heerjtdev committed on
Commit
4a83cb5
·
verified ·
1 Parent(s): 177dde3

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +46 -42
app.py CHANGED
@@ -61,13 +61,8 @@ def hash_key(kb, question):
61
  # ============================================================
62
 
63
  def generate_schema_with_llm(kb, question):
64
- """
65
- Uses a FREE HuggingFace LLM (Flan-T5) to generate grading schema.
66
- HF-safe, CPU-safe, deterministic.
67
- """
68
-
69
  prompt = f"""
70
- Extract the expected answer from the knowledge base.
71
 
72
  Knowledge Base:
73
  {kb}
@@ -75,40 +70,35 @@ Knowledge Base:
75
  Question:
76
  {question}
77
 
78
- Return ONLY valid JSON in this format:
79
-
80
- {{
81
- "question_type": "FACT",
82
- "required_concepts": ["fact1", "fact2"],
83
- "forbidden_concepts": [],
84
- "allow_extra_info": true
85
- }}
86
  """
87
 
88
  inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
89
 
90
  outputs = llm_model.generate(
91
  **inputs,
92
- max_new_tokens=256,
93
  temperature=0.0,
94
  do_sample=False
95
  )
96
 
97
  text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
98
 
99
- try:
100
- json_text = text[text.find("{"):text.rfind("}") + 1]
101
- return json.loads(json_text)
102
- except Exception:
103
- # HARD FAIL SAFE
104
- return {
105
- "question_type": "FACT",
106
- "required_concepts": [],
107
- "forbidden_concepts": [],
108
- "allow_extra_info": True,
109
- "error": "LLM schema parse failed",
110
- "raw_output": text
111
- }
 
 
112
 
113
  # ============================================================
114
  # ANSWER DECOMPOSITION
@@ -160,20 +150,34 @@ def evaluate_answer(answer, question, kb):
160
  logs["coverage"] = coverage
161
 
162
  # ---------------- CONTRADICTION CHECK ----------------
163
- contradictions = []
164
- kb_sents = split_sentences(kb)
165
-
166
- for claim in claims:
167
- for sent in kb_sents:
168
- probs = softmax_logits(nli_model.predict([(sent, claim)]))
169
- if probs[0] > CONTRADICTION_THRESHOLD:
170
- contradictions.append({
171
- "claim": claim,
172
- "sentence": sent,
173
- "confidence": round(probs[0] * 100, 1)
174
- })
175
-
176
- logs["contradictions"] = contradictions
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177
 
178
  # ---------------- FINAL DECISION ----------------
179
  if contradictions:
 
61
  # ============================================================
62
 
63
def generate_schema_with_llm(kb, question):
    """Build a grading schema for *question* from the knowledge base *kb*.

    Prompts the module-level seq2seq LLM (``llm_model`` / ``llm_tokenizer``,
    presumably Flan-T5 per the surrounding file — TODO confirm) for 1-3 short
    factual bullet points, then packages them as ``required_concepts``.

    Returns a dict with keys: ``question_type``, ``required_concepts``,
    ``forbidden_concepts``, ``allow_extra_info``, ``raw_llm_output``.
    """
    prompt = f"""
From the knowledge base below, answer the question using short factual points.

Knowledge Base:
{kb}

Question:
{question}

Write 1–3 short factual bullet points. Do NOT explain.
"""

    inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)

    # Greedy decoding keeps output deterministic. `temperature` is ignored
    # when do_sample=False (and recent transformers versions warn about the
    # unused flag), so it is intentionally not passed.
    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
    )

    text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract bullet-like facts. Strip bullet markers *before* the length
    # check so marker-only lines (e.g. "---- ") cannot slip through as
    # empty "facts" — the previous version filtered on the raw line length
    # and could emit empty strings into required_concepts.
    facts = []
    for line in text.split("\n"):
        fact = line.strip().strip("-•*").strip()
        if len(fact) > 3:
            facts.append(fact)

    return {
        "question_type": "FACT",
        "required_concepts": facts,
        "forbidden_concepts": [],
        "allow_extra_info": True,
        "raw_llm_output": text,
    }
 
103
  # ============================================================
104
  # ANSWER DECOMPOSITION
 
150
  logs["coverage"] = coverage
151
 
152
  # ---------------- CONTRADICTION CHECK ----------------
153
+ # contradictions = []
154
+ # kb_sents = split_sentences(kb)
155
+
156
+ # for claim in claims:
157
+ # for sent in kb_sents:
158
+ # probs = softmax_logits(nli_model.predict([(sent, claim)]))
159
+ # if probs[0] > CONTRADICTION_THRESHOLD:
160
+ # contradictions.append({
161
+ # "claim": claim,
162
+ # "sentence": sent,
163
+ # "confidence": round(probs[0] * 100, 1)
164
+ # })
165
+
166
+ # logs["contradictions"] = contradictions
167
+
168
+
169
+ relevant_kb = schema.get("required_concepts", [])
170
+
171
+ for claim in claims:
172
+ for sent in relevant_kb:
173
+ probs = softmax_logits(nli_model.predict([(sent, claim)]))
174
+ if probs[0] > CONTRADICTION_THRESHOLD:
175
+ contradictions.append({
176
+ "claim": claim,
177
+ "sentence": sent,
178
+ "confidence": round(probs[0] * 100, 1)
179
+ })
180
+
181
 
182
  # ---------------- FINAL DECISION ----------------
183
  if contradictions: