cs3319-project2 / figures_paper /scripts /fig7_error_heatmap.py
NLP-beginner's picture
CS3319 Project 2 final deliverable (public F1 = 0.96626)
f28d994
Raw
History Blame Contribute Delete
3.82 kB
"""Figure 7 — Error bucket heatmap.
Rows are curated error buckets (cold-start degree, citation popularity, internal
rank, local evidence); columns are the metrics (precision, recall, F1). Weak rows
(F1 < 0.9) are outlined in red to localise where the model still struggles.
"""
from pathlib import Path
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import seaborn as sns
from style import apply, save, COL2 # noqa: E402
# Figure identifier: passed to save() and used to build the reported file names.
KEY = "fig7_error_heatmap"
# Human-readable title echoed in the result dict returned by make().
TITLE = "Figure 7. Error bucket heatmap"
# (bucket_type, display label) pairs. bucket_type is matched against the CSV's
# bucket_type column; the label prefixes each heatmap row name. The list order
# fixes the order of the row groups and which palette colour each group's labels
# get. Buckets inside a group are ordered separately, by _bucket_order_key.
GROUPS = [
("author_degree", "author deg."),
("paper_degree", "paper deg."),
("paper_citation_in", "citation in-deg."),
("author_internal_rank", "author rank"),
("has_local_evidence", "local evidence"),
]
def _bucket_order_key(b):
import re
# robust low-bound extraction for "[low,high)", "=val", or bare value
toks = re.findall(r"-?inf|-?\d+(?:\.\d+)?", b)
vals = []
for v in toks:
if v.endswith("inf"):
vals.append(-1e9 if v.startswith("-") else 1e9)
else:
vals.append(float(v))
return vals[0] if vals else 0
def make(root, out):
    """Render the Figure 7 error-bucket heatmap and describe the result.

    Reads ``error_analysis_buckets.csv`` from the seed-202 validation run below
    *root*, keeps the bucket types listed in ``GROUPS`` and draws one heatmap row
    per bucket with precision, recall and F1 as columns. Rows whose F1 is below
    0.90 are outlined in red, and row labels are coloured by bucket type.

    Args:
        root: Project root; must support the ``/`` path operator (e.g. ``Path``).
        out: Output location, handed unchanged to ``save``.

    Returns:
        A dict with the keys ``key``, ``title``, ``status``, ``files``,
        ``sources`` and ``caption`` (plus ``note`` when skipped). ``status`` is
        ``"ok"`` once the figure has been handed to ``save``, and ``"skipped"``
        when the CSV is missing or has no rows for any curated bucket type.
    """
    apply()
    csv = (root / "validation_runs" / "dynamic_seed202"
           / "error_group_calibration" / "error_analysis_buckets.csv")
    if not csv.exists():
        return dict(key=KEY, title=TITLE, status="skipped", files=[], sources=[], note=str(csv),
                    caption="error_analysis_buckets.csv missing")
    df = pd.read_csv(csv)

    rows, rowlabels, rowcolors, weak = [], [], [], []
    for bt, lbl in GROUPS:
        sub = df[df.bucket_type == bt].copy()
        sub["lo"] = sub.bucket.apply(_bucket_order_key)
        # Stable sort: buckets with equal keys (e.g. non-numeric labels, which all
        # map to 0) keep their CSV order instead of an unspecified one.
        sub = sub.sort_values("lo", kind="mergesort")
        for bucket, precision, recall, f1 in zip(
                sub["bucket"], sub["precision"], sub["recall"], sub["f1"]):
            # Drop a repeated bucket-type prefix and surrounding brackets/spaces.
            tag = bucket.replace(bt, "").strip(" []")
            rows.append([precision, recall, f1])
            rowlabels.append(f"{lbl} · {tag}")
            rowcolors.append(bt)
            weak.append(f1 < 0.90)

    if not rows:
        # Nothing to plot: an empty matrix cannot form a rows x 3 heatmap, so
        # report a skip in the same shape as the missing-file case.
        return dict(key=KEY, title=TITLE, status="skipped", files=[], sources=[], note=str(csv),
                    caption="error_analysis_buckets.csv has no rows for the curated bucket types")

    M = np.array(rows)
    palette = sns.color_palette("deep")
    type_color = {bt: palette[i] for i, (bt, _) in enumerate(GROUPS)}

    fig, ax = plt.subplots(figsize=(COL2, 5.2))
    sns.heatmap(M, annot=True, fmt=".3f", cmap="RdYlGn", vmin=0.0, vmax=1.0,
                xticklabels=["precision", "recall", "F1"], yticklabels=rowlabels,
                cbar_kws={"label": "score", "shrink": 0.7}, linewidths=0.4, linecolor="white", ax=ax)
    ax.set_xticklabels(ax.get_xticklabels(), fontsize=8)
    ax.set_yticklabels(ax.get_yticklabels(), fontsize=6.8)
    # color-code row labels by bucket type
    for tick, bt in zip(ax.get_yticklabels(), rowcolors):
        tick.set_color(type_color[bt])
    # outline weak rows: heatmap row i spans y in [i, i + 1] and the 3 columns span x in [0, 3]
    for i, w in enumerate(weak):
        if w:
            ax.add_patch(plt.Rectangle((0, i), 3, 1, fill=False, edgecolor="red", lw=1.3))
    ax.set_title("Where the final model struggles (red = F1 < 0.90)", fontsize=9.5)
    save(fig, KEY, out)
    # NOTE(review): the file list assumes save() writes pdf, png and svg — confirm in style.save.
    return dict(key=KEY, title=TITLE, status="ok", files=[f"{KEY}.pdf", f"{KEY}.png", f"{KEY}.svg"],
                sources=[str(csv)],
                caption=(
                    "Error-bucket heatmap (validation, seed=202). Rows are curated node/pair buckets "
                    "(author and paper degree, citation in-degree, author internal rank, local-evidence flag); "
                    "columns are precision, recall and F1. Red-outlined rows have F1 below 0.90. The hard "
                    "regions are cold-start nodes (author/paper degree ≤ 2), low in-citation papers, mid-range "
                    "LightGCN-score pairs, and pairs without any local structural evidence — exactly the cases "
                    "the high-order propagation and random-walk features are designed to rescue."))
if __name__ == "__main__":
    # Script entry point: use the current working directory as the project root
    # and let style.ensure_dirs supply the output location, then report the outcome.
    from style import ensure_dirs

    project_root = Path(".")
    output_dir = ensure_dirs(Path("."))
    result = make(project_root, output_dir)
    print(result["key"], result["status"])