"""FastAPI entrypoint for the math agent backend.""" from __future__ import annotations import asyncio import os from typing import AsyncIterator from contextlib import asynccontextmanager from fastapi import Depends, FastAPI, HTTPException, Request from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse from .config import settings from .guardrails import GuardrailViolation from .logger import configure_logging, get_logger from .schemas import AgentResponse, ChatRequest, FeedbackRequest from .services.retrieval import MathAgent from .services.vector_store import ( load_vector_store, save_feedback_to_queue, ) from .services.kb_updater import update_knowledge_base from .tools.audio import transcribe_audio from .tools.validator import validate_user_solution from .tools.vision import extract_text_from_image logger = get_logger(__name__) def create_app() -> FastAPI: configure_logging() app = FastAPI(title=settings.app_name, version="0.1.0") # Configure CORS. Accept a single FRONTEND_URL or a comma-separated # FRONTEND_URLS env var. Default to common dev ports (5173 and 3000). frontend_env = os.getenv("FRONTEND_URL", os.getenv("FRONTEND_URLS", "http://localhost:5173,http://localhost:3000")) # split and strip any whitespace allowed_origins = [o.strip() for o in frontend_env.split(",") if o.strip()] app.add_middleware( CORSMiddleware, allow_origins=allowed_origins, allow_credentials=True, allow_methods=["*"], allow_headers=["*"], ) @asynccontextmanager async def lifespan(app: FastAPI) -> AsyncIterator[None]: # Startup logger.info("app.startup") app.state.vector_store = await asyncio.to_thread(load_vector_store) yield # Server is running # Shutdown logger.info("app.shutdown") # Clean up any resources if needed app.router.lifespan_context = lifespan def get_agent() -> MathAgent: vector_store = getattr(app.state, "vector_store", None) if vector_store is None: vector_store = load_vector_store() app.state.vector_store = vector_store return MathAgent(vector_store=vector_store) @app.middleware("http") async def add_process_time_header(request: Request, call_next): # type: ignore[override] response = await call_next(request) response.headers["X-App-Env"] = settings.environment return response @app.get("/health") async def health() -> dict[str, str]: return {"status": "ok", "environment": settings.environment} @app.post("/api/chat", response_model=AgentResponse, responses={400: {"description": "Guardrail failure"}}) async def chat_endpoint(payload: ChatRequest, agent: MathAgent = Depends(get_agent)) -> AgentResponse: logger.info("chat.request", modality=payload.modality) query = payload.query if payload.modality == "audio" and payload.audio_base64: query = transcribe_audio(payload.audio_base64) elif payload.modality == "image" and payload.image_base64: query = extract_text_from_image(payload.image_base64) try: response = await agent.handle_query(query) return response except GuardrailViolation as exc: raise HTTPException(status_code=400, detail=exc.message) from exc except Exception as exc: # pragma: no cover - defensive logging logger.exception("chat.error", error=str(exc)) raise HTTPException(status_code=500, detail="Unexpected error handling query") from exc @app.post("/api/feedback") async def feedback_endpoint(request: FeedbackRequest) -> JSONResponse: """Handle user feedback with optional solution upload.""" logger.info( "feedback.received", message_id=request.message_id, helpful=request.feedback.thumbs_up, issue=request.feedback.primary_issue ) # Always save the feedback first record = request.model_dump() save_feedback_to_queue(record) # If it's negative feedback with a solution if not request.feedback.thumbs_up and request.feedback.has_better_solution: solution = None # Get solution based on type if request.feedback.solution_type == "text": solution = request.feedback.better_solution_text elif request.feedback.solution_type == "pdf": # TODO: Extract text from PDF solution = request.feedback.better_solution_text elif request.feedback.solution_type == "image": # Use vision model to extract solution from image if request.feedback.better_solution_image_base64: solution = extract_text_from_image( request.feedback.better_solution_image_base64, "Extract the mathematical solution from this image." ) if solution: # Validate and update KB success = await update_knowledge_base(request.query, solution) return JSONResponse({ "status": "ok", "feedback_saved": True, "kb_updated": success }) return JSONResponse({ "status": "ok", "feedback_saved": True }) @app.get("/api/vector-store/reload") async def reload_vector_store() -> dict[str, str]: app.state.vector_store = await asyncio.to_thread(load_vector_store, True) return {"status": "reloaded"} return app app = create_app()