"""Competitive-exam answer checker.

Pipeline: an LLM (flan-t5) distills the knowledge base + question into a
small "schema" of required facts; a cross-encoder similarity model checks
that the student's answer covers each fact; an NLI cross-encoder checks the
answer does not contradict those facts. Served through a Gradio UI.
"""

import gradio as gr
import torch
import torch.nn.functional as F
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
import hashlib
import json

# ============================================================
# DEVICE
# ============================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ============================================================
# MODELS
# ============================================================
SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
LLM_NAME = "google/flan-t5-base"

print("Loading similarity + NLI models...")
sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)

print("Loading LLM for schema generation...")
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)

print("✅ All models loaded successfully")

# ============================================================
# CONFIGURATION
# ============================================================
# Minimum cross-encoder similarity for a required concept to count as covered.
SIM_THRESHOLD_REQUIRED = 0.55
# Minimum NLI contradiction probability for a claim to count as contradicting.
CONTRADICTION_THRESHOLD = 0.70

# In-memory cache of LLM-generated schemas, keyed by hash of (kb, question).
SCHEMA_CACHE = {}


# ============================================================
# UTILITIES
# ============================================================
def split_sentences(text):
    """Split *text* into sentences on terminal punctuation followed by space."""
    return re.split(r'(?<=[.!?])\s+', text.strip())


def softmax_logits(logits):
    """Convert raw model logits (array-like, possibly batched 1xN) to a
    plain list of probabilities."""
    t = torch.tensor(logits)
    if t.dim() > 1:
        t = t.squeeze(0)
    return F.softmax(t, dim=0).tolist()


def hash_key(kb, question):
    """Stable cache key for a (knowledge base, question) pair.

    A separator byte is inserted so that e.g. ("ab", "c") and ("a", "bc")
    do not collide, which plain concatenation would allow.
    """
    return hashlib.sha256((kb + "\x1f" + question).encode()).hexdigest()


# ============================================================
# LLM SCHEMA GENERATION (CORE PART YOU ASKED FOR)
# ============================================================
def generate_schema_with_llm(kb, question):
    """Ask the LLM to distill *kb* into 1-3 short factual points answering
    *question*.

    Returns a schema dict with keys: question_type, required_concepts,
    forbidden_concepts, allow_extra_info, raw_llm_output.
    """
    prompt = f"""
From the knowledge base below, answer the question using short
factual points.

Knowledge Base:
{kb}

Question: {question}

Write 1–3 short factual bullet points. Do NOT explain.
"""
    inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
    # Greedy decoding: with do_sample=False the temperature argument is
    # ignored (and temperature=0.0 is rejected by recent transformers),
    # so it is deliberately not passed.
    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False
    )
    text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract bullet-like facts: one per line, bullet markers stripped,
    # trivially short lines dropped.
    facts = [
        line.strip("-• ").strip()
        for line in text.split("\n")
        if len(line.strip()) > 3
    ]

    return {
        "question_type": "FACT",
        "required_concepts": facts,
        "forbidden_concepts": [],
        "allow_extra_info": True,
        "raw_llm_output": text
    }


# ============================================================
# ANSWER DECOMPOSITION
# ============================================================
def decompose_answer(answer):
    """Split a free-text answer into clause-level claims on common
    conjunctions, dropping empty fragments."""
    clauses = re.split(r'\b(?:and|because|before|after|while)\b', answer)
    return [c.strip() for c in clauses if c.strip()]


# ============================================================
# CORE EVALUATION
# ============================================================
def evaluate_answer(answer, question, kb):
    """Evaluate a student *answer* against *question* and knowledge base *kb*.

    Returns (verdict_string, debug_logs_dict). The verdict is one of
    "✅ CORRECT", "⚠️ PARTIALLY CORRECT", "❌ INCORRECT (Contradiction)".
    """
    logs = {
        "inputs": {
            "question": question,
            "answer": answer,
            "kb_length": len(kb)
        }
    }

    # ---------------- SCHEMA ----------------
    # Schema generation is the expensive LLM call, so cache it per
    # (kb, question) pair.
    key = hash_key(kb, question)
    if key not in SCHEMA_CACHE:
        SCHEMA_CACHE[key] = generate_schema_with_llm(kb, question)
    schema = SCHEMA_CACHE[key]
    logs["schema"] = schema

    # ---------------- ANSWER CLAIMS ----------------
    claims = decompose_answer(answer)
    logs["claims"] = claims

    # Guard: an empty answer yields no claims; predicting on an empty batch
    # (and calling .max() on the result) would crash, so short-circuit.
    if not claims:
        logs["coverage"] = []
        logs["contradictions"] = []
        logs["final_verdict"] = "⚠️ PARTIALLY CORRECT"
        return "⚠️ PARTIALLY CORRECT", logs

    # ---------------- REQUIRED COVERAGE ----------------
    coverage = []
    covered_all = True
    for concept in schema.get("required_concepts", []):
        # Best similarity between this required concept and any claim.
        scores = sim_model.predict([(concept, c) for c in claims])
        best = float(scores.max())
        covered = best >= SIM_THRESHOLD_REQUIRED
        coverage.append({
            "concept": concept,
            "similarity": round(best, 3),
            "covered": covered
        })
        if not covered:
            covered_all = False
    logs["coverage"] = coverage

    # ---------------- CONTRADICTION CHECK (FIXED) ----------------
    # NOTE(review): contradiction is checked against the schema's required
    # concepts rather than raw KB sentences. probs[0] is assumed to be the
    # contradiction class for cross-encoder/nli-deberta-v3-* (label order
    # [contradiction, entailment, neutral]) — confirm against the model card.
    contradictions = []
    relevant_kb = schema.get("required_concepts", [])
    for claim in claims:
        for sent in relevant_kb:
            probs = softmax_logits(nli_model.predict([(sent, claim)]))
            if probs[0] > CONTRADICTION_THRESHOLD:
                contradictions.append({
                    "claim": claim,
                    "sentence": sent,
                    "confidence": round(probs[0] * 100, 1)
                })
    logs["contradictions"] = contradictions

    # ---------------- FINAL DECISION ----------------
    # Contradiction trumps coverage; full coverage without contradiction
    # is correct; anything else is partial.
    if contradictions:
        verdict = "❌ INCORRECT (Contradiction)"
    elif covered_all:
        verdict = "✅ CORRECT"
    else:
        verdict = "⚠️ PARTIALLY CORRECT"

    logs["final_verdict"] = verdict
    return verdict, logs


# ============================================================
# GRADIO UI
# ============================================================
def run(answer, question, kb):
    """Gradio callback: thin wrapper around evaluate_answer."""
    return evaluate_answer(answer, question, kb)


with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
    gr.Markdown("## 🧠 Competitive Exam Answer Checker (HF-Free LLM Version)")

    kb = gr.Textbox(label="Knowledge Base", lines=7)
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Student Answer")

    verdict = gr.Textbox(label="Verdict")
    debug = gr.JSON(label="Debug Logs")

    btn = gr.Button("Evaluate")
    btn.click(run, [answer, question, kb], [verdict, debug])

# Launch only when run as a script, not when imported.
if __name__ == "__main__":
    demo.launch()