from typing import Annotated
from uuid import uuid4

from fastapi import Depends, FastAPI, HTTPException, Request, UploadFile, status
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware

from utils.file_parsers import ParserRouter
from utils.session_db import SessionDB
from utils.chat import ChatOpenAI
from utils.prompts import Prompt
from utils.pipeline import RAGPipeline
from utils.splitter import CharacterTextSplitter
from utils.embedding import EmbeddingModel
from utils.vector_db import VectorDatabase
from settings import Settings

app = FastAPI(debug=True)

# allow_origins/allow_methods/allow_headers expect sequences, and
# allow_credentials expects a bool (not the string "*").
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# NOTE: the secret key should be loaded from configuration, not hardcoded.
app.add_middleware(SessionMiddleware, secret_key="very-secret-key", max_age=None)

SESSION_DB = None


class ChatRequest(BaseModel):
    message: str


def get_settings() -> Settings:
    return Settings()


def get_embedding_model(
    settings: Annotated[Settings, Depends(get_settings)],
) -> EmbeddingModel:
    return EmbeddingModel(settings)


def get_vector_db(
    embedding_model: Annotated[EmbeddingModel, Depends(get_embedding_model)],
) -> VectorDatabase:
    return VectorDatabase(embedding_model)


def get_session_db() -> SessionDB:
    # Lazily create a single shared SessionDB instance.
    global SESSION_DB
    if SESSION_DB is None:
        SESSION_DB = SessionDB()
    return SESSION_DB


def get_splitter() -> CharacterTextSplitter:
    return CharacterTextSplitter()


def get_chat(settings: Annotated[Settings, Depends(get_settings)]) -> ChatOpenAI:
    return ChatOpenAI(settings)


def get_prompt() -> Prompt:
    return Prompt()


def get_pipeline(
    llm: Annotated[ChatOpenAI, Depends(get_chat)],
    db: Annotated[VectorDatabase, Depends(get_vector_db)],
    prompt: Annotated[Prompt, Depends(get_prompt)],
) -> RAGPipeline:
    return RAGPipeline(llm, db, prompt)


def get_parser() -> ParserRouter:
    return ParserRouter()


@app.post("/upload-file")
async def upload_file(
    request: Request,
    file: UploadFile,
    parser: Annotated[ParserRouter, Depends(get_parser)],
    splitter: Annotated[CharacterTextSplitter, Depends(get_splitter)],
    pipeline: Annotated[RAGPipeline, Depends(get_pipeline)],
    session_db: Annotated[SessionDB, Depends(get_session_db)],
):
    file_content = await file.read()
    try:
        parsed_text = parser.parse(file_content, file.filename)
    except KeyError:
        # An unrecognized extension is a client error, not a server failure.
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Unsupported file extension",
        )

    # Split the parsed text into chunks and index them in the vector store.
    documents = splitter.split(parsed_text)
    await pipeline.vector_db.abuild_from_list(documents, {})

    # Register the pipeline under a fresh key and remember it in the session cookie.
    key = str(uuid4())
    session_db.add(key, pipeline)
    request.session["session_key"] = key

    return JSONResponse(
        content={"message": "File uploaded successfully, please ask your questions"},
    )


@app.post("/chat")
async def chat(
    request: Request,
    chat_request: ChatRequest,
    session_db: Annotated[SessionDB, Depends(get_session_db)],
):
    try:
        user_message = chat_request.message

        # Look up the pipeline that was built for this session at upload time.
        session_key = request.session.get("session_key")
        if not session_key:
            return JSONResponse(
                content={
                    "response": "Waiting for a file. Please upload one so we can start."
                },
            )

        pipeline: RAGPipeline = session_db.get(session_key)
        response = await pipeline.arun_pipeline(user_message)
        return StreamingResponse(content=response["response"])
    except Exception as e:
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        )


def generate_response(user_message: str, file_data: str) -> str:
    # Placeholder function to generate a response using the RAG and LLM
    return f"This is a placeholder response. File data: {file_data}"


if __name__ == "__main__":
    import uvicorn

    uvicorn.run("asgi:app", host="localhost", port=8000, reload=True)
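

# Example usage sketch (assumptions: the server runs locally on port 8000 and
# "doc.pdf" is a hypothetical file whose extension ParserRouter supports).
# The session cookie set by SessionMiddleware on upload must be replayed on
# /chat, so a cookie jar is used:
#
#   curl -c cookies.txt -F "file=@doc.pdf" http://localhost:8000/upload-file
#   curl -b cookies.txt -H "Content-Type: application/json" \
#        -d '{"message": "What is this document about?"}' \
#        http://localhost:8000/chat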