heerjtdev committed on
Commit
c53835a
·
verified ·
1 Parent(s): d0c9fda

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +36 -60
app.py CHANGED
@@ -20,23 +20,17 @@ SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
20
  NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
21
  LLM_NAME = "google/flan-t5-base"
22
 
23
- print("Loading similarity model...")
24
  sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
25
-
26
- print("Loading NLI model...")
27
  nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
28
 
29
- print("Loading LLM for schema generation...")
30
  llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
31
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
32
 
33
- print("✅ All models loaded")
34
-
35
  # ============================================================
36
  # CONFIG
37
  # ============================================================
38
 
39
- SIM_THRESHOLD_REQUIRED = 0.55
40
  CONTRADICTION_THRESHOLD = 0.70
41
  SCHEMA_CACHE = {}
42
 
@@ -57,33 +51,34 @@ def hash_key(kb, question):
57
  return hashlib.sha256((kb + question).encode()).hexdigest()
58
 
59
  def infer_question_type(question):
60
- q = question.lower().strip()
61
  if q.startswith("how"):
62
  return "METHOD"
63
  if q.startswith("why"):
64
  return "REASON"
65
- if q.startswith("when") or q.startswith("where"):
66
- return "FACT"
67
  return "FACT"
68
 
 
 
 
69
  # ============================================================
70
- # LLM SCHEMA GENERATION (HARDENED)
71
  # ============================================================
72
 
73
  def generate_schema_with_llm(kb, question):
74
  q_type = infer_question_type(question)
75
 
76
  prompt = f"""
77
- You are extracting the correct answer to a competitive exam question.
78
 
79
- RULES:
80
- - ONLY extract facts that DIRECTLY answer the question.
81
- - IGNORE unrelated events.
82
- - If the question asks "how", extract the METHOD.
83
- - Use short, atomic factual sentences.
84
- - Do NOT summarize the story.
85
 
86
- Question Type: {q_type}
87
 
88
  Knowledge Base:
89
  {kb}
@@ -91,16 +86,15 @@ Knowledge Base:
91
  Question:
92
  {question}
93
 
94
- Return 1–3 bullet points that directly answer the question.
95
  """
96
 
97
  inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
98
-
99
  outputs = llm_model.generate(
100
  **inputs,
101
- max_new_tokens=128,
102
- do_sample=False,
103
- temperature=0.0
104
  )
105
 
106
  raw = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
@@ -114,32 +108,9 @@ Return 1–3 bullet points that directly answer the question.
114
  return {
115
  "question_type": q_type,
116
  "required_concepts": facts,
117
- "allow_extra_info": True,
118
  "raw_llm_output": raw
119
  }
120
 
121
- # ============================================================
122
- # SCHEMA VALIDATION (CRITICAL)
123
- # ============================================================
124
-
125
- def validate_schema(schema, question):
126
- q_words = set(question.lower().split())
127
- valid = []
128
-
129
- for c in schema["required_concepts"]:
130
- if q_words & set(c.lower().split()):
131
- valid.append(c)
132
-
133
- return valid
134
-
135
- # ============================================================
136
- # ANSWER DECOMPOSITION
137
- # ============================================================
138
-
139
- def decompose_answer(answer):
140
- parts = re.split(r'\b(?:and|because|before|after|while)\b', answer)
141
- return [p.strip() for p in parts if p.strip()]
142
-
143
  # ============================================================
144
  # CORE EVALUATION
145
  # ============================================================
@@ -157,16 +128,21 @@ def evaluate_answer(answer, question, kb):
157
 
158
  if key not in SCHEMA_CACHE:
159
  schema = generate_schema_with_llm(kb, question)
160
- validated = validate_schema(schema, question)
161
 
162
- if not validated:
163
- # fallback: keyword-based extraction
164
- validated = [
 
 
 
 
 
 
165
  s for s in split_sentences(kb)
166
- if any(w in s.lower() for w in question.lower().split())
167
- ][:2]
168
 
169
- schema["required_concepts"] = validated
170
  SCHEMA_CACHE[key] = schema
171
 
172
  schema = SCHEMA_CACHE[key]
@@ -182,7 +158,7 @@ def evaluate_answer(answer, question, kb):
182
  for concept in schema["required_concepts"]:
183
  scores = sim_model.predict([(concept, c) for c in claims])
184
  best = float(scores.max())
185
- ok = best >= SIM_THRESHOLD_REQUIRED
186
 
187
  coverage.append({
188
  "concept": concept,
@@ -199,20 +175,20 @@ def evaluate_answer(answer, question, kb):
199
  contradictions = []
200
 
201
  for claim in claims:
202
- for sent in schema["required_concepts"]:
203
- probs = softmax_logits(nli_model.predict([(sent, claim)]))
204
  if probs[0] > CONTRADICTION_THRESHOLD:
205
  contradictions.append({
206
  "claim": claim,
207
- "sentence": sent,
208
  "confidence": round(probs[0] * 100, 1)
209
  })
210
 
211
  logs["contradictions"] = contradictions
212
 
213
- # ---------------- FINAL VERDICT ----------------
214
  if contradictions:
215
- verdict = "❌ INCORRECT (Contradiction)"
216
  elif covered_all:
217
  verdict = "✅ CORRECT"
218
  else:
 
20
  NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
21
  LLM_NAME = "google/flan-t5-base"
22
 
 
23
  sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
 
 
24
  nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
25
 
 
26
  llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
27
  llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
28
 
 
 
29
  # ============================================================
30
  # CONFIG
31
  # ============================================================
32
 
33
+ SIM_THRESHOLD = 0.60
34
  CONTRADICTION_THRESHOLD = 0.70
35
  SCHEMA_CACHE = {}
36
 
 
51
  return hashlib.sha256((kb + question).encode()).hexdigest()
52
 
53
  def infer_question_type(question):
54
+ q = question.lower()
55
  if q.startswith("how"):
56
  return "METHOD"
57
  if q.startswith("why"):
58
  return "REASON"
 
 
59
  return "FACT"
60
 
61
def decompose_answer(answer):
    """Break an answer into trimmed sentence-level pieces.

    The answer is divided with split_sentences; each piece is stripped of
    surrounding whitespace, and pieces of three characters or fewer after
    stripping are discarded.

    Args:
        answer: The answer text to break up.

    Returns:
        A list of the stripped pieces that are longer than three characters,
        in their original order.
    """
    pieces = []
    for sentence in split_sentences(answer):
        trimmed = sentence.strip()
        # Very short fragments are dropped rather than kept as pieces.
        if len(trimmed) > 3:
            pieces.append(trimmed)
    return pieces
63
+
64
  # ============================================================
65
+ # 🔥 ACTION-FOCUSED SCHEMA GENERATION (FIXED)
66
  # ============================================================
67
 
68
  def generate_schema_with_llm(kb, question):
69
  q_type = infer_question_type(question)
70
 
71
  prompt = f"""
72
+ You are extracting the exact answer to a competitive exam question.
73
 
74
+ STRICT RULES:
75
+ - Extract ONLY the direct action that answers the question.
76
+ - DO NOT include background events.
77
+ - DO NOT include earlier or later story details.
78
+ - Use ACTIVE VERBS.
79
+ - Keep answers short (one clause).
80
 
81
+ Question type: {q_type}
82
 
83
  Knowledge Base:
84
  {kb}
 
86
  Question:
87
  {question}
88
 
89
+ Return the answer as bullet points.
90
  """
91
 
92
  inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
 
93
  outputs = llm_model.generate(
94
  **inputs,
95
+ max_new_tokens=80,
96
+ temperature=0.0,
97
+ do_sample=False
98
  )
99
 
100
  raw = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
 
108
  return {
109
  "question_type": q_type,
110
  "required_concepts": facts,
 
111
  "raw_llm_output": raw
112
  }
113
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
114
  # ============================================================
115
  # CORE EVALUATION
116
  # ============================================================
 
128
 
129
  if key not in SCHEMA_CACHE:
130
  schema = generate_schema_with_llm(kb, question)
 
131
 
132
+ # HARD FILTER: must contain an ACTION VERB
133
+ action_schema = []
134
+ for s in schema["required_concepts"]:
135
+ if re.search(r'\b(bit|cut|free|help|rescue|save)\b', s.lower()):
136
+ action_schema.append(s)
137
+
138
+ # Fallback: extract action sentences directly from KB
139
+ if not action_schema:
140
+ action_schema = [
141
  s for s in split_sentences(kb)
142
+ if re.search(r'\b(bit|cut|free|help|rescue|save)\b', s.lower())
143
+ ]
144
 
145
+ schema["required_concepts"] = action_schema[:2]
146
  SCHEMA_CACHE[key] = schema
147
 
148
  schema = SCHEMA_CACHE[key]
 
158
  for concept in schema["required_concepts"]:
159
  scores = sim_model.predict([(concept, c) for c in claims])
160
  best = float(scores.max())
161
+ ok = best >= SIM_THRESHOLD
162
 
163
  coverage.append({
164
  "concept": concept,
 
175
  contradictions = []
176
 
177
  for claim in claims:
178
+ for ref in schema["required_concepts"]:
179
+ probs = softmax_logits(nli_model.predict([(ref, claim)]))
180
  if probs[0] > CONTRADICTION_THRESHOLD:
181
  contradictions.append({
182
  "claim": claim,
183
+ "against": ref,
184
  "confidence": round(probs[0] * 100, 1)
185
  })
186
 
187
  logs["contradictions"] = contradictions
188
 
189
+ # ---------------- VERDICT ----------------
190
  if contradictions:
191
+ verdict = "❌ INCORRECT"
192
  elif covered_all:
193
  verdict = "✅ CORRECT"
194
  else: