"""FastAPI RAG chat service.

``/initiate_conversation`` registers uploaded documents for a conversation id.
``/get_response`` streams an LLM answer grounded in chunks retrieved from
those documents. Conversation state is kept in memory (evicted after a period
of inactivity) and the conversation -> collections mapping is persisted to
SQLite so it survives eviction and restarts.
"""

import asyncio
import base64
import logging
import os
import sqlite3
import time
from contextlib import closing
from typing import Dict, List, Optional

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing_extensions import Literal

from components.LLM import rLLM
from components.Database import AdvancedClient
from components.utils import create_refrences

# LLM API key. NOTE: str() turns an unset variable into the literal string
# "None"; kept as-is for compatibility with rLLM's expected argument type, but
# a warning is logged below when the variable is missing.
TOGETHER_API = str(os.getenv("TOGETHER_API_KEY"))

# Configure logging
logging.basicConfig(
    level=logging.WARNING,
    format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
    handlers=[logging.FileHandler("app.log"), logging.StreamHandler()],
)
logger = logging.getLogger(__name__)

if os.getenv("TOGETHER_API_KEY") is None:
    logger.warning(
        "TOGETHER_API_KEY is not set; rLLM will receive the string 'None' as its API key."
    )

app = FastAPI()

# SQLite setup
DB_PATH = "app/data/conversations.db"

# In-memory storage, keyed by conversation id.
# Chat history as a list of {"role": ..., "content": ...} messages.
conversations: Dict[str, List[Dict[str, str]]] = {}
# Collection ids (from AdvancedClient.selected_collections) used for retrieval.
COLLECTIONS: Dict[str, List[str]] = {}
# time.time() of the last request; drives clear_inactive_conversations().
last_activity: Dict[str, float] = {}

# Strong reference to the background sweeper task (see startup_event).
_cleanup_task: Optional["asyncio.Task"] = None


def init_db() -> None:
    """Create the SQLite directory, ``conversations`` table and lookup index.

    ``CREATE ... IF NOT EXISTS`` makes this safe to call on every start-up.
    It does NOT migrate a table created by an older schema.
    """
    logger.info("Initializing database")
    os.makedirs(os.path.dirname(DB_PATH), exist_ok=True)
    # closing() releases the connection; the inner `conn` context manager
    # commits on success and rolls back on error.
    with closing(sqlite3.connect(DB_PATH)) as conn, conn:
        conn.execute(
            """CREATE TABLE IF NOT EXISTS conversations
                 (id INTEGER PRIMARY KEY AUTOINCREMENT,
                  conversation_id TEXT,
                  collections TEXT,
                  lastmessage TEXT,
                  timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)"""
        )
        # Every lookup filters on conversation_id; avoid full-table scans.
        conn.execute(
            """CREATE INDEX IF NOT EXISTS idx_conversations_conversation_id
                 ON conversations (conversation_id)"""
        )
    logger.info("Database initialized successfully")


init_db()


def update_db(conversation_id: str, collections: str, message: str) -> None:
    """Append a row for *conversation_id*.

    Args:
        conversation_id: Client-supplied conversation id.
        collections: Collection ids joined with ``"|"``.
        message: Value stored in the ``lastmessage`` column.
    """
    logger.info("Updating database for conversation: %s", conversation_id)
    with closing(sqlite3.connect(DB_PATH)) as conn, conn:
        conn.execute(
            """INSERT INTO conversations (conversation_id, collections, lastmessage)
                 VALUES (?, ?, ?)""",
            (conversation_id, collections, message),
        )
    logger.info("Database updated successfully")


def get_collection_from_db(conversation_id: str) -> Optional[str]:
    """Return the most recently stored ``"|"``-joined collections string.

    A conversation may be registered more than once; the newest row (highest
    autoincrement ``id``) wins, matching the in-memory overwrite done by
    ``/initiate_conversation``.

    Returns:
        The collections string, or ``None`` if the id was never stored.
    """
    with closing(sqlite3.connect(DB_PATH)) as conn:
        row = conn.execute(
            """SELECT collections FROM conversations
                 WHERE conversation_id = ?
                 ORDER BY id DESC LIMIT 1""",
            (conversation_id,),
        ).fetchone()
    return row[0] if row else None


async def clear_inactive_conversations(
    timeout: float = 1800, interval: float = 60
) -> None:
    """Periodically evict in-memory state for idle conversations.

    Args:
        timeout: Seconds without activity before a conversation is evicted
            (default 30 minutes).
        interval: Seconds to sleep between sweeps (default 1 minute).

    Only in-memory dicts are cleared; the SQLite mapping is kept so an evicted
    conversation can be restored by ``/get_response``.
    """
    while True:
        try:
            logger.info("Clearing inactive conversations")
            current_time = time.time()
            # Materialize the ids first: the dicts are mutated below.
            inactive_convos = [
                conv_id
                for conv_id, last_time in last_activity.items()
                if current_time - last_time > timeout
            ]
            for conv_id in inactive_convos:
                conversations.pop(conv_id, None)
                last_activity.pop(conv_id, None)
                COLLECTIONS.pop(conv_id, None)
            logger.info(f"Cleared {len(inactive_convos)} inactive conversations")
        except Exception:
            # Keep the sweeper alive; an unhandled error would end the task.
            logger.exception("Error while clearing inactive conversations")
        await asyncio.sleep(interval)


@app.on_event("startup")
async def startup_event() -> None:
    """Start the background sweeper for idle conversations."""
    global _cleanup_task
    logger.info("Starting up the application")
    # The event loop keeps only weak references to tasks; hold a strong one
    # so the sweeper cannot be garbage-collected mid-execution.
    _cleanup_task = asyncio.create_task(clear_inactive_conversations())


class UploadedFiles(BaseModel):
    """Request body for ``/initiate_conversation``.

    The three lists are parallel: entry *i* of each describes the same file.
    ``FileData`` entries are base64-encoded file contents.
    """

    ConversationID: str = Field(examples=["123e4567-e89b-12d3-a456-426614174000"])
    FileNames: List[str] = Field(examples=[["file_1.pdf", "file_2.docx"]])
    FileTypes: List[Literal["pdf", "docx"]] = Field(examples=[["pdf", "docx"]])
    FileData: List[str]


class UserInput(BaseModel):
    """Request body for ``/get_response``."""

    ConversationID: str = Field(examples=["123e4567-e89b-12d3-a456-426614174000"])
    Query: str = Field(examples=["What is IT ACT 2000?"])


class ChunkResponse(BaseModel):
    """Schema for one streamed text chunk."""

    chunk: str = Field(examples=["This is", "streaming"])


class CompletedResponse(BaseModel):
    """Schema for a completed response with token counts."""

    FullResponse: str = Field(examples=["This is a complete response"])
    InputToken: int = Field(examples=[1024, 2048])
    OutputToken: int = Field(examples=[4096, 7000])


@app.post("/initiate_conversation")
async def get_conversation_id(files: UploadedFiles):
    """Register uploaded documents for a conversation and persist the mapping.

    Returns:
        ``True`` on success.

    Raises:
        HTTPException: 422 if the parallel file lists differ in length,
            400 if a ``FileData`` entry is not valid base64.
    """
    if not (len(files.FileNames) == len(files.FileTypes) == len(files.FileData)):
        raise HTTPException(
            status_code=422,
            detail="FileNames, FileTypes and FileData must have the same length.",
        )

    # Decode base64 payloads; binascii.Error is a ValueError subclass.
    try:
        data = [base64.b64decode(b) for b in files.FileData]
    except ValueError as err:
        raise HTTPException(
            status_code=400, detail="FileData entries must be valid base64."
        ) from err

    vector_db = AdvancedClient()
    vector_db.create_or_get_collection(
        file_names=files.FileNames,
        file_types=files.FileTypes,
        file_datas=data,
    )
    # presumably populated by create_or_get_collection — verify in AdvancedClient
    file_ids = vector_db.selected_collections

    # update in-memory data
    COLLECTIONS[files.ConversationID] = file_ids
    conversations[files.ConversationID] = []
    last_activity[files.ConversationID] = time.time()

    # update SQL data
    update_db(
        conversation_id=files.ConversationID,
        collections="|".join(file_ids),
        message="NONE",
    )
    return True


def _resolve_collections(conv_id: str) -> List[str]:
    """Return the collection ids for *conv_id*, restoring them from SQLite.

    A SQLite hit re-populates ``COLLECTIONS`` and ``last_activity`` so later
    requests skip the database and the sweeper can evict the entry again.

    Raises:
        HTTPException: 404 if the conversation was never registered.
    """
    if conv_id in COLLECTIONS:
        last_activity[conv_id] = time.time()
        return COLLECTIONS[conv_id]

    try:
        stored = get_collection_from_db(conv_id)
    except sqlite3.Error:
        logger.exception("Failed to look up conversation %s", conv_id)
        stored = None

    if stored is None:
        raise HTTPException(
            status_code=404,
            detail="Conversation ID does not exist, please register one with /initiate_conversation endpoint.",
        )

    # "".split("|") yields [""]; drop empties so "no files" maps to [].
    collection_ids = [c for c in stored.split("|") if c]
    COLLECTIONS[conv_id] = collection_ids
    last_activity[conv_id] = time.time()
    return collection_ids


@app.post("/get_response")
async def get_response_streaming(user_query: UserInput):
    """Stream a RAG answer for ``user_query.Query`` within its conversation.

    Flow: rephrase the query with ``llm.HyDE``, retrieve chunks from the
    conversation's collections, then stream ``llm.generate_rag_response``
    output. After the final chunk, a references section built by
    ``create_refrences`` is appended.

    Raises:
        HTTPException: 404 if the conversation id is unknown.
    """
    conv_id = user_query.ConversationID
    collection_to_use = _resolve_collections(conv_id)
    logger.debug("Using collections %s for conversation %s", collection_to_use, conv_id)

    llm = rLLM(llm_name="meta-llama/Llama-3-8b-chat-hf", api_key=TOGETHER_API)
    vector_db = AdvancedClient()
    # restrict retrieval to the user conversation's documents
    vector_db.selected_collections = collection_to_use

    # One shared list per conversation; a snapshot of the turns *before* this
    # query is passed to both LLM calls, which receive the query separately.
    history = conversations.setdefault(conv_id, [])
    prior_history = list(history)

    rephrased_query = llm.HyDE(query=user_query.Query, message_history=prior_history)
    retrieved_docs = vector_db.retrieve_chunks(query=rephrased_query)
    history.append({"role": "user", "content": user_query.Query})

    # "Refrence" spelling kept as-is: it is part of the prompt text and may be
    # matched by prompt templates elsewhere.
    context = "".join(
        f"Refrence {i}\n\n" + doc["document"] + "\n\n"
        for i, doc in enumerate(retrieved_docs, start=1)
    )

    def streaming():
        # Yields (completed, payload): text chunks while completed is False,
        # then (full_response, input_token, output_token) once completed.
        for completed, payload in llm.generate_rag_response(
            context=context,
            prompt=user_query.Query,
            message_history=prior_history,
        ):
            if completed:
                full_response, input_token, output_token = payload
                # Append to the captured list so a concurrent sweep removing
                # the dict entry cannot raise KeyError here.
                history.append({"role": "assistant", "content": full_response})
                logger.info("Input:%s \nOutput:%s", input_token, output_token)
                yield "\n\n\n" + create_refrences(retrieved_docs) + "\n"
            else:
                yield payload

    return StreamingResponse(streaming(), media_type="text/event-stream")