import logging
import os
from contextlib import asynccontextmanager
from pathlib import Path

from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from huggingface_hub import InferenceClient
from pydantic import BaseModel

import telemetry
from config import DOMAIN_CLIENTS, CLIENT_DOMAIN, DISPLAY_NAMES
from grader import get_embedder, get_nli_model
from pipeline import run, _build_index, clear_index_cache
|
|
# Module-level logger used by the lifespan hook and the route handlers below.
log = logging.getLogger(__name__)
# Configure the root logger at INFO so the log.info() calls in this module
# are emitted.
logging.basicConfig(level=logging.INFO)

# Front-end assets live in "ui", a sibling of the directory containing this file.
UI_DIR = Path(__file__).parent.parent / "ui"
|
|
|
|
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Prepare shared state before the application starts serving.

    Reads ``HF_TOKEN`` from the environment, stores an ``InferenceClient``
    built from it on ``app.state.hf_client``, then calls the model getters
    and builds the index for every domain in ``DOMAIN_CLIENTS``.
    No teardown is performed after the ``yield``.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is unset or empty.
    """
    hf_token = os.environ.get("HF_TOKEN")
    if not hf_token:
        raise RuntimeError("HF_TOKEN not set")
    app.state.hf_client = InferenceClient(token=hf_token)

    embedder = get_embedder()
    # Return value is discarded; presumably the call caches the NLI model so
    # later requests reuse it — verify against grader.get_nli_model.
    get_nli_model()
    for domain in DOMAIN_CLIENTS:
        _build_index(domain, embedder)
    log.info("Models and KB indexes pre-warmed. Ready.")
    yield
|
|
|
|
# The ASGI application; `lifespan` runs once at startup before requests are served.
app = FastAPI(title="AI Response Validator", lifespan=lifespan)

# NOTE(review): the wildcard origin lets any site issue GET/POST requests to
# this API with arbitrary headers — confirm this is intended for deployment.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)
|
|
|
|
class QueryRequest(BaseModel):
    """Request body for ``POST /query``."""

    # Free-text question; the handler rejects it when blank after stripping.
    query: str
    # Client identifier; the handler rejects values not present in CLIENT_DOMAIN.
    client: str
|
|
|
|
class QueryResponse(BaseModel):
    """Response schema declared as ``response_model`` for ``POST /query``.

    The handler returns ``result.response_payload`` from ``pipeline.run``;
    the exact content of each field is produced there, not in this module.
    """

    query: str
    client: str
    client_display: str
    answer: str
    flagged: bool
    sources: list[dict]
    evaluation: dict
|
|
|
|
@app.get("/health")
def health():
    """Return a constant ``{"status": "ok"}`` payload."""
    return {"status": "ok"}
|
|
|
|
@app.get("/config")
def get_config():
    """Domain/client structure for the UI switcher.

    Returns:
        ``{"domains": {domain: [{"id": client, "display": name}, ...]}}``
        where domains and their clients come from ``DOMAIN_CLIENTS`` and
        each display name is looked up in ``DISPLAY_NAMES``.
    """
    return {
        "domains": {
            domain: [{"id": c, "display": DISPLAY_NAMES[c]} for c in clients]
            for domain, clients in DOMAIN_CLIENTS.items()
        }
    }
|
|
|
|
@app.post("/refresh-cache")
def refresh_cache():
    """Evict KB index cache and rebuild all domain indexes from disk.

    Returns:
        ``{"refreshed": <value returned by clear_index_cache()>,
        "rebuilt": [<every domain in DOMAIN_CLIENTS>]}``.
    """
    evicted = clear_index_cache()
    embedder = get_embedder()
    # Snapshot the domain names once; the same list is logged and returned.
    domains = list(DOMAIN_CLIENTS)
    for domain in domains:
        _build_index(domain, embedder)
    log.info("Cache refreshed. Rebuilt indexes for: %s", domains)
    return {"refreshed": evicted, "rebuilt": domains}
|
|
|
|
@app.get("/metrics")
def get_metrics():
    """Live session stats from in-memory counters — resets on restart.

    Returns whatever ``telemetry.live_stats()`` produces, unmodified.
    """
    return telemetry.live_stats()
|
|
|
|
@app.get("/report")
def get_report():
    """Accumulated stats from HF Dataset shards — persists across restarts.

    Returns whatever ``telemetry.persistent_report()`` produces, unmodified.
    """
    return telemetry.persistent_report()
|
|
|
|
@app.post("/query", response_model=QueryResponse)
def handle_query(req: QueryRequest):
    """Validate the request, run the pipeline and return its payload.

    The stripped query is passed to ``pipeline.run`` together with the
    client id and the shared ``app.state.hf_client``.  When the resulting
    grade report is not overall-passing, the failed metric names are logged
    as a warning; the payload is returned either way.

    Raises:
        HTTPException: 400 if ``req.client`` is not in ``CLIENT_DOMAIN`` or
            if the query is empty after stripping whitespace.
    """
    if req.client not in CLIENT_DOMAIN:
        raise HTTPException(status_code=400, detail=f"Unknown client: {req.client!r}")

    # Strip once and reuse for validation, the pipeline call and logging.
    query = req.query.strip()
    if not query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    result = run(
        query=query,
        client=req.client,
        hf_client=app.state.hf_client,
    )
    if not result.grade_report.overall:
        failed = [r.metric for r in result.grade_report.results if not r.passed]
        # Only the first 80 characters of the query are logged.
        log.warning(
            "EVAL_FAIL client=%s failed_metrics=%s query=%r",
            req.client,
            failed,
            query[:80],
        )
    return result.response_payload
|
|
|
|
# Serve the files under UI_DIR at the /static URL prefix.
app.mount("/static", StaticFiles(directory=UI_DIR), name="static")
|
|
|
|
@app.get("/")
def root():
    """Serve ``index.html`` from ``UI_DIR`` at the site root."""
    return FileResponse(UI_DIR / "index.html")
|
|