Update app.py
app.py
CHANGED
@@ -60,27 +60,21 @@
 
 
 
-
-
-
 import gradio as gr
 import torch
 import torch.nn.functional as F
 from sentence_transformers import CrossEncoder
-import time
 
 # --- CONFIGURATION ---
-# GATE 1: Relevance (
-#
-# This prevents the "Lion Sleeping" failure.
+# GATE 1: Semantic Relevance (STS)
+# Checks if the Answer is conversationally related to the Question.
 relevance_model_name = 'cross-encoder/stsb-distilroberta-base'
 
-# GATE 2: Fact Checking (
-#
+# GATE 2: Fact Checking (NLI)
+# Checks if the Answer is supported by the Knowledge Base.
 nli_model_name = 'cross-encoder/nli-deberta-v3-xsmall'
 
 print(f"Loading Models...\n1. {relevance_model_name}\n2. {nli_model_name}")
-# We load them once.
 rel_model = CrossEncoder(relevance_model_name, device="cpu")
 nli_model = CrossEncoder(nli_model_name, device="cpu")
 print("✅ System Ready.")
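Note on the label order used later in this diff: the hard-coded `labels` list in Gate 2 assumes the checkpoint scores classes in the order contradiction, entailment, neutral. A quick sanity check, assuming the checkpoint publishes its mapping in `config.json` (standard for Hugging Face models):

    from transformers import AutoConfig

    # Inspect the class order the NLI checkpoint was trained with,
    # rather than trusting a hard-coded list.
    config = AutoConfig.from_pretrained('cross-encoder/nli-deberta-v3-xsmall')
    print(config.id2label)  # expected: {0: 'contradiction', 1: 'entailment', 2: 'neutral'}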
@@ -89,22 +83,18 @@ def evaluate_response(kb, question, user_answer):
     if not kb or not question or not user_answer:
         return "⚠️ Error: Missing Input", {}, "N/A"
 
     logs = {}
 
     # --- GATE 1: RELEVANCE CHECK (STS) ---
-    # Does the answer make sense in the context of the question?
-    # STS models output a score from 0.0 to 1.0 (usually).
     rel_score = rel_model.predict([(question, user_answer)])
 
-    #
-
-
+    # FIX 1: Use .item() to safely extract float from numpy array
+    rel_score_val = rel_score.item()
+
     logs['Gate 1 Model'] = relevance_model_name
     logs['Gate 1 Raw Score'] = f"{rel_score_val:.4f}"
 
-    # Threshold: STS
-    # > 0.15 is usually enough to say "These sentences are related".
-    # "Lion sleeping" vs "What lion doing" should score ~0.4 - 0.6
+    # Threshold: STS score > 0.15 usually implies relevance
    RELEVANCE_THRESHOLD = 0.15
 
     if rel_score_val < RELEVANCE_THRESHOLD:
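Why FIX 1 is needed: `CrossEncoder.predict` returns a NumPy array (one score per input pair), not a float, and formatting that raw array with `:.4f` raises a TypeError. A minimal sketch with a dummy array standing in for the model call:

    import numpy as np

    # Stand-in for rel_model.predict([(question, user_answer)]),
    # which yields one score for the single (question, answer) pair.
    rel_score = np.array([0.4321], dtype=np.float32)

    # .item() is a safe scalar extraction; it would raise if the
    # array held more than one element.
    rel_score_val = rel_score.item()
    print(f"{rel_score_val:.4f}")  # -> 0.4321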
@@ -113,13 +103,29 @@ def evaluate_response(kb, question, user_answer):
         return status, logs, "Blocked"
 
     # --- GATE 2: FACT CHECKING (NLI) ---
-    # Does the Knowledge Base entail the Answer?
     nli_logits = nli_model.predict([(kb, user_answer)])
-    nli_probs = F.softmax(torch.tensor(nli_logits), dim=0).tolist()
 
-    #
+    # FIX 2: Handle Dimensions safely
+    # Convert to tensor
+    nli_tensor = torch.tensor(nli_logits)
+
+    # If the model returns a batch dimension (e.g. [1, 3]), squeeze it to flat [3]
+    if nli_tensor.dim() > 1:
+        nli_tensor = nli_tensor.squeeze()
+
+    # Apply Softmax across the classes (now dim=0 is safe on a flat tensor)
+    nli_probs = F.softmax(nli_tensor, dim=0).tolist()
+
+    # Get the winner index
+    max_idx = nli_tensor.argmax().item()
+
+    # Standard NLI Labels
     labels = ["Contradiction", "Entailment", "Neutral"]
-
+
+    # Safety check for model label count mismatch
+    if max_idx >= len(labels):
+        return "⚠️ Model Error", {"Error": "Label mismatch"}, "N/A"
+
     nli_verdict = labels[max_idx]
     nli_conf = nli_probs[max_idx] * 100
 
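FIX 2 in isolation: the squeeze guard makes the softmax robust whether `predict` returns flat `[3]` logits or a batched `[1, 3]`. A self-contained sketch with dummy logits (no model download needed):

    import torch
    import torch.nn.functional as F

    # Dummy stand-in for nli_model.predict([(kb, user_answer)]);
    # some library versions return shape (1, 3), others a flat (3,).
    nli_logits = [[-1.2, 3.4, 0.1]]

    nli_tensor = torch.tensor(nli_logits)
    if nli_tensor.dim() > 1:   # (1, 3) -> (3,)
        nli_tensor = nli_tensor.squeeze()

    nli_probs = F.softmax(nli_tensor, dim=0).tolist()
    max_idx = nli_tensor.argmax().item()

    labels = ["Contradiction", "Entailment", "Neutral"]
    print(labels[max_idx], f"{nli_probs[max_idx] * 100:.1f}%")  # Entailment 95.5%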
@@ -141,17 +147,15 @@ def evaluate_response(kb, question, user_answer):
         logs['Final Outcome'] = "Answer contradicts the text."
 
     else: # Neutral
-        # The answer is relevant to the question, but the TEXT doesn't mention it.
-        # e.g., "The lion likes pizza." (Relevant topic, but hallucinated fact)
         status = "❌ INCORRECT (Hallucination/Not in Text)"
         logs['Final Outcome'] = "Answer not found in text."
 
     return status, logs, f"{nli_verdict} ({nli_conf:.1f}%)"
 
 # --- UI SETUP ---
-with gr.Blocks(title="NLI Logic Engine v5 (Debug Enabled)", theme=gr.themes.Soft()) as demo:
-    gr.Markdown("## 🧠 Neural Logic Engine v5")
-    gr.Markdown("Corrected Architecture:
+with gr.Blocks(title="NLI Logic Engine v5", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## 🧠 Neural Logic Engine v5.1 (Bug Fixes Applied)")
+    gr.Markdown("Corrected Architecture: STS for Relevance + NLI for Fact Checking.")
 
     with gr.Row():
         with gr.Column(scale=1):
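The Entailment and Contradiction branches fall between the hunks shown above. For orientation only, a hypothetical condensed version of the full three-way mapping; the status strings for the first two branches are assumed, not taken from the diff:

    # Hypothetical sketch; only the Neutral branch appears verbatim in the diff.
    def verdict_to_status(nli_verdict: str) -> str:
        if nli_verdict == "Entailment":
            return "✅ CORRECT"                               # assumed wording
        elif nli_verdict == "Contradiction":
            return "❌ INCORRECT (Contradicts Text)"          # assumed wording
        else:  # Neutral: on-topic answer, but the fact is not in the text
            return "❌ INCORRECT (Hallucination/Not in Text)"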
@@ -163,7 +167,6 @@ with gr.Blocks(title="NLI Logic Engine v5 (Debug Enabled)", theme=gr.themes.Soft
         with gr.Column(scale=1):
             verdict_out = gr.Textbox(label="Final Verdict", elem_classes="verdict")
             nli_metric = gr.Label(label="NLI Confidence")
-            # JSON output for full transparency
             debug_log = gr.JSON(label="System Internals (Debug Log)")
 
     btn.click(
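The diff cuts off inside `btn.click(`. Based on `evaluate_response`'s return order (status, logs, verdict string), the wiring presumably resembles the following; the input component names `kb_in`, `q_in`, and `ans_in` are placeholders, not from the source:

    # Hypothetical wiring for the truncated call; the three input
    # component names are assumed.
    btn.click(
        fn=evaluate_response,
        inputs=[kb_in, q_in, ans_in],
        outputs=[verdict_out, debug_log, nli_metric],
    )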