from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel, Field
from uuid import uuid4
from model.inference import model_instance
from db.database import SessionLocal
from db.models import Prediction
from core.advanced_cache import advanced_cache
from slowapi import Limiter
from slowapi.util import get_remote_address
import logging
import time
import torch
from typing import Optional

logger = logging.getLogger(__name__)
limiter = Limiter(key_func=get_remote_address)
router = APIRouter()


class PredictRequest(BaseModel):
    prompt: str = Field(..., min_length=1, max_length=2000, description="Context prompt")
    response: str = Field(..., min_length=1, max_length=2000, description="AI response to evaluate")
    question: str = Field(..., min_length=1, max_length=500, description="Question being answered")
    use_cache: Optional[bool] = Field(True, description="Whether to use caching")


class PredictResponse(BaseModel):
    is_hallucination: bool
    confidence_score: float
    raw_prediction: str
    processing_time: float
    request_id: str
    cached: bool
    method: Optional[str] = None  # Optional detection method reported by the model


@router.post("/debug-predict")
async def debug_predict(request: Request, predict_request: PredictRequest):
    """
    Debug version of the predict endpoint that returns the raw model output
    alongside the extracted prediction.
    """
    req_id = str(uuid4())
    try:
        # Build the prompt exactly as the production prediction path does
        input_text = model_instance.format_prompt(
            predict_request.prompt,
            predict_request.response,
            predict_request.question
        )

        # Tokenize and move tensors to the model's device
        inputs = model_instance.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True
        )
        inputs = {k: v.to(model_instance.device) for k, v in inputs.items()}

        # Get raw model prediction
        with torch.no_grad():
            outputs = model_instance.model.generate(
                **inputs,
                max_new_tokens=20,
                num_return_sequences=1,
                do_sample=False,  # greedy decoding: deterministic output for debugging
                pad_token_id=model_instance.tokenizer.pad_token_id,
                eos_token_id=model_instance.tokenizer.eos_token_id
            )

        # Decode and strip the prompt to isolate the model's answer
        full_output = model_instance.tokenizer.decode(outputs[0], skip_special_tokens=True)
        pred_text = full_output.replace(input_text, "").strip().lower()

        # Reuse the model's own confidence and hallucination heuristics
        confidence_score = model_instance._calculate_confidence(pred_text)
        is_hallucination = model_instance._is_hallucination(pred_text)

        return {
            "request_id": req_id,
            "input_prompt": input_text,
            "raw_model_output": full_output,
            "extracted_prediction": pred_text,
            "is_hallucination": is_hallucination,
            "confidence_score": confidence_score,
            "debug_info": {
                "model_name": getattr(model_instance.model.config, "name_or_path", "unknown"),
                "device": str(model_instance.device),
                "input_length": len(input_text),
                "output_length": len(full_output)
            }
        }
    except Exception as e:
        logger.error(f"Debug prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Debug prediction failed: {str(e)}")


@router.post("/predict", response_model=PredictResponse)
@limiter.limit("60/minute; 10/10seconds; 3/5seconds")
async def predict(request: Request, predict_request: PredictRequest):
    """
    Predict whether an AI response contains a hallucination.
    """
    req_id = str(uuid4())
    start_time = time.time()

    try:
        # Input validation: reject fields that are only whitespace
        if not all([predict_request.prompt.strip(),
                    predict_request.response.strip(),
                    predict_request.question.strip()]):
            raise HTTPException(status_code=400, detail="All fields must be non-empty")

        # Check the cache first if enabled
        cached_result = None
        if predict_request.use_cache:
            cached_result = advanced_cache.get(
                predict_request.prompt,
                predict_request.response,
                predict_request.question
            )
            if cached_result:
                logger.info(f"Cache hit for request {req_id}")
                return PredictResponse(
                    **cached_result,
                    request_id=req_id,
                    cached=True
                )

        # Make prediction
        result = model_instance.predict(
            predict_request.prompt,
            predict_request.response,
            predict_request.question
        )

        # Cache the result if caching is enabled
        if predict_request.use_cache:
            advanced_cache.set(
                predict_request.prompt,
                predict_request.response,
                predict_request.question,
                result
            )

        # Store in database; failures are logged but do not fail the request
        db = None
        try:
            db = SessionLocal()
            pred = Prediction(
                id=req_id,
                prompt=predict_request.prompt,
                response=predict_request.response,
                question=predict_request.question,
                is_hallucination=result["is_hallucination"],
                confidence_score=result["confidence_score"],
                raw_prediction=result["raw_prediction"],
                processing_time=result["processing_time"]
            )
            db.add(pred)
            db.commit()
            logger.info(f"Prediction saved: {req_id}")
        except Exception as db_error:
            logger.error(f"Database error: {str(db_error)}")
        finally:
            if db is not None:
                db.close()

        return PredictResponse(
            **result,
            request_id=req_id,
            cached=False
        )

    except HTTPException:
        raise
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")
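
# Example usage (illustrative sketch, not part of the router): one way to exercise
# the /predict endpoint from a test or script, using FastAPI's TestClient. The
# application import path ("main") and the absence of a route prefix are
# assumptions; adjust both to match how this router is actually mounted.
#
#   from fastapi.testclient import TestClient
#   from main import app  # hypothetical application entry point
#
#   client = TestClient(app)
#   payload = {
#       "prompt": "The Eiffel Tower is a landmark in Paris.",
#       "response": "The Eiffel Tower is located in Berlin.",
#       "question": "Where is the Eiffel Tower?",
#       "use_cache": False,
#   }
#   r = client.post("/predict", json=payload)  # prepend a prefix if the router uses one
#   data = r.json()
#   print(data["is_hallucination"], data["confidence_score"], data["cached"])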