cs3319-project2 / code /figures /fig9_lgcn_hyperparam.py
NLP-beginner's picture
CS3319 Project 2 final deliverable (public F1 = 0.96626)
f28d994
Raw
History Blame Contribute Delete
1.73 kB
"""Fig 9 (appendix): LightGCN hyperparameter sweep — validation F1 by embedding dim x layers.
Parses run names (l{layers}d{dim}) in dynamic_summary.csv (split dynamic_seed202)
and takes the best F1 per cell.
Highlights the chosen config (2 layers, dim 512; validation F1 = 0.93858).
"""
from pathlib import Path
import sys
import pandas as pd
import re
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np
# Put this script's own directory at the front of sys.path so the sibling
# plot_style module is importable from any working directory, then call its
# apply() hook (presumably the shared figure styling) before any plotting.
sys.path[:0] = [str(Path(__file__).resolve().parent)]
from plot_style import PALETTE_DEEP as C, apply, save  # noqa: E402
apply()
# Repository root: this file sits at <root>/code/figures/.
ROOT = Path(__file__).resolve().parents[2]
FIG = ROOT / "reports" / "figures"
SPLIT = "dynamic_seed202"  # only rows of dynamic_summary.csv with this split are plotted
df = pd.read_csv(ROOT / "validation_runs" / "dynamic_summary.csv")
df = df[df["split"] == SPLIT].copy()
# Fail early with a clear message rather than an opaque pandas/seaborn error
# further down when the filter leaves nothing to plot.
if df.empty:
    raise SystemExit(f"no rows with split == {SPLIT!r} in dynamic_summary.csv")
# Compiled once; "l{layers}d{dim}" with any number of digits in either field.
_RUN_RE = re.compile(r"l(\d+)d(\d+)")


def parse(run):
    """Extract ``[layers, dim]`` from a run name containing ``l{layers}d{dim}``.

    For example ``"lgcn_l2d512"`` gives ``[2, 512]``. A two-element Series is
    returned so that ``Series.apply(parse)`` expands into two columns. Names
    without the tag, and non-string values such as a NaN run name, give
    ``[nan, nan]`` so callers can drop them with ``dropna``.
    """
    if not isinstance(run, str):
        return pd.Series([np.nan, np.nan])
    m = _RUN_RE.search(run)
    if m is None:
        return pd.Series([np.nan, np.nan])
    return pd.Series([int(m.group(1)), int(m.group(2))])
# Configuration outlined on the heatmap as the one finally selected.
CHOSEN_LAYERS = 2
CHOSEN_DIM = 512
df[["layers", "dim"]] = df["run"].apply(parse)
# Runs without an l{layers}d{dim} tag parse to NaN and are dropped. Any NaN
# forces a float dtype, so cast back to int: otherwise the heatmap ticks read
# "2.0" / "512.0" instead of "2" / "512".
df = df.dropna(subset=["layers", "dim"]).astype({"layers": int, "dim": int})
if df.empty:
    raise SystemExit("no run names matched the l{layers}d{dim} pattern")
# Several runs may share a (dim, layers) cell; keep the best F1 of each.
piv = df.groupby(["dim", "layers"])["f1"].max().unstack()
fig, ax = plt.subplots(figsize=(8, 4.6))
sns.heatmap(piv, annot=True, fmt=".5f", cmap="viridis", ax=ax,
            cbar_kws={"label": "best validation F1"}, linewidths=0.5, linecolor="white")
# Outline the chosen config. seaborn draws cell (row, col) over the unit
# square starting at (col, row) in data coordinates.
if CHOSEN_LAYERS in piv.columns and CHOSEN_DIM in piv.index:
    col = piv.columns.get_loc(CHOSEN_LAYERS)
    row = piv.index.get_loc(CHOSEN_DIM)
    ax.add_patch(plt.Rectangle((col, row), 1, 1, fill=False, edgecolor=C[3], lw=3))
    ax.text(col + 0.5, row + 0.18, "chosen", ha="center", color=C[3], fontsize=8, fontweight="bold")
else:
    # Best-effort highlight: still save the figure, but say why it is missing.
    print(f"warning: chosen config l{CHOSEN_LAYERS}d{CHOSEN_DIM} not in sweep; no highlight drawn",
          file=sys.stderr)
ax.set_xlabel("propagation layers")
ax.set_ylabel("embedding dim")
ax.set_title("LightGCN hyperparameter sweep (best F1 per cell, split_seed=202)")
save(fig, "fig9_lgcn_hyperparam", FIG)
print("saved fig9_lgcn_hyperparam")