import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any
import telemetry
from config import CLIENT_DOMAIN, DISPLAY_NAMES, DOMAIN_CLIENTS
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from grader import get_embedder, get_nli_model
from huggingface_hub import InferenceClient
from pipeline import _build_index, clear_index_cache, run
from pydantic import BaseModel
# Configure the root logger first so records emitted through ``log`` at
# INFO level and above are shown.
logging.basicConfig(level=logging.INFO)
log = logging.getLogger(__name__)

# The UI assets live in a ``ui`` directory two levels above this file's
# own location (sibling of this module's parent directory).
UI_DIR = Path(__file__).parent.parent / "ui"
@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Prepare shared state before the application starts serving.

    Reads ``HF_TOKEN`` from the environment, stores an ``InferenceClient``
    built from it on ``app.state.hf_client``, then calls the model getters
    and builds one index per domain in ``DOMAIN_CLIENTS`` before yielding
    control back to the server. Nothing runs after the ``yield``.

    Raises:
        RuntimeError: if ``HF_TOKEN`` is missing or empty.
    """
    token = os.environ.get("HF_TOKEN")
    if not token:
        # Fail fast at startup instead of on the first request.
        raise RuntimeError("HF_TOKEN not set")
    app.state.hf_client = InferenceClient(token=token)

    embedder = get_embedder()
    # Return value is intentionally unused; presumably the call loads and
    # caches the NLI model so the first request doesn't pay for it.
    get_nli_model()
    for domain_name in DOMAIN_CLIENTS:
        _build_index(domain_name, embedder)

    log.info("Models and KB indexes pre-warmed. Ready.")
    yield
# Application object; ``lifespan`` runs its startup work before requests
# are served.
app = FastAPI(title="AI Response Validator", lifespan=lifespan)

# CORS: any origin and any request header are accepted, but only GET and
# POST methods are allowed.
# NOTE(review): the wildcard origin is wide open -- confirm this is
# intended for the deployment target.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_headers=["*"],
    allow_methods=["GET", "POST"],
)
class QueryRequest(BaseModel):
    """Request body accepted by ``POST /query``."""

    # Free-text question; the handler strips it and rejects it when blank.
    query: str
    # Client identifier; the handler requires it to be a key of
    # ``CLIENT_DOMAIN``.
    client: str
class QueryResponse(BaseModel):
    """Response schema declared as ``response_model`` for ``POST /query``.

    The payload itself is produced by ``pipeline.run``; the per-field notes
    below are inferred from the field names and should be confirmed there.
    """

    # Presumably echoes of the request fields -- TODO confirm.
    query: str
    client: str
    # Presumably the human-readable name of the client -- TODO confirm.
    client_display: str
    answer: str
    # Presumably true when the evaluation did not pass -- TODO confirm.
    flagged: bool
    # Free-form dictionaries; their key schema is not visible in this file.
    sources: list[dict[str, Any]]
    evaluation: dict[str, Any]
@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe: always answers with a constant ``ok`` status."""
    payload = {"status": "ok"}
    return payload
@app.get("/config")
def get_config() -> dict[str, Any]:
    """Domain/client structure for the UI switcher.

    Returns a single-key mapping ``{"domains": {...}}`` in which every
    domain of ``DOMAIN_CLIENTS`` maps to a list of
    ``{"id": ..., "display": ...}`` entries, one per client, with the
    display text looked up in ``DISPLAY_NAMES``.

    Raises:
        KeyError: if a client has no entry in ``DISPLAY_NAMES``.
    """
    domains: dict[str, Any] = {}
    for domain_name, client_ids in DOMAIN_CLIENTS.items():
        domains[domain_name] = [
            {"id": client_id, "display": DISPLAY_NAMES[client_id]}
            for client_id in client_ids
        ]
    return {"domains": domains}
@app.post("/refresh-cache")
def refresh_cache() -> dict[str, Any]:
    """Evict KB index cache and rebuild all domain indexes from disk.

    Returns a mapping with two keys: ``"refreshed"`` carries whatever
    ``clear_index_cache()`` returned, and ``"rebuilt"`` lists every domain
    in ``DOMAIN_CLIENTS`` whose index was built again.
    """
    evicted = clear_index_cache()
    embedder = get_embedder()

    # Snapshot the domain names once; the same list drives the rebuild,
    # the log line and the response.
    domain_names = list(DOMAIN_CLIENTS)
    for domain_name in domain_names:
        _build_index(domain_name, embedder)

    log.info("Cache refreshed. Rebuilt indexes for: %s", domain_names)
    return {"refreshed": evicted, "rebuilt": domain_names}
@app.get("/metrics")
def get_metrics() -> dict[str, Any]:
    """Live session stats from in-memory counters — resets on restart.

    Thin pass-through to ``telemetry.live_stats()``.
    """
    stats = telemetry.live_stats()
    return stats
@app.get("/report")
def get_report() -> dict[str, Any]:
    """Accumulated stats from HF Dataset shards — persists across restarts.

    Thin pass-through to ``telemetry.persistent_report()``.
    """
    report = telemetry.persistent_report()
    return report
@app.post("/query", response_model=QueryResponse)
def handle_query(req: QueryRequest) -> dict[str, Any]:
    """Validate a query request, run the pipeline and return its payload.

    The client is checked against ``CLIENT_DOMAIN`` first, then the query
    is checked for being non-blank. The whitespace-stripped query is what
    gets passed to ``run`` together with the shared ``hf_client`` stored on
    ``app.state``. When the grade report's ``overall`` flag is falsy, the
    names of the metrics that did not pass are logged as a warning; the
    response payload is returned either way.

    Raises:
        HTTPException: status 400 for an unknown client or a blank query.
    """
    if req.client not in CLIENT_DOMAIN:
        raise HTTPException(status_code=400, detail=f"Unknown client: {req.client!r}")

    question = req.query.strip()
    if not question:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    outcome = run(
        query=question,
        client=req.client,
        hf_client=app.state.hf_client,
    )

    report = outcome.grade_report
    if not report.overall:
        failing = [check.metric for check in report.results if not check.passed]
        # Only the first 80 characters of the query go into the log line.
        log.warning(
            "EVAL_FAIL client=%s failed_metrics=%s query=%r",
            req.client,
            failing,
            question[:80],
        )
    return outcome.response_payload
# Serve everything under the UI directory at /static.
app.mount("/static", StaticFiles(directory=UI_DIR), name="static")


@app.get("/")
def root() -> FileResponse:
    """Serve the UI entry page, ``index.html`` from ``UI_DIR``."""
    index_page = UI_DIR / "index.html"
    return FileResponse(index_page)
|