File size: 3,685 Bytes
ebe934f
 
e181667
ebe934f
 
e181667
ebe934f
e181667
 
ebe934f
 
 
 
907c06a
e181667
 
 
ebe934f
 
 
 
 
 
 
 
e181667
ebb06ed
 
 
 
aef9f0f
907c06a
aef9f0f
 
 
ebe934f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
907c06a
e181667
 
ebe934f
 
 
e181667
ebe934f
 
 
 
e181667
ebe934f
 
 
 
 
 
 
 
 
e77a2f2
e181667
e77a2f2
 
 
 
 
 
 
 
 
c79d967
e181667
c79d967
 
 
 
 
e181667
c79d967
 
 
 
ebe934f
e181667
ebe934f
 
 
 
 
 
 
 
ebb06ed
ebe934f
907c06a
 
 
 
ebe934f
 
 
 
 
 
 
e181667
ebe934f
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import logging
import os
from collections.abc import AsyncGenerator
from contextlib import asynccontextmanager
from pathlib import Path
from typing import Any

import telemetry
from config import CLIENT_DOMAIN, DISPLAY_NAMES, DOMAIN_CLIENTS
from fastapi import FastAPI, HTTPException
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import FileResponse
from fastapi.staticfiles import StaticFiles
from grader import get_embedder, get_nli_model
from huggingface_hub import InferenceClient
from pipeline import _build_index, clear_index_cache, run
from pydantic import BaseModel

# Module logger. basicConfig installs an INFO-level root handler so the
# startup "pre-warmed" message and the EVAL_FAIL warnings in this file are
# emitted. (basicConfig does nothing if the root logger already has handlers.)
log = logging.getLogger(__name__)
logging.basicConfig(level=logging.INFO)

# UI assets directory: the "ui" folder that sits next to this file's parent
# directory. If this file is <root>/<pkg>/main.py, this is <root>/ui.
# __file__ is not resolved, so the path may be relative to the CWD.
UI_DIR = Path(__file__).parent.parent / "ui"


@asynccontextmanager
async def lifespan(app: FastAPI) -> AsyncGenerator[None, None]:
    """Prepare shared state before the app starts serving requests.

    Steps, in order:
      1. Create the Hugging Face ``InferenceClient`` and store it on
         ``app.state.hf_client``; ``handle_query`` reads it from there.
      2. Load the embedder and the NLI model.
      3. Build the KB index for every domain in ``DOMAIN_CLIENTS``, using
         the embedder.

    Nothing is torn down after ``yield``; shutdown performs no cleanup.

    Raises:
        RuntimeError: if the ``HF_TOKEN`` environment variable is missing
            or empty.
    """
    if not (token := os.environ.get("HF_TOKEN")):
        raise RuntimeError("HF_TOKEN not set")
    app.state.hf_client = InferenceClient(token=token)

    embedding_model = get_embedder()
    # The return value is discarded; this call is made only for its loading
    # side effect. NOTE(review): presumably get_nli_model() caches the model —
    # confirm in grader.
    get_nli_model()

    for domain_name in DOMAIN_CLIENTS:
        _build_index(domain_name, embedding_model)

    log.info("Models and KB indexes pre-warmed. Ready.")
    yield


# ASGI application. The `lifespan` hook above handles startup pre-warming.
app = FastAPI(title="AI Response Validator", lifespan=lifespan)

# CORS: browsers on any origin may call the GET/POST endpoints with any
# request headers.
# NOTE(review): POST /refresh-cache has no auth in this file, so a
# wide-open origin list lets pages from any site trigger a full index
# rebuild. Confirm this is intended.
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_methods=["GET", "POST"],
    allow_headers=["*"],
)


# Request body for POST /query.
# `#` comments are used instead of a docstring on purpose: pydantic would
# publish a docstring as the OpenAPI schema description.
class QueryRequest(BaseModel):
    query: str  # user question; handle_query strips it and rejects it if blank (400)
    client: str  # client id; handle_query returns 400 unless it is a key of CLIENT_DOMAIN


# Response model declared on POST /query (response_model=QueryResponse).
# FastAPI validates the handler's returned dict (the pipeline's
# response_payload) against these fields.
# `#` comments are used instead of a docstring on purpose, to keep the
# OpenAPI schema unchanged.
class QueryResponse(BaseModel):
    query: str
    client: str
    client_display: str  # presumably DISPLAY_NAMES[client]; confirm in pipeline.run
    answer: str
    flagged: bool  # NOTE(review): set by pipeline.run; meaning not visible in this file
    sources: list[dict[str, Any]]  # per-source dicts; inner keys are not validated here
    evaluation: dict[str, Any]  # free-form evaluation details produced by the pipeline


@app.get("/health")
def health() -> dict[str, str]:
    """Liveness probe: answers ``{"status": "ok"}`` whenever the app is up."""
    return dict(status="ok")


@app.get("/config")
def get_config() -> dict[str, Any]:
    """Expose the domain -> client layout that drives the UI switcher.

    Returns:
        ``{"domains": {domain: [{"id": client_id, "display": label}, ...]}}``,
        where ``label`` comes from ``DISPLAY_NAMES``. Domains and clients keep
        the iteration order of ``DOMAIN_CLIENTS``.

    NOTE(review): a client listed in DOMAIN_CLIENTS but missing from
    DISPLAY_NAMES raises KeyError here, which surfaces as a 500.
    """
    layout = {}
    for domain, client_ids in DOMAIN_CLIENTS.items():
        entries = []
        for client_id in client_ids:
            entries.append({"id": client_id, "display": DISPLAY_NAMES[client_id]})
        layout[domain] = entries
    return {"domains": layout}


@app.post("/refresh-cache")
def refresh_cache() -> dict[str, Any]:
    """Evict KB index cache and rebuild all domain indexes from disk.

    Evicts the cache first, then rebuilds one index per domain in
    ``DOMAIN_CLIENTS`` order, using the shared embedder.

    Returns:
        ``{"refreshed": <whatever clear_index_cache() returned>,
        "rebuilt": [domain, ...]}``.

    NOTE(review): this route has no authentication in this file, and it
    triggers a full rebuild. Consider protecting it.
    """
    eviction_result = clear_index_cache()
    embedding_model = get_embedder()
    domains = list(DOMAIN_CLIENTS)
    for domain in domains:
        _build_index(domain, embedding_model)
    log.info("Cache refreshed. Rebuilt indexes for: %s", domains)
    return {"refreshed": eviction_result, "rebuilt": domains}


@app.get("/metrics")
def get_metrics() -> dict[str, Any]:
    """Return the live session stats from ``telemetry.live_stats()``.

    The counters are held in memory, so they reset whenever the process
    restarts.
    """
    snapshot = telemetry.live_stats()
    return snapshot


@app.get("/report")
def get_report() -> dict[str, Any]:
    """Return the accumulated stats from ``telemetry.persistent_report()``.

    These are read from HF Dataset shards, so unlike ``/metrics`` they
    survive restarts.
    """
    report = telemetry.persistent_report()
    return report


@app.post("/query", response_model=QueryResponse)
def handle_query(req: QueryRequest) -> dict[str, Any]:
    """Run the validation pipeline for one query and return its payload.

    The query is whitespace-stripped once. The stripped text is used for
    the emptiness check, the pipeline call, and the failure log line. If
    the grade report's ``overall`` verdict is falsy, a warning lists the
    failed metric names along with the first 80 characters of the query.

    Raises:
        HTTPException: 400 if ``req.client`` is not a key of
            ``CLIENT_DOMAIN`` (checked first), or if the query is empty or
            whitespace-only.
    """
    if req.client not in CLIENT_DOMAIN:
        raise HTTPException(status_code=400, detail=f"Unknown client: {req.client!r}")

    cleaned_query = req.query.strip()
    if not cleaned_query:
        raise HTTPException(status_code=400, detail="Query cannot be empty")

    outcome = run(
        query=cleaned_query,
        client=req.client,
        hf_client=app.state.hf_client,
    )

    report = outcome.grade_report
    if not report.overall:
        failed_metrics = [check.metric for check in report.results if not check.passed]
        log.warning(
            "EVAL_FAIL client=%s failed_metrics=%s query=%r",
            req.client,
            failed_metrics,
            cleaned_query[:80],
        )
    return outcome.response_payload


# Serve the UI's static assets from UI_DIR under /static/...
# NOTE(review): Starlette's StaticFiles checks by default that the
# directory exists when it is constructed, so a missing UI_DIR presumably
# fails at import time rather than per request. Confirm for the pinned
# Starlette version.
app.mount("/static", StaticFiles(directory=UI_DIR), name="static")


@app.get("/")
def root() -> FileResponse:
    """Serve the UI entry page, ``index.html`` from ``UI_DIR``."""
    index_page = UI_DIR / "index.html"
    return FileResponse(index_page)