heerjtdev commited on
Commit
e2daaeb
·
verified ·
1 Parent(s): 4134c06

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +176 -7
app.py CHANGED
@@ -1,20 +1,189 @@
1
  import gradio as gr
2
- from pipeline.evaluator import evaluate_answer
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3
 
4
- def run(answer, question, kb):
5
- schema = load_schema(kb, question)
6
- verdict, logs = evaluate_answer(answer, question, kb, schema, MODELS)
7
  return verdict, logs
8
 
9
- with gr.Blocks() as demo:
 
 
 
 
 
 
 
 
 
10
  kb = gr.Textbox(label="Knowledge Base", lines=6)
11
  question = gr.Textbox(label="Question")
12
  answer = gr.Textbox(label="Student Answer")
13
 
14
  verdict = gr.Textbox(label="Verdict")
15
- logs = gr.JSON(label="Debug Logs")
16
 
17
  btn = gr.Button("Evaluate")
18
- btn.click(run, [answer, question, kb], [verdict, logs])
19
 
20
  demo.launch()
 
1
  import gradio as gr
2
+ import torch
3
+ import torch.nn.functional as F
4
+ from sentence_transformers import CrossEncoder
5
+ import re
6
+ import hashlib
7
+ import json
8
+
9
# ============================================================
# MODEL LOADING (ONCE)
# ============================================================

# Force CPU: keeps the app runnable on GPU-less (free-tier) hardware.
DEVICE = "cpu"

# Cross-encoder checkpoints: one for semantic similarity scoring,
# one for natural-language-inference (contradiction detection).
SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"

print("Loading models...")
sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
print("✅ Models ready")

# ============================================================
# CONFIGURATION
# ============================================================

# Minimum similarity for a required concept to count as "covered".
SIM_THRESHOLD_REQUIRED = 0.55
# NOTE(review): the two thresholds below are declared but not
# referenced anywhere in the visible code — confirm intended use.
SIM_THRESHOLD_FORBIDDEN = 0.60
ENTAILMENT_THRESHOLD = 0.65

# Keyed by hash_key(kb, question); lives only for the process lifetime.
SCHEMA_CACHE = {} # in-memory cache (HF-safe)
32
+
33
+ # ============================================================
34
+ # UTILITIES
35
+ # ============================================================
36
+
37
def split_sentences(text):
    """Split *text* into sentences on terminal punctuation (. ! ?).

    Returns a list of non-empty sentence strings. Empty or
    whitespace-only input yields an empty list — the previous
    version returned [''], and downstream code would then score
    that empty string as if it were a real KB sentence.
    """
    return [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]
39
+
40
def softmax_logits(logits):
    """Return the softmax of raw logits as a plain Python list."""
    scores = torch.tensor(logits)
    # A single-pair predict() comes back batched as shape (1, k);
    # drop the leading batch axis before normalizing.
    scores = scores.squeeze(0) if scores.dim() > 1 else scores
    return F.softmax(scores, dim=0).tolist()
45
+
46
def hash_key(kb, question):
    """Stable cache key for a (kb, question) pair.

    A unit-separator character (U+001F) is inserted between the two
    strings: the previous version hashed the bare concatenation, so
    distinct pairs such as ("ab", "c") and ("a", "bc") collided and
    could serve each other's cached schema.
    """
    return hashlib.sha256(f"{kb}\x1f{question}".encode()).hexdigest()
48
+
49
+ # ============================================================
50
+ # QUESTION CLASSIFIER
51
+ # ============================================================
52
+
53
def classify_question(question):
    """Heuristically bucket a question as FACT, DEFINITION, or EXPLANATION."""
    normalized = question.lower()
    if normalized.startswith(("what was", "who")):
        return "FACT"
    if normalized.startswith("define"):
        return "DEFINITION"
    if "explain" in normalized or "why" in normalized:
        return "EXPLANATION"
    # No pattern matched: treat as a plain factual question.
    return "FACT"
62
+
63
+ # ============================================================
64
+ # SCHEMA GENERATION (AUTO, NO LLM)
65
+ # ============================================================
66
+
67
def generate_schema(kb, question):
    """Auto-generate a grading schema directly from the KB.

    Deterministic (no LLM call): the single KB sentence scored most
    similar to the question by the cross-encoder becomes the one
    required concept.

    Returns a dict with "question_type", "required_concepts",
    "forbidden_concepts", and "allow_extra_info" keys.
    """
    sentences = split_sentences(kb)
    q_type = classify_question(question)

    # Guard: an empty/blank KB would leave no sentences to score and
    # make argmax() fail on an empty prediction array — fall back to
    # an empty requirement list instead of crashing.
    required = []
    meaningful = [s for s in sentences if s.strip()]
    if meaningful:
        # Score each (sentence, question) pair; the best-scoring
        # sentence is taken as the required concept.
        scores = sim_model.predict([(s, question) for s in meaningful])
        required = [meaningful[int(scores.argmax())]]

    return {
        "question_type": q_type,
        "required_concepts": required,
        "forbidden_concepts": [],
        "allow_extra_info": True,
    }
87
+
88
+ # ============================================================
89
+ # ANSWER DECOMPOSITION
90
+ # ============================================================
91
+
92
def decompose_answer(answer):
    """Break a free-text answer into clauses on common conjunctions."""
    parts = re.split(r'\b(?:and|before|after|because|while)\b', answer)
    stripped = (part.strip() for part in parts)
    return [clause for clause in stripped if clause]
95
+
96
+ # ============================================================
97
+ # CORE EVALUATION
98
+ # ============================================================
99
+
100
def evaluate_answer(answer, question, kb):
    """Grade a student *answer* against the knowledge base *kb*.

    Returns (verdict, logs): *verdict* is a human-readable string and
    *logs* is a dict of every intermediate result for debugging.
    """
    logs = {}

    # --------------------
    # SCHEMA LOAD / CREATE
    # --------------------
    # Schema generation is deterministic, so results are cached
    # per (kb, question) pair for the life of the process.
    key = hash_key(kb, question)
    if key not in SCHEMA_CACHE:
        SCHEMA_CACHE[key] = generate_schema(kb, question)

    schema = SCHEMA_CACHE[key]
    logs["schema"] = schema

    # --------------------
    # ANSWER PARSING
    # --------------------
    claims = decompose_answer(answer)
    logs["answer_claims"] = claims

    # --------------------
    # REQUIRED CONCEPT CHECK
    # --------------------
    required = schema["required_concepts"]
    coverage = []

    for req in required:
        if claims:
            scores = sim_model.predict([(req, c) for c in claims])
            best = float(scores.max())
        else:
            # Empty answer: nothing to compare, so nothing is covered.
            # (Previously .max() on the empty score array would raise.)
            best = 0.0
        coverage.append({
            "concept": req,
            "max_similarity": round(best, 3),
            "covered": best >= SIM_THRESHOLD_REQUIRED
        })

    logs["required_coverage"] = coverage
    covered_all = all(c["covered"] for c in coverage)

    # --------------------
    # CONTRADICTION CHECK (NLI)
    # --------------------
    kb_sentences = split_sentences(kb)
    contradictions = []

    for claim in claims:
        for sent in kb_sentences:
            probs = softmax_logits(nli_model.predict([(sent, claim)]))
            # NOTE(review): assumes index 0 of this NLI model's label
            # order is "contradiction" — confirm against the model card.
            if probs[0] > 0.70:
                contradictions.append({
                    "claim": claim,
                    "sentence": sent,
                    "confidence": round(probs[0] * 100, 1)
                })

    logs["contradictions"] = contradictions

    # --------------------
    # FINAL DECISION
    # --------------------
    # Contradictions trump coverage; full coverage with no
    # contradictions is correct; anything else is partial credit.
    if covered_all and not contradictions:
        verdict = "✅ CORRECT"
    elif contradictions:
        verdict = "❌ INCORRECT (Contradiction)"
    else:
        verdict = "⚠️ PARTIALLY CORRECT"

    logs["final_decision"] = verdict

    return verdict, logs
168
 
169
+ # ============================================================
170
+ # GRADIO UI
171
+ # ============================================================
172
+
173
def run(answer, question, kb):
    """Gradio click handler: forward UI inputs to the evaluation engine."""
    verdict, logs = evaluate_answer(answer, question, kb)
    return verdict, logs
175
+
176
with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
    gr.Markdown("## 🧠 Competitive Exam Answer Checker (Single-File Engine)")

    # Inputs: grading context, the question, and the student's answer.
    kb = gr.Textbox(label="Knowledge Base", lines=6)
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Student Answer")

    # Outputs: final verdict plus the engine's intermediate debug logs.
    verdict = gr.Textbox(label="Verdict")
    debug = gr.JSON(label="Debug Logs")

    # Note the input order the handler expects: (answer, question, kb).
    btn = gr.Button("Evaluate")
    btn.click(run, [answer, question, kb], [verdict, debug])

demo.launch()