import os
# Pin CUDA device selection before torch is imported (import-time side effect).
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# Default stays "3" (previous hard-coded GPU); override via SUPPORT_GPU env var.
os.environ["CUDA_VISIBLE_DEVICES"] = os.getenv("SUPPORT_GPU", "3")
"""
FastAPI service for support claim checking using HHEM model.
This service provides an API endpoint to check if subclaims are supported by context.
"""
import os
import sys
import warnings
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

warnings.filterwarnings("ignore")
# Optional heavy dependencies: the service still starts without torch/transformers;
# /check_support then labels every subclaim "invalid" instead of crashing.
try:
import torch
from transformers import AutoModelForSequenceClassification
_HHEM_AVAILABLE = True
except ImportError:
torch = None
AutoModelForSequenceClassification = None
_HHEM_AVAILABLE = False
# --- HHEM (vectara/hallucination_evaluation_model) for support checking ---
HHEM_MODEL_NAME = os.getenv("HHEM_MODEL_NAME", "vectara/hallucination_evaluation_model")
# Process-wide model cache; populated lazily by load_hhem_model().
_HHEM_MODEL = None
def load_hhem_model(model_name: Optional[str] = None):
    """Load and cache the HHEM model for subclaim verification.

    HHEM scores entailment with premise=generated text, hypothesis=subclaim.

    Args:
        model_name: Hugging Face model id; falls back to HHEM_MODEL_NAME.
            NOTE(review): once a model is cached in _HHEM_MODEL, a different
            model_name is silently ignored and the cached model is returned.

    Returns:
        The loaded model in eval mode (cached module-wide in _HHEM_MODEL).

    Raises:
        RuntimeError: if torch/transformers failed to import.
    """
    global _HHEM_MODEL
    if not _HHEM_AVAILABLE:
        raise RuntimeError("torch and transformers are required for HHEM support checking")
    if _HHEM_MODEL is not None:
        return _HHEM_MODEL
    name = model_name or HHEM_MODEL_NAME
    # trust_remote_code: HHEM ships custom modeling code on the Hub;
    # device_map="auto" places weights on the visible GPU(s) set above.
    _HHEM_MODEL = AutoModelForSequenceClassification.from_pretrained(
        name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    _HHEM_MODEL.eval()
    return _HHEM_MODEL
def verify_subclaims_in_text(
    model,
    generated_text: str,
    subclaims: List[str],
    threshold: float = 0.5,
    batch_size: int = 32,
) -> List[Dict[str, Any]]:
    """
    Verify how much information from subclaims exists in generated text.
    HHEM: premise=generated text, hypothesis=subclaim. Returns PASS/FAIL per subclaim.
    """
    # Every pair shares the same premise (the generated text).
    pair_list = [(generated_text, sc) for sc in subclaims]
    verdicts: List[Dict[str, Any]] = []
    for start in range(0, len(pair_list), batch_size):
        chunk = pair_list[start : start + batch_size]
        chunk_scores = model.predict(chunk)
        # Pair each score back with its claim via a parallel slice.
        for claim, raw in zip(subclaims[start : start + batch_size], chunk_scores):
            # Scores may arrive as tensors (have .item()) or plain numbers.
            value = raw.item() if hasattr(raw, "item") else float(raw)
            passed = value > threshold
            verdicts.append({
                "subclaim": claim,
                "score": round(value, 4),
                "status": "PASS" if passed else "FAIL",
                "exists_in_text": passed,
            })
    return verdicts
# FastAPI app
# Module-level application instance; served by uvicorn in the __main__ guard below.
app = FastAPI(title="Support Claim Checking API", version="1.0.0")
class SupportCheckRequest(BaseModel):
    """Request model for support claim checking."""
    context: str  # Premise text the subclaims are checked against.
    subclaims: List[str]  # Hypotheses to verify against the context.
    threshold: float = 0.5  # Scores strictly above this count as supported.
    batch_size: int = 32  # Number of (context, subclaim) pairs per HHEM call.
class SupportCheckResponse(BaseModel):
    """Response model for support claim checking."""
    labels: List[str]  # One per subclaim: "supported" | "not_supported" | "invalid"
    details: List[Dict[str, Any]]  # Per-subclaim score dicts; empty when HHEM is unavailable or input is empty.
@app.get("/health")
async def health_check():
    """Report service liveness plus HHEM dependency and model-cache state."""
    payload = {
        "status": "healthy",
        "hhem_available": _HHEM_AVAILABLE,
        "model_loaded": _HHEM_MODEL is not None,
    }
    return payload
@app.post("/check_support", response_model=SupportCheckResponse)
async def check_support(request: SupportCheckRequest):
    """
    Check if subclaims are supported by the context.
    Args:
        request: SupportCheckRequest containing context, subclaims, threshold, and batch_size
    Returns:
        SupportCheckResponse with labels and detailed results
    """
    # Guard: nothing to check when either side is empty.
    if not (request.context and request.subclaims):
        return SupportCheckResponse(labels=[], details=[])
    # Without torch/transformers every claim is unverifiable.
    if not _HHEM_AVAILABLE:
        return SupportCheckResponse(
            labels=["invalid"] * len(request.subclaims),
            details=[],
        )
    try:
        hhem = load_hhem_model()
        outcome = verify_subclaims_in_text(
            hhem,
            request.context,
            request.subclaims,
            threshold=request.threshold,
            batch_size=request.batch_size,
        )
        # Map PASS -> "supported", FAIL -> "not_supported" to match existing reward logic
        verdict_labels = [
            "supported" if entry["status"] == "PASS" else "not_supported"
            for entry in outcome
        ]
        return SupportCheckResponse(labels=verdict_labels, details=outcome)
    except Exception as exc:
        raise HTTPException(
            status_code=500,
            detail=f"HHEM support check failed: {str(exc)}"
        )
if __name__ == "__main__":
    # Dev entry point: serve the app directly; host/port overridable via env.
    import uvicorn

    serve_host = os.getenv("SUPPORT_API_HOST", "0.0.0.0")
    serve_port = int(os.getenv("SUPPORT_API_PORT", "8091"))
    uvicorn.run(app, host=serve_host, port=serve_port)