heerjtdev committed on
Commit
7753020
·
verified ·
1 Parent(s): d427ea6

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +83 -69
app.py CHANGED
@@ -15,31 +15,29 @@ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
15
  # ============================================================
16
  # MODELS
17
  # ============================================================
18
- SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
19
  NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
20
  LLM_NAME = "google/flan-t5-base"
21
 
22
- print("Loading similarity + NLI models...")
23
- sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
24
  nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
25
 
26
- print("Loading LLM for atomic fact extraction...")
27
  llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
28
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
29
 
30
  print("✅ All models loaded")
31
 
32
  # ============================================================
33
- # CONFIGURATION
34
  # ============================================================
35
- SIM_THRESHOLD_REQUIRED = 0.55
36
- CONTRADICTION_THRESHOLD = 0.70
37
  SCHEMA_CACHE = {}
38
 
39
  # ============================================================
40
  # UTILITIES
41
  # ============================================================
42
- def split_sentences(text):
43
  return re.split(r'(?<=[.!?])\s+', text.strip())
44
 
45
  def softmax_logits(logits):
@@ -51,55 +49,71 @@ def softmax_logits(logits):
51
  def hash_key(kb, question):
52
  return hashlib.sha256((kb + question).encode()).hexdigest()
53
 
54
- def decompose_answer(answer):
55
- """Split answer into atomic claims."""
56
- parts = re.split(r'\b(?:and|because|before|after|while|then|so)\b', answer)
57
- return [p.strip() for p in parts if p.strip()]
 
 
 
58
 
59
  # ============================================================
60
- # LLM FACT EXTRACTION
61
  # ============================================================
62
- def generate_atomic_facts(kb, question):
63
  """
64
- Ask LLM to extract 1-5 atomic facts from KB that directly answer the question.
65
- Returns JSON: {"facts": [ ... ]}
66
  """
67
- prompt = f"""
68
- From the Knowledge Base, extract the character transformation of Matilda.
69
 
70
- Rules:
71
- - Identify INITIAL traits, CAUSAL EVENTS, and FINAL traits.
72
- - Use short factual statements grounded ONLY in the knowledge base.
73
- - Do NOT paraphrase the question.
74
- - Return facts that can be checked independently.
 
75
 
76
- Output strictly as JSON:
 
 
 
 
 
 
 
 
77
  {
78
  "facts": [
79
- "Initially Matilda desired a luxurious life despite her humble background",
80
- "She pretended to be wealthy and borrowed a necklace to attend the ball",
81
- "She lost the borrowed necklace, causing long-term suffering",
82
- "As a result of hardship, she became mature, humble, and grateful"
83
  ]
84
  }
85
 
 
 
 
 
 
86
  """
 
 
 
87
  inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
88
  outputs = llm_model.generate(
89
  **inputs,
90
- max_new_tokens=128,
91
  do_sample=False,
92
  temperature=0.0
93
  )
 
94
  raw = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
95
-
96
  try:
97
  data = json.loads(raw)
98
  facts = data.get("facts", [])
99
- except:
100
- # fallback: parse line by line if JSON fails
101
- facts = [line.strip("-• ").strip() for line in raw.split("\n") if len(line.strip()) > 3]
102
-
103
  return {
104
  "required_concepts": facts,
105
  "raw_llm_output": raw
@@ -108,68 +122,68 @@ Output strictly as JSON:
108
  # ============================================================
109
  # CORE EVALUATION
110
  # ============================================================
111
- def evaluate_answer(answer, question, kb):
112
- logs = {"inputs": {"question": question, "answer": answer, "kb_length": len(kb)}}
113
-
 
 
 
 
 
 
114
  key = hash_key(kb, question)
115
  if key not in SCHEMA_CACHE:
116
- schema = generate_atomic_facts(kb, question)
117
- SCHEMA_CACHE[key] = schema
118
-
119
  schema = SCHEMA_CACHE[key]
120
  logs["schema"] = schema
121
-
122
  claims = decompose_answer(answer)
123
  logs["claims"] = claims
124
-
125
  # ---------------- COVERAGE ----------------
126
  coverage = []
127
  covered_all = True
128
- for concept in schema["required_concepts"]:
129
- if claims:
130
 
131
- probs = softmax_logits(nli_model.predict([(c, concept)]))
132
- # index 2 = entailment for NLI DeBERTa
133
- entailment = probs[2]
134
- ok = entailment > 0.6
135
- best = entailment
136
 
 
 
 
137
 
138
- # scores = sim_model.predict([(concept, c) for c in claims])
139
- # best = float(scores.max())
140
- # ok = best >= SIM_THRESHOLD_REQUIRED
141
- else:
142
- best = 0.0
143
- ok = False
144
  coverage.append({
145
  "concept": concept,
146
- "similarity": round(best, 3),
147
  "covered": ok
148
  })
 
149
  if not ok:
150
  covered_all = False
 
151
  logs["coverage"] = coverage
152
-
153
  # ---------------- CONTRADICTIONS ----------------
154
  contradictions = []
155
  kb_sents = split_sentences(kb)
 
156
  for claim in claims:
157
  for sent in kb_sents:
158
  probs = softmax_logits(nli_model.predict([(sent, claim)]))
159
-
160
  contradiction = probs[0]
161
  entailment = probs[2]
162
 
163
- if contradiction > 0.8 and entailment < 0.2:
164
- # probs = softmax_logits(nli_model.predict([(sent, claim)]))
165
- # if probs[0] > CONTRADICTION_THRESHOLD:
166
  contradictions.append({
167
  "claim": claim,
168
  "sentence": sent,
169
- "confidence": round(probs[0] * 100, 1)
170
  })
 
171
  logs["contradictions"] = contradictions
172
-
173
  # ---------------- FINAL VERDICT ----------------
174
  if contradictions:
175
  verdict = "❌ INCORRECT (Contradiction)"
@@ -177,7 +191,7 @@ def evaluate_answer(answer, question, kb):
177
  verdict = "✅ CORRECT"
178
  else:
179
  verdict = "⚠️ PARTIALLY CORRECT"
180
-
181
  logs["final_verdict"] = verdict
182
  return verdict, logs
183
 
@@ -188,16 +202,16 @@ def run(answer, question, kb):
188
  return evaluate_answer(answer, question, kb)
189
 
190
  with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
191
- gr.Markdown("## 🧠 Competitive Exam Answer Checker (Robust General Version)")
192
-
193
  kb = gr.Textbox(label="Knowledge Base", lines=10)
194
  question = gr.Textbox(label="Question")
195
  answer = gr.Textbox(label="Student Answer")
196
-
197
  verdict = gr.Textbox(label="Verdict")
198
  debug = gr.JSON(label="Debug Logs")
199
-
200
  btn = gr.Button("Evaluate")
201
  btn.click(run, [answer, question, kb], [verdict, debug])
202
 
203
- demo.launch()
 
15
  # ============================================================
16
  # MODELS
17
  # ============================================================
 
18
  NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
19
  LLM_NAME = "google/flan-t5-base"
20
 
21
+ print("Loading NLI model...")
 
22
  nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
23
 
24
+ print("Loading LLM for schema extraction...")
25
  llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
26
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
27
 
28
  print("✅ All models loaded")
29
 
30
  # ============================================================
31
+ # CONFIG
32
  # ============================================================
33
+ ENTAILMENT_THRESHOLD = 0.6
34
+ CONTRADICTION_THRESHOLD = 0.8
35
  SCHEMA_CACHE = {}
36
 
37
  # ============================================================
38
  # UTILITIES
39
  # ============================================================
40
+ def split_sentences(text: str):
41
  return re.split(r'(?<=[.!?])\s+', text.strip())
42
 
43
  def softmax_logits(logits):
 
49
  def hash_key(kb, question):
50
  return hashlib.sha256((kb + question).encode()).hexdigest()
51
 
52
+ def decompose_answer(answer: str):
53
+ """
54
+ Conservative sentence-based decomposition.
55
+ Avoids fragments that break NLI.
56
+ """
57
+ sentences = split_sentences(answer)
58
+ return [s.strip() for s in sentences if len(s.split()) >= 5]
59
 
60
  # ============================================================
61
+ # LLM SCHEMA EXTRACTION (GENERALISABLE)
62
  # ============================================================
63
+ def generate_atomic_facts(kb: str, question: str):
64
  """
65
+ Extract minimal checkable propositions from the KB.
 
66
  """
 
 
67
 
68
+ prompt = """
69
+ You are constructing a grading schema.
70
+
71
+ Task:
72
+ From the Knowledge Base, extract the MINIMAL set of factual propositions
73
+ that a correct answer to the Question must entail.
74
 
75
+ Rules:
76
+ - Use ONLY information present in the knowledge base.
77
+ - Do NOT restate or paraphrase the question.
78
+ - Do NOT add explanations.
79
+ - Each fact must be independently checkable.
80
+ - Prefer concrete states, events, causes, or outcomes.
81
+ - Return between 2 and 6 facts.
82
+
83
+ Output STRICTLY in valid JSON:
84
  {
85
  "facts": [
86
+ "fact 1",
87
+ "fact 2",
88
+ "fact 3"
 
89
  ]
90
  }
91
 
92
+ Knowledge Base:
93
+ <<<KB>>>
94
+
95
+ Question:
96
+ <<<QUESTION>>>
97
  """
98
+
99
+ prompt = prompt.replace("<<<KB>>>", kb).replace("<<<QUESTION>>>", question)
100
+
101
  inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
102
  outputs = llm_model.generate(
103
  **inputs,
104
+ max_new_tokens=192,
105
  do_sample=False,
106
  temperature=0.0
107
  )
108
+
109
  raw = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
110
+
111
  try:
112
  data = json.loads(raw)
113
  facts = data.get("facts", [])
114
+ except Exception:
115
+ facts = []
116
+
 
117
  return {
118
  "required_concepts": facts,
119
  "raw_llm_output": raw
 
122
  # ============================================================
123
  # CORE EVALUATION
124
  # ============================================================
125
+ def evaluate_answer(answer: str, question: str, kb: str):
126
+ logs = {
127
+ "inputs": {
128
+ "question": question,
129
+ "answer": answer,
130
+ "kb_length": len(kb)
131
+ }
132
+ }
133
+
134
  key = hash_key(kb, question)
135
  if key not in SCHEMA_CACHE:
136
+ SCHEMA_CACHE[key] = generate_atomic_facts(kb, question)
137
+
 
138
  schema = SCHEMA_CACHE[key]
139
  logs["schema"] = schema
140
+
141
  claims = decompose_answer(answer)
142
  logs["claims"] = claims
143
+
144
  # ---------------- COVERAGE ----------------
145
  coverage = []
146
  covered_all = True
 
 
147
 
148
+ for concept in schema["required_concepts"]:
149
+ best_entailment = 0.0
 
 
 
150
 
151
+ for claim in claims:
152
+ probs = softmax_logits(nli_model.predict([(claim, concept)]))
153
+ best_entailment = max(best_entailment, probs[2]) # entailment
154
 
155
+ ok = best_entailment >= ENTAILMENT_THRESHOLD
 
 
 
 
 
156
  coverage.append({
157
  "concept": concept,
158
+ "entailment": round(best_entailment, 3),
159
  "covered": ok
160
  })
161
+
162
  if not ok:
163
  covered_all = False
164
+
165
  logs["coverage"] = coverage
166
+
167
  # ---------------- CONTRADICTIONS ----------------
168
  contradictions = []
169
  kb_sents = split_sentences(kb)
170
+
171
  for claim in claims:
172
  for sent in kb_sents:
173
  probs = softmax_logits(nli_model.predict([(sent, claim)]))
 
174
  contradiction = probs[0]
175
  entailment = probs[2]
176
 
177
+ # Conservative contradiction rule
178
+ if contradiction >= CONTRADICTION_THRESHOLD and entailment < 0.2:
 
179
  contradictions.append({
180
  "claim": claim,
181
  "sentence": sent,
182
+ "confidence": round(contradiction * 100, 1)
183
  })
184
+
185
  logs["contradictions"] = contradictions
186
+
187
  # ---------------- FINAL VERDICT ----------------
188
  if contradictions:
189
  verdict = "❌ INCORRECT (Contradiction)"
 
191
  verdict = "✅ CORRECT"
192
  else:
193
  verdict = "⚠️ PARTIALLY CORRECT"
194
+
195
  logs["final_verdict"] = verdict
196
  return verdict, logs
197
 
 
202
  return evaluate_answer(answer, question, kb)
203
 
204
  with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
205
+ gr.Markdown("## 🧠 Competitive Exam Answer Checker")
206
+
207
  kb = gr.Textbox(label="Knowledge Base", lines=10)
208
  question = gr.Textbox(label="Question")
209
  answer = gr.Textbox(label="Student Answer")
210
+
211
  verdict = gr.Textbox(label="Verdict")
212
  debug = gr.JSON(label="Debug Logs")
213
+
214
  btn = gr.Button("Evaluate")
215
  btn.click(run, [answer, question, kb], [verdict, debug])
216
 
217
+ demo.launch()