# NOTE(review): stray "Spaces: Running" text removed — Hugging Face Spaces page
# header residue from extraction, not part of the source.
"""FastAPI Backend for VietQA Multi-Agent System."""
import re
from contextlib import asynccontextmanager

from fastapi import FastAPI, Header, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from pydantic import BaseModel

from src.data_processing.answer import normalize_answer
from src.data_processing.formatting import question_to_state
from src.data_processing.models import QuestionInput
from src.graph import get_graph
from src.utils.firebase import delete_user_permanently, verify_id_token
from src.utils.ingestion import get_qdrant_client
from src.utils.llm import get_available_large_models, set_large_model_override
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Fast startup - lazy load models on first request.

    FastAPI's ``lifespan=`` parameter requires an async context manager;
    the ``@asynccontextmanager`` decorator (whose import was otherwise
    unused) appears to have been lost in extraction and is restored here.
    """
    print("[Startup] Server starting (models will load on first request)...")
    print("[Startup] Server ready!")
    yield
# Application instance; `lifespan` defers model loading to the first request.
app = FastAPI(
    title="VietQA Multi-Agent API",
    description="API cho hệ thống trả lời câu hỏi trắc nghiệm tiếng Việt",
    version="1.0.0",
    lifespan=lifespan,
)
# CORS for the frontend.
# NOTE(review): wildcard origins combined with allow_credentials=True is a very
# permissive setup — confirm whether an explicit origin list is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["*"],
    allow_headers=["*"],
)
class SolveRequest(BaseModel):
    """Request payload for solving a multiple-choice question."""

    question: str              # question text; must be non-empty
    choices: list[str]         # answer options; at least 2 are required
    model: str | None = None   # optional large-model id override for this call
class SolveResponse(BaseModel):
    """Response payload for a solved question."""

    answer: str      # normalized choice letter (e.g. "A")
    route: str       # which pipeline route handled the question
    reasoning: str   # model reasoning with <think> blocks stripped
    context: str     # retrieved context used for the answer
def clean_thinking_tags(text: str) -> str:
    """Remove ``<think>...</think>`` blocks from a model response.

    Also strips an unterminated trailing ``<think>`` block, which can occur
    when a response is truncated mid-thought.

    Args:
        text: Raw model output, possibly containing thinking tags.

    Returns:
        The text with all thinking content removed and whitespace trimmed.
    """
    # Remove complete think blocks (non-greedy so multiple blocks all match).
    cleaned = re.sub(r'<think>.*?</think>', '', text, flags=re.DOTALL)
    # Defensive: drop an unclosed <think> tail from a truncated response.
    cleaned = re.sub(r'<think>.*\Z', '', cleaned, flags=re.DOTALL)
    return cleaned.strip()
class ChatRequest(BaseModel):
    """Request payload for free-form chat."""

    message: str               # user message; must be non-empty
    model: str | None = None   # optional large-model id override for this call
class ChatResponse(BaseModel):
    """Response payload for free-form chat."""

    response: str   # model reply with <think> blocks stripped
    route: str      # which pipeline route handled the message
class ModelsResponse(BaseModel):
    """Response payload listing the available large models."""

    models: list[dict]   # each entry: {"id": full_model_id, "name": short_name}
# NOTE(review): route decorator restored — all @-lines appear lost in
# extraction; confirm the path against the frontend.
@app.get("/")
async def root():
    """Service banner / basic liveness endpoint."""
    return {"message": "VietQA Multi-Agent API", "status": "running"}
# NOTE(review): route decorator restored after extraction loss — confirm path.
@app.get("/health")
async def health():
    """Health check endpoint for Render."""
    return {"status": "ok"}
# NOTE(review): route decorator restored after extraction loss — confirm path.
@app.get("/models")
async def get_models():
    """Get available large models.

    Returns:
        ``{"models": [{"id": full_model_id, "name": short_name}, ...]}``
        where the short name is the last ``/``-separated segment of the id.
    """
    models = get_available_large_models()
    return {
        "models": [{"id": m, "name": m.split("/")[-1]} for m in models]
    }
# NOTE(review): route decorator restored after extraction loss — confirm path.
@app.post("/solve")
async def solve_question(req: SolveRequest):
    """Solve a multiple-choice question via the agent graph.

    Validates the request, optionally pins a large-model override for this
    call, runs the pipeline, and normalizes the raw answer to a choice letter.

    Raises:
        HTTPException(400): empty question or fewer than 2 choices.
        HTTPException(500): any pipeline failure.
    """
    if not req.question.strip():
        raise HTTPException(400, "Question is required")
    if len(req.choices) < 2:
        raise HTTPException(400, "At least 2 choices required")
    # NOTE(review): this override looks process-wide; under concurrent
    # requests with different models it may race — confirm. Cleared in
    # `finally` so it never outlives the request.
    set_large_model_override(req.model)
    try:
        q = QuestionInput(qid="api", question=req.question, choices=req.choices)
        state = question_to_state(q)
        graph = get_graph()
        result = await graph.ainvoke(state)
        # Map raw model output to a valid choice letter, defaulting to "A".
        answer = normalize_answer(
            answer=result.get("answer", "A"),
            num_choices=len(req.choices),
            question_id="api",
            default="A"
        )
        return SolveResponse(
            answer=answer,
            route=result.get("route", "unknown"),
            reasoning=clean_thinking_tags(result.get("raw_response", "")),
            context=result.get("context", "")
        )
    except Exception as e:
        # Log the full traceback server-side before mapping to a 500.
        import traceback
        traceback.print_exc()
        raise HTTPException(500, str(e)) from e
    finally:
        set_large_model_override(None)
# NOTE(review): route decorator restored after extraction loss — confirm path.
@app.post("/chat")
async def chat(req: ChatRequest):
    """Free-form chat (routes through pipeline without choices).

    Raises:
        HTTPException(400): empty message.
        HTTPException(500): any pipeline failure.
    """
    if not req.message.strip():
        raise HTTPException(400, "Message is required")
    # Per-request model override; cleared in `finally` (see solve_question).
    set_large_model_override(req.model)
    try:
        # An empty choices list signals chat mode to the pipeline.
        q = QuestionInput(qid="chat", question=req.message, choices=[])
        state = question_to_state(q)
        graph = get_graph()
        result = await graph.ainvoke(state)
        return ChatResponse(
            response=clean_thinking_tags(result.get("raw_response", "")),
            route=result.get("route", "unknown")
        )
    except Exception as e:
        # Consistency with solve_question: log the traceback server-side.
        import traceback
        traceback.print_exc()
        raise HTTPException(500, str(e)) from e
    finally:
        set_large_model_override(None)
class DeleteUserRequest(BaseModel):
    """Request payload for permanent user deletion."""

    uid: str   # Firebase UID of the account to delete
# NOTE(review): route decorator restored after extraction loss — confirm path.
@app.post("/admin/delete-user")
async def delete_user(req: DeleteUserRequest, authorization: str = Header(None)):
    """Permanently delete a user (Supervisor only).

    Requires a valid Firebase ID token in the ``Authorization`` header.

    Raises:
        HTTPException(401): missing or malformed Authorization header.
        HTTPException(500): token verification or deletion failure.
    """
    if not authorization or not authorization.startswith("Bearer "):
        raise HTTPException(401, "Missing or invalid authorization token")
    id_token = authorization.removeprefix("Bearer ")
    try:
        # 1. Ensure the caller presents a valid Firebase ID token.
        #    TODO(security): also verify the caller's role server-side (custom
        #    claim or a Firestore lookup for role == 'SUPERVISOR'); today the
        #    Supervisor PIN is only enforced on the frontend.
        verify_id_token(id_token)
        # 2. Perform the deletion.
        delete_user_permanently(req.uid)
        return {"status": "success", "message": f"User {req.uid} deleted permanently"}
    except Exception as e:
        print(f"Delete user error: {e}")
        raise HTTPException(500, str(e)) from e
# NOTE(review): route decorator restored after extraction loss — confirm path.
@app.get("/qdrant/stats")
async def get_qdrant_stats():
    """Get Qdrant database statistics.

    Returns per-collection vector counts. A collection that fails to load is
    reported with ``status: "error"`` instead of failing the whole request.

    Raises:
        HTTPException(500): when the Qdrant client itself is unreachable.
    """
    try:
        client = get_qdrant_client()
        collections = client.get_collections().collections
        stats = {
            "collections": [],
            "total_vectors": 0,
            # NOTE(review): never updated below — no byte-size source is read
            # here; kept at 0 for response-shape compatibility. Confirm intent.
            "total_size_bytes": 0,
        }
        for collection in collections:
            try:
                collection_info = client.get_collection(collection.name)
                vector_count = collection_info.points_count or 0
                stats["collections"].append({
                    "name": collection.name,
                    "vectors": vector_count,
                    "status": "healthy",
                })
                stats["total_vectors"] += vector_count
            except Exception as e:
                # Best-effort: record the broken collection and keep going.
                print(f"Error getting collection {collection.name}: {e}")
                stats["collections"].append({
                    "name": collection.name,
                    "vectors": 0,
                    "status": "error",
                })
        return stats
    except Exception as e:
        print(f"Qdrant stats error: {e}")
        raise HTTPException(500, f"Failed to get Qdrant stats: {str(e)}") from e
if __name__ == "__main__":
    # Local development entry point; deployments run via an external ASGI server.
    import uvicorn

    uvicorn.run(app, host="0.0.0.0", port=8000)