petter2025 committed on
Commit
141f2db
·
verified ·
1 Parent(s): 162cb4f

Update nli_detector.py

Browse files
Files changed (1) hide show
  1. nli_detector.py +33 -15
nli_detector.py CHANGED
@@ -9,15 +9,26 @@ from transformers import pipeline
9
  logger = logging.getLogger(__name__)
10
 
11
class NLIDetector:
    """Uses an NLI model to detect contradictions/hallucinations.

    ``check()`` returns the model's entailment probability for a
    premise/hypothesis pair: higher means more consistent.
    """

    def __init__(self, model_name: str = "typeform/distilroberta-base-mnli"):
        """Load the NLI text-classification pipeline.

        Args:
            model_name: Hugging Face model identifier for an NLI model.
        """
        try:
            # return_all_scores=True makes the pipeline emit a score for
            # every class, so check() can read the ENTAILMENT probability
            # directly instead of seeing only the top-scoring label.
            self.pipeline = pipeline(
                "text-classification",
                model=model_name,
                device=0 if torch.cuda.is_available() else -1,
                return_all_scores=True,
            )
            logger.info(f"NLI model {model_name} loaded.")
        except Exception as e:
            logger.error(f"Failed to load NLI model: {e}")
            # Degrade gracefully: check() will return None.
            self.pipeline = None

    def check(self, premise: str, hypothesis: str) -> Optional[float]:
        """
        Returns probability of entailment (higher means more consistent).

        Args:
            premise: The original input/context.
            hypothesis: The generated response to verify.

        Returns:
            Entailment probability in [0, 1], or None if the model is
            unavailable or inference failed.
        """
        if self.pipeline is None:
            return None
        try:
            # With return_all_scores=True a single input yields a list
            # containing one list of {'label', 'score'} dicts.
            scores = self.pipeline(f"{premise} </s></s> {hypothesis}")[0]
            # Fix: previously only the top label was available, so any
            # non-entailment prediction collapsed to 0.0 and the real
            # entailment probability was lost. Now we look it up directly.
            for item in scores:
                if item['label'] == 'ENTAILMENT':
                    return item['score']
            return 0.0
        except Exception as e:
            logger.error(f"NLI error: {e}")
            return None
 
9
  logger = logging.getLogger(__name__)
10
 
11
class NLIDetector:
    """
    Uses an NLI model to detect contradictions/hallucinations.
    Returns entailment probability (0 to 1) for a given premise-hypothesis pair.
    """

    def __init__(self, model_name: str = "typeform/distilroberta-base-mnli"):
        """
        Args:
            model_name: Hugging Face model identifier for NLI.
        """
        try:
            # Request all scores to obtain probabilities for each class.
            # The pipeline returns a list of lists of dicts: each dict has
            # 'label' and 'score'.
            # NOTE(review): return_all_scores is deprecated in newer
            # transformers releases in favor of top_k=None (which returns a
            # flat list of dicts); check() tolerates both output shapes.
            self.pipeline = pipeline(
                "text-classification",
                model=model_name,
                device=0 if torch.cuda.is_available() else -1,
                return_all_scores=True
            )
            logger.info(f"NLI model {model_name} loaded with return_all_scores=True.")
        except Exception as e:
            logger.error(f"Failed to load NLI model: {e}")
            self.pipeline = None

    def check(self, premise: str, hypothesis: str) -> Optional[float]:
        """
        Returns probability of entailment (higher means more consistent).

        Args:
            premise: The original input/context.
            hypothesis: The generated response.

        Returns:
            Float between 0 and 1, or None if model unavailable.
        """
        if self.pipeline is None:
            return None
        try:
            outputs = self.pipeline(f"{premise} </s></s> {hypothesis}")
            # return_all_scores=True nests the per-class scores for a single
            # input as [[{...}, ...]]; top_k=None yields a flat [{...}, ...].
            # Accept either shape so a transformers upgrade doesn't break us.
            if outputs and isinstance(outputs[0], list):
                scores = outputs[0]
            else:
                scores = outputs
            # Label casing varies across NLI models ('ENTAILMENT' vs
            # 'entailment'), so compare case-insensitively.
            for item in scores:
                if item['label'].upper() == 'ENTAILMENT':
                    return item['score']
            # If the label is not found (should not happen), fall back to 0.0.
            logger.warning("ENTAILMENT label not found in NLI output; returning 0.0.")
            return 0.0
        except Exception as e:
            logger.error(f"NLI error: {e}")
            return None