Spaces:
Runtime error
Runtime error
Update app.py
Browse files
app.py
CHANGED
|
@@ -61,13 +61,8 @@ def hash_key(kb, question):
|
|
| 61 |
# ============================================================
|
| 62 |
|
| 63 |
def generate_schema_with_llm(kb, question):
|
| 64 |
-
"""
|
| 65 |
-
Uses a FREE HuggingFace LLM (Flan-T5) to generate grading schema.
|
| 66 |
-
HF-safe, CPU-safe, deterministic.
|
| 67 |
-
"""
|
| 68 |
-
|
| 69 |
prompt = f"""
|
| 70 |
-
|
| 71 |
|
| 72 |
Knowledge Base:
|
| 73 |
{kb}
|
|
@@ -75,40 +70,35 @@ Knowledge Base:
|
|
| 75 |
Question:
|
| 76 |
{question}
|
| 77 |
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
{{
|
| 81 |
-
"question_type": "FACT",
|
| 82 |
-
"required_concepts": ["fact1", "fact2"],
|
| 83 |
-
"forbidden_concepts": [],
|
| 84 |
-
"allow_extra_info": true
|
| 85 |
-
}}
|
| 86 |
"""
|
| 87 |
|
| 88 |
inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
|
| 89 |
|
| 90 |
outputs = llm_model.generate(
|
| 91 |
**inputs,
|
| 92 |
-
max_new_tokens=
|
| 93 |
temperature=0.0,
|
| 94 |
do_sample=False
|
| 95 |
)
|
| 96 |
|
| 97 |
text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 98 |
|
| 99 |
-
|
| 100 |
-
|
| 101 |
-
|
| 102 |
-
|
| 103 |
-
|
| 104 |
-
|
| 105 |
-
|
| 106 |
-
|
| 107 |
-
|
| 108 |
-
|
| 109 |
-
|
| 110 |
-
|
| 111 |
-
|
|
|
|
|
|
|
| 112 |
|
| 113 |
# ============================================================
|
| 114 |
# ANSWER DECOMPOSITION
|
|
@@ -160,20 +150,34 @@ def evaluate_answer(answer, question, kb):
|
|
| 160 |
logs["coverage"] = coverage
|
| 161 |
|
| 162 |
# ---------------- CONTRADICTION CHECK ----------------
|
| 163 |
-
contradictions = []
|
| 164 |
-
kb_sents = split_sentences(kb)
|
| 165 |
-
|
| 166 |
-
for claim in claims:
|
| 167 |
-
|
| 168 |
-
|
| 169 |
-
|
| 170 |
-
|
| 171 |
-
|
| 172 |
-
|
| 173 |
-
|
| 174 |
-
|
| 175 |
-
|
| 176 |
-
logs["contradictions"] = contradictions
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 177 |
|
| 178 |
# ---------------- FINAL DECISION ----------------
|
| 179 |
if contradictions:
|
|
|
|
| 61 |
# ============================================================
|
| 62 |
|
| 63 |
def generate_schema_with_llm(kb, question):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 64 |
prompt = f"""
|
| 65 |
+
From the knowledge base below, answer the question using short factual points.
|
| 66 |
|
| 67 |
Knowledge Base:
|
| 68 |
{kb}
|
|
|
|
| 70 |
Question:
|
| 71 |
{question}
|
| 72 |
|
| 73 |
+
Write 1–3 short factual bullet points. Do NOT explain.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
"""
|
| 75 |
|
| 76 |
inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
|
| 77 |
|
| 78 |
outputs = llm_model.generate(
|
| 79 |
**inputs,
|
| 80 |
+
max_new_tokens=128,
|
| 81 |
temperature=0.0,
|
| 82 |
do_sample=False
|
| 83 |
)
|
| 84 |
|
| 85 |
text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| 86 |
|
| 87 |
+
# Extract bullet-like facts
|
| 88 |
+
facts = [
|
| 89 |
+
line.strip("-• ").strip()
|
| 90 |
+
for line in text.split("\n")
|
| 91 |
+
if len(line.strip()) > 3
|
| 92 |
+
]
|
| 93 |
+
|
| 94 |
+
return {
|
| 95 |
+
"question_type": "FACT",
|
| 96 |
+
"required_concepts": facts,
|
| 97 |
+
"forbidden_concepts": [],
|
| 98 |
+
"allow_extra_info": True,
|
| 99 |
+
"raw_llm_output": text
|
| 100 |
+
}
|
| 101 |
+
|
| 102 |
|
| 103 |
# ============================================================
|
| 104 |
# ANSWER DECOMPOSITION
|
|
|
|
| 150 |
logs["coverage"] = coverage
|
| 151 |
|
| 152 |
# ---------------- CONTRADICTION CHECK ----------------
|
| 153 |
+
# contradictions = []
|
| 154 |
+
# kb_sents = split_sentences(kb)
|
| 155 |
+
|
| 156 |
+
# for claim in claims:
|
| 157 |
+
# for sent in kb_sents:
|
| 158 |
+
# probs = softmax_logits(nli_model.predict([(sent, claim)]))
|
| 159 |
+
# if probs[0] > CONTRADICTION_THRESHOLD:
|
| 160 |
+
# contradictions.append({
|
| 161 |
+
# "claim": claim,
|
| 162 |
+
# "sentence": sent,
|
| 163 |
+
# "confidence": round(probs[0] * 100, 1)
|
| 164 |
+
# })
|
| 165 |
+
|
| 166 |
+
# logs["contradictions"] = contradictions
|
| 167 |
+
|
| 168 |
+
|
| 169 |
+
relevant_kb = schema.get("required_concepts", [])
|
| 170 |
+
|
| 171 |
+
for claim in claims:
|
| 172 |
+
for sent in relevant_kb:
|
| 173 |
+
probs = softmax_logits(nli_model.predict([(sent, claim)]))
|
| 174 |
+
if probs[0] > CONTRADICTION_THRESHOLD:
|
| 175 |
+
contradictions.append({
|
| 176 |
+
"claim": claim,
|
| 177 |
+
"sentence": sent,
|
| 178 |
+
"confidence": round(probs[0] * 100, 1)
|
| 179 |
+
})
|
| 180 |
+
|
| 181 |
|
| 182 |
# ---------------- FINAL DECISION ----------------
|
| 183 |
if contradictions:
|