"""FastAPI service for support claim checking using the HHEM model.

Exposes an API endpoint that checks whether subclaims are supported by a
given context, using Vectara's HHEM (hallucination evaluation) model.
HHEM scores a (premise, hypothesis) pair; here premise = the generated
text / context and hypothesis = the subclaim being verified.
"""

import os
import sys
import warnings
from typing import Any, Dict, List, Optional

# GPU selection must happen before torch is imported. Use setdefault so an
# externally supplied CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES is respected
# instead of being silently clobbered by this module.
os.environ.setdefault("CUDA_DEVICE_ORDER", "PCI_BUS_ID")
os.environ.setdefault("CUDA_VISIBLE_DEVICES", "3")

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

warnings.filterwarnings("ignore")

# torch/transformers are optional at import time: the service still starts
# without them and reports "invalid" labels instead of crashing.
try:
    import torch
    from transformers import AutoModelForSequenceClassification

    _HHEM_AVAILABLE = True
except ImportError:
    torch = None
    AutoModelForSequenceClassification = None
    _HHEM_AVAILABLE = False

# --- HHEM (vectara/hallucination_evaluation_model) for support checking ---
HHEM_MODEL_NAME = os.getenv("HHEM_MODEL_NAME", "vectara/hallucination_evaluation_model")
_HHEM_MODEL = None  # lazily-loaded singleton; see load_hhem_model()


def load_hhem_model(model_name: Optional[str] = None):
    """Load and cache the HHEM model for subclaim verification.

    HHEM is applied as (premise=generated text, hypothesis=subclaim).
    The loaded model is cached in a module-level global; subsequent calls
    return the cached instance. NOTE: once a model is cached, a different
    `model_name` passed later is ignored (pre-existing behavior, kept for
    interface compatibility).

    Args:
        model_name: Optional HF model id. Defaults to the HHEM_MODEL_NAME
            environment variable or the Vectara hallucination model.

    Returns:
        The loaded model, in eval mode.

    Raises:
        RuntimeError: If torch/transformers are not installed.
    """
    global _HHEM_MODEL
    if not _HHEM_AVAILABLE:
        raise RuntimeError("torch and transformers are required for HHEM support checking")
    if _HHEM_MODEL is not None:
        return _HHEM_MODEL
    name = model_name or HHEM_MODEL_NAME
    _HHEM_MODEL = AutoModelForSequenceClassification.from_pretrained(
        name,
        trust_remote_code=True,  # HHEM ships custom model code (predict())
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    _HHEM_MODEL.eval()
    return _HHEM_MODEL


def verify_subclaims_in_text(
    model,
    generated_text: str,
    subclaims: List[str],
    threshold: float = 0.5,
    batch_size: int = 32,
) -> List[Dict[str, Any]]:
    """Score how well each subclaim is entailed by the generated text.

    HHEM convention: premise=generated text, hypothesis=subclaim. A claim
    passes when its score strictly exceeds `threshold`.

    Args:
        model: Loaded HHEM model exposing `predict(list[(premise, hypothesis)])`.
        generated_text: The text acting as premise for every pair.
        subclaims: Claims to verify against the text.
        threshold: Score above which a claim counts as present.
        batch_size: Number of pairs scored per predict() call.

    Returns:
        One dict per subclaim with keys: subclaim, score (rounded to 4
        decimals), status ("PASS"/"FAIL"), exists_in_text (bool).
    """
    pairs = [(generated_text, claim) for claim in subclaims]
    results: List[Dict[str, Any]] = []
    for start in range(0, len(pairs), batch_size):
        batch_pairs = pairs[start : start + batch_size]
        batch_claims = subclaims[start : start + batch_size]
        batch_scores = model.predict(batch_pairs)
        for claim, score in zip(batch_claims, batch_scores):
            # Scores may arrive as tensors or plain floats depending on the
            # model implementation; normalize to a Python float.
            s = score.item() if hasattr(score, "item") else float(score)
            passed = s > threshold
            results.append({
                "subclaim": claim,
                "score": round(s, 4),
                "status": "PASS" if passed else "FAIL",
                "exists_in_text": passed,
            })
    return results


# FastAPI app
app = FastAPI(title="Support Claim Checking API", version="1.0.0")


class SupportCheckRequest(BaseModel):
    """Request model for support claim checking."""

    context: str
    subclaims: List[str]
    threshold: float = 0.5
    batch_size: int = 32


class SupportCheckResponse(BaseModel):
    """Response model for support claim checking."""

    labels: List[str]  # "supported" | "not_supported" | "invalid"
    details: List[Dict[str, Any]]  # Detailed results with scores


@app.get("/health")
async def health_check():
    """Health check endpoint reporting HHEM availability and load state."""
    return {
        "status": "healthy",
        "hhem_available": _HHEM_AVAILABLE,
        "model_loaded": _HHEM_MODEL is not None,
    }


@app.post("/check_support", response_model=SupportCheckResponse)
async def check_support(request: SupportCheckRequest):
    """Check if subclaims are supported by the context.

    Args:
        request: SupportCheckRequest containing context, subclaims,
            threshold, and batch_size.

    Returns:
        SupportCheckResponse with one label per subclaim and detailed
        per-claim scores. Empty input yields an empty response; a missing
        HHEM stack yields "invalid" labels rather than an error.
    """
    if not request.context or not request.subclaims:
        return SupportCheckResponse(labels=[], details=[])
    if not _HHEM_AVAILABLE:
        # Degrade gracefully: callers can distinguish "invalid" from a
        # genuine not_supported verdict.
        return SupportCheckResponse(
            labels=["invalid"] * len(request.subclaims),
            details=[],
        )
    try:
        model = load_hhem_model()
        results = verify_subclaims_in_text(
            model,
            request.context,
            request.subclaims,
            threshold=request.threshold,
            batch_size=request.batch_size,
        )
        # Map PASS -> "supported", FAIL -> "not_supported" to match existing
        # reward logic downstream.
        labels = ["supported" if r["status"] == "PASS" else "not_supported" for r in results]
        return SupportCheckResponse(labels=labels, details=results)
    except Exception as exc:
        # Chain the cause so server logs keep the original traceback.
        raise HTTPException(
            status_code=500,
            detail=f"HHEM support check failed: {str(exc)}",
        ) from exc


if __name__ == "__main__":
    import uvicorn

    port = int(os.getenv("SUPPORT_API_PORT", "8091"))
    host = os.getenv("SUPPORT_API_HOST", "0.0.0.0")
    uvicorn.run(app, host=host, port=port)