heerjtdev committed on
Commit
b79fdd7
·
verified ·
1 Parent(s): c53835a

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +60 -96
app.py CHANGED
@@ -5,39 +5,40 @@ from sentence_transformers import CrossEncoder
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import re
7
  import hashlib
 
8
 
9
  # ============================================================
10
  # DEVICE
11
  # ============================================================
12
-
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  # ============================================================
16
  # MODELS
17
  # ============================================================
18
-
19
  SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
20
  NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
21
  LLM_NAME = "google/flan-t5-base"
22
 
 
23
  sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
24
  nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
25
 
 
26
  llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
27
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
28
 
 
 
29
  # ============================================================
30
- # CONFIG
31
  # ============================================================
32
-
33
- SIM_THRESHOLD = 0.60
34
  CONTRADICTION_THRESHOLD = 0.70
35
  SCHEMA_CACHE = {}
36
 
37
  # ============================================================
38
  # UTILITIES
39
  # ============================================================
40
-
41
def split_sentences(text):
    """Split *text* into sentences on whitespace following terminal punctuation (., !, ?)."""
    cleaned = text.strip()
    boundary = r'(?<=[.!?])\s+'
    return re.split(boundary, cleaned)
43
 
@@ -50,35 +51,21 @@ def softmax_logits(logits):
50
def hash_key(kb, question):
    """Return a stable SHA-256 hex digest identifying a (kb, question) pair."""
    combined = "".join((kb, question))
    return hashlib.sha256(combined.encode()).hexdigest()
52
 
53
def infer_question_type(question):
    """Classify a question as METHOD ("how..."), REASON ("why..."), or FACT (anything else)."""
    lowered = question.lower()
    for prefix, label in (("how", "METHOD"), ("why", "REASON")):
        if lowered.startswith(prefix):
            return label
    return "FACT"
60
-
61
def decompose_answer(answer):
    """Break an answer into trimmed sentence-level claims, dropping fragments of 3 chars or fewer."""
    claims = []
    for sentence in split_sentences(answer):
        trimmed = sentence.strip()
        if len(trimmed) > 3:
            claims.append(trimmed)
    return claims
 
 
63
 
64
  # ============================================================
65
- # 🔥 ACTION-FOCUSED SCHEMA GENERATION (FIXED)
66
  # ============================================================
67
-
68
- def generate_schema_with_llm(kb, question):
69
- q_type = infer_question_type(question)
70
-
 
71
  prompt = f"""
72
- You are extracting the exact answer to a competitive exam question.
73
-
74
- STRICT RULES:
75
- - Extract ONLY the direct action that answers the question.
76
- - DO NOT include background events.
77
- - DO NOT include earlier or later story details.
78
- - Use ACTIVE VERBS.
79
- - Keep answers short (one clause).
80
-
81
- Question type: {q_type}
82
 
83
  Knowledge Base:
84
  {kb}
@@ -86,27 +73,29 @@ Knowledge Base:
86
  Question:
87
  {question}
88
 
89
- Return the answer as bullet points.
 
 
 
 
90
  """
91
-
92
  inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
93
  outputs = llm_model.generate(
94
  **inputs,
95
- max_new_tokens=80,
96
- temperature=0.0,
97
- do_sample=False
98
  )
99
-
100
  raw = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
101
-
102
- facts = [
103
- line.strip("-• ").strip()
104
- for line in raw.split("\n")
105
- if len(line.strip()) > 4
106
- ]
107
-
 
108
  return {
109
- "question_type": q_type,
110
  "required_concepts": facts,
111
  "raw_llm_output": raw
112
  }
@@ -114,106 +103,81 @@ Return the answer as bullet points.
114
  # ============================================================
115
  # CORE EVALUATION
116
  # ============================================================
117
-
118
  def evaluate_answer(answer, question, kb):
119
- logs = {
120
- "inputs": {
121
- "question": question,
122
- "answer": answer,
123
- "kb_length": len(kb)
124
- }
125
- }
126
-
127
  key = hash_key(kb, question)
128
-
129
  if key not in SCHEMA_CACHE:
130
- schema = generate_schema_with_llm(kb, question)
131
-
132
- # HARD FILTER: must contain an ACTION VERB
133
- action_schema = []
134
- for s in schema["required_concepts"]:
135
- if re.search(r'\b(bit|cut|free|help|rescue|save)\b', s.lower()):
136
- action_schema.append(s)
137
-
138
- # Fallback: extract action sentences directly from KB
139
- if not action_schema:
140
- action_schema = [
141
- s for s in split_sentences(kb)
142
- if re.search(r'\b(bit|cut|free|help|rescue|save)\b', s.lower())
143
- ]
144
-
145
- schema["required_concepts"] = action_schema[:2]
146
  SCHEMA_CACHE[key] = schema
147
-
148
  schema = SCHEMA_CACHE[key]
149
  logs["schema"] = schema
150
-
151
  claims = decompose_answer(answer)
152
  logs["claims"] = claims
153
-
154
  # ---------------- COVERAGE ----------------
155
  coverage = []
156
  covered_all = True
157
-
158
  for concept in schema["required_concepts"]:
159
- scores = sim_model.predict([(concept, c) for c in claims])
160
- best = float(scores.max())
161
- ok = best >= SIM_THRESHOLD
162
-
 
 
 
163
  coverage.append({
164
  "concept": concept,
165
  "similarity": round(best, 3),
166
  "covered": ok
167
  })
168
-
169
  if not ok:
170
  covered_all = False
171
-
172
  logs["coverage"] = coverage
173
-
174
  # ---------------- CONTRADICTIONS ----------------
175
  contradictions = []
176
-
177
  for claim in claims:
178
- for ref in schema["required_concepts"]:
179
- probs = softmax_logits(nli_model.predict([(ref, claim)]))
180
  if probs[0] > CONTRADICTION_THRESHOLD:
181
  contradictions.append({
182
  "claim": claim,
183
- "against": ref,
184
  "confidence": round(probs[0] * 100, 1)
185
  })
186
-
187
  logs["contradictions"] = contradictions
188
-
189
- # ---------------- VERDICT ----------------
190
  if contradictions:
191
- verdict = "❌ INCORRECT"
192
  elif covered_all:
193
  verdict = "✅ CORRECT"
194
  else:
195
  verdict = "⚠️ PARTIALLY CORRECT"
196
-
197
  logs["final_verdict"] = verdict
198
  return verdict, logs
199
 
200
  # ============================================================
201
  # GRADIO UI
202
  # ============================================================
203
-
204
def run(answer, question, kb):
    """Gradio click handler; forwards the three textbox values to evaluate_answer."""
    return evaluate_answer(answer, question, kb)
206
 
207
  with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
208
- gr.Markdown("## 🧠 Competitive Exam Answer Checker")
209
-
210
- kb = gr.Textbox(label="Knowledge Base", lines=8)
211
  question = gr.Textbox(label="Question")
212
  answer = gr.Textbox(label="Student Answer")
213
-
214
  verdict = gr.Textbox(label="Verdict")
215
  debug = gr.JSON(label="Debug Logs")
216
-
217
  btn = gr.Button("Evaluate")
218
  btn.click(run, [answer, question, kb], [verdict, debug])
219
 
 
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import re
7
  import hashlib
8
+ import json
9
 
10
  # ============================================================
11
  # DEVICE
12
  # ============================================================
 
13
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
  # ============================================================
16
  # MODELS
17
  # ============================================================
 
18
  SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
19
  NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
20
  LLM_NAME = "google/flan-t5-base"
21
 
22
+ print("Loading similarity + NLI models...")
23
  sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
24
  nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
25
 
26
+ print("Loading LLM for atomic fact extraction...")
27
  llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
28
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
29
 
30
+ print("✅ All models loaded")
31
+
32
  # ============================================================
33
+ # CONFIGURATION
34
  # ============================================================
35
+ SIM_THRESHOLD_REQUIRED = 0.55
 
36
  CONTRADICTION_THRESHOLD = 0.70
37
  SCHEMA_CACHE = {}
38
 
39
  # ============================================================
40
  # UTILITIES
41
  # ============================================================
 
42
def split_sentences(text):
    # Split on whitespace that follows sentence-ending punctuation (., !, ?);
    # leading/trailing whitespace is stripped first so edges don't produce empty parts.
    return re.split(r'(?<=[.!?])\s+', text.strip())
 
 
51
def hash_key(kb, question):
    """Return a SHA-256 hex key for caching per (kb, question) pair.

    A field separator is inserted between the two inputs so that pairs whose
    plain concatenations are equal — e.g. ("ab", "c") vs ("a", "bc") — do not
    collide onto the same cache entry.
    """
    combined = kb + "\x1f" + question  # 0x1f = ASCII unit separator
    return hashlib.sha256(combined.encode()).hexdigest()
53
 
 
 
 
 
 
 
 
 
54
def decompose_answer(answer):
    """Split an answer into atomic claims at connective words.

    Splits on coordinating/subordinating connectives (and, because, before,
    after, while, then, so). Matching is case-insensitive so sentence-initial
    connectives ("Because", "And", ...) also act as claim boundaries — the
    original case-sensitive pattern silently missed them.

    Returns a list of non-empty, whitespace-trimmed claim strings.
    """
    parts = re.split(
        r'\b(?:and|because|before|after|while|then|so)\b',
        answer,
        flags=re.IGNORECASE,
    )
    return [p.strip() for p in parts if p.strip()]
58
 
59
  # ============================================================
60
+ # LLM FACT EXTRACTION
61
  # ============================================================
62
def generate_atomic_facts(kb, question):
    """Extract 1-5 atomic facts from the knowledge base that answer the question.

    Prompts the seq2seq LLM for a strict-JSON response and parses it, falling
    back to line-by-line parsing when the model emits non-JSON text (common
    for small instruction models like flan-t5-base).

    Returns:
        dict with:
          - "required_concepts": list[str] of extracted facts
          - "raw_llm_output": the decoded model text, kept for debugging
    """
    prompt = f"""
Extract atomic facts that directly answer the question.

Knowledge Base:
{kb}

Question:
{question}

RULES:
- Return 1-5 short factual statements that directly answer the question.
- Output strictly in JSON format: {{"facts": ["fact1", "fact2", ...]}}
- Do not include unrelated events or explanations.
- Each fact should be self-contained.
"""
    inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
    # Greedy decoding for determinism. `temperature` is omitted: it has no
    # effect (and triggers transformers warnings) when do_sample=False.
    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False
    )
    raw = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

    try:
        data = json.loads(raw)
        # Guard against syntactically valid JSON that is not the expected
        # {"facts": [...]} object (e.g. a bare list or string).
        facts = data.get("facts", []) if isinstance(data, dict) else []
    except json.JSONDecodeError:
        # Narrowed from a bare `except:` which also hid real bugs
        # (KeyboardInterrupt, NameError, ...).
        facts = []
    if not facts:
        # Fallback: treat each non-trivial output line as a fact, with any
        # leading bullet characters stripped.
        facts = [
            line.strip("-• ").strip()
            for line in raw.split("\n")
            if len(line.strip()) > 3
        ]

    return {
        "required_concepts": facts,
        "raw_llm_output": raw
    }
 
103
  # ============================================================
104
  # CORE EVALUATION
105
  # ============================================================
 
106
def evaluate_answer(answer, question, kb):
    """Grade a student answer against a knowledge base.

    Returns a (verdict, logs) tuple: verdict is one of "✅ CORRECT",
    "⚠️ PARTIALLY CORRECT", or "❌ INCORRECT (Contradiction)"; logs is a dict
    of all intermediate results for the debug panel.
    """
    logs = {"inputs": {"question": question, "answer": answer, "kb_length": len(kb)}}

    # Cache the LLM-extracted fact schema per (kb, question) pair so repeated
    # evaluations of the same question skip the expensive generate call.
    key = hash_key(kb, question)
    if key not in SCHEMA_CACHE:
        schema = generate_atomic_facts(kb, question)
        SCHEMA_CACHE[key] = schema

    schema = SCHEMA_CACHE[key]
    logs["schema"] = schema

    claims = decompose_answer(answer)
    logs["claims"] = claims

    # ---------------- COVERAGE ----------------
    # A required concept counts as "covered" when its best cross-encoder
    # similarity against any claim reaches SIM_THRESHOLD_REQUIRED.
    coverage = []
    covered_all = True
    for concept in schema["required_concepts"]:
        if claims:
            scores = sim_model.predict([(concept, c) for c in claims])
            best = float(scores.max())
            ok = best >= SIM_THRESHOLD_REQUIRED
        else:
            # No claims extracted: nothing can be covered.
            best = 0.0
            ok = False
        coverage.append({
            "concept": concept,
            "similarity": round(best, 3),
            "covered": ok
        })
        if not ok:
            covered_all = False
    logs["coverage"] = coverage
    # NOTE(review): if required_concepts is empty the loop never runs,
    # covered_all stays True, and an uncontradicted answer grades CORRECT —
    # confirm that is intended.

    # ---------------- CONTRADICTIONS ----------------
    # Every claim is checked pairwise against every KB sentence with the NLI
    # cross-encoder; one predict call per pair.
    contradictions = []
    kb_sents = split_sentences(kb)
    for claim in claims:
        for sent in kb_sents:
            # probs[0] is assumed to be the contradiction probability —
            # TODO confirm against softmax_logits (defined elsewhere in this
            # file) and the NLI model's label ordering.
            probs = softmax_logits(nli_model.predict([(sent, claim)]))
            if probs[0] > CONTRADICTION_THRESHOLD:
                contradictions.append({
                    "claim": claim,
                    "sentence": sent,
                    "confidence": round(probs[0] * 100, 1)
                })
    logs["contradictions"] = contradictions

    # ---------------- FINAL VERDICT ----------------
    # Any contradiction overrides coverage; otherwise full coverage wins.
    if contradictions:
        verdict = "❌ INCORRECT (Contradiction)"
    elif covered_all:
        verdict = "✅ CORRECT"
    else:
        verdict = "⚠️ PARTIALLY CORRECT"

    logs["final_verdict"] = verdict
    return verdict, logs
164
 
165
  # ============================================================
166
  # GRADIO UI
167
  # ============================================================
 
168
def run(answer, question, kb):
    """Gradio callback: delegate to evaluate_answer and return (verdict, logs)."""
    result_verdict, result_logs = evaluate_answer(answer, question, kb)
    return result_verdict, result_logs
170
 
171
# Build the Gradio interface. `demo` is presumably launched elsewhere
# (no demo.launch() is visible in this chunk) — TODO confirm.
with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
    gr.Markdown("## 🧠 Competitive Exam Answer Checker (Robust General Version)")

    # Inputs
    kb = gr.Textbox(label="Knowledge Base", lines=10)
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Student Answer")

    # Outputs
    verdict = gr.Textbox(label="Verdict")
    debug = gr.JSON(label="Debug Logs")

    btn = gr.Button("Evaluate")
    # run's parameter order is (answer, question, kb); the inputs list below
    # deliberately matches that order, not the on-screen order.
    btn.click(run, [answer, question, kb], [verdict, debug])
183