heerjtdev commited on
Commit
d0c9fda
·
verified ·
1 Parent(s): d14dbdf

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +76 -34
app.py CHANGED
@@ -5,7 +5,6 @@ from sentence_transformers import CrossEncoder
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import re
7
  import hashlib
8
- import json
9
 
10
  # ============================================================
11
  # DEVICE
@@ -21,23 +20,24 @@ SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
21
  NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
22
  LLM_NAME = "google/flan-t5-base"
23
 
24
- print("Loading similarity + NLI models...")
25
  sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
 
 
26
  nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
27
 
28
  print("Loading LLM for schema generation...")
29
  llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
30
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
31
 
32
- print("✅ All models loaded successfully")
33
 
34
  # ============================================================
35
- # CONFIGURATION
36
  # ============================================================
37
 
38
  SIM_THRESHOLD_REQUIRED = 0.55
39
  CONTRADICTION_THRESHOLD = 0.70
40
-
41
  SCHEMA_CACHE = {}
42
 
43
  # ============================================================
@@ -56,13 +56,34 @@ def softmax_logits(logits):
56
  def hash_key(kb, question):
57
  return hashlib.sha256((kb + question).encode()).hexdigest()
58
 
 
 
 
 
 
 
 
 
 
 
59
  # ============================================================
60
- # LLM SCHEMA GENERATION (CORE PART YOU ASKED FOR)
61
  # ============================================================
62
 
63
  def generate_schema_with_llm(kb, question):
 
 
64
  prompt = f"""
65
- From the knowledge base below, answer the question using short factual points.
 
 
 
 
 
 
 
 
 
66
 
67
  Knowledge Base:
68
  {kb}
@@ -70,7 +91,7 @@ Knowledge Base:
70
  Question:
71
  {question}
72
 
73
- Write 1–3 short factual bullet points. Do NOT explain.
74
  """
75
 
76
  inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
@@ -78,35 +99,46 @@ Write 1–3 short factual bullet points. Do NOT explain.
78
  outputs = llm_model.generate(
79
  **inputs,
80
  max_new_tokens=128,
81
- temperature=0.0,
82
- do_sample=False
83
  )
84
 
85
- text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
86
 
87
- # Extract bullet-like facts
88
  facts = [
89
  line.strip("-• ").strip()
90
- for line in text.split("\n")
91
- if len(line.strip()) > 3
92
  ]
93
 
94
  return {
95
- "question_type": "FACT",
96
  "required_concepts": facts,
97
- "forbidden_concepts": [],
98
  "allow_extra_info": True,
99
- "raw_llm_output": text
100
  }
101
 
 
 
 
 
 
 
 
 
 
 
 
 
 
102
 
103
  # ============================================================
104
  # ANSWER DECOMPOSITION
105
  # ============================================================
106
 
107
  def decompose_answer(answer):
108
- clauses = re.split(r'\b(?:and|because|before|after|while)\b', answer)
109
- return [c.strip() for c in clauses if c.strip()]
110
 
111
  # ============================================================
112
  # CORE EVALUATION
@@ -121,42 +153,53 @@ def evaluate_answer(answer, question, kb):
121
  }
122
  }
123
 
124
- # ---------------- SCHEMA ----------------
125
  key = hash_key(kb, question)
 
126
  if key not in SCHEMA_CACHE:
127
- SCHEMA_CACHE[key] = generate_schema_with_llm(kb, question)
 
 
 
 
 
 
 
 
 
 
 
128
 
129
  schema = SCHEMA_CACHE[key]
130
  logs["schema"] = schema
131
 
132
- # ---------------- ANSWER CLAIMS ----------------
133
  claims = decompose_answer(answer)
134
  logs["claims"] = claims
135
 
136
- # ---------------- REQUIRED COVERAGE ----------------
137
  coverage = []
138
  covered_all = True
139
 
140
- for concept in schema.get("required_concepts", []):
141
  scores = sim_model.predict([(concept, c) for c in claims])
142
  best = float(scores.max())
143
- covered = best >= SIM_THRESHOLD_REQUIRED
 
144
  coverage.append({
145
  "concept": concept,
146
  "similarity": round(best, 3),
147
- "covered": covered
148
  })
149
- if not covered:
 
150
  covered_all = False
151
 
152
  logs["coverage"] = coverage
153
 
154
- # ---------------- CONTRADICTION CHECK (FIXED) ----------------
155
  contradictions = []
156
- relevant_kb = schema.get("required_concepts", [])
157
 
158
  for claim in claims:
159
- for sent in relevant_kb:
160
  probs = softmax_logits(nli_model.predict([(sent, claim)]))
161
  if probs[0] > CONTRADICTION_THRESHOLD:
162
  contradictions.append({
@@ -167,7 +210,7 @@ def evaluate_answer(answer, question, kb):
167
 
168
  logs["contradictions"] = contradictions
169
 
170
- # ---------------- FINAL DECISION ----------------
171
  if contradictions:
172
  verdict = "❌ INCORRECT (Contradiction)"
173
  elif covered_all:
@@ -178,7 +221,6 @@ def evaluate_answer(answer, question, kb):
178
  logs["final_verdict"] = verdict
179
  return verdict, logs
180
 
181
-
182
  # ============================================================
183
  # GRADIO UI
184
  # ============================================================
@@ -187,9 +229,9 @@ def run(answer, question, kb):
187
  return evaluate_answer(answer, question, kb)
188
 
189
  with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
190
- gr.Markdown("## 🧠 Competitive Exam Answer Checker (HF-Free LLM Version)")
191
 
192
- kb = gr.Textbox(label="Knowledge Base", lines=7)
193
  question = gr.Textbox(label="Question")
194
  answer = gr.Textbox(label="Student Answer")
195
 
 
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
6
  import re
7
  import hashlib
 
8
 
9
  # ============================================================
10
  # DEVICE
 
20
  NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
21
  LLM_NAME = "google/flan-t5-base"
22
 
23
+ print("Loading similarity model...")
24
  sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
25
+
26
+ print("Loading NLI model...")
27
  nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
28
 
29
  print("Loading LLM for schema generation...")
30
  llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
31
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
32
 
33
+ print("✅ All models loaded")
34
 
35
  # ============================================================
36
+ # CONFIG
37
  # ============================================================
38
 
39
  SIM_THRESHOLD_REQUIRED = 0.55
40
  CONTRADICTION_THRESHOLD = 0.70
 
41
  SCHEMA_CACHE = {}
42
 
43
  # ============================================================
 
56
  def hash_key(kb, question):
57
  return hashlib.sha256((kb + question).encode()).hexdigest()
58
 
59
+ def infer_question_type(question):
60
+ q = question.lower().strip()
61
+ if q.startswith("how"):
62
+ return "METHOD"
63
+ if q.startswith("why"):
64
+ return "REASON"
65
+ if q.startswith("when") or q.startswith("where"):
66
+ return "FACT"
67
+ return "FACT"
68
+
69
  # ============================================================
70
+ # LLM SCHEMA GENERATION (HARDENED)
71
  # ============================================================
72
 
73
  def generate_schema_with_llm(kb, question):
74
+ q_type = infer_question_type(question)
75
+
76
  prompt = f"""
77
+ You are extracting the correct answer to a competitive exam question.
78
+
79
+ RULES:
80
+ - ONLY extract facts that DIRECTLY answer the question.
81
+ - IGNORE unrelated events.
82
+ - If the question asks "how", extract the METHOD.
83
+ - Use short, atomic factual sentences.
84
+ - Do NOT summarize the story.
85
+
86
+ Question Type: {q_type}
87
 
88
  Knowledge Base:
89
  {kb}
 
91
  Question:
92
  {question}
93
 
94
+ Return 1–3 bullet points that directly answer the question.
95
  """
96
 
97
  inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
 
99
  outputs = llm_model.generate(
100
  **inputs,
101
  max_new_tokens=128,
102
+ do_sample=False,
103
+ temperature=0.0
104
  )
105
 
106
+ raw = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
107
 
 
108
  facts = [
109
  line.strip("-• ").strip()
110
+ for line in raw.split("\n")
111
+ if len(line.strip()) > 4
112
  ]
113
 
114
  return {
115
+ "question_type": q_type,
116
  "required_concepts": facts,
 
117
  "allow_extra_info": True,
118
+ "raw_llm_output": raw
119
  }
120
 
121
+ # ============================================================
122
+ # SCHEMA VALIDATION (CRITICAL)
123
+ # ============================================================
124
+
125
+ def validate_schema(schema, question):
126
+ q_words = set(question.lower().split())
127
+ valid = []
128
+
129
+ for c in schema["required_concepts"]:
130
+ if q_words & set(c.lower().split()):
131
+ valid.append(c)
132
+
133
+ return valid
134
 
135
  # ============================================================
136
  # ANSWER DECOMPOSITION
137
  # ============================================================
138
 
139
  def decompose_answer(answer):
140
+ parts = re.split(r'\b(?:and|because|before|after|while)\b', answer)
141
+ return [p.strip() for p in parts if p.strip()]
142
 
143
  # ============================================================
144
  # CORE EVALUATION
 
153
  }
154
  }
155
 
 
156
  key = hash_key(kb, question)
157
+
158
  if key not in SCHEMA_CACHE:
159
+ schema = generate_schema_with_llm(kb, question)
160
+ validated = validate_schema(schema, question)
161
+
162
+ if not validated:
163
+ # fallback: keyword-based extraction
164
+ validated = [
165
+ s for s in split_sentences(kb)
166
+ if any(w in s.lower() for w in question.lower().split())
167
+ ][:2]
168
+
169
+ schema["required_concepts"] = validated
170
+ SCHEMA_CACHE[key] = schema
171
 
172
  schema = SCHEMA_CACHE[key]
173
  logs["schema"] = schema
174
 
 
175
  claims = decompose_answer(answer)
176
  logs["claims"] = claims
177
 
178
+ # ---------------- COVERAGE ----------------
179
  coverage = []
180
  covered_all = True
181
 
182
+ for concept in schema["required_concepts"]:
183
  scores = sim_model.predict([(concept, c) for c in claims])
184
  best = float(scores.max())
185
+ ok = best >= SIM_THRESHOLD_REQUIRED
186
+
187
  coverage.append({
188
  "concept": concept,
189
  "similarity": round(best, 3),
190
+ "covered": ok
191
  })
192
+
193
+ if not ok:
194
  covered_all = False
195
 
196
  logs["coverage"] = coverage
197
 
198
+ # ---------------- CONTRADICTIONS ----------------
199
  contradictions = []
 
200
 
201
  for claim in claims:
202
+ for sent in schema["required_concepts"]:
203
  probs = softmax_logits(nli_model.predict([(sent, claim)]))
204
  if probs[0] > CONTRADICTION_THRESHOLD:
205
  contradictions.append({
 
210
 
211
  logs["contradictions"] = contradictions
212
 
213
+ # ---------------- FINAL VERDICT ----------------
214
  if contradictions:
215
  verdict = "❌ INCORRECT (Contradiction)"
216
  elif covered_all:
 
221
  logs["final_verdict"] = verdict
222
  return verdict, logs
223
 
 
224
  # ============================================================
225
  # GRADIO UI
226
  # ============================================================
 
229
  return evaluate_answer(answer, question, kb)
230
 
231
  with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
232
+ gr.Markdown("## 🧠 Competitive Exam Answer Checker")
233
 
234
+ kb = gr.Textbox(label="Knowledge Base", lines=8)
235
  question = gr.Textbox(label="Question")
236
  answer = gr.Textbox(label="Student Answer")
237