Spaces:
No application file
No application file
import logging
import os
import secrets
import threading
from functools import lru_cache
from typing import Annotated
from uuid import uuid4

from fastapi import FastAPI, UploadFile, HTTPException, status, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware

from utils.file_parsers import ParserRouter
from utils.session_db import SessionDB
from utils.chat import ChatOpenAI
from utils.prompts import Prompt
from utils.pipeline import RAGPipeline
from utils.splitter import CharacterTextSplitter
from utils.embedding import EmbeddingModel
from utils.vector_db import VectorDatabase
from settings import Settings
# NOTE(review): debug=True makes FastAPI return tracebacks to clients;
# disable it outside local development.
app = FastAPI(debug=True)

app.add_middleware(
    CORSMiddleware,
    # Starlette expects sequences of strings for the allow_* lists and a bool
    # for allow_credentials; the previous bare "*" strings only worked by
    # accident ("*" in "*" is True).
    # NOTE(review): wildcard origins combined with credentials lets any site
    # make cookie-bearing requests; consider restricting allow_origins.
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)

# The session cookie is signed with this key, so a constant committed to source
# lets anyone forge cookies. Read it from the environment; if unset, fall back
# to a random per-process key (sessions then do not survive a restart, which
# matches the lifetime of the module-level SESSION_DB below).
app.add_middleware(
    SessionMiddleware,
    secret_key=os.environ.get("SESSION_SECRET_KEY") or secrets.token_urlsafe(32),
    max_age=None,
)

# Lazily created, process-wide SessionDB; populated by get_session_db().
SESSION_DB = None
class ChatRequest(BaseModel):
    """JSON body accepted by the chat handler."""

    # Text of the user's message to answer.
    message: str
@lru_cache
def get_settings() -> Settings:
    """Return the application ``Settings``, constructed once per process.

    Previously a new ``Settings`` object was built on every request. Caching it
    follows the pattern in FastAPI's settings documentation and avoids repeating
    that construction (presumably re-reading environment/config). Call
    ``get_settings.cache_clear()`` to force a rebuild, e.g. in tests.
    """
    return Settings()
def get_embedding_model(
    settings: Annotated[Settings, Depends(get_settings)],
) -> EmbeddingModel:
    """FastAPI dependency: construct an ``EmbeddingModel`` from ``settings``."""
    model = EmbeddingModel(settings)
    return model
def get_vector_db(
    embedding_model: Annotated[EmbeddingModel, Depends(get_embedding_model)],
) -> VectorDatabase:
    """FastAPI dependency: construct a ``VectorDatabase`` over ``embedding_model``."""
    vector_db = VectorDatabase(embedding_model)
    return vector_db
# Guards the one-time creation of SESSION_DB. FastAPI runs plain ``def``
# dependencies in a worker threadpool, so two concurrent first requests could
# otherwise both see None and build separate SessionDB instances, losing
# whatever was stored in the discarded one.
_SESSION_DB_LOCK = threading.Lock()


def get_session_db() -> SessionDB:
    """Return the process-wide ``SessionDB``, creating it on first use.

    Uses double-checked locking. The unlocked check is the fast path for every
    call after initialisation. The re-check under the lock ensures only one
    instance is ever assigned to the module-level ``SESSION_DB``.
    """
    global SESSION_DB
    if SESSION_DB is None:
        with _SESSION_DB_LOCK:
            if SESSION_DB is None:
                SESSION_DB = SessionDB()
    return SESSION_DB
def get_splitter() -> CharacterTextSplitter:
    """FastAPI dependency: a ``CharacterTextSplitter`` with default arguments."""
    splitter = CharacterTextSplitter()
    return splitter
def get_chat(
    settings: Annotated[Settings, Depends(get_settings)],
) -> ChatOpenAI:
    """FastAPI dependency: construct the ``ChatOpenAI`` client from ``settings``."""
    llm = ChatOpenAI(settings)
    return llm
def get_prompt() -> Prompt:
    """FastAPI dependency: a ``Prompt`` built with default arguments."""
    prompt = Prompt()
    return prompt
def get_pipeline(
    llm: Annotated[ChatOpenAI, Depends(get_chat)],
    db: Annotated[VectorDatabase, Depends(get_vector_db)],
    prompt: Annotated[Prompt, Depends(get_prompt)],
) -> RAGPipeline:
    """FastAPI dependency: assemble a ``RAGPipeline`` from its three parts.

    Each call receives freshly constructed dependencies, so every request gets
    its own pipeline (and its own vector database).
    """
    rag = RAGPipeline(llm, db, prompt)
    return rag
def get_parser() -> ParserRouter:
    """FastAPI dependency: a ``ParserRouter`` built with default arguments."""
    router = ParserRouter()
    return router
async def upload_file(
    request: Request,
    file: UploadFile,
    parser: Annotated[ParserRouter, Depends(get_parser)],
    splitter: Annotated[CharacterTextSplitter, Depends(get_splitter)],
    pipeline: Annotated[RAGPipeline, Depends(get_pipeline)],
    session_db: Annotated[SessionDB, Depends(get_session_db)],
):
    """Index an uploaded file and bind a RAG pipeline to the caller's session.

    Steps:

    - ``parser`` turns the file bytes into text. It receives the filename,
      presumably to choose a parser by extension.
    - ``splitter`` splits that text into documents.
    - The documents are loaded into the vector database of this request's
      ``pipeline``.
    - The pipeline is stored in ``session_db`` under a fresh UUID4 key, and the
      key is written to the session so the chat handler can find it.

    NOTE(review): no route decorator is visible on this handler; confirm it is
    registered with ``app`` elsewhere.

    Returns:
        A JSON confirmation message.

    Raises:
        HTTPException: 400 if the upload carries no filename; 415 if
            ``parser.parse`` raises ``KeyError`` (unsupported file extension).
    """
    if not file.filename:
        # UploadFile.filename is optional, but the parser is given it.
        raise HTTPException(
            status_code=status.HTTP_400_BAD_REQUEST,
            detail="Uploaded file has no filename",
        )

    file_content = await file.read()
    try:
        parsed_text = parser.parse(file_content, file.filename)
    except KeyError as err:
        # An unsupported file type is a client error, not a server fault, so
        # report 415 rather than 500.
        raise HTTPException(
            status_code=status.HTTP_415_UNSUPPORTED_MEDIA_TYPE,
            detail="Unavailable file extension",
        ) from err

    documents = splitter.split(parsed_text)
    await pipeline.vector_db.abuild_from_list(documents, {})

    # NOTE(review): a pipeline from an earlier upload in the same session stays
    # in session_db; evict it here if SessionDB supports removal.
    key = str(uuid4())
    session_db.add(key, pipeline)
    request.session["session_key"] = key

    return JSONResponse(
        content={"message": "File uploaded successfully, please ask your questions"},
    )
def _find_pipeline(session_db: SessionDB, session_key: str) -> "RAGPipeline | None":
    """Return the pipeline stored under ``session_key``, or None if absent."""
    # How SessionDB.get reports a miss is not visible here; treat both a None
    # return and a KeyError as "no pipeline".
    try:
        return session_db.get(session_key)
    except KeyError:
        return None


async def chat(
    request: Request,
    chat_request: ChatRequest,
    session_db: Annotated[SessionDB, Depends(get_session_db)],
):
    """Answer ``chat_request.message`` with the RAG pipeline bound to the session.

    The caller is asked to upload a file first if either:

    - the session has no ``session_key``; or
    - the key no longer maps to a pipeline, e.g. because a server restart reset
      the module-level ``SESSION_DB`` while the cookie survived.

    Otherwise the ``"response"`` entry of the pipeline's result is streamed back.

    NOTE(review): no route decorator is visible on this handler; confirm it is
    registered with ``app`` elsewhere.

    Raises:
        HTTPException: 500 wrapping any error raised while looking up or
            running the pipeline. The traceback is also logged.
    """
    try:
        session_key = request.session.get("session_key")
        pipeline = _find_pipeline(session_db, session_key) if session_key else None
        if pipeline is None:
            return JSONResponse(
                content={
                    "response": "Waiting for file. Please upload one so we can start"
                },
            )
        response = await pipeline.arun_pipeline(chat_request.message)
        # Only the pipeline call is guarded here. The stream itself is consumed
        # after this handler returns, outside this try block.
        return StreamingResponse(content=response["response"])
    except Exception as e:
        # FastAPI does not log HTTPExceptions, so record the traceback here.
        logging.getLogger(__name__).exception("Chat request failed")
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
        ) from e
def generate_response(user_message: str, file_data: str) -> str:
    """Return a canned reply that echoes ``file_data``.

    Placeholder for real RAG + LLM generation. ``user_message`` is accepted for
    interface compatibility but is currently ignored.
    """
    template = "This is a placeholder response. File data: {}"
    return template.format(file_data)
if __name__ == "__main__":
    # Development entry point: serve with auto-reload. uvicorn's reload mode
    # needs the app as an import string rather than an object.
    # NOTE(review): "asgi:app" assumes this module is importable as `asgi`.
    import uvicorn

    uvicorn.run(
        "asgi:app",
        host="localhost",
        port=8000,
        reload=True,
    )