from typing import Annotated
from uuid import uuid4
from fastapi import FastAPI, UploadFile, HTTPException, status, Request, Depends
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel
from starlette.middleware.cors import CORSMiddleware
from starlette.middleware.sessions import SessionMiddleware
from utils.file_parsers import ParserRouter
from utils.session_db import SessionDB
from utils.chat import ChatOpenAI
from utils.prompts import Prompt
from utils.pipeline import RAGPipeline
from utils.splitter import CharacterTextSplitter
from utils.embedding import EmbeddingModel
from utils.vector_db import VectorDatabase
from settings import Settings

app = FastAPI(debug=True)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],  # wide open for development; restrict for production
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
# The secret key signs the session cookie; it belongs in configuration, not source code.
app.add_middleware(SessionMiddleware, secret_key="very-secret-key", max_age=None)

# Lazily-created singleton so uploaded pipelines survive across requests.
SESSION_DB = None


class ChatRequest(BaseModel):
    message: str


# Dependency providers, wired together through FastAPI's Depends.
def get_settings() -> Settings:
    return Settings()


def get_embedding_model(
    settings: Annotated[Settings, Depends(get_settings)],
) -> EmbeddingModel:
    return EmbeddingModel(settings)


def get_vector_db(
    embedding_model: Annotated[EmbeddingModel, Depends(get_embedding_model)],
) -> VectorDatabase:
    return VectorDatabase(embedding_model)


def get_session_db() -> SessionDB:
    # Reuse one SessionDB for the whole process so state outlives a request.
    global SESSION_DB
    if SESSION_DB is None:
        SESSION_DB = SessionDB()
    return SESSION_DB


def get_splitter() -> CharacterTextSplitter:
    return CharacterTextSplitter()


def get_chat(settings: Annotated[Settings, Depends(get_settings)]) -> ChatOpenAI:
    return ChatOpenAI(settings)


def get_prompt() -> Prompt:
    return Prompt()


def get_pipeline(
    llm: Annotated[ChatOpenAI, Depends(get_chat)],
    db: Annotated[VectorDatabase, Depends(get_vector_db)],
    prompt: Annotated[Prompt, Depends(get_prompt)],
) -> RAGPipeline:
    return RAGPipeline(llm, db, prompt)


def get_parser() -> ParserRouter:
    return ParserRouter()
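
# For reference, FastAPI resolves the provider graph above per request to
# roughly the following (a sketch of the wiring, not code that runs here):
#
#   settings = Settings()
#   pipeline = RAGPipeline(
#       ChatOpenAI(settings),
#       VectorDatabase(EmbeddingModel(settings)),
#       Prompt(),
#   )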
@app.post("/upload-file")
async def upload_file(
request: Request,
file: UploadFile,
parser: Annotated[ParserRouter, Depends(get_parser)],
splitter: Annotated[CharacterTextSplitter, Depends(get_splitter)],
pipeline: Annotated[RAGPipeline, Depends(get_pipeline)],
session_db: Annotated[SessionDB, Depends(get_session_db)],
):
file_content = await file.read()
try:
parsed_text = parser.parse(file_content, file.filename)
except KeyError:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
detail="Unavailable file extension",
)
documents = splitter.split(parsed_text)
await pipeline.vector_db.abuild_from_list(documents, {})
key = str(uuid4())
session_db.add(key, pipeline)
request.session["session_key"] = key
return JSONResponse(
content={"message": "File uploaded successfully, please ask your questions"},
)
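
# Example request against this endpoint (a sketch: assumes the dev server from
# the __main__ block below is running and that ".txt" is one of the extensions
# ParserRouter supports). The -c flag saves the session cookie for /chat:
#
#   curl -c cookies.txt -F "file=@notes.txt" http://localhost:8000/upload-file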
@app.post("/chat")
async def chat(
request: Request,
chat_request: ChatRequest,
session_db: Annotated[SessionDB, Depends(get_session_db)],
):
try:
user_message = chat_request.message
# Retrieve data from session
session_key = request.session.get("session_key")
if not session_key:
return JSONResponse(
content={
"response": "Waiting for file. Please upload one so we can start"
},
)
pipeline: RAGPipeline = session_db.get(session_key)
response = await pipeline.arun_pipeline(user_message)
return StreamingResponse(content=response["response"])
except Exception as e:
raise HTTPException(
status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=str(e)
)
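
# Follow-up request reusing the cookie saved above (again a sketch; the JSON
# body matches the ChatRequest model):
#
#   curl -b cookies.txt -H "Content-Type: application/json" \
#        -d '{"message": "What is this document about?"}' \
#        http://localhost:8000/chat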

def generate_response(user_message: str, file_data: str) -> str:
    # Unused placeholder from early development; /chat answers through
    # RAGPipeline.arun_pipeline instead of this stub.
    return f"This is a placeholder response. File data: {file_data}"

if __name__ == "__main__":
    import uvicorn

    # "asgi:app" assumes this module is saved as asgi.py.
    uvicorn.run("asgi:app", host="localhost", port=8000, reload=True)
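
# End-to-end client sketch (assumes the server is running locally, the httpx
# package is installed, and "notes.txt" is a file ParserRouter can parse; the
# file name and question are illustrative only):
#
#   import httpx
#
#   with httpx.Client(base_url="http://localhost:8000") as client:
#       # Upload a document; the client keeps the session cookie.
#       with open("notes.txt", "rb") as f:
#           client.post("/upload-file", files={"file": ("notes.txt", f)})
#       # Stream the answer chunk by chunk.
#       with client.stream("POST", "/chat", json={"message": "Summarize the file"}) as r:
#           for chunk in r.iter_text():
#               print(chunk, end="")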