# Source: app/api/predict.py (uploaded via huggingface_hub, commit 9594d90)
from fastapi import APIRouter, HTTPException, Request, Depends
from pydantic import BaseModel, Field
from uuid import uuid4
from model.inference import model_instance
from db.database import SessionLocal
from db.models import Prediction
from core.advanced_cache import advanced_cache
from slowapi import Limiter
from slowapi.util import get_remote_address
import logging
import time
import torch
from typing import Optional
# Module-level singletons shared by every endpoint in this router.
logger = logging.getLogger(__name__)
# Rate limiting keyed by the caller's remote IP address; applied per-route
# via the @limiter.limit(...) decorator (see /predict below).
limiter = Limiter(key_func=get_remote_address)
router = APIRouter()
class PredictRequest(BaseModel):
    """Request payload for the hallucination-prediction endpoints.

    The three text fields are required and length-bounded by pydantic;
    note that min_length=1 still admits whitespace-only strings, which the
    /predict handler rejects separately.
    """

    prompt: str = Field(..., min_length=1, max_length=2000, description="Context prompt")
    response: str = Field(..., min_length=1, max_length=2000, description="AI response to evaluate")
    question: str = Field(..., min_length=1, max_length=500, description="Question being answered")
    # When True (default), /predict consults and populates advanced_cache.
    use_cache: Optional[bool] = Field(True, description="Whether to use caching")
class PredictResponse(BaseModel):
    """Response body returned by /predict.

    Fields other than request_id/cached/method come from the model result
    dict (or from a cached copy of it) and are splatted in via **result.
    """

    is_hallucination: bool
    confidence_score: float
    raw_prediction: str
    # Seconds spent in model inference, as reported by the model result
    # (not measured by the endpoint itself).
    processing_time: float
    request_id: str
    # True when the result was served from advanced_cache.
    cached: bool
    method: Optional[str] = None  # Optional: which prediction method produced the result
@router.post("/debug-predict")
async def debug_predict(request: Request, predict_request: PredictRequest):
    """
    Debug version of the predict endpoint that exposes raw model output.

    Runs the same tokenize -> generate -> decode pipeline the model uses,
    but returns intermediate artifacts (formatted prompt, raw decoded text,
    extracted prediction string) alongside the final verdict so prompt and
    decoding issues can be diagnosed.

    Raises:
        HTTPException: 500 when any stage of inference fails.
    """
    req_id = str(uuid4())
    try:
        # Build the model-specific prompt from the three request fields.
        input_text = model_instance.format_prompt(
            predict_request.prompt,
            predict_request.response,
            predict_request.question,
        )

        # Tokenize and move all input tensors to the model's device.
        inputs = model_instance.tokenizer(
            input_text,
            return_tensors="pt",
            max_length=512,
            truncation=True,
            padding=True,
        )
        inputs = {k: v.to(model_instance.device) for k, v in inputs.items()}

        with torch.no_grad():
            # Greedy decoding for deterministic, reproducible debug output.
            # NOTE: the original also passed temperature=0.1, which is
            # ignored (and warned about / rejected by newer transformers
            # releases) when do_sample=False, so it is dropped here.
            outputs = model_instance.model.generate(
                **inputs,
                max_new_tokens=20,
                num_return_sequences=1,
                do_sample=False,
                pad_token_id=model_instance.tokenizer.pad_token_id,
                eos_token_id=model_instance.tokenizer.eos_token_id,
            )

        # Decode, then strip the echoed prompt to isolate the generated text.
        # (Assumes a causal LM that echoes the prompt — TODO confirm.)
        full_output = model_instance.tokenizer.decode(outputs[0], skip_special_tokens=True)
        pred_text = full_output.replace(input_text, "").strip().lower()

        # Reuse the model's own scoring helpers so the debug verdict matches
        # what /predict would report for the same text.
        confidence_score = model_instance._calculate_confidence(pred_text)
        is_hallucination = model_instance._is_hallucination(pred_text)

        return {
            "request_id": req_id,
            "input_prompt": input_text,
            "raw_model_output": full_output,
            "extracted_prediction": pred_text,
            "is_hallucination": is_hallucination,
            "confidence_score": confidence_score,
            "debug_info": {
                # getattr with default replaces the hasattr/ternary of the original.
                "model_name": getattr(model_instance.model.config, "name_or_path", "unknown"),
                "device": str(model_instance.device),
                "input_length": len(input_text),
                "output_length": len(full_output),
            },
        }
    except HTTPException:
        # Let deliberate HTTP errors pass through untouched (consistent with /predict).
        raise
    except Exception as e:
        logger.error(f"Debug prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Debug prediction failed: {str(e)}")
@router.post("/predict", response_model=PredictResponse)
@limiter.limit("60/minute; 10/10seconds; 3/5seconds")
async def predict(request: Request, predict_request: PredictRequest):
    """
    Predict whether an AI response contains hallucination.

    Flow: validate non-blank fields -> optional cache lookup -> model
    inference -> optional cache store -> best-effort DB persistence ->
    response. Database failures are logged but never fail the request.

    Raises:
        HTTPException: 400 for whitespace-only fields, 500 on inference failure.
    """
    req_id = str(uuid4())
    try:
        # Pydantic's min_length=1 admits whitespace-only strings; reject them here.
        if not all([predict_request.prompt.strip(), predict_request.response.strip(), predict_request.question.strip()]):
            raise HTTPException(status_code=400, detail="All fields must be non-empty")

        # Serve from cache when enabled and a hit exists.
        # (Assumes cached_result holds only the model-result fields, never
        # request_id/cached — otherwise the splat below would raise. TODO confirm.)
        if predict_request.use_cache:
            cached_result = advanced_cache.get(
                predict_request.prompt,
                predict_request.response,
                predict_request.question,
            )
            if cached_result:
                logger.info(f"Cache hit for request {req_id}")
                return PredictResponse(
                    **cached_result,
                    request_id=req_id,
                    cached=True,
                )

        # Run model inference. processing_time comes from the result dict,
        # so no local timer is needed (the original's unused start_time is removed).
        result = model_instance.predict(
            predict_request.prompt,
            predict_request.response,
            predict_request.question,
        )

        # Populate the cache for subsequent identical requests.
        if predict_request.use_cache:
            advanced_cache.set(
                predict_request.prompt,
                predict_request.response,
                predict_request.question,
                result,
            )

        # Best-effort persistence: construct the session BEFORE the try so
        # that the finally-block close() can never hit an unbound name (the
        # original created it inside the try, so a SessionLocal() failure
        # raised NameError from `finally: db.close()`).
        db = SessionLocal()
        try:
            pred = Prediction(
                id=req_id,
                prompt=predict_request.prompt,
                response=predict_request.response,
                question=predict_request.question,
                is_hallucination=result["is_hallucination"],
                confidence_score=result["confidence_score"],
                raw_prediction=result["raw_prediction"],
                processing_time=result["processing_time"],
            )
            db.add(pred)
            db.commit()
            logger.info(f"Prediction saved: {req_id}")
        except Exception as db_error:
            # Leave the session in a clean state before closing it.
            db.rollback()
            logger.error(f"Database error: {str(db_error)}")
        finally:
            db.close()

        return PredictResponse(
            **result,
            request_id=req_id,
            cached=False,
        )
    except HTTPException:
        # Deliberate HTTP errors (e.g. the 400 above) pass through untouched.
        raise
    except Exception as e:
        logger.error(f"Prediction error: {str(e)}", exc_info=True)
        raise HTTPException(status_code=500, detail=f"Prediction failed: {str(e)}")