import gradio as gr
import torch
import torch.nn.functional as F
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
import hashlib
# ============================================================
# DEVICE
# ============================================================
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# ============================================================
# MODELS
# ============================================================
SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
LLM_NAME = "google/flan-t5-base"
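# sim_model scores semantic similarity (STS-B cross-encoder, roughly 0..1),
# nli_model classifies entailment/contradiction between sentence pairs, and
# the FLAN-T5 LLM turns (KB, question) into a checkable fact schema.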
print("Loading similarity + NLI models...")
sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)
print("Loading LLM for schema generation...")
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)
print("✅ All models loaded successfully")
# ============================================================
# CONFIGURATION
# ============================================================
SIM_THRESHOLD_REQUIRED = 0.55
CONTRADICTION_THRESHOLD = 0.70
SCHEMA_CACHE = {}
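# SIM_THRESHOLD_REQUIRED: minimum similarity for a required concept to count
# as covered. CONTRADICTION_THRESHOLD: minimum NLI contradiction probability
# to flag a claim. SCHEMA_CACHE memoizes LLM schemas per (KB, question) pair
# so repeated evaluations skip generation.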
# ============================================================
# UTILITIES
# ============================================================
def split_sentences(text):
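    # Naive sentence splitter on terminal punctuation. Currently unused in the
    # evaluation flow; kept as a helper for sentence-level KB checks.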
return re.split(r'(?<=[.!?])\s+', text.strip())
def softmax_logits(logits):
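    # CrossEncoder.predict returns raw logits (shape (1, n_labels) for a
    # single pair); squeeze and softmax them into label probabilities.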
t = torch.tensor(logits)
if t.dim() > 1:
t = t.squeeze(0)
return F.softmax(t, dim=0).tolist()
def hash_key(kb, question):
    # Join with a separator so (kb="ab", q="c") and (kb="a", q="bc")
    # produce different cache keys.
    return hashlib.sha256((kb + "\x1f" + question).encode()).hexdigest()
# ============================================================
# LLM SCHEMA GENERATION
# ============================================================
def generate_schema_with_llm(kb, question):
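    # Ask the LLM to distill the KB into 1-3 short facts that answer the
    # question; these become the required concepts an answer must cover.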
prompt = f"""
From the knowledge base below, answer the question using short factual points.
Knowledge Base:
{kb}
Question:
{question}
Write 1–3 short factual bullet points. Do NOT explain.
"""
inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)
    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False  # greedy decoding; temperature has no effect (and 0.0 is invalid) when sampling is off
    )
text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)
    # Extract bullet-like facts. FLAN-T5 often returns everything on one line,
    # so split on newlines and inline bullet markers, not newlines alone.
    facts = [
        line.strip("-• ").strip()
        for line in re.split(r"[\n•]", text)
        if len(line.strip("-• ").strip()) > 3
    ]
return {
"question_type": "FACT",
"required_concepts": facts,
"forbidden_concepts": [],
"allow_extra_info": True,
"raw_llm_output": text
}
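# Illustrative schema for KB "France is a country in Europe. Its capital is
# Paris." and question "What is the capital of France?" (actual output varies
# by model and prompt):
#   {"question_type": "FACT",
#    "required_concepts": ["Paris is the capital of France"],
#    "forbidden_concepts": [], "allow_extra_info": True,
#    "raw_llm_output": "..."}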
# ============================================================
# ANSWER DECOMPOSITION
# ============================================================
def decompose_answer(answer):
    # Split the answer into clause-level claims on common conjunctions
    # (case-insensitive) so each claim can be scored independently.
    clauses = re.split(r'\b(?:and|because|before|after|while)\b', answer, flags=re.IGNORECASE)
    return [c.strip() for c in clauses if c.strip()]
# ============================================================
# CORE EVALUATION
# ============================================================
def evaluate_answer(answer, question, kb):
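    # Pipeline: build or fetch the schema -> split the answer into claims ->
    # check required-concept coverage -> check claims for contradictions.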
logs = {
"inputs": {
"question": question,
"answer": answer,
"kb_length": len(kb)
}
}
# ---------------- SCHEMA ----------------
key = hash_key(kb, question)
if key not in SCHEMA_CACHE:
SCHEMA_CACHE[key] = generate_schema_with_llm(kb, question)
schema = SCHEMA_CACHE[key]
logs["schema"] = schema
# ---------------- ANSWER CLAIMS ----------------
claims = decompose_answer(answer)
logs["claims"] = claims
# ---------------- REQUIRED COVERAGE ----------------
coverage = []
covered_all = True
    for concept in schema.get("required_concepts", []):
        if claims:
            scores = sim_model.predict([(concept, c) for c in claims])
            best = float(scores.max())
        else:
            best = 0.0  # an empty answer covers nothing; avoids predict([]) on no claims
        covered = best >= SIM_THRESHOLD_REQUIRED
coverage.append({
"concept": concept,
"similarity": round(best, 3),
"covered": covered
})
if not covered:
covered_all = False
logs["coverage"] = coverage
    # ---------------- CONTRADICTION CHECK ----------------
    # Claims are compared against the schema facts (not raw KB sentences),
    # which keeps the NLI inputs short and on-topic.
    contradictions = []
    schema_facts = schema.get("required_concepts", [])
    for claim in claims:
        for fact in schema_facts:
            # cross-encoder/nli-deberta-v3-* label order is
            # [contradiction, entailment, neutral], so probs[0] is P(contradiction).
            probs = softmax_logits(nli_model.predict([(fact, claim)]))
            if probs[0] > CONTRADICTION_THRESHOLD:
                contradictions.append({
                    "claim": claim,
                    "sentence": fact,
                    "confidence": round(probs[0] * 100, 1)
                })
logs["contradictions"] = contradictions
# ---------------- FINAL DECISION ----------------
if contradictions:
verdict = "❌ INCORRECT (Contradiction)"
elif covered_all:
verdict = "✅ CORRECT"
else:
verdict = "⚠️ PARTIALLY CORRECT"
logs["final_verdict"] = verdict
return verdict, logs
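# Quick local sanity check (illustrative only; uncomment and run in place of
# the UI, output depends on the loaded models):
#   verdict, logs = evaluate_answer(
#       answer="Paris is the capital of France.",
#       question="What is the capital of France?",
#       kb="France is a country in Europe. Its capital is Paris.",
#   )
#   print(verdict)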
# ============================================================
# GRADIO UI
# ============================================================
def run(answer, question, kb):
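    # Thin wrapper so Gradio maps (answer, question, kb) -> (verdict, logs).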
return evaluate_answer(answer, question, kb)
with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
    gr.Markdown("## 🧠 Competitive Exam Answer Checker (Local Models, No Inference API)")
kb = gr.Textbox(label="Knowledge Base", lines=7)
question = gr.Textbox(label="Question")
answer = gr.Textbox(label="Student Answer")
verdict = gr.Textbox(label="Verdict")
debug = gr.JSON(label="Debug Logs")
btn = gr.Button("Evaluate")
btn.click(run, [answer, question, kb], [verdict, debug])
if __name__ == "__main__":
    demo.launch()