"""Figure 7 — Error bucket heatmap.

Curated rows (cold-start degree, citation popularity, internal rank, local
evidence) x metrics (precision, recall, F1). Weak rows (F1 < 0.9) are
outlined in red to localise where the model still struggles.
"""
import re
from pathlib import Path

import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns

from style import apply, save, COL2  # noqa: E402

KEY = "fig7_error_heatmap"
TITLE = "Figure 7. Error bucket heatmap"

# Rows whose F1 falls strictly below this threshold are outlined in red.
WEAK_F1 = 0.90

# (bucket_type, display label); row groups are drawn top-to-bottom in this order.
GROUPS = [
    ("author_degree", "author deg."),
    ("paper_degree", "paper deg."),
    ("paper_citation_in", "citation in-deg."),
    ("author_internal_rank", "author rank"),
    ("has_local_evidence", "local evidence"),
]

# Numeric tokens inside a bucket label: signed infinities or signed decimals.
_NUM_TOKEN = re.compile(r"-?inf|-?\d+(?:\.\d+)?")


def _bucket_order_key(b):
    """Return a sortable low bound for a bucket label.

    Handles "[low,high)", "=val" and bare values by taking the first numeric
    token; "inf" / "-inf" map to +1e9 / -1e9. Labels with no numeric token
    that end in "true" / "false" (case-insensitive) map to 1.0 / 0.0 so that
    boolean flag buckets sort False-before-True. Anything else sorts as 0.
    Non-string labels are converted with ``str`` first.
    """
    s = str(b)
    tokens = _NUM_TOKEN.findall(s)
    if tokens:
        first = tokens[0]
        if first.endswith("inf"):
            return -1e9 if first.startswith("-") else 1e9
        return float(first)
    lowered = s.strip().lower()
    if lowered.endswith("true"):
        return 1.0
    if lowered.endswith("false"):
        return 0.0
    return 0


def _bucket_tag(bucket, bucket_type):
    """Remove the bucket-type name and surrounding spaces/square brackets from a label."""
    return str(bucket).replace(bucket_type, "").strip(" []")


def _collect_rows(df):
    """Gather heatmap rows for every GROUPS bucket present in *df*.

    Returns four parallel lists: metric triples [precision, recall, f1],
    display labels, bucket types (used for label colouring), and weak flags
    (f1 < WEAK_F1). Groups keep GROUPS order; buckets within a group are
    ordered by their low bound.
    """
    rows, rowlabels, rowtypes, weak = [], [], [], []
    for bt, lbl in GROUPS:
        sub = df[df.bucket_type == bt].copy()
        sub["tag"] = [_bucket_tag(b, bt) for b in sub["bucket"]]
        sub["lo"] = sub["tag"].map(_bucket_order_key)
        # Stable sort: buckets sharing a low bound keep their CSV order.
        sub = sub.sort_values("lo", kind="mergesort")
        for _, r in sub.iterrows():
            rows.append([r["precision"], r["recall"], r["f1"]])
            rowlabels.append(f"{lbl} · {r['tag']}")
            rowtypes.append(bt)
            weak.append(r["f1"] < WEAK_F1)
    return rows, rowlabels, rowtypes, weak


def make(root, out):
    """Render the error-bucket heatmap for the seed-202 validation run.

    Reads ``error_analysis_buckets.csv`` beneath *root* and writes the figure
    into *out* via ``style.save``. Returns a manifest dict with key, title,
    status, files, sources and caption (plus ``note`` when skipped). Status is
    "skipped" when the CSV is missing or contains none of the GROUPS buckets.
    """
    apply()
    csv = (root / "validation_runs" / "dynamic_seed202"
           / "error_group_calibration" / "error_analysis_buckets.csv")
    if not csv.exists():
        return dict(key=KEY, title=TITLE, status="skipped", files=[], sources=[],
                    note=str(csv), caption="error_analysis_buckets.csv missing")

    df = pd.read_csv(csv)
    rows, rowlabels, rowtypes, weak = _collect_rows(df)
    if not rows:
        # sns.heatmap cannot draw an empty matrix; report a skip instead of crashing.
        return dict(key=KEY, title=TITLE, status="skipped", files=[], sources=[str(csv)],
                    note=str(csv),
                    caption="no curated error buckets found in error_analysis_buckets.csv")

    M = np.array(rows)
    palette = sns.color_palette("deep")
    type_color = {bt: palette[i] for i, (bt, _) in enumerate(GROUPS)}

    fig, ax = plt.subplots(figsize=(COL2, 5.2))
    sns.heatmap(M, annot=True, fmt=".3f", cmap="RdYlGn", vmin=0.0, vmax=1.0,
                xticklabels=["precision", "recall", "F1"], yticklabels=rowlabels,
                cbar_kws={"label": "score", "shrink": 0.7},
                linewidths=0.4, linecolor="white", ax=ax)
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=8)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=6.8)

    # Colour-code row labels by bucket type.
    for tick, bt in zip(ax.get_yticklabels(), rowtypes):
        tick.set_color(type_color[bt])

    # Outline weak rows; clip_on=False keeps the border fully visible at the axes edges.
    n_cols = M.shape[1]
    for i, is_weak in enumerate(weak):
        if is_weak:
            ax.add_patch(plt.Rectangle((0, i), n_cols, 1, fill=False,
                                       edgecolor="red", lw=1.3, clip_on=False))

    ax.set_title(f"Where the final model struggles (red = F1 < {WEAK_F1:.2f})", fontsize=9.5)
    save(fig, KEY, out)
    # Release the figure so batch builds do not accumulate open pyplot figures.
    plt.close(fig)

    return dict(
        key=KEY, title=TITLE, status="ok",
        files=[f"{KEY}.pdf", f"{KEY}.png", f"{KEY}.svg"],
        sources=[str(csv)],
        caption=(
            "Error-bucket heatmap (validation, seed=202). Rows are curated node/pair buckets "
            "(author and paper degree, citation in-degree, author internal rank, local-evidence flag); "
            f"columns are precision, recall and F1. Red-outlined rows have F1 below {WEAK_F1:.2f}. The hard "
            "regions are cold-start nodes (author/paper degree ≤ 2), low in-citation papers, mid-range "
            "LightGCN-score pairs, and pairs without any local structural evidence — exactly the cases "
            "the high-order propagation and random-walk features are designed to rescue."))


if __name__ == "__main__":
    from style import ensure_dirs

    r = make(Path("."), ensure_dirs(Path(".")))
    print(r["key"], r["status"])