# FastAPI backend for a document-grounded (RAG) chat service.
import asyncio
import base64
import logging
import os
import sqlite3
import time
from typing import Dict, List

from fastapi import FastAPI, HTTPException
from fastapi.responses import StreamingResponse
from pydantic import BaseModel, Field
from typing_extensions import Literal

from components.Database import AdvancedClient
from components.LLM import rLLM
from components.utils import create_refrences
| # LLM API key | |
| TOGETHER_API = str(os.getenv("TOGETHER_API_KEY")) | |
| # Configure logging | |
| logging.basicConfig( | |
| level=logging.WARNING, | |
| format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", | |
| handlers=[logging.FileHandler("app.log"), logging.StreamHandler()], | |
| ) | |
| logger = logging.getLogger(__name__) | |
| app = FastAPI() | |
| # SQLite setup | |
| DB_PATH = "app/data/conversations.db" | |
| # In-memory storage for conversations | |
| conversations: Dict[str, List[Dict[str, str]]] = {} | |
| COLLECTIONS: Dict[str, List[str]] = {} | |
| last_activity: Dict[str, float] = {} | |
| # initialize SQLite database | |
| def init_db(): | |
| logger.info("Initializing database") | |
| os.makedirs(os.path.dirname(DB_PATH), exist_ok=True) | |
| conn = sqlite3.connect(DB_PATH) | |
| c = conn.cursor() | |
| c.execute( | |
| """CREATE TABLE IF NOT EXISTS conversations | |
| (id INTEGER PRIMARY KEY AUTOINCREMENT, | |
| conversation_id TEXT, | |
| collections TEXT, | |
| lastmessage TEXT | |
| timestamp DATETIME DEFAULT CURRENT_TIMESTAMP)""" | |
| ) | |
| conn.commit() | |
| conn.close() | |
| logger.info("Database initialized successfully") | |
| init_db() | |
| def update_db(conversation_id, collections, message): | |
| logger.info(f"Updating database for conversation: {conversation_id}") | |
| conn = sqlite3.connect(DB_PATH) | |
| c = conn.cursor() | |
| c.execute( | |
| """INSERT INTO conversations (conversation_id, collections, lastmessage) | |
| VALUES (?, ?, ?)""", | |
| (conversation_id, collections, message), | |
| ) | |
| conn.commit() | |
| conn.close() | |
| logger.info("Database updated successfully") | |
| def get_collection_from_db(conversation_id): | |
| conn = sqlite3.connect(DB_PATH) | |
| try: | |
| c = conn.cursor() | |
| c.execute( | |
| """SELECT collections FROM conversations WHERE conversation_id = ?""", | |
| (conversation_id,), | |
| ) | |
| collection = c.fetchone() | |
| if collection: | |
| return collection[0] | |
| else: | |
| return None | |
| finally: | |
| conn.close() | |
| async def clear_inactive_conversations(): | |
| while True: | |
| logger.info("Clearing inactive conversations") | |
| current_time = time.time() | |
| inactive_convos = [ | |
| conv_id | |
| for conv_id, last_time in last_activity.items() | |
| if current_time - last_time > 1800 | |
| ] # 30 minutes | |
| for conv_id in inactive_convos: | |
| if conv_id in conversations: | |
| del conversations[conv_id] | |
| if conv_id in last_activity: | |
| del last_activity[conv_id] | |
| if conv_id in COLLECTIONS: | |
| del COLLECTIONS[conv_id] | |
| logger.info(f"Cleared {len(inactive_convos)} inactive conversations") | |
| await asyncio.sleep(60) # Check every minute | |
| async def startup_event(): | |
| logger.info("Starting up the application") | |
| asyncio.create_task(clear_inactive_conversations()) | |
| class UploadedFiles(BaseModel): | |
| ConversationID: str = Field(examples=["123e4567-e89b-12d3-a456-426614174000"]) | |
| FileNames: List[str] = Field(examples=[["file_1.pdf", "file_2.docx"]]) | |
| FileTypes: List[Literal["pdf", "docx"]] = Field(examples=[["pdf", "docx"]]) | |
| FileData: List[str] | |
| class UserInput(BaseModel): | |
| ConversationID: str = Field(examples=["123e4567-e89b-12d3-a456-426614174000"]) | |
| Query: str = Field(examples=["What is IT ACT 2000?"]) | |
| class ChunkResponse(BaseModel): | |
| chunk: str = Field(examples=["This is", "streaming"]) | |
| class CompletedResponse(BaseModel): | |
| FullResponse: str = Field(examples=["This is a complete response"]) | |
| InputToken: int = Field(examples=[1024, 2048]) | |
| OutputToken: int = Field(examples=[4096, 7000]) | |
| async def get_conversation_id(files: UploadedFiles): | |
| # Decoding bytes data | |
| data = [base64.b64decode(b) for b in files.FileData] | |
| vector_db = AdvancedClient() | |
| vector_db.create_or_get_collection( | |
| file_names=files.FileNames, | |
| file_types=files.FileTypes, | |
| file_datas=data, | |
| ) | |
| file_ids = vector_db.selected_collections | |
| # update in-memory data | |
| COLLECTIONS[files.ConversationID] = file_ids | |
| conversations[files.ConversationID] = [] | |
| last_activity[files.ConversationID] = time.time() | |
| # update SQL data | |
| update_db( | |
| conversation_id=files.ConversationID, | |
| collections="|".join(file_ids), | |
| message="NONE", | |
| ) | |
| return True | |
| async def get_response_streaming(user_query: UserInput): | |
| llm = rLLM(llm_name="meta-llama/Llama-3-8b-chat-hf", api_key=TOGETHER_API) | |
| conv_id = user_query.ConversationID | |
| try: | |
| print(COLLECTIONS) | |
| if conv_id in COLLECTIONS: | |
| collection_to_use = COLLECTIONS[conv_id] | |
| last_activity[conv_id] = time.time() | |
| else: | |
| collections = get_collection_from_db(conv_id) | |
| if collections: | |
| collection_to_use = collections.split("|") | |
| except: | |
| return HTTPException( | |
| status_code=404, | |
| detail="Conversation ID does not exist, please register one with /initiate_conversation endpoint.", | |
| ) | |
| vector_db = AdvancedClient() | |
| # update database to user conversation's documents | |
| vector_db.selected_collections = collection_to_use | |
| try: | |
| conversation_history = conversations[conv_id] | |
| except: | |
| conversations[conv_id] = [] | |
| conversation_history = [] | |
| rephrased_query = llm.HyDE( | |
| query=user_query.Query, message_history=conversation_history | |
| ) | |
| retrieved_docs = vector_db.retrieve_chunks(query=rephrased_query) | |
| conversations[conv_id].append({"role": "user", "content": user_query.Query}) | |
| context = "" | |
| for i, doc in enumerate(retrieved_docs, start=1): | |
| context += f"Refrence {i}\n\n" + doc["document"] + "\n\n" | |
| def streaming(): | |
| for data in llm.generate_rag_response( | |
| context=context, | |
| prompt=user_query.Query, | |
| message_history=conversation_history, | |
| ): | |
| completed, chunk = data | |
| if completed: | |
| full_response, input_token, output_token = chunk | |
| conversations[conv_id].append( | |
| {"role": "assistant", "content": full_response} | |
| ) | |
| logger.info(msg=f"Input:{input_token} \nOuptut:{output_token}") | |
| yield "\n\n<REFRENCES>\n" + create_refrences( | |
| retrieved_docs | |
| ) + "\n</REFRENCES>" | |
| else: | |
| chunk = chunk | |
| yield chunk | |
| return StreamingResponse(streaming(), media_type="text/event-stream") | |