import gradio as gr
import torch
import torch.nn.functional as F
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
import hashlib
import json
# ============================================================
# DEVICE
# ============================================================
# Prefer GPU when available; all three models are moved to this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ============================================================
# MODELS
# ============================================================
# Sentence-pair similarity scorer (STS-B regression head).
SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
# NLI classifier used for contradiction detection.
NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
# Seq2seq LLM used to distill the knowledge base into required facts.
LLM_NAME = "google/flan-t5-base"
print("Loading similarity + NLI models...")
sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
print("Loading LLM for schema generation...")
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
print("✅ All models loaded successfully")
# ============================================================
# CONFIGURATION
# ============================================================
# Minimum similarity for a required concept to count as "covered".
SIM_THRESHOLD_REQUIRED = 0.55
# Minimum NLI contradiction probability to flag a claim.
CONTRADICTION_THRESHOLD = 0.70
# In-process cache: hash_key(kb, question) -> schema dict, to avoid
# re-running the LLM for repeated evaluations of the same pair.
SCHEMA_CACHE = {}
# ============================================================
# UTILITIES
# ============================================================
def split_sentences(text):
    """Split *text* into sentences on terminal punctuation.

    Splits after '.', '!' or '?' followed by whitespace. Empty or
    whitespace-only input yields an empty list; the previous version
    returned [''], a phantom empty "sentence".
    """
    return [s for s in re.split(r'(?<=[.!?])\s+', text.strip()) if s]
def softmax_logits(logits):
    """Convert raw model logits into a list of softmax probabilities."""
    tensor = torch.tensor(logits)
    # CrossEncoder.predict returns shape (1, num_labels) for a single
    # pair; flatten that leading batch dimension before normalizing.
    if tensor.dim() > 1:
        tensor = tensor.squeeze(0)
    probabilities = F.softmax(tensor, dim=0)
    return probabilities.tolist()
def hash_key(kb, question):
    """Deterministic SHA-256 cache key for a (kb, question) pair.

    A NUL byte separates the two fields so that distinct pairs whose
    plain concatenations coincide — e.g. ("ab", "c") vs ("a", "bc") —
    no longer map to the same key and share a cached schema.
    """
    return hashlib.sha256((kb + "\x00" + question).encode()).hexdigest()
# ============================================================
# LLM SCHEMA GENERATION (CORE PART YOU ASKED FOR)
# ============================================================
def generate_schema_with_llm(kb, question):
    """Distill the knowledge base into short required facts via the LLM.

    Returns a schema dict whose ``required_concepts`` are bullet-style
    facts extracted from the model output; ``raw_llm_output`` keeps the
    undecoded text for debugging. Decoding is greedy, so the result is
    deterministic and safe to cache per (kb, question).
    """
    prompt = f"""
From the knowledge base below, answer the question using short factual points.
Knowledge Base:
{kb}
Question:
{question}
Write 1–3 short factual bullet points. Do NOT explain.
"""
    inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
    # do_sample=False already selects greedy (deterministic) decoding.
    # Passing temperature=0.0 alongside it is ignored and emits a
    # transformers warning (and raises in recent versions), so it is
    # deliberately omitted here.
    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False
    )
    text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract bullet-like facts: one per non-trivial line, with any
    # leading bullet markers ("-", "•") trimmed off.
    facts = [
        line.strip("-• ").strip()
        for line in text.split("\n")
        if len(line.strip()) > 3
    ]
    return {
        "question_type": "FACT",
        "required_concepts": facts,
        "forbidden_concepts": [],
        "allow_extra_info": True,
        "raw_llm_output": text
    }
# ============================================================
# ANSWER DECOMPOSITION
# ============================================================
def decompose_answer(answer):
    """Break an answer into clause-level claims at common connectives."""
    fragments = re.split(r'\b(?:and|because|before|after|while)\b', answer)
    claims = []
    for fragment in fragments:
        cleaned = fragment.strip()
        if cleaned:
            claims.append(cleaned)
    return claims
# ============================================================
# CORE EVALUATION
# ============================================================
def evaluate_answer(answer, question, kb):
    """Grade a student *answer* against a schema derived from *kb*.

    Returns ``(verdict, logs)`` where verdict is one of CORRECT,
    PARTIALLY CORRECT, or INCORRECT (contradiction), and *logs* is a
    JSON-serializable dict of every intermediate result for debugging.
    """
    logs = {
        "inputs": {
            "question": question,
            "answer": answer,
            "kb_length": len(kb)
        }
    }
    # ---------------- SCHEMA ----------------
    # Schema generation runs the LLM, so cache it per (kb, question).
    key = hash_key(kb, question)
    if key not in SCHEMA_CACHE:
        SCHEMA_CACHE[key] = generate_schema_with_llm(kb, question)
    schema = SCHEMA_CACHE[key]
    logs["schema"] = schema
    # ---------------- ANSWER CLAIMS ----------------
    claims = decompose_answer(answer)
    logs["claims"] = claims
    # ---------------- REQUIRED COVERAGE ----------------
    coverage = []
    covered_all = True
    for concept in schema.get("required_concepts", []):
        if claims:
            scores = sim_model.predict([(concept, c) for c in claims])
            best = float(scores.max())
        else:
            # Empty answer: nothing can be covered. Guard here because
            # predicting on an empty batch (or .max() on an empty
            # result) would raise instead of grading.
            best = 0.0
        covered = best >= SIM_THRESHOLD_REQUIRED
        coverage.append({
            "concept": concept,
            "similarity": round(best, 3),
            "covered": covered
        })
        if not covered:
            covered_all = False
    logs["coverage"] = coverage
    # ---------------- CONTRADICTION CHECK ----------------
    contradictions = []
    relevant_kb = schema.get("required_concepts", [])
    for claim in claims:
        for sent in relevant_kb:
            probs = softmax_logits(nli_model.predict([(sent, claim)]))
            # probs[0] is treated as the contradiction class
            # (label order [contradiction, entailment, neutral] for
            # this NLI cross-encoder) — TODO confirm against the
            # model card if the NLI checkpoint is swapped.
            if probs[0] > CONTRADICTION_THRESHOLD:
                contradictions.append({
                    "claim": claim,
                    "sentence": sent,
                    "confidence": round(probs[0] * 100, 1)
                })
    logs["contradictions"] = contradictions
    # ---------------- FINAL DECISION ----------------
    # Contradictions dominate: a contradicted answer is wrong even if
    # it also covers every required concept.
    if contradictions:
        verdict = "❌ INCORRECT (Contradiction)"
    elif covered_all:
        verdict = "✅ CORRECT"
    else:
        verdict = "⚠️ PARTIALLY CORRECT"
    logs["final_verdict"] = verdict
    return verdict, logs
# ============================================================
# GRADIO UI
# ============================================================
def run(answer, question, kb):
    """Gradio click handler: delegate grading to the core evaluator."""
    verdict, logs = evaluate_answer(answer, question, kb)
    return verdict, logs
# Build and launch the Gradio UI: three text inputs feed run(), whose
# verdict string and debug-log dict fill the two output components.
with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
    gr.Markdown("## 🧠 Competitive Exam Answer Checker (HF-Free LLM Version)")
    # Inputs: the reference material, the question, and the student's answer.
    kb = gr.Textbox(label="Knowledge Base", lines=7)
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Student Answer")
    # Outputs: final verdict plus the full evaluation trace as JSON.
    verdict = gr.Textbox(label="Verdict")
    debug = gr.JSON(label="Debug Logs")
    btn = gr.Button("Evaluate")
    btn.click(run, [answer, question, kb], [verdict, debug])
demo.launch()