Spaces:
Sleeping
Sleeping
| # main.py | |
| from fastapi import FastAPI, HTTPException | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from uuid import uuid4 | |
| from models import InitChatRequest, ChatRequest, EndSessionRequest | |
| from rag_chain import build_chain | |
# FastAPI application object that the route handlers in this module attach to.
app: FastAPI = FastAPI()

# CORS setup: accept cross-origin requests from any origin, using any
# method and any header, with credentials allowed.
# NOTE(review): FastAPI's CORS documentation says allow_origins /
# allow_methods / allow_headers should not be ["*"] when
# allow_credentials=True. Starlette handles this combination by echoing the
# caller's Origin, which in effect lets any site make credentialed requests.
# Consider explicit origins, or drop allow_credentials if no cookies or auth
# headers are used.
app.add_middleware(
    CORSMiddleware,
    allow_credentials=True,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["*"],
)

# In-memory session storage: session_id (str) -> chain object from
# build_chain(). This is a plain module-level dict, so it lives only in this
# process: it is emptied on restart and not shared between worker processes.
chat_sessions: dict = dict()
@app.get("/")
def root():
    """Report that the API process is up.

    Returns:
        A dict with a fixed ``message`` string, serialised by FastAPI as JSON.
    """
    # NOTE(review): no route decorator was visible for this handler, so it was
    # never registered. "/" is inferred from the function name and the
    # liveness-style payload; confirm against any client that probes it.
    return {"message": "API is running"}
def initialize_chat(req: InitChatRequest):
    """Build a RAG chain for ``req.video_id`` and store it under a new session.

    NOTE(review): no route decorator is visible for this handler in this
    source, so FastAPI will not expose it as written. Restore the original
    ``@app.<method>(path)`` line; the path cannot be recovered from here.

    Args:
        req: Incoming request model; only ``req.video_id`` is read.

    Returns:
        A dict with a confirmation ``message`` and the new ``session_id``
        (a random UUID4 string). Clients pass that ID to the chat and
        end-session handlers.

    Raises:
        HTTPException: 500 when ``build_chain`` raises. The detail is the
            stringified underlying error.
    """
    try:
        qa_chain = build_chain(req.video_id)
    except Exception as e:
        # Turn any chain-construction failure into a 500. Chain the original
        # exception so the server-side traceback keeps the real cause.
        # NOTE(review): str(e) is sent to the client and may expose internals.
        raise HTTPException(status_code=500, detail=str(e)) from e

    # Mint and register the session only after the chain was built.
    session_id = str(uuid4())
    chat_sessions[session_id] = qa_chain
    return {"message": "Chat session started", "session_id": session_id}
def chat(req: ChatRequest):
    """Answer ``req.query`` using the chain stored for ``req.session_id``.

    NOTE(review): no route decorator is visible for this handler in this
    source, so FastAPI will not expose it as written. Restore the original
    ``@app.<method>(path)`` line; the path cannot be recovered from here.

    Args:
        req: Incoming request model; ``req.session_id`` and ``req.query``
            are read.

    Returns:
        A dict with two keys. ``answer`` is the chain output's ``"result"``.
        ``sources`` holds the ``page_content`` of each document under the
        output's ``"source_documents"``, or is an empty list when that key
        is absent.

    Raises:
        HTTPException: 404 if the session ID is unknown. 500 if invoking the
            chain fails or its output lacks ``"result"``.
    """
    # Look the session up once (EAFP) instead of check-then-index. FastAPI
    # runs plain ``def`` handlers in a thread pool, so end_chat_session could
    # remove the entry between a membership test and the read. That used to
    # surface as a 500 with a KeyError detail instead of a 404.
    try:
        qa_chain = chat_sessions[req.session_id]
    except KeyError:
        raise HTTPException(
            status_code=404,
            detail="Invalid session ID. Initialize session first.",
        ) from None

    try:
        result = qa_chain.invoke({"query": req.query})
        answer = result["result"]
        # Accept chain outputs that carry no "source_documents" entry rather
        # than failing the whole request.
        sources = [doc.page_content for doc in result.get("source_documents", [])]
    except Exception as e:
        raise HTTPException(status_code=500, detail=str(e)) from e

    return {"answer": answer, "sources": sources}
def end_chat_session(req: EndSessionRequest):
    """Discard the chain stored for ``req.session_id``.

    NOTE(review): no route decorator is visible for this handler in this
    source, so FastAPI will not expose it as written. Restore the original
    ``@app.<method>(path)`` line; the path cannot be recovered from here.

    Args:
        req: Incoming request model; only ``req.session_id`` is read.

    Returns:
        A dict whose ``message`` confirms which session was ended.

    Raises:
        HTTPException: 404 if the session ID is not in ``chat_sessions``.
    """
    session_id = req.session_id
    try:
        # Delete directly rather than test-then-delete. With concurrent
        # handlers (FastAPI runs plain ``def`` handlers in a thread pool), two
        # end requests for the same ID could both pass a membership test, and
        # the second ``del`` would raise an unhandled KeyError (a 500).
        del chat_sessions[session_id]
    except KeyError:
        raise HTTPException(status_code=404, detail="Session ID not found.") from None
    return {"message": f"Session {session_id} ended successfully."}