import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, HTTPException
from huggingface_hub import InferenceClient
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from pydantic import BaseModel

import telemetry
from config import DOMAIN_CLIENTS, CLIENT_DOMAIN, DISPLAY_NAMES
from grader import get_embedder, get_nli_model
from pipeline import run, _build_index, clear_index_cache

log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# Directory holding the UI assets: the "ui" folder two levels above this file.
UI_DIR = Path(__file__).parent.parent / "ui"


def _build_all_indexes(embedder):
    """Call ``_build_index`` once per domain key in ``DOMAIN_CLIENTS``."""
    for domain_name in DOMAIN_CLIENTS:
        _build_index(domain_name, embedder)


@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare shared state before the application starts serving.

    Reads ``HF_TOKEN`` from the environment, stores an ``InferenceClient``
    on ``app.state.hf_client``, then calls the model getters and builds an
    index for every configured domain.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is unset or empty.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        raise RuntimeError("HF_TOKEN not set")
    app.state.hf_client = InferenceClient(token=token)

    shared_embedder = get_embedder()
    # Return value intentionally discarded; presumably the call loads and
    # caches the model so the first request does not pay for it (see the
    # "pre-warmed" log line below) - confirm against grader.get_nli_model.
    get_nli_model()
    _build_all_indexes(shared_embedder)

    log.info("Models and KB indexes pre-warmed. Ready.")
    yield


app = FastAPI(title="AI Response Validator", lifespan=lifespan)

app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)


# Request body accepted by POST /query.
# (Documented with comments rather than a docstring so the generated schema
# is left untouched.)
class QueryRequest(BaseModel):
    query: str
    client: str


# Response schema declared for POST /query.
class QueryResponse(BaseModel):
    query: str
    client: str
    client_display: str
    answer: str
    flagged: bool
    sources: list[dict]
    evaluation: dict


# Liveness probe: always answers with a fixed payload.
@app.get("/health")
def health():
    return {"status": "ok"}


@app.get("/config")
def get_config():
    """Domain/client structure for the UI switcher."""
    domains = {}
    for domain_name, client_ids in DOMAIN_CLIENTS.items():
        domains[domain_name] = [
            {"id": client_id, "display": DISPLAY_NAMES[client_id]}
            for client_id in client_ids
        ]
    return {"domains": domains}


@app.post("/refresh-cache")
def refresh_cache():
    """Evict KB index cache and rebuild all domain indexes from disk."""
    evicted = clear_index_cache()
    _build_all_indexes(get_embedder())

    rebuilt = list(DOMAIN_CLIENTS)
    log.info("Cache refreshed. Rebuilt indexes for: %s", rebuilt)
    return {"refreshed": evicted, "rebuilt": rebuilt}


@app.get("/metrics")
def get_metrics():
    """Live session stats from in-memory counters — resets on restart."""
    return telemetry.live_stats()


@app.get("/report")
def get_report():
    """Accumulated stats from HF Dataset shards — persists across restarts."""
    return telemetry.persistent_report()


# Validates the request, runs the pipeline for the given client, logs a
# warning when the grade report's overall flag is falsy, and returns the
# pipeline's response payload. Responds with HTTP 400 for an unknown client
# or a query that is empty after stripping whitespace.
@app.post("/query", response_model=QueryResponse)
def handle_query(req: QueryRequest):
    if req.client not in CLIENT_DOMAIN:
        raise HTTPException(status_code=400, detail=f"Unknown client: {req.client!r}")

    query_text = req.query.strip()
    if not query_text:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    outcome = run(
        query=query_text,
        client=req.client,
        hf_client=app.state.hf_client,
    )

    report = outcome.grade_report
    if not report.overall:
        failed_metrics = [entry.metric for entry in report.results if not entry.passed]
        # Only the first 80 characters of the query are logged.
        log.warning(
            "EVAL_FAIL client=%s failed_metrics=%s query=%r",
            req.client,
            failed_metrics,
            query_text[:80],
        )

    return outcome.response_payload


app.mount("/static", StaticFiles(directory=UI_DIR), name="static")


# Serves the UI entry page from UI_DIR.
@app.get("/")
def root():
    return FileResponse(UI_DIR / "index.html")