Update app.py
app.py
CHANGED
@@ -63,111 +63,113 @@
 
 
 
-
-
 import gradio as gr
 import torch
 import torch.nn.functional as F
 from sentence_transformers import CrossEncoder
+import time
 
 # --- CONFIGURATION ---
-#
-#
-#
-
-
-
-#
-# We use a DeBERTa-v3-xsmall or similar high-performance NLI model.
-# It is very robust at detecting Hallucinations vs Entailment.
+# GATE 1: Relevance (Is the answer related to the question?)
+# We switch from MS-MARCO (Search) to STS (Semantic Similarity).
+# This prevents the "Lion Sleeping" failure.
+relevance_model_name = 'cross-encoder/stsb-distilroberta-base'
+
+# GATE 2: Fact Checking (Is the answer supported by the text?)
+# DeBERTa-v3 is state-of-the-art for NLI.
 nli_model_name = 'cross-encoder/nli-deberta-v3-xsmall'
 
-print("
-
+print(f"Loading Models...\n1. {relevance_model_name}\n2. {nli_model_name}")
+# We load them once.
+rel_model = CrossEncoder(relevance_model_name, device="cpu")
 nli_model = CrossEncoder(nli_model_name, device="cpu")
-print("System Ready.")
+print("✅ System Ready.")
 
 def evaluate_response(kb, question, user_answer):
     if not kb or not question or not user_answer:
-        return "⚠️ Missing Input",
+        return "⚠️ Error: Missing Input", {}, "N/A"
 
-
-
-    #
-
-
+    logs = {}  # Dictionary to store debug info
+
+    # --- GATE 1: RELEVANCE CHECK (STS) ---
+    # Does the answer make sense in the context of the question?
+    # STS models output a score from 0.0 to 1.0 (usually).
+    rel_score = rel_model.predict([(question, user_answer)])
+
+    # Check if the model output is raw logits or normalized
+    # STSb models usually output 0-1. If not, we clip/normalize.
+    rel_score_val = float(rel_score)
+    logs['Gate 1 Model'] = relevance_model_name
+    logs['Gate 1 Raw Score'] = f"{rel_score_val:.4f}"
+
+    # Threshold: STS scores are usually tighter.
+    # > 0.15 is usually enough to say "These sentences are related".
+    # "Lion sleeping" vs "What lion doing" should score ~0.4 - 0.6
+    RELEVANCE_THRESHOLD = 0.15
 
-
-
-
-
-
-
-
-    if not is_relevant:
-        return (
-            "❌ INCORRECT (Irrelevant Answer)",
-            f"Low Relevance ({qa_confidence:.1f}%)",
-            "Skipped (Not an answer)"
-        )
-
-    # --- GATE 2: Knowledge Base Verification (NLI) ---
-    # Now that we know it IS an answer, we check if it is TRUE based on the KB.
-    # Premise = KB
-    # Hypothesis = user_answer (Clean check, no complex prompt engineering needed)
+    if rel_score_val < RELEVANCE_THRESHOLD:
+        status = "❌ INCORRECT (Irrelevant)"
+        logs['Verdict'] = "Blocked at Gate 1 (Answer unrelated to Question)"
+        return status, logs, "Blocked"
+
+    # --- GATE 2: FACT CHECKING (NLI) ---
+    # Does the Knowledge Base entail the Answer?
     nli_logits = nli_model.predict([(kb, user_answer)])
     nli_probs = F.softmax(torch.tensor(nli_logits), dim=0).tolist()
 
-    #
-
-    # Label 0 = Contradiction, Label 1 = Entailment, Label 2 = Neutral
-    labels = ["CONTRADICTION", "ENTAILMENT", "NEUTRAL"]
+    # DeBERTa-v3-xsmall Labels: 0: Contradiction, 1: Entailment, 2: Neutral
+    labels = ["Contradiction", "Entailment", "Neutral"]
     max_idx = torch.tensor(nli_logits).argmax().item()
-
-
-
-
-
-
-
+    nli_verdict = labels[max_idx]
+    nli_conf = nli_probs[max_idx] * 100
+
+    logs['Gate 2 Model'] = nli_model_name
+    logs['Gate 2 Probabilities'] = {
+        "Contradiction": f"{nli_probs[0]*100:.1f}%",
+        "Entailment": f"{nli_probs[1]*100:.1f}%",
+        "Neutral": f"{nli_probs[2]*100:.1f}%"
+    }
+    logs['Gate 2 Verdict'] = nli_verdict
+
+    # --- FINAL DECISION LOGIC ---
+    if nli_verdict == "Entailment":
         status = "✅ CORRECT (Confirmed)"
-
-
-
-
-
-
-
-
-
-
-
+        logs['Final Outcome'] = "Answer is Relevant and Factual."
+
+    elif nli_verdict == "Contradiction":
+        status = "❌ INCORRECT (False Information)"
+        logs['Final Outcome'] = "Answer contradicts the text."
+
+    else:  # Neutral
+        # The answer is relevant to the question, but the TEXT doesn't mention it.
+        # e.g., "The lion likes pizza." (Relevant topic, but hallucinated fact)
+        status = "❌ INCORRECT (Hallucination/Not in Text)"
+        logs['Final Outcome'] = "Answer not found in text."
+
+    return status, logs, f"{nli_verdict} ({nli_conf:.1f}%)"
 
 # --- UI SETUP ---
-with gr.Blocks(title="
-    gr.Markdown("## 🧠 Neural
-    gr.Markdown("
+with gr.Blocks(title="NLI Logic Engine v5 (Debug Enabled)", theme=gr.themes.Soft()) as demo:
+    gr.Markdown("## 🧠 Neural Logic Engine v5")
+    gr.Markdown("Corrected Architecture: Uses **STS (Semantic Similarity)** for Relevance and **NLI** for Fact Checking.")
 
     with gr.Row():
-
-
-
-
-
-
-
-
-
-
-
-
-
-        nli_metric = gr.Label(label="Gate 2: Fact Check")
-
-    check_btn.click(
+        with gr.Column(scale=1):
+            kb_input = gr.Textbox(label="Knowledge Base", lines=5, value="When a lion was resting in the jungle, a mouse began racing up and down his body for fun. The lion's sleep was disturbed, and he woke in anger.")
+            q_input = gr.Textbox(label="Question", value="What was the lion doing?")
+            a_input = gr.Textbox(label="User Answer", value="The lion was sleeping in the jungle.")
+            btn = gr.Button("Evaluate", variant="primary")
+
+        with gr.Column(scale=1):
+            verdict_out = gr.Textbox(label="Final Verdict", elem_classes="verdict")
+            nli_metric = gr.Label(label="NLI Confidence")
+            # JSON output for full transparency
+            debug_log = gr.JSON(label="System Internals (Debug Log)")
+
+    btn.click(
         fn=evaluate_response,
         inputs=[kb_input, q_input, a_input],
-        outputs=[
+        outputs=[verdict_out, debug_log, nli_metric]
    )
 
 if __name__ == "__main__":
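
A note on the Gate 2 scoring above: for a multi-label NLI cross-encoder, CrossEncoder.predict on a list of pairs returns an array of shape (num_pairs, num_labels), so for a single pair F.softmax(..., dim=0) normalizes over the batch axis of size 1 and every class comes out as 1.0, which also breaks the nli_probs[...] indexing further down. A minimal sketch of a fix, taking row 0 before normalizing (variable names follow the diff; the shape assumption should be verified against the installed sentence-transformers version):

import torch
import torch.nn.functional as F
from sentence_transformers import CrossEncoder

nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-xsmall', device="cpu")

kb = "When a lion was resting in the jungle, a mouse began racing up and down his body for fun."
user_answer = "The lion was sleeping in the jungle."

# predict() returns shape (1, 3) for one pair; take row 0 so the softmax
# runs over the three class logits rather than the batch axis.
nli_logits = nli_model.predict([(kb, user_answer)])[0]
nli_probs = F.softmax(torch.tensor(nli_logits), dim=-1).tolist()
max_idx = torch.tensor(nli_logits).argmax().item()

Alternatively, predict(..., apply_softmax=True) should return normalized class probabilities directly if the installed version supports that flag. A similar point applies to Gate 1: float(rel_score[0]) avoids relying on NumPy's deprecated one-element-array-to-scalar conversion.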
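
The label order hard-coded in labels is also worth verifying rather than assuming, since swapping Entailment and Neutral would silently flip verdicts. CrossEncoder wraps a Hugging Face model, and recent sentence-transformers versions expose its config; a sketch, assuming the config attribute is available in the installed version:

# Hypothetical check; attribute layout can differ across library versions.
print(nli_model.config.id2label)
# cross-encoder/nli-deberta-v3-xsmall is expected to report
# {0: 'contradiction', 1: 'entailment', 2: 'neutral'},
# matching the order used in evaluate_response.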
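
For a quick smoke test of both gates, the demo values from the UI can be fed straight into evaluate_response (the comments describe the verdict each input is designed to trigger; actual scores depend on the model weights):

kb = ("When a lion was resting in the jungle, a mouse began racing up and down "
      "his body for fun. The lion's sleep was disturbed, and he woke in anger.")
q = "What was the lion doing?"

for ans in [
    "The lion was sleeping in the jungle.",  # relevant and entailed -> CORRECT
    "The lion likes pizza.",                 # relevant but unsupported -> Hallucination
    "Paris is the capital of France.",       # unrelated -> blocked at Gate 1
]:
    status, logs, confidence = evaluate_response(kb, q, ans)
    print(status, "|", confidence)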