heerjtdev committed
Commit 8189a78 · verified · 1 Parent(s): d9658ea

Update app.py

Files changed (1):
  1. app.py (+32 −29)
app.py CHANGED
@@ -60,27 +60,21 @@
 
 
 
-
-
-
 import gradio as gr
 import torch
 import torch.nn.functional as F
 from sentence_transformers import CrossEncoder
-import time
 
 # --- CONFIGURATION ---
-# GATE 1: Relevance (Is the answer related to the question?)
-# We switch from MS-MARCO (Search) to STS (Semantic Similarity).
-# This prevents the "Lion Sleeping" failure.
+# GATE 1: Semantic Relevance (STS)
+# Checks if the Answer is conversationally related to the Question.
 relevance_model_name = 'cross-encoder/stsb-distilroberta-base'
 
-# GATE 2: Fact Checking (Is the answer supported by the text?)
-# DeBERTa-v3 is state-of-the-art for NLI.
+# GATE 2: Fact Checking (NLI)
+# Checks if the Answer is supported by the Knowledge Base.
 nli_model_name = 'cross-encoder/nli-deberta-v3-xsmall'
 
 print(f"Loading Models...\n1. {relevance_model_name}\n2. {nli_model_name}")
-# We load them once.
 rel_model = CrossEncoder(relevance_model_name, device="cpu")
 nli_model = CrossEncoder(nli_model_name, device="cpu")
 print("✅ System Ready.")
@@ -89,22 +83,18 @@ def evaluate_response(kb, question, user_answer):
     if not kb or not question or not user_answer:
         return "⚠️ Error: Missing Input", {}, "N/A"
 
-    logs = {} # Dictionary to store debug info
+    logs = {}
 
     # --- GATE 1: RELEVANCE CHECK (STS) ---
-    # Does the answer make sense in the context of the question?
-    # STS models output a score from 0.0 to 1.0 (usually).
     rel_score = rel_model.predict([(question, user_answer)])
 
-    # Check if the model output is raw logits or normalized
-    # STSb models usually output 0-1. If not, we clip/normalize.
-    rel_score_val = float(rel_score)
+    # FIX 1: Use .item() to safely extract float from numpy array
+    rel_score_val = rel_score.item()
+
     logs['Gate 1 Model'] = relevance_model_name
     logs['Gate 1 Raw Score'] = f"{rel_score_val:.4f}"
 
-    # Threshold: STS scores are usually tighter.
-    # > 0.15 is usually enough to say "These sentences are related".
-    # "Lion sleeping" vs "What lion doing" should score ~0.4 - 0.6
+    # Threshold: STS score > 0.15 usually implies relevance
     RELEVANCE_THRESHOLD = 0.15
 
     if rel_score_val < RELEVANCE_THRESHOLD:
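Why FIX 1 matters, as a standalone sketch (the shape-(1,) array is a stand-in for what `predict` returns here): `float()` on a one-element array still works today, but NumPy 1.25+ deprecates converting arrays with ndim > 0 to scalars, while `.item()` stays safe.

```python
import numpy as np

rel_score = np.array([0.4321], dtype=np.float32)  # stand-in for predict() output

rel_score_val = rel_score.item()  # extracts the lone element as a Python float
# float(rel_score) raises a DeprecationWarning on NumPy >= 1.25
# ("Conversion of an array with ndim > 0 to a scalar is deprecated"),
# so .item() is the future-proof spelling used in this commit.
print(f"{rel_score_val:.4f}")  # 0.4321
```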
@@ -113,13 +103,29 @@ def evaluate_response(kb, question, user_answer):
         return status, logs, "Blocked"
 
     # --- GATE 2: FACT CHECKING (NLI) ---
-    # Does the Knowledge Base entail the Answer?
     nli_logits = nli_model.predict([(kb, user_answer)])
-    nli_probs = F.softmax(torch.tensor(nli_logits), dim=0).tolist()
 
-    # DeBERTa-v3-xsmall Labels: 0: Contradiction, 1: Entailment, 2: Neutral
+    # FIX 2: Handle Dimensions safely
+    # Convert to tensor
+    nli_tensor = torch.tensor(nli_logits)
+
+    # If the model returns a batch dimension (e.g. [1, 3]), squeeze it to flat [3]
+    if nli_tensor.dim() > 1:
+        nli_tensor = nli_tensor.squeeze()
+
+    # Apply Softmax across the classes (now dim=0 is safe on a flat tensor)
+    nli_probs = F.softmax(nli_tensor, dim=0).tolist()
+
+    # Get the winner index
+    max_idx = nli_tensor.argmax().item()
+
+    # Standard NLI Labels
     labels = ["Contradiction", "Entailment", "Neutral"]
-    max_idx = torch.tensor(nli_logits).argmax().item()
+
+    # Safety check for model label count mismatch
+    if max_idx >= len(labels):
+        return "⚠️ Model Error", {"Error": "Label mismatch"}, "N/A"
+
     nli_verdict = labels[max_idx]
     nli_conf = nli_probs[max_idx] * 100
 
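The point of FIX 2: with a batched [1, 3] input, `F.softmax(..., dim=0)` normalizes over the batch dimension of size one and returns 1.0 for every class, so the old code could report nonsense confidences. After the squeeze, dim=0 ranges over the three classes in either case. A quick check with made-up logits:

```python
import torch
import torch.nn.functional as F

for logits in (torch.tensor([-1.0, 2.5, 0.3]),     # flat [3]
               torch.tensor([[-1.0, 2.5, 0.3]])):  # batched [1, 3]
    t = logits.squeeze() if logits.dim() > 1 else logits
    probs = F.softmax(t, dim=0)  # dim=0 now ranges over the 3 classes
    print(t.argmax().item(), [round(p, 3) for p in probs.tolist()])
# Both iterations print: 1 [0.026, 0.876, 0.097] -> "Entailment" wins
```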
@@ -141,17 +147,15 @@ def evaluate_response(kb, question, user_answer):
         logs['Final Outcome'] = "Answer contradicts the text."
 
     else: # Neutral
-        # The answer is relevant to the question, but the TEXT doesn't mention it.
-        # e.g., "The lion likes pizza." (Relevant topic, but hallucinated fact)
         status = "❌ INCORRECT (Hallucination/Not in Text)"
         logs['Final Outcome'] = "Answer not found in text."
 
     return status, logs, f"{nli_verdict} ({nli_conf:.1f}%)"
 
 # --- UI SETUP ---
-with gr.Blocks(title="NLI Logic Engine v5 (Debug Enabled)", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## 🧠 Neural Logic Engine v5")
-    gr.Markdown("Corrected Architecture: Uses **STS (Semantic Similarity)** for Relevance and **NLI** for Fact Checking.")
+with gr.Blocks(title="NLI Logic Engine v5", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## 🧠 Neural Logic Engine v5.1 (Bug Fixes Applied)")
+    gr.Markdown("Corrected Architecture: STS for Relevance + NLI for Fact Checking.")
 
     with gr.Row():
         with gr.Column(scale=1):
@@ -163,7 +167,6 @@ with gr.Blocks(title="NLI Logic Engine v5 (Debug Enabled)", theme=gr.themes.Soft
         with gr.Column(scale=1):
             verdict_out = gr.Textbox(label="Final Verdict", elem_classes="verdict")
             nli_metric = gr.Label(label="NLI Confidence")
-            # JSON output for full transparency
             debug_log = gr.JSON(label="System Internals (Debug Log)")
 
     btn.click(
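The diff cuts off at `btn.click(`. For orientation only, here is a hypothetical sketch of the wiring implied by the visible pieces; the input component names `kb_in`, `q_in`, and `ans_in` do not appear in this diff and are assumptions:

```python
# Hypothetical sketch -- kb_in, q_in, ans_in are assumed names, not from the diff.
# evaluate_response returns (status, logs, nli_string), so the outputs would map
# to the three components defined above, in this order:
btn.click(
    fn=evaluate_response,
    inputs=[kb_in, q_in, ans_in],
    outputs=[verdict_out, debug_log, nli_metric],
)
```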
 
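End-to-end, the two gates yield four possible outcomes. A hedged usage sketch of `evaluate_response` as defined above (the verdicts in the comments are the expected ones, taking the "lion likes pizza" case from the removed comments; actual model outputs may differ):

```python
kb = "Lions sleep for up to 20 hours a day."
question = "What is the lion doing?"

evaluate_response(kb, question, "The lion is sleeping.")    # expected: Entailment -> correct
evaluate_response(kb, question, "The lion is wide awake.")  # expected: Contradiction -> incorrect
evaluate_response(kb, question, "The lion likes pizza.")    # expected: Neutral -> hallucination
evaluate_response(kb, question, "I like programming.")      # expected: blocked by Gate 1 (irrelevant)
```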