| from fastapi import APIRouter, HTTPException, Request, Depends
|
| from pydantic import BaseModel, Field
|
| from uuid import uuid4
|
| from model.inference import model_instance
|
| from db.database import SessionLocal
|
| from db.models import Prediction
|
| from core.advanced_cache import advanced_cache
|
| from slowapi import Limiter
|
| from slowapi.util import get_remote_address
|
| import logging
|
| import time
|
| import torch
|
| from typing import Optional
|
|
|
# Module-scoped logger; inherits configuration from the application root.
logger = logging.getLogger(__name__)

# Rate limiter keyed on the caller's remote address (used on /predict below).
limiter = Limiter(key_func=get_remote_address)



# Router mounted by the application; hosts the prediction endpoints.
router = APIRouter()
|
|
|
class PredictRequest(BaseModel):
    """Request payload shared by the /predict and /debug-predict endpoints."""

    # Context the AI response should be grounded in.
    prompt: str = Field(..., min_length=1, max_length=2000, description="Context prompt")

    # The AI-generated answer being checked for hallucination.
    response: str = Field(..., min_length=1, max_length=2000, description="AI response to evaluate")

    # The question the response is answering.
    question: str = Field(..., min_length=1, max_length=500, description="Question being answered")

    # When False, /predict skips both the cache lookup and the cache store.
    use_cache: Optional[bool] = Field(True, description="Whether to use caching")
|
|
|
class PredictResponse(BaseModel):
    """Response body for /predict (also validates cached results)."""

    # True when the model classified the response as a hallucination.
    is_hallucination: bool

    # Model confidence for the classification above.
    confidence_score: float

    # Raw text emitted by the model before post-processing.
    raw_prediction: str

    # Model-side processing time in seconds (produced by model_instance.predict).
    processing_time: float

    # Unique id generated per request (uuid4), also used as the DB row id.
    request_id: str

    # True when the result came from advanced_cache rather than the model.
    cached: bool

    # Optional label of the method that produced the result, if the
    # model/cache supplies one.
    method: Optional[str] = None
|
|
|
@router.post("/debug-predict")
async def debug_predict(request: Request, predict_request: PredictRequest):
    """
    Debug version of predict endpoint that shows raw model output.

    Runs the tokenize -> generate -> decode pipeline step by step and
    returns every intermediate artifact (formatted prompt, raw decoded
    output, extracted prediction) alongside the final classification,
    bypassing the caching and persistence layers used by /predict.

    Raises:
        HTTPException: 500 if any stage of the pipeline fails.
    """
    req_id = str(uuid4())

    try:
        # Exact text the model will see for this (prompt, response, question).
        input_text = model_instance.format_prompt(
            predict_request.prompt,
            predict_request.response,
            predict_request.question
        )

        inputs = model_instance.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        )
        inputs = {k: v.to(model_instance.device) for k, v in inputs.items()}

        with torch.no_grad():
            # Greedy decoding (do_sample=False): a temperature kwarg would be
            # ignored and only raise a transformers warning, so none is passed.
            outputs = model_instance.model.generate(
                **inputs,
                max_new_tokens=20,
                num_return_sequences=1,
                do_sample=False,
                pad_token_id=model_instance.tokenizer.pad_token_id,
                eos_token_id=model_instance.tokenizer.eos_token_id
            )

        full_output = model_instance.tokenizer.decode(outputs[0], skip_special_tokens=True)

        # Slice the prompt tokens off instead of str.replace(): decoding can
        # normalize whitespace/special tokens, so the decoded prompt does not
        # always match input_text byte-for-byte — the old replace() could leave
        # prompt fragments inside pred_text.
        prompt_token_count = inputs["input_ids"].shape[1]
        generated_tokens = outputs[0][prompt_token_count:]
        pred_text = model_instance.tokenizer.decode(
            generated_tokens, skip_special_tokens=True
        ).strip().lower()

        confidence_score = model_instance._calculate_confidence(pred_text)
        is_hallucination = model_instance._is_hallucination(pred_text)

        return {
            "request_id": req_id,
            "input_prompt": input_text,
            "raw_model_output": full_output,
            "extracted_prediction": pred_text,
            "is_hallucination": is_hallucination,
            "confidence_score": confidence_score,
            "debug_info": {
                # getattr with a default replaces the hasattr ternary.
                "model_name": getattr(model_instance.model.config, "name_or_path", "unknown"),
                "device": str(model_instance.device),
                "input_length": len(input_text),
                "output_length": len(full_output)
            }
        }

    except Exception as e:
        logger.error(f"Debug prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Debug prediction failed: {str(e)}")
|
|
|
def _persist_prediction(req_id: str, predict_request: PredictRequest, result: dict) -> None:
    """
    Best-effort persistence of one prediction row.

    Failures are logged and swallowed so a storage problem never turns a
    successful prediction into a 500 for the caller.
    """
    db = None
    try:
        db = SessionLocal()
        pred = Prediction(
            id=req_id,
            prompt=predict_request.prompt,
            response=predict_request.response,
            question=predict_request.question,
            is_hallucination=result["is_hallucination"],
            confidence_score=result["confidence_score"],
            raw_prediction=result["raw_prediction"],
            processing_time=result["processing_time"]
        )
        db.add(pred)
        db.commit()
        logger.info(f"Prediction saved: {req_id}")
    except Exception as db_error:
        # Roll back so the session is not left in a broken transaction state.
        if db is not None:
            db.rollback()
        logger.error(f"Database error: {str(db_error)}")
    finally:
        # Guard: the original called db.close() even when SessionLocal()
        # itself raised, producing a NameError from the finally block.
        if db is not None:
            db.close()


@router.post("/predict", response_model=PredictResponse)
@limiter.limit("60/minute; 10/10seconds; 3/5seconds")
async def predict(request: Request, predict_request: PredictRequest):
    """
    Predict whether an AI response contains hallucination.

    Flow: validate non-blank fields -> optional cache lookup -> model
    inference -> optional cache store -> best-effort DB persistence ->
    response.

    Raises:
        HTTPException: 400 for whitespace-only fields, 500 on any
            unexpected failure.
    """
    req_id = str(uuid4())

    try:
        # Pydantic enforces min_length=1, but whitespace-only values slip
        # through; reject them explicitly.
        if not all([predict_request.prompt.strip(), predict_request.response.strip(), predict_request.question.strip()]):
            raise HTTPException(status_code=400, detail="All fields must be non-empty")

        cached_result = None
        if predict_request.use_cache:
            cached_result = advanced_cache.get(
                predict_request.prompt,
                predict_request.response,
                predict_request.question
            )

        if cached_result:
            logger.info(f"Cache hit for request {req_id}")
            return PredictResponse(
                **cached_result,
                request_id=req_id,
                cached=True
            )

        result = model_instance.predict(
            predict_request.prompt,
            predict_request.response,
            predict_request.question
        )

        if predict_request.use_cache:
            advanced_cache.set(
                predict_request.prompt,
                predict_request.response,
                predict_request.question,
                result
            )

        _persist_prediction(req_id, predict_request, result)

        return PredictResponse(
            **result,
            request_id=req_id,
            cached=False
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
|
|
|