heerjtdev commited on
Commit
b267053
Β·
verified Β·
1 Parent(s): e78173e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +27 -33
app.py CHANGED
@@ -3,68 +3,62 @@ import torch
3
  import torch.nn.functional as F
4
  from sentence_transformers import SentenceTransformer, CrossEncoder, util
5
 
6
- # Force CPU usage for the Free Tier
7
  device = "cpu"
8
 
9
- # Load models
10
- print("Loading models on CPU...")
11
  sim_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
12
- nli_model = CrossEncoder('cross-encoder/nli-distilroberta-base', device=device)
 
13
 
14
  def evaluate_response(kb, question, user_answer):
15
- # --- GATE 1: RELEVANCE ---
16
  q_emb = sim_model.encode(question, convert_to_tensor=True, device=device)
17
  a_emb = sim_model.encode(user_answer, convert_to_tensor=True, device=device)
18
  relevance_score = util.cos_sim(q_emb, a_emb).item()
19
 
20
- # --- GATE 2: FACTUALITY ---
21
- hypothesis = f"The answer to the question '{question}' is '{user_answer}'"
22
  logits = nli_model.predict([(kb, hypothesis)])
23
  probabilities = F.softmax(torch.tensor(logits), dim=1).tolist()[0]
24
 
 
25
  labels = ["CONTRADICTION", "ENTAILMENT", "NEUTRAL"]
26
  max_idx = torch.tensor(logits).argmax().item()
27
  verdict = labels[max_idx]
28
  confidence = probabilities[max_idx] * 100
29
 
30
- # --- DECISION LOGIC ---
31
- if verdict == "CONTRADICTION" and confidence > 60:
 
32
  status = "❌ INCORRECT (Fact Mismatch)"
33
- color = "#ff4b4b"
34
- elif verdict == "ENTAILMENT" and confidence > 45:
35
  status = "βœ… CORRECT (Directly Supported)"
36
- color = "#2ecc71"
37
- elif relevance_score > 0.30 and verdict != "CONTRADICTION":
38
  status = "βœ… CORRECT (Inferred)"
39
- color = "#f1c40f"
40
  else:
41
- status = "❌ IRRELEVANT / WRONG"
42
- color = "#95a5a6"
43
 
44
  return status, f"{relevance_score:.2f}", f"{verdict} ({confidence:.1f}%)"
45
 
46
- # Interactive UI
47
- with gr.Blocks(title="AI Answer Checker") as demo:
48
- gr.Markdown("# 🧠 Smart Answer Verifier")
49
- gr.Markdown("Test how well an answer matches the context provided.")
50
 
51
  with gr.Row():
52
  with gr.Column():
53
- kb_input = gr.Textbox(label="Knowledge Base (Context)", placeholder="Paste your text here...", lines=6)
54
- q_input = gr.Textbox(label="The Question", placeholder="What do you want to ask?")
55
- ans_input = gr.Textbox(label="User's Answer", placeholder="What did the user say?")
56
- btn = gr.Button("Analyze Answer", variant="primary")
57
-
58
  with gr.Column():
59
- verdict_out = gr.Textbox(label="Final Verdict")
60
- rel_out = gr.Label(label="Relevance Score (0 to 1)")
61
- nli_out = gr.Label(label="NLI Confidence")
62
 
63
- btn.click(
64
- fn=evaluate_response,
65
- inputs=[kb_input, q_input, ans_input],
66
- outputs=[verdict_out, rel_out, nli_out]
67
- )
68
 
69
  if __name__ == "__main__":
70
  demo.launch()
 
3
  import torch.nn.functional as F
4
  from sentence_transformers import SentenceTransformer, CrossEncoder, util
5
 
6
+ # Optimized for Free Tier CPU
7
  device = "cpu"
8
 
9
+ # UPGRADED MODELS
10
+ # 1. Similarity: Lightweight and fast
11
  sim_model = SentenceTransformer('all-MiniLM-L6-v2', device=device)
12
+ # 2. Reasoning: DeBERTa-v3-base is significantly better at logic than DistilRoBERTa
13
+ nli_model = CrossEncoder('cross-encoder/nli-deberta-v3-base', device=device)
14
 
15
  def evaluate_response(kb, question, user_answer):
16
+ # GATE 1: RELEVANCE
17
  q_emb = sim_model.encode(question, convert_to_tensor=True, device=device)
18
  a_emb = sim_model.encode(user_answer, convert_to_tensor=True, device=device)
19
  relevance_score = util.cos_sim(q_emb, a_emb).item()
20
 
21
+ # GATE 2: FACTUALITY (The Reasoning Step)
22
+ hypothesis = f"Question: {question} Answer: {user_answer}"
23
  logits = nli_model.predict([(kb, hypothesis)])
24
  probabilities = F.softmax(torch.tensor(logits), dim=1).tolist()[0]
25
 
26
+ # DeBERTa-v3 Label Mapping: 0: contradiction, 1: entailment, 2: neutral
27
  labels = ["CONTRADICTION", "ENTAILMENT", "NEUTRAL"]
28
  max_idx = torch.tensor(logits).argmax().item()
29
  verdict = labels[max_idx]
30
  confidence = probabilities[max_idx] * 100
31
 
32
+ # UPGRADED DECISION LOGIC
33
+ # We trust DeBERTa more, so we can be slightly more rigid with its logic
34
+ if verdict == "CONTRADICTION" and confidence > 55:
35
  status = "❌ INCORRECT (Fact Mismatch)"
36
+ elif verdict == "ENTAILMENT" and confidence > 40:
 
37
  status = "βœ… CORRECT (Directly Supported)"
38
+ elif relevance_score > 0.35 and verdict == "NEUTRAL":
 
39
  status = "βœ… CORRECT (Inferred)"
 
40
  else:
41
+ status = "❌ IRRELEVANT / LOGICALLY WEAK"
 
42
 
43
  return status, f"{relevance_score:.2f}", f"{verdict} ({confidence:.1f}%)"
44
 
45
+ # Interface setup (same as before)
46
+ with gr.Blocks(title="Advanced Reasoning Verifier") as demo:
47
+ gr.Markdown("# 🧠 Advanced Answer Verifier (DeBERTa-v3)")
48
+ gr.Markdown("Using high-performance Cross-Encoders for superior logical reasoning.")
49
 
50
  with gr.Row():
51
  with gr.Column():
52
+ kb_input = gr.Textbox(label="Knowledge Base", lines=6)
53
+ q_input = gr.Textbox(label="Question")
54
+ ans_input = gr.Textbox(label="User Answer")
55
+ btn = gr.Button("Analyze", variant="primary")
 
56
  with gr.Column():
57
+ verdict_out = gr.Textbox(label="Verdict")
58
+ rel_out = gr.Label(label="Similarity")
59
+ nli_out = gr.Label(label="NLI Reasoning")
60
 
61
+ btn.click(fn=evaluate_response, inputs=[kb_input, q_input, ans_input], outputs=[verdict_out, rel_out, nli_out])
 
 
 
 
62
 
63
  if __name__ == "__main__":
64
  demo.launch()