| """Figure 7 — Error bucket heatmap. |
| |
| Curated rows (cold-start degree, citation popularity, internal rank, local evidence) |
| x metrics (precision, recall, F1). Weak rows (F1 < 0.9) are outlined in red to |
| localise where the model still struggles. |
| """ |
| from pathlib import Path |
| import pandas as pd |
| import numpy as np |
| import matplotlib.pyplot as plt |
| import seaborn as sns |
| from style import apply, save, COL2 |
|
|
# Identifier used for output file names and the figure manifest entry.
KEY = "fig7_error_heatmap"
# Human-readable figure title reported back in the manifest.
TITLE = "Figure 7. Error bucket heatmap"

# Curated heatmap row groups, top to bottom.
# Each entry is (value of the CSV `bucket_type` column, short label prefixed
# to every row label of that group).
GROUPS = [
    ("author_degree", "author deg."),
    ("paper_degree", "paper deg."),
    ("paper_citation_in", "citation in-deg."),
    ("author_internal_rank", "author rank"),
    ("has_local_evidence", "local evidence"),
]
|
|
|
|
| def _bucket_order_key(b): |
| import re |
| |
| toks = re.findall(r"-?inf|-?\d+(?:\.\d+)?", b) |
| vals = [] |
| for v in toks: |
| if v.endswith("inf"): |
| vals.append(-1e9 if v.startswith("-") else 1e9) |
| else: |
| vals.append(float(v)) |
| return vals[0] if vals else 0 |
|
|
|
|
def make(root, out, f1_threshold=0.90):
    """Render the error-bucket heatmap and return its figure-manifest entry.

    Reads ``validation_runs/dynamic_seed202/error_group_calibration/
    error_analysis_buckets.csv`` under *root*. For every bucket type in
    ``GROUPS`` (in that order) the matching rows are ordered by
    ``_bucket_order_key`` of their ``bucket`` label. Each row's precision,
    recall and F1 are plotted as one heatmap row. Rows whose F1 is below
    *f1_threshold* are outlined in red.

    Args:
        root: Path of the directory containing ``validation_runs``.
        out: Output location passed through to ``save``.
        f1_threshold: F1 value below which a row is flagged as weak
            (default 0.90; the title and caption quote it with 2 decimals).

    Returns:
        dict with ``key``, ``title``, ``status`` ("ok" or "skipped"),
        ``files``, ``sources`` and ``caption``; skipped results also carry
        ``note`` with the CSV path.
    """
    apply()
    csv = (root / "validation_runs" / "dynamic_seed202"
           / "error_group_calibration" / "error_analysis_buckets.csv")
    if not csv.exists():
        return dict(key=KEY, title=TITLE, status="skipped", files=[], sources=[], note=str(csv),
                    caption="error_analysis_buckets.csv missing")
    df = pd.read_csv(csv)

    rows, rowlabels, rowtypes, weak = [], [], [], []
    for bt, lbl in GROUPS:
        sub = df[df.bucket_type == bt].copy()
        sub["lo"] = sub.bucket.apply(_bucket_order_key)
        # mergesort is stable: buckets sharing a key keep their CSV order
        # instead of being shuffled by the default quicksort.
        sub = sub.sort_values("lo", kind="mergesort")
        for r in sub.itertuples(index=False):
            tag = str(r.bucket).replace(bt, "").strip(" []")
            rows.append([r.precision, r.recall, r.f1])
            rowlabels.append(f"{lbl} · {tag}")
            rowtypes.append(bt)
            weak.append(r.f1 < f1_threshold)

    # With no matching rows np.array([]) is 1-D and seaborn cannot draw it.
    if not rows:
        return dict(key=KEY, title=TITLE, status="skipped", files=[], sources=[str(csv)],
                    note=str(csv),
                    caption="no curated buckets found in error_analysis_buckets.csv")
    M = np.array(rows, dtype=float)
    n_cols = M.shape[1]

    # One colour per bucket type, used to tint the y tick labels by group.
    palette = sns.color_palette("deep")
    type_color = {bt: palette[i % len(palette)] for i, (bt, _) in enumerate(GROUPS)}

    fig, ax = plt.subplots(figsize=(COL2, 5.2))
    try:
        sns.heatmap(M, annot=True, fmt=".3f", cmap="RdYlGn", vmin=0.0, vmax=1.0,
                    xticklabels=["precision", "recall", "F1"], yticklabels=rowlabels,
                    cbar_kws={"label": "score", "shrink": 0.7}, linewidths=0.4,
                    linecolor="white", ax=ax)
        ax.set_xticklabels(ax.get_xticklabels(), fontsize=8)
        ax.set_yticklabels(ax.get_yticklabels(), fontsize=6.8)

        for tick, bt in zip(ax.get_yticklabels(), rowtypes):
            tick.set_color(type_color[bt])

        # Heatmap cells live on an integer grid: row i spans y in [i, i+1].
        for i, is_weak in enumerate(weak):
            if is_weak:
                ax.add_patch(plt.Rectangle((0, i), n_cols, 1, fill=False,
                                           edgecolor="red", lw=1.3))
        ax.set_title(f"Where the final model struggles (red = F1 < {f1_threshold:.2f})",
                     fontsize=9.5)
        save(fig, KEY, out)
    finally:
        # Release the figure so repeated figure builds do not accumulate
        # open pyplot figures.
        plt.close(fig)

    # NOTE(review): the caption mentions "mid-range LightGCN-score pairs", but no
    # GROUPS entry produces a LightGCN-score row — confirm the caption wording.
    return dict(key=KEY, title=TITLE, status="ok", files=[f"{KEY}.pdf", f"{KEY}.png", f"{KEY}.svg"],
                sources=[str(csv)],
                caption=(
                    "Error-bucket heatmap (validation, seed=202). Rows are curated node/pair buckets "
                    "(author and paper degree, citation in-degree, author internal rank, local-evidence flag); "
                    f"columns are precision, recall and F1. Red-outlined rows have F1 below {f1_threshold:.2f}. The hard "
                    "regions are cold-start nodes (author/paper degree ≤ 2), low in-citation papers, mid-range "
                    "LightGCN-score pairs, and pairs without any local structural evidence — exactly the cases "
                    "the high-order propagation and random-walk features are designed to rescue."))
|
|
|
|
if __name__ == "__main__":
    # Script entry: build the figure relative to the current working directory
    # and print a one-line "<key> <status>" summary.
    from style import ensure_dirs

    here = Path(".")
    result = make(here, ensure_dirs(here))
    print(result["key"], result["status"])
|
|