File size: 5,920 Bytes
212d6cc
e2daaeb
 
 
177dde3
e2daaeb
 
 
 
 
177dde3
e2daaeb
 
177dde3
32a28cd
177dde3
 
 
e2daaeb
 
 
177dde3
e2daaeb
177dde3
e2daaeb
 
177dde3
 
 
 
 
 
e2daaeb
 
 
 
 
 
177dde3
e2daaeb
177dde3
e2daaeb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
177dde3
e2daaeb
 
32a28cd
 
4a83cb5
32a28cd
 
 
 
 
 
 
4a83cb5
32a28cd
 
177dde3
 
 
 
4a83cb5
177dde3
 
32a28cd
 
177dde3
32a28cd
4a83cb5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
32a28cd
e2daaeb
 
 
 
 
177dde3
e2daaeb
 
 
 
 
 
 
177dde3
 
 
 
 
 
 
e2daaeb
d14dbdf
e2daaeb
 
32a28cd
e2daaeb
 
 
 
d14dbdf
e2daaeb
177dde3
e2daaeb
177dde3
e2daaeb
177dde3
e2daaeb
177dde3
 
e2daaeb
177dde3
e2daaeb
177dde3
 
 
e2daaeb
177dde3
 
e2daaeb
177dde3
e2daaeb
d14dbdf
 
 
4a83cb5
d14dbdf
 
 
 
 
 
 
 
 
4a83cb5
d14dbdf
e2daaeb
177dde3
 
e2daaeb
177dde3
 
e2daaeb
 
 
177dde3
212d6cc
 
d14dbdf
e2daaeb
 
 
 
 
 
 
 
177dde3
e2daaeb
177dde3
212d6cc
 
 
 
e2daaeb
212d6cc
 
e2daaeb
212d6cc
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
import gradio as gr
import torch
import torch.nn.functional as F
from sentence_transformers import CrossEncoder
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
import re
import hashlib
import json

# ============================================================
# DEVICE
# ============================================================

# Prefer GPU when available; all models/inputs below are moved here.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

# ============================================================
# MODELS
# ============================================================

# Sentence-pair similarity scorer used for required-concept coverage.
SIM_MODEL_NAME = "cross-encoder/stsb-distilroberta-base"
# NLI cross-encoder used for the contradiction check.
NLI_MODEL_NAME = "cross-encoder/nli-deberta-v3-xsmall"
# Small seq2seq LLM that distills the knowledge base into short facts.
LLM_NAME = "google/flan-t5-base"

print("Loading similarity + NLI models...")
sim_model = CrossEncoder(SIM_MODEL_NAME, device=DEVICE)
nli_model = CrossEncoder(NLI_MODEL_NAME, device=DEVICE)

print("Loading LLM for schema generation...")
llm_tokenizer = AutoTokenizer.from_pretrained(LLM_NAME)
llm_model = AutoModelForSeq2SeqLM.from_pretrained(LLM_NAME).to(DEVICE)

print("✅ All models loaded successfully")

# ============================================================
# CONFIGURATION
# ============================================================

# Minimum similarity score for a required concept to count as "covered".
SIM_THRESHOLD_REQUIRED = 0.55
# Minimum contradiction probability for a claim to be flagged.
CONTRADICTION_THRESHOLD = 0.70

# In-memory cache: sha256 key of (kb, question) -> generated schema dict.
# Unbounded; lives only for the process lifetime.
SCHEMA_CACHE = {}

# ============================================================
# UTILITIES
# ============================================================

def split_sentences(text):
    """Split *text* into sentences at whitespace following ., ! or ?.

    Returns an empty list for blank/whitespace-only input (the previous
    version returned [''] because re.split("") yields one empty string).
    """
    text = text.strip()
    if not text:
        return []
    return re.split(r'(?<=[.!?])\s+', text)

def softmax_logits(logits):
    """Convert raw logit scores into a list of probabilities.

    Accepts anything ``torch.tensor`` accepts; a leading batch dimension
    (e.g. shape (1, n)) is squeezed away before the softmax.
    """
    scores = torch.tensor(logits)
    scores = scores.squeeze(0) if scores.dim() > 1 else scores
    probabilities = F.softmax(scores, dim=0)
    return probabilities.tolist()

def hash_key(kb, question):
    """Return a stable sha256 hex cache key for a (kb, question) pair.

    A NUL separator between the two fields prevents collisions that plain
    concatenation allows, e.g. ("ab", "c") vs ("a", "bc"). The cache is
    in-memory only, so changing the key format is safe.
    """
    return hashlib.sha256(f"{kb}\x00{question}".encode()).hexdigest()

# ============================================================
# LLM SCHEMA GENERATION (CORE PART YOU ASKED FOR)
# ============================================================

def generate_schema_with_llm(kb, question):
    """Ask the seq2seq LLM to distill the KB into 1-3 short facts.

    Returns a schema dict whose "required_concepts" holds the extracted
    fact lines; "forbidden_concepts" is always empty here and
    "raw_llm_output" preserves the decoded generation for debugging.
    """
    prompt = f"""
From the knowledge base below, answer the question using short factual points.

Knowledge Base:
{kb}

Question:
{question}

Write 1–3 short factual bullet points. Do NOT explain.
"""

    inputs = llm_tokenizer(prompt, return_tensors="pt", truncation=True).to(DEVICE)

    # Greedy decoding. Do NOT pass temperature alongside do_sample=False:
    # temperature is a sampling-only flag, and recent transformers releases
    # warn (and some validation paths reject) unused sampling flags.
    outputs = llm_model.generate(
        **inputs,
        max_new_tokens=128,
        do_sample=False
    )

    text = llm_tokenizer.decode(outputs[0], skip_special_tokens=True)

    # Extract bullet-like facts: strip bullet markers from each line and
    # drop fragments too short to be a fact.
    facts = [
        line.strip("-• ").strip()
        for line in text.split("\n")
        if len(line.strip()) > 3
    ]

    return {
        "question_type": "FACT",
        "required_concepts": facts,
        "forbidden_concepts": [],
        "allow_extra_info": True,
        "raw_llm_output": text
    }


# ============================================================
# ANSWER DECOMPOSITION
# ============================================================

def decompose_answer(answer):
    """Break an answer into clauses at common connective words.

    Splits on whole-word 'and', 'because', 'before', 'after', 'while'
    (connectives are discarded) and drops empty/whitespace fragments.
    """
    connective = r'\b(?:and|because|before|after|while)\b'
    fragments = (piece.strip() for piece in re.split(connective, answer))
    return [fragment for fragment in fragments if fragment]

# ============================================================
# CORE EVALUATION
# ============================================================

def evaluate_answer(answer, question, kb):
    """Grade a student answer against an LLM-derived fact schema.

    Pipeline: build/fetch the schema for (kb, question), split the answer
    into claims, score coverage of each required concept with the
    similarity cross-encoder, then run an NLI contradiction check.

    Returns (verdict_string, logs_dict) where logs_dict traces every stage.
    """
    logs = {
        "inputs": {
            "question": question,
            "answer": answer,
            "kb_length": len(kb)
        }
    }

    # ---------------- SCHEMA ----------------
    # Schema generation hits the LLM, so results are cached per (kb, question).
    key = hash_key(kb, question)
    if key not in SCHEMA_CACHE:
        SCHEMA_CACHE[key] = generate_schema_with_llm(kb, question)

    schema = SCHEMA_CACHE[key]
    logs["schema"] = schema

    # ---------------- ANSWER CLAIMS ----------------
    claims = decompose_answer(answer)
    logs["claims"] = claims

    # ---------------- REQUIRED COVERAGE ----------------
    coverage = []
    covered_all = True

    for concept in schema.get("required_concepts", []):
        if claims:
            scores = sim_model.predict([(concept, c) for c in claims])
            best = float(scores.max())
        else:
            # A blank answer yields no claims; the previous version crashed
            # here calling .max() on an empty score array. Nothing can be
            # covered, so the best score is 0.
            best = 0.0
        covered = best >= SIM_THRESHOLD_REQUIRED
        coverage.append({
            "concept": concept,
            "similarity": round(best, 3),
            "covered": covered
        })
        if not covered:
            covered_all = False

    logs["coverage"] = coverage

    # ---------------- CONTRADICTION CHECK (FIXED) ----------------
    # Claims are compared against the schema's extracted facts, not the raw
    # KB sentences, to keep the NLI pass small and focused.
    contradictions = []
    relevant_kb = schema.get("required_concepts", [])

    for claim in claims:
        for sent in relevant_kb:
            probs = softmax_logits(nli_model.predict([(sent, claim)]))
            # NOTE(review): assumes index 0 is the "contradiction" class for
            # this NLI cross-encoder -- confirm against the model card.
            if probs[0] > CONTRADICTION_THRESHOLD:
                contradictions.append({
                    "claim": claim,
                    "sentence": sent,
                    "confidence": round(probs[0] * 100, 1)
                })

    logs["contradictions"] = contradictions

    # ---------------- FINAL DECISION ----------------
    # Any contradiction overrides coverage; otherwise full coverage wins.
    if contradictions:
        verdict = "❌ INCORRECT (Contradiction)"
    elif covered_all:
        verdict = "✅ CORRECT"
    else:
        verdict = "⚠️ PARTIALLY CORRECT"

    logs["final_verdict"] = verdict
    return verdict, logs


# ============================================================
# GRADIO UI
# ============================================================

def run(answer, question, kb):
    """Gradio callback: delegate to the evaluator and pass its
    (verdict, logs) pair straight through to the two output widgets."""
    verdict, logs = evaluate_answer(answer, question, kb)
    return verdict, logs

# Build the Gradio interface: three text inputs, a verdict box, and the
# raw evaluation trace as JSON for debugging.
with gr.Blocks(title="Competitive Exam Answer Checker") as demo:
    gr.Markdown("## 🧠 Competitive Exam Answer Checker (HF-Free LLM Version)")

    # Inputs: reference text, the question asked, and the student's answer.
    kb = gr.Textbox(label="Knowledge Base", lines=7)
    question = gr.Textbox(label="Question")
    answer = gr.Textbox(label="Student Answer")

    # Outputs: human-readable verdict plus the full per-stage logs.
    verdict = gr.Textbox(label="Verdict")
    debug = gr.JSON(label="Debug Logs")

    btn = gr.Button("Evaluate")
    btn.click(run, [answer, question, kb], [verdict, debug])

# Launch the web app (blocking call).
demo.launch()