Spaces:
Build error
Build error
| """FastAPI entrypoint for the math agent backend.""" | |
| from __future__ import annotations | |
| import asyncio | |
| import os | |
| from typing import AsyncIterator | |
| from contextlib import asynccontextmanager | |
| from fastapi import Depends, FastAPI, HTTPException, Request | |
| from fastapi.middleware.cors import CORSMiddleware | |
| from fastapi.responses import JSONResponse | |
| from .config import settings | |
| from .guardrails import GuardrailViolation | |
| from .logger import configure_logging, get_logger | |
| from .schemas import AgentResponse, ChatRequest, FeedbackRequest | |
| from .services.retrieval import MathAgent | |
| from .services.vector_store import ( | |
| load_vector_store, | |
| save_feedback_to_queue, | |
| ) | |
| from .services.kb_updater import update_knowledge_base | |
| from .tools.audio import transcribe_audio | |
| from .tools.validator import validate_user_solution | |
| from .tools.vision import extract_text_from_image | |
# Module-level logger from the project's logging helper; used by the
# lifespan hook and the request handlers defined inside create_app().
logger = get_logger(__name__)
def create_app() -> FastAPI:
    """Build and configure the FastAPI application for the math agent.

    The factory configures logging, installs a lifespan hook that loads the
    vector store off the event loop, enables CORS for the configured frontend
    origins, and registers one HTTP middleware plus the health, chat, feedback
    and vector-store reload handlers.

    Returns:
        The fully configured ``FastAPI`` instance.

    NOTE(review): the copy of this file that was reviewed had lost every
    decorator line, so the middleware and handlers were defined but never
    attached to the app. The ``asynccontextmanager`` and ``middleware("http")``
    decorators follow directly from the imports and signatures; the route
    paths and HTTP methods below are reconstructed from the handler names and
    must be confirmed against the frontend before deploying.
    """
    configure_logging()

    @asynccontextmanager
    async def lifespan(app: FastAPI) -> AsyncIterator[None]:
        """Load the vector store on startup and log on shutdown."""
        logger.info("app.startup")
        # Loading is done in a worker thread so startup does not block the loop.
        app.state.vector_store = await asyncio.to_thread(load_vector_store)
        yield  # Server is running
        logger.info("app.shutdown")
        # Clean up any resources if needed

    # The lifespan must be an async context manager factory; passing it to the
    # constructor is the documented way to install it.
    app = FastAPI(title=settings.app_name, version="0.1.0", lifespan=lifespan)

    # Configure CORS. FRONTEND_URL wins over FRONTEND_URLS; either may hold a
    # comma-separated list. An unset *or empty* variable falls through to the
    # next source so an empty value cannot silently produce zero origins.
    default_origins = "http://localhost:5173,http://localhost:3000"
    frontend_env = (
        os.getenv("FRONTEND_URL")
        or os.getenv("FRONTEND_URLS")
        or default_origins
    )
    allowed_origins = [
        origin.strip() for origin in frontend_env.split(",") if origin.strip()
    ]
    app.add_middleware(
        CORSMiddleware,
        allow_origins=allowed_origins,
        allow_credentials=True,
        allow_methods=["*"],
        allow_headers=["*"],
    )

    def get_agent() -> MathAgent:
        """Dependency: build a MathAgent around the shared vector store.

        Falls back to loading the store on demand (and caching it on
        ``app.state``) when the lifespan hook has not populated it.
        """
        vector_store = getattr(app.state, "vector_store", None)
        if vector_store is None:
            vector_store = load_vector_store()
            app.state.vector_store = vector_store
        return MathAgent(vector_store=vector_store)

    @app.middleware("http")
    async def add_process_time_header(request: Request, call_next):  # type: ignore[override]
        """Tag every response with the current environment name.

        Despite its name, this sets ``X-App-Env``; no timing is recorded.
        """
        response = await call_next(request)
        response.headers["X-App-Env"] = settings.environment
        return response

    # NOTE(review): path reconstructed from the handler name -- confirm.
    @app.get("/health")
    async def health() -> dict[str, str]:
        """Report liveness together with the configured environment."""
        return {"status": "ok", "environment": settings.environment}

    # NOTE(review): path reconstructed from the handler name -- confirm.
    @app.post("/chat")
    async def chat_endpoint(
        payload: ChatRequest, agent: MathAgent = Depends(get_agent)
    ) -> AgentResponse:
        """Answer a math query given as text, audio or image.

        Audio and image payloads are first converted to text; the resulting
        query is handed to the agent.

        Raises:
            HTTPException: 400 when a guardrail rejects the query, 500 for any
                other failure while the agent handles it.
        """
        logger.info("chat.request", modality=payload.modality)
        query = payload.query
        # The converters are called without ``await`` upstream, so they are
        # treated as blocking and run in a worker thread, the same way the
        # vector store is loaded.
        if payload.modality == "audio" and payload.audio_base64:
            query = await asyncio.to_thread(transcribe_audio, payload.audio_base64)
        elif payload.modality == "image" and payload.image_base64:
            query = await asyncio.to_thread(
                extract_text_from_image, payload.image_base64
            )
        try:
            return await agent.handle_query(query)
        except GuardrailViolation as exc:
            raise HTTPException(status_code=400, detail=exc.message) from exc
        except Exception as exc:  # pragma: no cover - defensive logging
            logger.exception("chat.error", error=str(exc))
            raise HTTPException(
                status_code=500, detail="Unexpected error handling query"
            ) from exc

    # NOTE(review): path reconstructed from the handler name -- confirm.
    @app.post("/feedback")
    async def feedback_endpoint(request: FeedbackRequest) -> JSONResponse:
        """Handle user feedback with optional solution upload.

        The feedback is always queued. When it is negative and carries a
        better solution, that solution is passed to the knowledge-base updater
        and the response additionally reports ``kb_updated``.
        """
        feedback = request.feedback
        logger.info(
            "feedback.received",
            message_id=request.message_id,
            helpful=feedback.thumbs_up,
            issue=feedback.primary_issue,
        )
        # Always save the feedback first
        save_feedback_to_queue(request.model_dump())

        solution = None
        if not feedback.thumbs_up and feedback.has_better_solution:
            if feedback.solution_type in ("text", "pdf"):
                # TODO: Extract text from PDF; until then the text field is
                # used for PDF submissions as well.
                solution = feedback.better_solution_text
            elif (
                feedback.solution_type == "image"
                and feedback.better_solution_image_base64
            ):
                # Use vision model to extract solution from image (off-loop).
                solution = await asyncio.to_thread(
                    extract_text_from_image,
                    feedback.better_solution_image_base64,
                    "Extract the mathematical solution from this image.",
                )

        if solution:
            # Validate and update KB
            success = await update_knowledge_base(request.query, solution)
            return JSONResponse(
                {"status": "ok", "feedback_saved": True, "kb_updated": success}
            )
        return JSONResponse({"status": "ok", "feedback_saved": True})

    # NOTE(review): path and method reconstructed from the handler name -- confirm.
    @app.post("/vector-store/reload")
    async def reload_vector_store() -> dict[str, str]:
        """Reload the vector store and replace the shared instance."""
        # The positional ``True`` is forwarded to load_vector_store unchanged.
        app.state.vector_store = await asyncio.to_thread(load_vector_store, True)
        return {"status": "reloaded"}

    return app
# Application instance, built once when this module is imported.
# NOTE(review): presumably the object the ASGI server is pointed at -- confirm
# against the deployment configuration.
app = create_app()