Spaces:
Runtime error
Runtime error
| import gradio as gr | |
| import torch | |
| import torch.nn.functional as F | |
| from sentence_transformers import CrossEncoder | |
| from transformers import AutoTokenizer, AutoModelForSeq2SeqLM | |
| import re | |
| import hashlib | |
| import json | |
# ============================================================
# DEVICE
# ============================================================
# All models below are created on (or moved to) this device.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ============================================================
# MODELS
# ============================================================
# Cross-encoder scoring semantic similarity between a required concept
# and an answer claim.
SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
# Cross-encoder NLI model used for contradiction detection.
NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
# Seq2seq LLM that generates the answer schema (required facts).
LLM_NAME = "google/flan-t5-base"
print("Loading similarity + NLI models...")
sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
print("Loading LLM for schema generation...")
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
print("✅ All models loaded successfully")
# ============================================================
# CONFIGURATION
# ============================================================
# Minimum cross-encoder similarity for a required concept to count as covered.
SIM_THRESHOLD_REQUIRED = 0.55
# Minimum NLI contradiction probability to flag a claim.
CONTRADICTION_THRESHOLD = 0.70
# In-memory, per-process cache: hash_key(kb, question) -> generated schema dict.
SCHEMA_CACHE = {}
| # ============================================================ | |
| # UTILITIES | |
| # ============================================================ | |
def split_sentences(text):
    """Split *text* into sentences on whitespace that follows ., ! or ?."""
    boundary = re.compile(r'(?<=[.!?])\s+')
    return boundary.split(text.strip())
def softmax_logits(logits):
    """Softmax raw NLI logits into a plain list of probabilities.

    A leading batch dimension (if present) is squeezed away first so the
    softmax is always taken over the class axis.
    """
    values = torch.tensor(logits)
    values = values.squeeze(0) if values.dim() > 1 else values
    return torch.softmax(values, dim=0).tolist()
def hash_key(kb, question):
    """Return a stable sha256 cache key for a (kb, question) pair.

    A NUL separator is inserted between the two strings so that pairs whose
    plain concatenations coincide (e.g. ("ab", "c") vs ("a", "bc")) do not
    collide to the same key. The cache is in-memory only, so the changed key
    format does not invalidate any persisted data.
    """
    return hashlib.sha256(f"{kb}\x00{question}".encode()).hexdigest()
| # ============================================================ | |
| # LLM SCHEMA GENERATION (CORE PART YOU ASKED FOR) | |
| # ============================================================ | |
def generate_schema_with_llm(kb, question):
    """Ask the seq2seq LLM for 1-3 short factual points answering *question*.

    Returns a schema dict with the extracted facts under "required_concepts"
    and the raw model output under "raw_llm_output". "forbidden_concepts" is
    always empty and "question_type" is always "FACT" in this version.
    """
    prompt = f"""
From the knowledge base below, answer the question using short factual points.
Knowledge Base:
{kb}
Question:
{question}
Write 1–3 short factual bullet points. Do NOT explain.
"""
    inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
    # Greedy decoding. NOTE: the original passed temperature=0.0 alongside
    # do_sample=False; recent transformers versions warn on (and may reject)
    # sampling parameters when sampling is disabled, and 0.0 is not a valid
    # sampling temperature anyway — so it is omitted. Output is unchanged.
    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False,
    )
    text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Keep non-trivial lines, stripping leading/trailing bullet markers.
    facts = [
        line.strip("-• ").strip()
        for line in text.split("\n")
        if len(line.strip()) > 3
    ]
    return {
        "question_type": "FACT",
        "required_concepts": facts,
        "forbidden_concepts": [],
        "allow_extra_info": True,
        "raw_llm_output": text,
    }
| # ============================================================ | |
| # ANSWER DECOMPOSITION | |
| # ============================================================ | |
def decompose_answer(answer):
    """Break an answer into clause-level claims on common connectives."""
    pieces = re.split(r'\b(?:and|because|before|after|while)\b', answer)
    trimmed = (piece.strip() for piece in pieces)
    return [piece for piece in trimmed if piece]
| # ============================================================ | |
| # CORE EVALUATION | |
| # ============================================================ | |
def evaluate_answer(answer, question, kb):
    """Evaluate a student *answer* for *question* against knowledge base *kb*.

    Pipeline: generate (or fetch cached) fact schema via the LLM, decompose
    the answer into claims, score coverage of each required concept with the
    similarity cross-encoder, then check each claim against each required
    concept with the NLI model for contradictions.

    Returns a (verdict_string, logs_dict) tuple; logs contain the full
    evaluation trace for the debug panel.
    """
    logs = {
        "inputs": {
            "question": question,
            "answer": answer,
            "kb_length": len(kb),
        }
    }
    # ---------------- SCHEMA (cached per (kb, question) pair) ----------------
    key = hash_key(kb, question)
    if key not in SCHEMA_CACHE:
        SCHEMA_CACHE[key] = generate_schema_with_llm(kb, question)
    schema = SCHEMA_CACHE[key]
    logs["schema"] = schema
    # ---------------- ANSWER CLAIMS ----------------
    claims = decompose_answer(answer)
    logs["claims"] = claims
    # ---------------- REQUIRED COVERAGE ----------------
    coverage = []
    covered_all = True
    for concept in schema.get("required_concepts", []):
        if claims:
            scores = sim_model.predict([(concept, c) for c in claims])
            best = float(scores.max())
        else:
            # FIX: an empty answer produces zero claims; calling .max() on an
            # empty score array would raise. Treat the concept as uncovered.
            best = 0.0
        covered = best >= SIM_THRESHOLD_REQUIRED
        coverage.append({
            "concept": concept,
            "similarity": round(best, 3),
            "covered": covered,
        })
        if not covered:
            covered_all = False
    logs["coverage"] = coverage
    # ---------------- CONTRADICTION CHECK ----------------
    contradictions = []
    relevant_kb = schema.get("required_concepts", [])
    for claim in claims:
        for sent in relevant_kb:
            # probs[0] is taken as the contradiction class for this NLI
            # cross-encoder — NOTE(review): confirm label order for
            # cross-encoder/nli-deberta-v3-xsmall.
            probs = softmax_logits(nli_model.predict([(sent, claim)]))
            if probs[0] > CONTRADICTION_THRESHOLD:
                contradictions.append({
                    "claim": claim,
                    "sentence": sent,
                    "confidence": round(probs[0] * 100, 1),
                })
    logs["contradictions"] = contradictions
    # ---------------- FINAL DECISION ----------------
    # Any contradiction dominates; otherwise full coverage wins.
    if contradictions:
        verdict = "❌ INCORRECT (Contradiction)"
    elif covered_all:
        verdict = "✅ CORRECT"
    else:
        verdict = "⚠️ PARTIALLY CORRECT"
    logs["final_verdict"] = verdict
    return verdict, logs
| # ============================================================ | |
| # GRADIO UI | |
| # ============================================================ | |
def run(answer, question, kb):
    """Gradio click handler: delegate to evaluate_answer.

    Returns the (verdict, logs) pair that Gradio maps onto the Verdict
    textbox and the Debug Logs JSON panel.
    """
    verdict, logs = evaluate_answer(answer, question, kb)
    return verdict, logs
# ---------------- UI LAYOUT ----------------
with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
    gr.Markdown("## 🧠 Competitive Exam Answer Checker (HF-Free LLM Version)")
    # Input widgets.
    kb = gr.Textbox(label="Knowledge Base", lines=7)
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Student Answer")
    # Output widgets: final verdict plus the full evaluation trace.
    verdict = gr.Textbox(label="Verdict")
    debug = gr.JSON(label="Debug Logs")
    btn = gr.Button("Evaluate")
    # run() returns (verdict, logs), mapped onto the two outputs above.
    btn.click(run, [answer, question, kb], [verdict, debug])
demo.launch()