|
|
import re |
|
|
import io |
|
|
import math |
|
|
import os |
|
|
import numpy as np |
|
|
import networkx as nx |
|
|
import matplotlib.pyplot as plt |
|
|
from PIL import Image |
|
|
from dataclasses import dataclass |
|
|
import gradio as gr |
|
|
|
|
|
|
|
|
# Opt-in switch for the optional LLM-backed explanation. "true" keeps working
# as before; surrounding whitespace and the spellings 1/yes/on are accepted too.
USE_LLM = os.getenv("USE_LLM", "false").strip().lower() in {"1", "true", "yes", "on"}
MODEL_ID = "TinyLlama/TinyLlama-1.1B-Chat-v1.0"

# `llm` stays None unless the model was requested AND loaded successfully;
# llm_explain() treats None as "use the rule-based text instead".
llm = None
if USE_LLM:
    try:
        # Imported lazily so the app runs without `transformers` installed
        # when the LLM path is disabled.
        from transformers import AutoModelForCausalLM, AutoTokenizer, pipeline

        tok = AutoTokenizer.from_pretrained(MODEL_ID)
        mdl = AutoModelForCausalLM.from_pretrained(MODEL_ID, device_map="auto")
        llm = pipeline("text-generation", model=mdl, tokenizer=tok, max_new_tokens=220)
    except Exception:
        # Best effort: the app must still start, but say why the LLM is off
        # instead of failing silently.
        import logging

        logging.getLogger(__name__).warning(
            "Could not load %s; falling back to rule-based explanations.",
            MODEL_ID,
            exc_info=True,
        )
        llm = None
|
|
|
|
|
|
|
|
@dataclass |
|
|
class Parsed: |
|
|
pA: float | None |
|
|
pB: float | None |
|
|
pAintB: float | None |
|
|
text: str |
|
|
|
|
|
# "P(A) = <number>" and "P(B) = <number>"; group 1 captures the number
# (digits with an optional decimal point, e.g. "0.4", ".5", "1").
prob_pat = r"P\(\s*A\s*\)\s*=\s*([0-9]*\.?[0-9]+)"
prob_patB = r"P\(\s*B\s*\)\s*=\s*([0-9]*\.?[0-9]+)"

# Left-hand side of the joint probability in any supported spelling:
# P(A∩B), P(A \cap B), P(A & B), P(A and B), P(A, B).
# Whitespace around the operator is optional for every spelling, including the
# LaTeX one. The pattern has no capturing groups, so callers can append their
# own value group to it.
prob_pAnB = r"(?:P\(\s*A\s*(?:∩|\\cap|&|and|,)\s*B\s*\))"
|
|
|
|
|
def parse_input(txt: str) -> Parsed:
    """Extract P(A), P(B) and P(A∩B) from a free-text problem statement.

    Args:
        txt: The problem text. ``None`` is treated like an empty string.

    Returns:
        A ``Parsed`` whose probability fields are floats, or ``None`` for each
        value that does not appear in the text; ``text`` is the stripped input.
    """
    source = txt or ""

    # All three lookups ignore case so that P(A)/P(B) are treated the same way
    # as the joint probability (which needs it for spellings such as "AND").
    match_a = re.search(prob_pat, source, re.IGNORECASE)
    match_b = re.search(prob_patB, source, re.IGNORECASE)
    # A named group keeps the lookup correct even if the joint-probability
    # pattern ever gains capturing groups of its own.
    match_joint = re.search(
        rf"{prob_pAnB}\s*=\s*(?P<value>[0-9]*\.?[0-9]+)", source, re.IGNORECASE
    )

    return Parsed(
        pA=float(match_a.group(1)) if match_a else None,
        pB=float(match_b.group(1)) if match_b else None,
        pAintB=float(match_joint.group("value")) if match_joint else None,
        text=source.strip(),
    )
|
|
|
|
|
def compute_reasoning(parsed: Parsed, user_reasoning: str | None = ""): |
|
|
res = { |
|
|
"can_compute": False, |
|
|
"p_cond": None, |
|
|
"independence_claimed": False, |
|
|
"independence_holds": None, |
|
|
"explanation": "", |
|
|
} |
|
|
if parsed.pA is None or parsed.pB is None or parsed.pAintB is None: |
|
|
res["explanation"] = "Missing P(A), P(B) or P(A∩B)." |
|
|
return res |
|
|
|
|
|
pA, pB, pAnB = parsed.pA, parsed.pB, parsed.pAintB |
|
|
res["p_cond"] = pAnB / pB if pB != 0 else None |
|
|
res["can_compute"] = pB != 0 |
|
|
|
|
|
if user_reasoning: |
|
|
if re.search(r"independ", user_reasoning, re.IGNORECASE) or re.search(r"P\(A\|B\)\s*=\s*P\(A\)", user_reasoning): |
|
|
res["independence_claimed"] = True |
|
|
|
|
|
res["independence_holds"] = math.isclose(pAnB, pA * pB, rel_tol=1e-6, abs_tol=1e-6) |
|
|
|
|
|
return res |
|
|
|
|
|
def make_graph_image(parsed: Parsed, info: dict):
    """Render the reasoning steps as a small directed graph.

    Args:
        parsed: Parsed probabilities; any of ``pA``, ``pB``, ``pAintB`` may be
            ``None`` when the value was not found in the problem text.
        info: Result dict of ``compute_reasoning``; the keys ``can_compute``,
            ``independence_claimed`` and ``independence_holds`` are read.

    Returns:
        The rendered figure as an RGB ``numpy`` array.
    """
    node_size = 2200

    def fmt4(value):
        # Missing values must not crash the number formatting.
        return "n/a" if value is None else f"{value:.4f}"

    product = (
        parsed.pA * parsed.pB
        if parsed.pA is not None and parsed.pB is not None
        else None
    )
    # can_compute implies that every value is present and P(B) != 0.
    result = f"{(parsed.pAintB / parsed.pB):.4f}" if info["can_compute"] else "undefined"

    G = nx.DiGraph()
    G.add_node("Given", desc=f"P(A)={parsed.pA}, P(B)={parsed.pB}, P(A∩B)={parsed.pAintB}")
    G.add_node("Formula", desc="P(A|B)=P(A∩B)/P(B)")
    G.add_node("Compute", desc=f"P(A|B) = {parsed.pAintB}/{parsed.pB}")
    G.add_node("Result", desc=f"P(A|B)={result}")
    G.add_node("Independence", desc="Assume A ⟂ B (P(A|B)=P(A))")
    G.add_node("Check", desc=f"P(A)P(B)={fmt4(product)} vs P(A∩B)={fmt4(parsed.pAintB)}")

    G.add_edges_from([
        ("Given", "Formula"),
        ("Formula", "Compute"),
        ("Compute", "Result"),
        ("Given", "Check"),
        ("Independence", "Result"),
    ])

    # Fixed layout so the picture is stable between runs.
    pos = {
        "Given": (-0.9, 0.3),
        "Formula": (-0.2, 0.5),
        "Compute": (0.5, 0.3),
        "Result": (0.6, -0.5),
        "Independence": (-0.7, -0.4),
        "Check": (-0.1, -0.2),
    }

    # One colour per node, in G.nodes() order: red-ish for a claimed but
    # violated independence assumption, green-ish for the result.
    node_colors = []
    for n in G.nodes():
        if n == "Independence" and info["independence_claimed"] and not info["independence_holds"]:
            node_colors.append("#f8d7da")
        elif n == "Result":
            node_colors.append("#d1e7dd")
        else:
            node_colors.append("#f0f4ff")

    fig, ax = plt.subplots(figsize=(5.6, 4.2), dpi=180)
    buf = io.BytesIO()
    try:
        ax.axis('off')

        # Draw on this figure's axes explicitly instead of relying on
        # pyplot's implicit "current figure".
        nx.draw_networkx_nodes(G, pos, ax=ax, node_size=node_size, node_color=node_colors,
                               linewidths=1.5, edgecolors="#213555")
        # node_size tells networkx where the node borders are, so the
        # arrowheads stop at the border instead of ending inside the circle.
        nx.draw_networkx_edges(G, pos, ax=ax, node_size=node_size, arrows=True,
                               arrowstyle="->", width=1.2, edge_color="#213555")
        nx.draw_networkx_labels(G, pos, ax=ax, font_size=9, font_color="#1a2b3c")

        # Per-node description, placed slightly below each node centre.
        for n, (x, y) in pos.items():
            ax.text(x, y - 0.12, G.nodes[n]["desc"], fontsize=8, ha="center", color="#334e68")

        fig.tight_layout()
        fig.savefig(buf, format="png", bbox_inches="tight", dpi=180)
    finally:
        # Always release the figure, even when drawing or saving fails.
        plt.close(fig)

    buf.seek(0)
    img = Image.open(buf).convert("RGB")
    return np.array(img)
|
|
|
|
|
def rule_based_explanation(parsed: Parsed, info: dict):
    """Build the bullet-point diagnosis that is used when no LLM text exists.

    Args:
        parsed: Parsed probabilities (``pA``, ``pB``, ``pAintB``).
        info: Result dict of ``compute_reasoning``; the keys ``can_compute``,
            ``independence_claimed`` and ``independence_holds`` are read.

    Returns:
        A newline-joined list of bullet points, or a single sentence when
        P(A|B) cannot be computed.
    """
    if not info["can_compute"]:
        return "Cannot compute: P(B)=0 or missing values."

    joint, marginal_b = parsed.pAintB, parsed.pB
    bullets = [
        "• Using the definition, P(A|B)=P(A∩B)/P(B).",
        f"• Substituting: {joint} / {marginal_b} → {joint/marginal_b:.4f}.",
    ]

    if not info["independence_claimed"]:
        # Nothing to correct: the user did not rely on independence.
        bullets.append("• No independence assumption detected; standard conditional rule applied.")
    elif info["independence_holds"]:
        bullets.extend([
            "• Your assumption of independence holds (P(A∩B)=P(A)P(B)).",
            "• In this case P(A|B)=P(A) is consistent with the data.",
        ])
    else:
        # Claimed but contradicted by the numbers: show the mismatch, the fix,
        # and what the answer would have been under independence.
        bullets.extend([
            f"• Your assumption of independence is violated: P(A)P(B)={parsed.pA*parsed.pB:.4f} ≠ P(A∩B)={parsed.pAintB:.4f}.",
            "• Minimal fix: treat A and B as dependent; use P(A|B)=P(A∩B)/P(B).",
            f"• Counterfactual: if independence held, P(A|B) would equal P(A)={parsed.pA:.4f}.",
        ])

    return "\n".join(bullets)
|
|
|
|
|
def llm_explain(prompt):
    """Ask the optional LLM for an explanation.

    Args:
        prompt: The full prompt text.

    Returns:
        The generated continuation with the echoed prompt removed, or ``None``
        when no LLM is loaded, generation fails, or nothing usable was
        produced. Callers fall back to the rule-based text on ``None``.
    """
    if llm is None:
        return None
    try:
        generated = llm(prompt)[0]["generated_text"]
    except Exception:
        # Best effort: never break the UI because of the LLM, but record why
        # it was skipped.
        import logging

        logging.getLogger(__name__).warning(
            "LLM generation failed; using the rule-based explanation.", exc_info=True
        )
        return None

    if not isinstance(generated, str):
        return None
    # Text-generation pipelines return prompt + continuation by default; the
    # user should only see the continuation.
    if generated.startswith(prompt):
        generated = generated[len(prompt):]
    # An empty answer becomes None so the caller uses its fallback.
    return generated.strip() or None
|
|
|
|
|
# Default content of the "Problem" textbox and first entry of the examples list.
EXAMPLE_PROBLEM = "Problem: Given P(A)=0.4, P(B)=0.5, P(A∩B)=0.18, find P(A|B)."
# Matching reasoning that asserts independence although 0.4 * 0.5 != 0.18.
EXAMPLE_REASONING = "You assumed independence, so P(A|B)=P(A)=0.4."
|
|
|
|
|
def _build_llm_prompt(problem_text, user_reasoning, parsed, info):
    """Assemble the prompt sent to the optional LLM (real newlines between sections)."""
    return (
        f"Explain the mistake and minimal fix clearly:\nProblem: {problem_text}\nUser reasoning: {user_reasoning}\n"
        f"Key facts: P(A)={parsed.pA}, P(B)={parsed.pB}, P(A∩B)={parsed.pAintB}.\n"
        f"Computed P(A|B)={info['p_cond']}. Independence claimed: {info['independence_claimed']}. "
        f"Independence holds: {info['independence_holds']}.\n"
        "Give a concise, student-friendly diagnosis and a contrastive / counterfactual fix."
    )


def run(problem_text, user_reasoning):
    """Gradio callback: parse the problem, analyse the reasoning, build all outputs.

    Args:
        problem_text: Content of the "Problem" textbox (may be empty or None).
        user_reasoning: Content of the reasoning textbox (may be empty or None).

    Returns:
        A tuple ``(graph_img, final_text, stat_str)``: the reasoning graph as
        an image array, the diagnosis text (LLM output when available,
        otherwise the rule-based text), and a multi-line status summary.
    """
    parsed = parse_input(problem_text or "")
    info = compute_reasoning(parsed, user_reasoning or "")
    graph_img = make_graph_image(parsed, info)
    rb_text = rule_based_explanation(parsed, info)

    llm_out = llm_explain(_build_llm_prompt(problem_text, user_reasoning, parsed, info))
    # The rule-based text is the fallback whenever the LLM yields nothing.
    final_text = llm_out if llm_out else rb_text

    status = {
        "Parsed": f"P(A)={parsed.pA}, P(B)={parsed.pB}, P(A∩B)={parsed.pAintB}",
        "Independence claimed": info["independence_claimed"],
        "Independence holds": info["independence_holds"],
        "P(A|B)": f"{info['p_cond']:.4f}" if info["p_cond"] is not None else "undefined",
    }
    stat_str = "\n".join(f"{k}: {v}" for k, v in status.items())
    return graph_img, final_text, stat_str
|
|
|
|
|
# Gradio UI: one row with three columns — inputs and examples, the reasoning
# graph, and the text outputs.
# NOTE(review): the nesting below was reconstructed from a source whose
# indentation was lost — confirm the Row/Column layout is the intended one.
with gr.Blocks(theme=gr.themes.Soft(primary_hue='indigo')) as demo:
    gr.Markdown("## Aicher (MVP) — Visual Causal Tutor\n"
                "Enter a conditional-probability problem (or use the example). "
                "Optionally type your own reasoning to test independence errors.")

    with gr.Row():
        with gr.Column(scale=1):
            # Both textboxes start pre-filled with the built-in example.
            problem = gr.Textbox(label="Problem", value=EXAMPLE_PROBLEM, lines=4)
            reasoning = gr.Textbox(label="Your reasoning (optional)", value=EXAMPLE_REASONING, lines=3)
            btn = gr.Button("Explain reasoning")
            # Each example is a [problem, reasoning] pair for the two inputs.
            gr.Examples(
                examples=[[EXAMPLE_PROBLEM, EXAMPLE_REASONING],
                          ["P(A)=0.3, P(B)=0.25, P(A∩B)=0.05. Find P(A|B).", "No independence claimed."]],
                inputs=[problem, reasoning],
            )
        with gr.Column(scale=2):
            graph = gr.Image(label="Reasoning Graph", type="numpy")
        with gr.Column(scale=1):
            diagnosis = gr.Textbox(label="Diagnosis & Minimal Fix", lines=12)
            status = gr.Textbox(label="Parsed / Status", lines=6)

    # run() returns (image, diagnosis text, status text), in the order of `outputs`.
    btn.click(run, inputs=[problem, reasoning], outputs=[graph, diagnosis, status])
|
|
|
|
|
# Start the web app only when this file is executed as a script.
if __name__ == "__main__":
    # Binds to all network interfaces (0.0.0.0) and passes share=True.
    # NOTE(review): presumably this also publishes a public Gradio share link —
    # confirm that this exposure is intended outside of demos.
    demo.launch(server_name="0.0.0.0", share=True)
|
|
|