petter2025's picture
Create nli_detector.py
d22a44b verified
raw
history blame
1.97 kB
"""
Natural Language Inference detector – checks if generated response is consistent with input.
"""
import logging
from typing import Optional
import torch
from transformers import pipeline
logger = logging.getLogger(__name__)
class NLIDetector:
    """Uses an NLI model to detect contradictions/hallucinations."""

    def __init__(self, model_name: str = "typeform/distilroberta-base-mnli"):
        """Load the NLI text-classification pipeline.

        Args:
            model_name: Hugging Face model id of an MNLI-style NLI model.
        """
        try:
            self.pipeline = pipeline(
                "text-classification",
                model=model_name,
                # First GPU when available; -1 selects CPU.
                device=0 if torch.cuda.is_available() else -1
            )
            logger.info(f"NLI model {model_name} loaded.")
        except Exception as e:
            logger.error(f"Failed to load NLI model: {e}")
            # Degrade gracefully: check() returns None while the model is unavailable.
            self.pipeline = None

    def check(self, premise: str, hypothesis: str) -> Optional[float]:
        """
        Returns probability of entailment (higher means more consistent).

        Returns None if the model failed to load or inference raised.
        """
        if self.pipeline is None:
            return None
        try:
            # Request scores for ALL labels (top_k=None) so the entailment
            # probability is available even when entailment is not the top
            # label. The previous version inspected only the top label and
            # returned a hard-coded 0.0 otherwise, discarding the actual
            # entailment probability.
            results = self.pipeline(
                f"{premise} </s></s> {hypothesis}", top_k=None
            )
            # For a single input, top_k=None yields a flat list of
            # {'label', 'score'} dicts; some versions nest it one level deep.
            if results and isinstance(results[0], list):
                results = results[0]
            for result in results:
                # Label casing varies across models ('ENTAILMENT' vs 'entailment').
                if result['label'].upper() == 'ENTAILMENT':
                    return result['score']
            # Unexpected label set with no entailment entry: treat as inconsistent.
            return 0.0
        except Exception as e:
            logger.error(f"NLI error: {e}")
            return None