petter2025 commited on
Commit
d22a44b
·
verified ·
1 Parent(s): 550ce84

Create nli_detector.py

Browse files
Files changed (1) hide show
  1. nli_detector.py +46 -0
nli_detector.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Natural Language Inference detector – checks if generated response is consistent with input.
3
+ """
4
+ import logging
5
+ from typing import Optional
6
+ import torch
7
+ from transformers import pipeline
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+ class NLIDetector:
12
+ """Uses an NLI model to detect contradictions/hallucinations."""
13
+ def __init__(self, model_name: str = "typeform/distilroberta-base-mnli"):
14
+ try:
15
+ self.pipeline = pipeline(
16
+ "text-classification",
17
+ model=model_name,
18
+ device=0 if torch.cuda.is_available() else -1
19
+ )
20
+ logger.info(f"NLI model {model_name} loaded.")
21
+ except Exception as e:
22
+ logger.error(f"Failed to load NLI model: {e}")
23
+ self.pipeline = None
24
+
25
+ def check(self, premise: str, hypothesis: str) -> Optional[float]:
26
+ """
27
+ Returns probability of entailment (higher means more consistent).
28
+ """
29
+ if self.pipeline is None:
30
+ return None
31
+ try:
32
+ result = self.pipeline(f"{premise} </s></s> {hypothesis}")[0]
33
+ # The model outputs label and score
34
+ if result['label'] == 'ENTAILMENT':
35
+ return result['score']
36
+ else:
37
+ # For contradiction/neutral, return 1 - score? Better to return entailment probability directly.
38
+ # Some models give 'CONTRADICTION' and 'NEUTRAL' – we can treat as low consistency.
39
+ # We'll use the score of the entailment class if present, else 0.
40
+ # But the pipeline might return only the top label. Let's get probabilities for all labels.
41
+ # This is more complex. For simplicity, we'll assume the model gives entailment score.
42
+ # In practice, we'd use a dedicated NLI model that returns probabilities.
43
+ return 0.0
44
+ except Exception as e:
45
+ logger.error(f"NLI error: {e}")
46
+ return None