# GPU selection must run before torch/CUDA is imported anywhere below:
# CUDA reads CUDA_DEVICE_ORDER / CUDA_VISIBLE_DEVICES at initialization time.
import os
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
# NOTE(review): because statements precede it, the string below is no longer a
# real module docstring — it is just an expression statement. Kept as-is.
"""
FastAPI service for support claim checking using HHEM model.
This service provides an API endpoint to check if subclaims are supported by context.
"""
import os
import sys
import warnings
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

warnings.filterwarnings("ignore")
|
|
# Optional heavy dependencies: the service can still start (and /health can
# report hhem_available=False) when torch/transformers are not installed.
try:
    import torch
    from transformers import AutoModelForSequenceClassification
    _HHEM_AVAILABLE = True
except ImportError:
    # Placeholders keep later module-level references resolvable without the packages.
    torch = None
    AutoModelForSequenceClassification = None
    _HHEM_AVAILABLE = False
|
|
| |
# Hugging Face model id for the HHEM scorer; overridable via environment variable.
HHEM_MODEL_NAME = os.getenv("HHEM_MODEL_NAME", "vectara/hallucination_evaluation_model")
# Process-wide cache for the loaded model (populated lazily by load_hhem_model).
_HHEM_MODEL = None
|
|
|
|
def load_hhem_model(model_name: Optional[str] = None):
    """Load and memoize the HHEM model used for subclaim verification.

    HHEM scores entailment with premise=generated text, hypothesis=subclaim.
    The loaded model is cached in the module-level ``_HHEM_MODEL`` so repeated
    calls return the same instance.

    Args:
        model_name: Hugging Face model id; defaults to ``HHEM_MODEL_NAME``.
            (Note: a cached model is returned as-is even if a different
            ``model_name`` is passed on a later call.)

    Returns:
        The model in eval mode.

    Raises:
        RuntimeError: If torch/transformers are not installed.
    """
    global _HHEM_MODEL
    if not _HHEM_AVAILABLE:
        raise RuntimeError("torch and transformers are required for HHEM support checking")
    if _HHEM_MODEL is not None:
        return _HHEM_MODEL
    name = model_name or HHEM_MODEL_NAME
    # trust_remote_code is required because the HHEM repo ships custom model
    # code. NOTE(review): this executes code from the model repository — pin a
    # revision if the model source is not fully trusted.
    _HHEM_MODEL = AutoModelForSequenceClassification.from_pretrained(
        name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    _HHEM_MODEL.eval()
    return _HHEM_MODEL
|
|
|
|
def verify_subclaims_in_text(
    model,
    generated_text: str,
    subclaims: List[str],
    threshold: float = 0.5,
    batch_size: int = 32,
) -> List[Dict[str, Any]]:
    """Score each subclaim against the generated text with HHEM.

    Each (premise=generated_text, hypothesis=subclaim) pair is scored in
    batches of ``batch_size``. A subclaim passes when its score is strictly
    above ``threshold``.

    Returns:
        One dict per subclaim (input order preserved) with keys:
        ``subclaim``, ``score`` (rounded to 4 places), ``status``
        ("PASS"/"FAIL"), and ``exists_in_text``.
    """
    verdicts: List[Dict[str, Any]] = []
    for start in range(0, len(subclaims), batch_size):
        chunk = subclaims[start : start + batch_size]
        scores = model.predict([(generated_text, claim) for claim in chunk])
        for claim, raw in zip(chunk, scores):
            # Model may yield tensors (with .item()) or plain numbers.
            value = raw.item() if hasattr(raw, "item") else float(raw)
            passed = value > threshold
            verdicts.append({
                "subclaim": claim,
                "score": round(value, 4),
                "status": "PASS" if passed else "FAIL",
                "exists_in_text": passed,
            })
    return verdicts
|
|
|
|
| |
# ASGI application instance; endpoints below are registered via decorators.
app = FastAPI(title="Support Claim Checking API", version="1.0.0")
|
|
|
|
class SupportCheckRequest(BaseModel):
    """Request model for support claim checking."""
    context: str          # Generated text used as the HHEM premise.
    subclaims: List[str]  # Claims to verify against the context.
    threshold: float = 0.5   # Scores strictly above this count as supported.
    batch_size: int = 32     # (context, claim) pairs scored per model call.
|
|
|
|
class SupportCheckResponse(BaseModel):
    """Response model for support claim checking."""
    labels: List[str]  # "supported"/"not_supported" per subclaim ("invalid" when HHEM deps missing).
    details: List[Dict[str, Any]]  # Per-claim dicts: subclaim, score, status, exists_in_text.
|
|
|
|
@app.get("/health")
async def health_check():
    """Liveness probe: reports whether HHEM deps import and the model is cached."""
    payload: Dict[str, Any] = {"status": "healthy"}
    payload["hhem_available"] = _HHEM_AVAILABLE
    payload["model_loaded"] = _HHEM_MODEL is not None
    return payload
|
|
|
|
@app.post("/check_support", response_model=SupportCheckResponse)
async def check_support(request: SupportCheckRequest):
    """
    Check if subclaims are supported by the context.

    Args:
        request: SupportCheckRequest containing context, subclaims, threshold,
            and batch_size.

    Returns:
        SupportCheckResponse whose ``labels`` align one-to-one with
        ``request.subclaims`` (empty when context or subclaims is empty, or
        all "invalid" with empty details when HHEM deps are unavailable).

    Raises:
        HTTPException: 500 when model loading or scoring fails.
    """
    # Nothing to score: empty context or no subclaims yields an empty response.
    if not request.context or not request.subclaims:
        return SupportCheckResponse(labels=[], details=[])

    # Degrade gracefully instead of erroring when torch/transformers are missing.
    if not _HHEM_AVAILABLE:
        return SupportCheckResponse(
            labels=["invalid"] * len(request.subclaims),
            details=[],
        )

    try:
        model = load_hhem_model()
        results = verify_subclaims_in_text(
            model,
            request.context,
            request.subclaims,
            threshold=request.threshold,
            batch_size=request.batch_size,
        )
        labels = [
            "supported" if r["status"] == "PASS" else "not_supported"
            for r in results
        ]
        return SupportCheckResponse(labels=labels, details=results)
    except Exception as exc:
        # Boundary handler: surface any model/runtime failure as a 500.
        # Chain the original exception so server logs keep the real cause.
        raise HTTPException(
            status_code=500,
            detail=f"HHEM support check failed: {str(exc)}",
        ) from exc
|
|
|
|
if __name__ == "__main__":
    # Run the service directly with uvicorn; bind address/port come from env.
    import uvicorn

    bind_host = os.getenv("SUPPORT_API_HOST", "0.0.0.0")
    bind_port = int(os.getenv("SUPPORT_API_PORT", "8091"))
    uvicorn.run(app, host=bind_host, port=bind_port)
|
|