File size: 4,807 Bytes
030876e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
import os
# NOTE: CUDA device selection must be configured BEFORE torch is imported
# further down; these variables have no effect once CUDA is initialized.
os.environ['CUDA_DEVICE_ORDER'] = 'PCI_BUS_ID'
# Pin this service to GPU index 3 (PCI bus order) — TODO confirm this is the
# intended device on the deployment host.
os.environ["CUDA_VISIBLE_DEVICES"] = "3"
"""
FastAPI service for support claim checking using HHEM model.
This service provides an API endpoint to check if subclaims are supported by context.
"""
import os
import sys
import warnings
from typing import Any, Dict, List, Optional

from fastapi import FastAPI, HTTPException
from pydantic import BaseModel

warnings.filterwarnings("ignore")

try:
    # Heavy optional dependencies: when missing, the service still starts and
    # reports unavailability instead of crashing at import time.
    import torch
    from transformers import AutoModelForSequenceClassification
    _HHEM_AVAILABLE = True
except ImportError:
    # Placeholders so later references fail loudly only when actually used.
    torch = None
    AutoModelForSequenceClassification = None
    _HHEM_AVAILABLE = False

# --- HHEM (vectara/hallucination_evaluation_model) for support checking ---
# Model id is overridable via the HHEM_MODEL_NAME environment variable.
HHEM_MODEL_NAME = os.getenv("HHEM_MODEL_NAME", "vectara/hallucination_evaluation_model")
# Process-wide singleton, populated lazily by load_hhem_model().
_HHEM_MODEL = None


def load_hhem_model(model_name: Optional[str] = None):
    """Load and cache the HHEM model for subclaim verification.

    HHEM scores entailment with premise=generated text, hypothesis=subclaim.

    Args:
        model_name: Optional Hugging Face model id; defaults to
            ``HHEM_MODEL_NAME`` when not given.

    Returns:
        The loaded sequence-classification model, switched to eval mode.

    Raises:
        RuntimeError: If torch/transformers are not installed.

    Note:
        The cache is keyed on "a model has been loaded", not on the name:
        once any model is cached, a different ``model_name`` is ignored.
    """
    global _HHEM_MODEL
    if not _HHEM_AVAILABLE:
        raise RuntimeError("torch and transformers are required for HHEM support checking")
    if _HHEM_MODEL is not None:
        return _HHEM_MODEL
    name = model_name or HHEM_MODEL_NAME
    # trust_remote_code is required: HHEM ships custom modeling code (its
    # .predict() helper) alongside the weights.
    _HHEM_MODEL = AutoModelForSequenceClassification.from_pretrained(
        name,
        trust_remote_code=True,
        torch_dtype=torch.bfloat16,
        device_map="auto",
    )
    _HHEM_MODEL.eval()
    return _HHEM_MODEL


def verify_subclaims_in_text(
    model,
    generated_text: str,
    subclaims: List[str],
    threshold: float = 0.5,
    batch_size: int = 32,
) -> List[Dict[str, Any]]:
    """
    Score each subclaim against the generated text with an HHEM-style model.

    Each (premise, hypothesis) pair is (generated_text, subclaim); pairs are
    fed to ``model.predict`` in chunks of ``batch_size``. A claim whose score
    strictly exceeds ``threshold`` is marked PASS, otherwise FAIL.
    """
    premise_hypothesis_pairs = [(generated_text, subclaim) for subclaim in subclaims]
    verdicts: List[Dict[str, Any]] = []
    for start in range(0, len(premise_hypothesis_pairs), batch_size):
        stop = start + batch_size
        chunk_scores = model.predict(premise_hypothesis_pairs[start:stop])
        for claim, raw_score in zip(subclaims[start:stop], chunk_scores):
            # Scores may arrive as tensors (with .item()) or plain numbers.
            value = raw_score.item() if hasattr(raw_score, "item") else float(raw_score)
            passed = value > threshold
            verdicts.append({
                "subclaim": claim,
                "score": round(value, 4),
                "status": "PASS" if passed else "FAIL",
                "exists_in_text": passed,
            })
    return verdicts


# FastAPI application instance; routes are registered via decorators below.
app = FastAPI(title="Support Claim Checking API", version="1.0.0")


class SupportCheckRequest(BaseModel):
    """Request model for support claim checking.

    `context` is the premise text; each entry of `subclaims` is scored as a
    hypothesis against it.
    """
    context: str            # Generated/source text used as the HHEM premise.
    subclaims: List[str]    # Claims to verify against the context.
    threshold: float = 0.5  # Scores strictly above this count as supported.
    batch_size: int = 32    # (premise, claim) pairs per model.predict call.


class SupportCheckResponse(BaseModel):
    """Response model for support claim checking.

    `labels` is index-aligned with the request's `subclaims`. Note that
    `details` may be empty while `labels` is not (e.g. when the HHEM
    dependencies are unavailable and every label is "invalid").
    """
    labels: List[str]  # "supported" | "not_supported" | "invalid"
    details: List[Dict[str, Any]]  # Detailed results with scores


@app.get("/health")
async def health_check():
    """Report liveness plus HHEM dependency and model-cache state."""
    payload = {"status": "healthy"}
    payload["hhem_available"] = _HHEM_AVAILABLE
    payload["model_loaded"] = _HHEM_MODEL is not None
    return payload


@app.post("/check_support", response_model=SupportCheckResponse)
async def check_support(request: SupportCheckRequest):
    """
    Check if subclaims are supported by the context.

    Args:
        request: SupportCheckRequest containing context, subclaims, threshold, and batch_size

    Returns:
        SupportCheckResponse with labels ("supported" | "not_supported" |
        "invalid") aligned to the request's subclaims, plus per-claim score
        details.

    Raises:
        HTTPException: 500 when model loading or scoring fails.
    """
    # Nothing to score: mirror empty input with an empty response.
    if not request.context or not request.subclaims:
        return SupportCheckResponse(labels=[], details=[])

    # Without torch/transformers we cannot score anything; mark every claim
    # "invalid". details intentionally stays empty — callers key off labels.
    if not _HHEM_AVAILABLE:
        return SupportCheckResponse(
            labels=["invalid"] * len(request.subclaims),
            details=[],
        )

    try:
        model = load_hhem_model()
        results = verify_subclaims_in_text(
            model,
            request.context,
            request.subclaims,
            threshold=request.threshold,
            batch_size=request.batch_size,
        )
    except Exception as exc:
        # Chain the original exception so server logs retain the root cause.
        raise HTTPException(
            status_code=500,
            detail=f"HHEM support check failed: {str(exc)}"
        ) from exc

    # Map PASS -> "supported", FAIL -> "not_supported" to match existing reward logic
    labels = ["supported" if r["status"] == "PASS" else "not_supported" for r in results]
    return SupportCheckResponse(labels=labels, details=results)


if __name__ == "__main__":
    # uvicorn is imported lazily so merely importing this module (e.g. for
    # tests or an external ASGI server) does not require it.
    import uvicorn
    # Host/port configurable via env; defaults bind all interfaces on 8091.
    port = int(os.getenv("SUPPORT_API_PORT", "8091"))
    host = os.getenv("SUPPORT_API_HOST", "0.0.0.0")
    uvicorn.run(app, host=host, port=port)