petter2025's picture
Update nli_detector.py
8cbb09b verified
raw
history blame
2.54 kB
"""
Natural Language Inference detector – checks if generated response is consistent with input.
"""
import logging
from typing import Optional
import torch
from transformers import pipeline
logger = logging.getLogger(__name__)
class NLIDetector:
    """
    Uses an NLI model to detect contradictions/hallucinations.

    The detector wraps a Hugging Face ``text-classification`` pipeline and
    returns the entailment probability (0 to 1) for a premise-hypothesis pair.
    If the model cannot be loaded, ``self.pipeline`` is ``None`` and
    :meth:`check` returns ``None`` instead of raising.
    """

    # Compared case-insensitively: checkpoints differ in label casing
    # (e.g. "ENTAILMENT" vs. "entailment").
    _ENTAILMENT_LABEL = "entailment"

    def __init__(self, model_name: str = "microsoft/deberta-base-mnli"):
        """
        Load the NLI classification pipeline.

        Args:
            model_name: Hugging Face model identifier for NLI.
                Default is a public model that does not require authentication.
        """
        try:
            # top_k=None asks the pipeline for the score of every class; it is
            # the documented replacement for the deprecated
            # return_all_scores=True.
            self.pipeline = pipeline(
                "text-classification",
                model=model_name,
                device=0 if torch.cuda.is_available() else -1,
                top_k=None,
            )
            logger.info("NLI model %s loaded with top_k=None.", model_name)
        except Exception as e:
            # Best-effort load: callers detect the failure through
            # check() returning None.
            logger.error("Failed to load NLI model: %s", e)
            self.pipeline = None

    def check(self, premise: str, hypothesis: str) -> Optional[float]:
        """
        Return the probability of entailment (higher means more consistent).

        Args:
            premise: The original input/context.
            hypothesis: The generated response.

        Returns:
            Float between 0 and 1, or None if the model is unavailable or
            inference fails. Returns 0.0 when the model output contains no
            entailment label.
        """
        if self.pipeline is None:
            return None
        try:
            # Pass the pair as text/text_pair so the tokenizer inserts the
            # model's own separator tokens. A hard-coded "</s></s>" is a
            # RoBERTa-style separator and is not a special token for the
            # default DeBERTa checkpoint, so it was scored as plain text.
            result = self.pipeline({"text": premise, "text_pair": hypothesis})

            # Depending on input type and library version, a single input
            # yields either a flat list of class-score dicts or that list
            # wrapped in an outer list; accept both shapes.
            scores = result
            if isinstance(result, list) and result and isinstance(result[0], list):
                scores = result[0]

            for item in scores:
                if str(item["label"]).lower() == self._ENTAILMENT_LABEL:
                    return float(item["score"])

            # If the label is not found (should not happen), fall back to 0.0.
            logger.warning("ENTAILMENT label not found in NLI output; returning 0.0.")
            return 0.0
        except Exception as e:
            logger.error("NLI error: %s", e)
            return None