File size: 1,355 Bytes
a0acbd1
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
from fastapi import APIRouter, HTTPException
from pydantic import BaseModel
from uuid import uuid4
from model.inference import model_instance
from db.database import SessionLocal
from db.models import Prediction
from typing import List

# Router that this module's endpoints are registered on (see the
# @router.post decorator below).
router = APIRouter()

class BatchPredictRequest(BaseModel):
    """Request body accepted by the ``/batch_predict`` endpoint."""

    # One dict per item to score. The endpoint reads the "prompt",
    # "response" and "question" keys of each dict; this model itself does
    # not constrain which keys a dict contains.
    items: List[dict]

@router.post("/batch_predict")
def batch_predict(request: BatchPredictRequest):
    """Run the model on every item of a batch and store each prediction.

    Each item must be a dict with the keys ``"prompt"``, ``"response"`` and
    ``"question"``. For every item a fresh UUID is generated,
    ``model_instance.predict`` is called, and a ``Prediction`` row is added
    to the session. All rows are committed together after the last item, so
    a failure on any item persists nothing.

    Args:
        request: Parsed request body holding the list of items to score.

    Returns:
        ``{"results": [...]}`` where each entry is the dict returned by
        ``model_instance.predict`` with an added ``"request_id"`` key.

    Raises:
        HTTPException: 422 if an item lacks a required key; 500 if
            prediction or persistence fails.
    """
    # Validate the whole batch up front: a malformed item is a client error
    # and should be rejected before any inference or database work happens,
    # rather than surfacing later as an opaque 500 caused by a KeyError.
    required_keys = ("prompt", "response", "question")
    for index, item in enumerate(request.items):
        missing = [key for key in required_keys if key not in item]
        if missing:
            raise HTTPException(
                status_code=422,
                detail=(
                    f"items[{index}] is missing required field(s): "
                    f"{', '.join(missing)}"
                ),
            )

    results = []
    db = SessionLocal()
    try:
        for item in request.items:
            req_id = str(uuid4())
            result = model_instance.predict(
                item["prompt"], item["response"], item["question"]
            )
            db.add(
                Prediction(
                    id=req_id,
                    prompt=item["prompt"],
                    response=item["response"],
                    question=item["question"],
                    is_hallucination=result["is_hallucination"],
                    confidence_score=result["confidence_score"],
                    raw_prediction=result["raw_prediction"],
                    processing_time=result["processing_time"],
                )
            )
            results.append({**result, "request_id": req_id})
        # Single commit for the whole batch keeps it all-or-nothing.
        db.commit()
    except Exception as e:
        # Discard any pending rows explicitly before reporting the failure.
        db.rollback()
        raise HTTPException(status_code=500, detail=str(e)) from e
    finally:
        # Always release the session, on success and on failure alike.
        db.close()
    return {"results": results}