heerjtdev committed
Commit 82d9acf · verified · 1 Parent(s): 8189a78

Update app.py

Files changed (1)
  1. app.py +310 -83
app.py CHANGED
@@ -60,120 +60,347 @@
-import gradio as gr
-import torch
-import torch.nn.functional as F
-from sentence_transformers import CrossEncoder
-
-# --- CONFIGURATION ---
-# GATE 1: Semantic Relevance (STS)
-# Checks if the Answer is conversationally related to the Question.
-relevance_model_name = 'cross-encoder/stsb-distilroberta-base'
-
-# GATE 2: Fact Checking (NLI)
-# Checks if the Answer is supported by the Knowledge Base.
-nli_model_name = 'cross-encoder/nli-deberta-v3-xsmall'
-
-print(f"Loading Models...\n1. {relevance_model_name}\n2. {nli_model_name}")
-rel_model = CrossEncoder(relevance_model_name, device="cpu")
-nli_model = CrossEncoder(nli_model_name, device="cpu")
-print("✅ System Ready.")
-
-def evaluate_response(kb, question, user_answer):
-    if not kb or not question or not user_answer:
-        return "⚠️ Error: Missing Input", {}, "N/A"
-
-    logs = {}
-
-    # --- GATE 1: RELEVANCE CHECK (STS) ---
-    rel_score = rel_model.predict([(question, user_answer)])
-
-    # FIX 1: Use .item() to safely extract float from numpy array
-    rel_score_val = rel_score.item()
-
-    logs['Gate 1 Model'] = relevance_model_name
-    logs['Gate 1 Raw Score'] = f"{rel_score_val:.4f}"
-
-    # Threshold: STS score > 0.15 usually implies relevance
-    RELEVANCE_THRESHOLD = 0.15
-
-    if rel_score_val < RELEVANCE_THRESHOLD:
-        status = "❌ INCORRECT (Irrelevant)"
-        logs['Verdict'] = "Blocked at Gate 1 (Answer unrelated to Question)"
-        return status, logs, "Blocked"
-
-    # --- GATE 2: FACT CHECKING (NLI) ---
-    nli_logits = nli_model.predict([(kb, user_answer)])
-
-    # FIX 2: Handle dimensions safely
-    # Convert to tensor
-    nli_tensor = torch.tensor(nli_logits)
-
-    # If the model returns a batch dimension (e.g. [1, 3]), squeeze it to flat [3]
-    if nli_tensor.dim() > 1:
-        nli_tensor = nli_tensor.squeeze()
-
-    # Apply softmax across the classes (now dim=0 is safe on a flat tensor)
-    nli_probs = F.softmax(nli_tensor, dim=0).tolist()
-
-    # Get the winner index
-    max_idx = nli_tensor.argmax().item()
-
-    # Standard NLI labels
-    labels = ["Contradiction", "Entailment", "Neutral"]
-
-    # Safety check for model label count mismatch
-    if max_idx >= len(labels):
-        return "⚠️ Model Error", {"Error": "Label mismatch"}, "N/A"
-
-    nli_verdict = labels[max_idx]
-    nli_conf = nli_probs[max_idx] * 100
-
-    logs['Gate 2 Model'] = nli_model_name
-    logs['Gate 2 Probabilities'] = {
-        "Contradiction": f"{nli_probs[0]*100:.1f}%",
-        "Entailment": f"{nli_probs[1]*100:.1f}%",
-        "Neutral": f"{nli_probs[2]*100:.1f}%"
-    }
-    logs['Gate 2 Verdict'] = nli_verdict
-
-    # --- FINAL DECISION LOGIC ---
-    if nli_verdict == "Entailment":
-        status = "✅ CORRECT (Confirmed)"
-        logs['Final Outcome'] = "Answer is Relevant and Factual."
-
-    elif nli_verdict == "Contradiction":
-        status = "❌ INCORRECT (False Information)"
-        logs['Final Outcome'] = "Answer contradicts the text."
-
-    else:  # Neutral
-        status = "❌ INCORRECT (Hallucination/Not in Text)"
-        logs['Final Outcome'] = "Answer not found in text."
-
-    return status, logs, f"{nli_verdict} ({nli_conf:.1f}%)"
-
-# --- UI SETUP ---
-with gr.Blocks(title="NLI Logic Engine v5", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## 🧠 Neural Logic Engine v5.1 (Bug Fixes Applied)")
-    gr.Markdown("Corrected Architecture: STS for Relevance + NLI for Fact Checking.")
-
     with gr.Row():
         with gr.Column(scale=1):
-            kb_input = gr.Textbox(label="Knowledge Base", lines=5, value="When a lion was resting in the jungle, a mouse began racing up and down his body for fun. The lion's sleep was disturbed, and he woke in anger.")
-            q_input = gr.Textbox(label="Question", value="What was the lion doing?")
-            a_input = gr.Textbox(label="User Answer", value="The lion was sleeping in the jungle.")
             btn = gr.Button("Evaluate", variant="primary")
-
         with gr.Column(scale=1):
-            verdict_out = gr.Textbox(label="Final Verdict", elem_classes="verdict")
-            nli_metric = gr.Label(label="NLI Confidence")
-            debug_log = gr.JSON(label="System Internals (Debug Log)")
 
     btn.click(
         fn=evaluate_response,
         inputs=[kb_input, q_input, a_input],
-        outputs=[verdict_out, debug_log, nli_metric]
     )
 
 if __name__ == "__main__":
-    demo.launch()
 
+# import gradio as gr
+# import torch
+# import torch.nn.functional as F
+# from sentence_transformers import CrossEncoder
+
+# # --- CONFIGURATION ---
+# # GATE 1: Semantic Relevance (STS)
+# # Checks if the Answer is conversationally related to the Question.
+# relevance_model_name = 'cross-encoder/stsb-distilroberta-base'
+
+# # GATE 2: Fact Checking (NLI)
+# # Checks if the Answer is supported by the Knowledge Base.
+# nli_model_name = 'cross-encoder/nli-deberta-v3-xsmall'
+
+# print(f"Loading Models...\n1. {relevance_model_name}\n2. {nli_model_name}")
+# rel_model = CrossEncoder(relevance_model_name, device="cpu")
+# nli_model = CrossEncoder(nli_model_name, device="cpu")
+# print("✅ System Ready.")
+
+# def evaluate_response(kb, question, user_answer):
+#     if not kb or not question or not user_answer:
+#         return "⚠️ Error: Missing Input", {}, "N/A"
+
+#     logs = {}
+
+#     # --- GATE 1: RELEVANCE CHECK (STS) ---
+#     rel_score = rel_model.predict([(question, user_answer)])
+
+#     # FIX 1: Use .item() to safely extract float from numpy array
+#     rel_score_val = rel_score.item()
+
+#     logs['Gate 1 Model'] = relevance_model_name
+#     logs['Gate 1 Raw Score'] = f"{rel_score_val:.4f}"
+
+#     # Threshold: STS score > 0.15 usually implies relevance
+#     RELEVANCE_THRESHOLD = 0.15
+
+#     if rel_score_val < RELEVANCE_THRESHOLD:
+#         status = "❌ INCORRECT (Irrelevant)"
+#         logs['Verdict'] = "Blocked at Gate 1 (Answer unrelated to Question)"
+#         return status, logs, "Blocked"
+
+#     # --- GATE 2: FACT CHECKING (NLI) ---
+#     nli_logits = nli_model.predict([(kb, user_answer)])
+
+#     # FIX 2: Handle dimensions safely
+#     # Convert to tensor
+#     nli_tensor = torch.tensor(nli_logits)
+
+#     # If the model returns a batch dimension (e.g. [1, 3]), squeeze it to flat [3]
+#     if nli_tensor.dim() > 1:
+#         nli_tensor = nli_tensor.squeeze()
+
+#     # Apply softmax across the classes (now dim=0 is safe on a flat tensor)
+#     nli_probs = F.softmax(nli_tensor, dim=0).tolist()
+
+#     # Get the winner index
+#     max_idx = nli_tensor.argmax().item()
+
+#     # Standard NLI labels
+#     labels = ["Contradiction", "Entailment", "Neutral"]
+
+#     # Safety check for model label count mismatch
+#     if max_idx >= len(labels):
+#         return "⚠️ Model Error", {"Error": "Label mismatch"}, "N/A"
+
+#     nli_verdict = labels[max_idx]
+#     nli_conf = nli_probs[max_idx] * 100
+
+#     logs['Gate 2 Model'] = nli_model_name
+#     logs['Gate 2 Probabilities'] = {
+#         "Contradiction": f"{nli_probs[0]*100:.1f}%",
+#         "Entailment": f"{nli_probs[1]*100:.1f}%",
+#         "Neutral": f"{nli_probs[2]*100:.1f}%"
+#     }
+#     logs['Gate 2 Verdict'] = nli_verdict
+
+#     # --- FINAL DECISION LOGIC ---
+#     if nli_verdict == "Entailment":
+#         status = "✅ CORRECT (Confirmed)"
+#         logs['Final Outcome'] = "Answer is Relevant and Factual."
+
+#     elif nli_verdict == "Contradiction":
+#         status = "❌ INCORRECT (False Information)"
+#         logs['Final Outcome'] = "Answer contradicts the text."
+
+#     else:  # Neutral
+#         status = "❌ INCORRECT (Hallucination/Not in Text)"
+#         logs['Final Outcome'] = "Answer not found in text."
+
+#     return status, logs, f"{nli_verdict} ({nli_conf:.1f}%)"
+
+# # --- UI SETUP ---
+# with gr.Blocks(title="NLI Logic Engine v5", theme=gr.themes.Soft()) as demo:
+#     gr.Markdown("## 🧠 Neural Logic Engine v5.1 (Bug Fixes Applied)")
+#     gr.Markdown("Corrected Architecture: STS for Relevance + NLI for Fact Checking.")
+
+#     with gr.Row():
+#         with gr.Column(scale=1):
+#             kb_input = gr.Textbox(label="Knowledge Base", lines=5, value="When a lion was resting in the jungle, a mouse began racing up and down his body for fun. The lion's sleep was disturbed, and he woke in anger.")
+#             q_input = gr.Textbox(label="Question", value="What was the lion doing?")
+#             a_input = gr.Textbox(label="User Answer", value="The lion was sleeping in the jungle.")
+#             btn = gr.Button("Evaluate", variant="primary")
+
+#         with gr.Column(scale=1):
+#             verdict_out = gr.Textbox(label="Final Verdict", elem_classes="verdict")
+#             nli_metric = gr.Label(label="NLI Confidence")
+#             debug_log = gr.JSON(label="System Internals (Debug Log)")
+
+#     btn.click(
+#         fn=evaluate_response,
+#         inputs=[kb_input, q_input, a_input],
+#         outputs=[verdict_out, debug_log, nli_metric]
+#     )
+
+# if __name__ == "__main__":
+#     demo.launch()
+
+
+import gradio as gr
+import torch
+import torch.nn.functional as F
+from sentence_transformers import CrossEncoder
+import re
+
+# ==============================
+# CONFIGURATION
+# ==============================
+
+RELEVANCE_MODEL = "cross-encoder/stsb-distilroberta-base"
+NLI_MODEL = "cross-encoder/nli-deberta-v3-xsmall"
+
+RELEVANCE_THRESHOLD_QA = 0.15
+RELEVANCE_THRESHOLD_KB = 0.30
+ENTAILMENT_THRESHOLD = 0.65
+
+DEVICE = "cpu"
+
+# ==============================
+# LOAD MODELS
+# ==============================
+
+print("Loading models...")
+rel_model = CrossEncoder(RELEVANCE_MODEL, device=DEVICE)
+nli_model = CrossEncoder(NLI_MODEL, device=DEVICE)
+print("✅ Models loaded")
+
+# ==============================
+# UTILITIES
+# ==============================
+
+def split_sentences(text):
+    text = text.strip()
+    if not text:
+        return []
+    return re.split(r'(?<=[.!?])\s+', text)
+
+def softmax_logits(logits):
+    t = torch.tensor(logits)
+    if t.dim() > 1:
+        t = t.squeeze(0)
+    probs = F.softmax(t, dim=0).tolist()
+    return probs
+
+# ==============================
+# CORE EVALUATION FUNCTION
+# ==============================
+
+def evaluate_response(kb, question, user_answer):
+    logs = {}
+
+    # ------------------------------
+    # INPUT VALIDATION
+    # ------------------------------
+    if not kb or not question or not user_answer:
+        return "⚠️ ERROR: Missing input", {}, "N/A"
+
+    logs["Inputs"] = {
+        "Question": question,
+        "User Answer": user_answer,
+        "KB Length (chars)": len(kb)
+    }
+
+    # ------------------------------
+    # GATE 1 - QUESTION ↔ ANSWER RELEVANCE
+    # ------------------------------
+    qa_score = rel_model.predict([(question, user_answer)]).item()
+
+    logs["Gate 1 - QA Relevance"] = {
+        "Model": RELEVANCE_MODEL,
+        "Score": round(qa_score, 4),
+        "Threshold": RELEVANCE_THRESHOLD_QA
+    }
+
+    if qa_score < RELEVANCE_THRESHOLD_QA:
+        logs["Final Decision"] = "Blocked at Gate 1 (Irrelevant Answer)"
+        return (
+            "❌ INCORRECT (Irrelevant)",
+            logs,
+            f"Relevance {qa_score:.2f}"
+        )
+
+    # ------------------------------
+    # GATE 2 - KB SENTENCE SELECTION (STS)
+    # ------------------------------
+    kb_sentences = split_sentences(kb)
+    logs["KB Processing"] = {
+        "Total Sentences": len(kb_sentences),
+        "Sentences": kb_sentences
+    }
+
+    if not kb_sentences:
+        logs["Final Decision"] = "Empty KB after sentence split"
+        return "❌ INCORRECT (Empty KB)", logs, "N/A"
+
+    sentence_pairs = [(s, user_answer) for s in kb_sentences]
+    sim_scores = rel_model.predict(sentence_pairs)
+
+    best_idx = int(sim_scores.argmax())
+    best_sentence = kb_sentences[best_idx]
+    best_score = float(sim_scores[best_idx])
+
+    logs["Gate 2 - KB Sentence Selection"] = {
+        "Model": RELEVANCE_MODEL,
+        "Best Sentence": best_sentence,
+        "Best Similarity Score": round(best_score, 4),
+        "Threshold": RELEVANCE_THRESHOLD_KB,
+        "All Scores": [
+            {"sentence": s, "score": round(float(sc), 4)}
+            for s, sc in zip(kb_sentences, sim_scores)
+        ]
+    }
+
+    if best_score < RELEVANCE_THRESHOLD_KB:
+        logs["Final Decision"] = "Answer not grounded in KB"
+        return (
+            "❌ INCORRECT (Not Found in Text)",
+            logs,
+            f"KB Similarity {best_score:.2f}"
+        )
+
+    # ------------------------------
+    # GATE 3 - NLI (Sentence ↔ Answer)
+    # ------------------------------
+    nli_logits = nli_model.predict([(best_sentence, user_answer)])
+    probs = softmax_logits(nli_logits)
+
+    labels = ["Contradiction", "Entailment", "Neutral"]
+    verdict_idx = int(torch.tensor(probs).argmax())
+    verdict = labels[verdict_idx]
+    confidence = probs[verdict_idx] * 100
+
+    logs["Gate 3 - NLI Verification"] = {
+        "Model": NLI_MODEL,
+        "Premise": best_sentence,
+        "Hypothesis": user_answer,
+        "Probabilities": {
+            "Contradiction": f"{probs[0]*100:.2f}%",
+            "Entailment": f"{probs[1]*100:.2f}%",
+            "Neutral": f"{probs[2]*100:.2f}%"
+        },
+        "Verdict": verdict,
+        "Confidence": f"{confidence:.2f}%",
+        "Entailment Threshold": f"{ENTAILMENT_THRESHOLD*100:.0f}%"
+    }
+
+    # ------------------------------
+    # FINAL DECISION
+    # ------------------------------
+    if verdict == "Entailment" and probs[1] >= ENTAILMENT_THRESHOLD:
+        logs["Final Decision"] = "Answer is Supported by Text"
+        return (
+            "✅ CORRECT (Confirmed)",
+            logs,
+            f"Entailment {confidence:.1f}%"
+        )
+
+    if verdict == "Contradiction":
+        logs["Final Decision"] = "Answer Contradicts Text"
+        return (
+            "❌ INCORRECT (Contradiction)",
+            logs,
+            f"Contradiction {confidence:.1f}%"
+        )
+
+    logs["Final Decision"] = "Answer Not Explicitly Stated"
+    return (
+        "❌ INCORRECT (Neutral / Not in Text)",
+        logs,
+        f"Neutral {confidence:.1f}%"
+    )
+
+# ==============================
+# GRADIO UI
+# ==============================
+
+with gr.Blocks(title="Neural Logic Engine v6", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## 🧠 Neural Logic Engine v6")
+    gr.Markdown(
+        "**Architecture:**\n"
+        "- Gate 1: Question ↔ Answer relevance (STS)\n"
+        "- Gate 2: KB sentence grounding (STS)\n"
+        "- Gate 3: Sentence-level NLI verification\n"
+        "- Fully logged, deterministic decisions"
+    )
+
     with gr.Row():
         with gr.Column(scale=1):
+            kb_input = gr.Textbox(
+                label="Knowledge Base",
+                lines=6,
+                value="When a lion was resting in the jungle, a mouse began racing up and down his body for fun. "
+                      "The lion's sleep was disturbed, and he woke in anger."
+            )
+            q_input = gr.Textbox(
+                label="Question",
+                value="What was the lion doing?"
+            )
+            a_input = gr.Textbox(
+                label="User Answer",
+                value="The lion was sleeping in the jungle."
+            )
             btn = gr.Button("Evaluate", variant="primary")
+
         with gr.Column(scale=1):
+            verdict_out = gr.Textbox(label="Final Verdict")
+            confidence_out = gr.Label(label="Model Confidence")
+            debug_log = gr.JSON(label="System Internals (FULL DEBUG LOG)")
 
     btn.click(
         fn=evaluate_response,
         inputs=[kb_input, q_input, a_input],
+        outputs=[verdict_out, debug_log, confidence_out]
     )
 
 if __name__ == "__main__":
+    demo.launch()
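
For reference, a minimal smoke test of the new three-gate pipeline outside the Gradio UI. This is a sketch, assuming the updated file is saved as app.py and importable as `app` (importing it loads both CrossEncoder models, so the first run downloads weights from the Hugging Face Hub); the expected verdicts in the comments are illustrative and depend on the actual model scores against the 0.15 / 0.30 / 0.65 thresholds.

from app import evaluate_response  # hypothetical import path; assumes the file above is app.py

kb = ("When a lion was resting in the jungle, a mouse began racing up and down "
      "his body for fun. The lion's sleep was disturbed, and he woke in anger.")
question = "What was the lion doing?"

# Grounded answer: should clear Gate 1 (QA relevance), Gate 2 (KB grounding)
# and Gate 3 (NLI), provided the entailment probability reaches 0.65.
status, logs, confidence = evaluate_response(kb, question, "The lion was sleeping in the jungle.")
print(status, "|", confidence)

# Off-topic answer: expected to be blocked at Gate 1 (STS score below 0.15).
status, logs, confidence = evaluate_response(kb, question, "Bananas are rich in potassium.")
print(status, "|", confidence)

# Contradicting answer: expected to reach Gate 3 and come back as Contradiction.
status, logs, confidence = evaluate_response(kb, question, "The lion stayed wide awake the whole time.")
print(status, "|", confidence)

The `logs` dict returned alongside the verdict carries the full per-gate debug trail (scores, selected KB sentence, NLI probabilities), which is the same object rendered in the UI's JSON panel.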