File size: 2,123 Bytes
b707cd3
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
from contextlib import asynccontextmanager

from fastapi import FastAPI, Depends
from fastapi.concurrency import run_in_threadpool
from pydantic import BaseModel
from sqlalchemy.ext.asyncio import AsyncSession

from api.detector import HallucinationDetector
from api.database import engine, Base, get_db
from api.models import HallucinationLog

# Single detector instance, created once at import time and reused by the
# scoring endpoint for every request.
# NOTE(review): presumably the constructor loads the NLI model — confirm in api.detector.
detector = HallucinationDetector()

# Lifespan context: create tables on startup, release the connection pool on shutdown.
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Run startup and shutdown work around the application's serving period.

    Startup: opens a transaction on ``engine`` and runs
    ``Base.metadata.create_all`` through ``run_sync``.
    Shutdown: disposes of ``engine`` so its pooled connections are closed
    while the event loop is still running.

    Args:
        app: The FastAPI application being started. Not used in the body;
            it is part of the lifespan callable signature.

    Yields:
        None. Control returns to the application between startup and shutdown.
    """
    async with engine.begin() as conn:
        # In a real production app you'd use Alembic migrations,
        # but create_all is sufficient for the current phase.
        await conn.run_sync(Base.metadata.create_all)
    try:
        yield
    finally:
        # Close pooled connections explicitly; otherwise they are left for
        # garbage collection after the event loop has already shut down.
        await engine.dispose()

# ASGI application object; startup/shutdown work is delegated to `lifespan`.
app = FastAPI(
    lifespan=lifespan,
    title="HalluciGuard API",
    version="1.0.0",
    description="Async API for detecting LLM hallucinations using NLI.",
)

# Request body for POST /api/v1/score. Both fields are required strings.
class HallucinationRequest(BaseModel):
    # Reference text; passed as the first argument to detector.analyze.
    context: str
    # Text produced by the LLM; passed as the second argument to detector.analyze.
    llm_output: str

# Root endpoint: returns a fixed status payload showing the service is up.
@app.get("/")
async def root():
    payload = {
        "status": "online",
        "message": "HalluciGuard API is running.",
    }
    return payload

# Scores an LLM output against its context, stores the result, and returns it.
#
# Parameters:
#   request: validated JSON body carrying `context` and `llm_output`.
#   db:      request-scoped AsyncSession injected via `Depends(get_db)`.
#
# Returns a dict with the stored row's id and creation timestamp, the echoed
# inputs, and the detector's raw result mapping.
@app.post("/api/v1/score")
async def score_hallucination(request: HallucinationRequest, db: AsyncSession = Depends(get_db)):

    # 1. Run the ML model in a worker thread.
    #    `detector.analyze` is a plain synchronous call; invoking it directly
    #    inside this coroutine would block the event loop (and every other
    #    in-flight request) for the whole inference.
    #    NOTE(review): assumes `detector.analyze` is safe to call from a
    #    worker thread — confirm in api.detector.
    results = await run_in_threadpool(
        detector.analyze, request.context, request.llm_output
    )

    # 2. Package the data for PostgreSQL
    new_log = HallucinationLog(
        context=request.context,
        llm_output=request.llm_output,
        contradiction_score=results["contradiction_score"],
        entailment_score=results["entailment_score"],
        neutral_score=results["neutral_score"],
        is_hallucination=results["is_hallucination"],
    )

    # 3. Async commit to the database
    db.add(new_log)
    await db.commit()
    # Reload the row so database-generated values (id, created_at) are populated.
    await db.refresh(new_log)

    # 4. Return to the user
    return {
        "log_id": new_log.id,
        "context": request.context,
        "llm_output": request.llm_output,
        "results": results,
        "timestamp": new_log.created_at,
    }