# hallucination-detector-project: app/api/batch_predict.py
# Uploaded by KShoichi with huggingface_hub (commit a0acbd1, verified).
import logging
from typing import List
from uuid import uuid4

from fastapi import APIRouter, HTTPException
from pydantic import BaseModel

from model.inference import model_instance
from db.database import SessionLocal
from db.models import Prediction
# Router that carries this module's POST /batch_predict endpoint; it is
# presumably mounted on the app elsewhere via include_router (confirm).
router: APIRouter = APIRouter()

class BatchPredictRequest(BaseModel):
    """Request body for ``POST /batch_predict``.

    ``items`` is a list of plain dicts. The handler in this module reads the
    ``"prompt"``, ``"response"`` and ``"question"`` keys of each one; the
    ``List[dict]`` annotation itself does not require those keys.
    """

    # NOTE(review): a dedicated item model with typed prompt/response/question
    # fields would let FastAPI validate each item itself, but the handler
    # indexes items as dicts (item["prompt"]), so both must change together.
    items: List[dict]

logger = logging.getLogger(__name__)

# Keys every batch item must provide: they are the three arguments passed to
# ``model_instance.predict`` and are also stored on each ``Prediction`` row.
_REQUIRED_FIELDS = ("prompt", "response", "question")


def _missing_fields(items: List[dict]) -> List[str]:
    """Return ``"items[i].<field>"`` for each required field absent from an item.

    An empty list means every item carries all of ``_REQUIRED_FIELDS``.
    """
    return [
        f"items[{index}].{field}"
        for index, item in enumerate(items)
        for field in _REQUIRED_FIELDS
        if field not in item
    ]


@router.post("/batch_predict")
def batch_predict(request: BatchPredictRequest):
    """Score every item in the batch and persist one ``Prediction`` per item.

    Args:
        request: Body whose ``items`` are dicts with ``"prompt"``,
            ``"response"`` and ``"question"`` keys.

    Returns:
        ``{"results": [...]}``: for each item, in request order, the dict
        returned by ``model_instance.predict`` extended with a generated
        ``"request_id"`` (the same value is stored as ``Prediction.id``).

    Raises:
        HTTPException: 422 listing every missing field when any item lacks a
            required key; this check runs before any inference.
            500 carrying the error text when prediction or the database write
            fails. The only ``commit`` is issued after the last item, so a
            failure mid-batch never reaches it.
    """
    missing = _missing_fields(request.items)
    if missing:
        # Reject malformed input as a client error before doing any model
        # work, instead of failing mid-batch with a KeyError reported as 500.
        raise HTTPException(
            status_code=422,
            detail="Missing required field(s): " + ", ".join(missing),
        )

    results = []
    db = SessionLocal()
    try:
        for item in request.items:
            req_id = str(uuid4())
            result = model_instance.predict(
                item["prompt"], item["response"], item["question"]
            )
            db.add(
                Prediction(
                    id=req_id,
                    prompt=item["prompt"],
                    response=item["response"],
                    question=item["question"],
                    is_hallucination=result["is_hallucination"],
                    confidence_score=result["confidence_score"],
                    raw_prediction=result["raw_prediction"],
                    processing_time=result["processing_time"],
                )
            )
            results.append({**result, "request_id": req_id})
        # One commit after the loop: if any item fails, no commit is issued.
        db.commit()
    except Exception as e:
        # Keep the full traceback in server logs; the client only sees str(e).
        logger.exception(
            "batch_predict failed for %d item(s)", len(request.items)
        )
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always release the session, including on BaseException (e.g.
        # KeyboardInterrupt) that the except clause above does not catch.
        # NOTE(review): presumably a SQLAlchemy Session, whose close() also
        # discards any uncommitted adds -- confirm.
        db.close()
    return {"results": results}