File size: 1,732 Bytes
f28d994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
"""Fig 9 (appendix): LightGCN hyperparameter sweep — validation F1 by embedding dim x layers.

Parses run names (l{layers}d{dim}) in dynamic_summary.csv, takes best F1 per cell.
Highlights the chosen config (2 layers, dim 512 = 0.93858).
"""
from pathlib import Path
import sys

import pandas as pd
import re
import matplotlib.pyplot as plt
import seaborn as sns
import numpy as np

# Allow importing the sibling plot_style helper when this file runs as a script.
sys.path.insert(0, str(Path(__file__).resolve().parent))
from plot_style import PALETTE_DEEP as C, apply, save  # noqa: E402

apply()
# Repository root is three levels up from this file; figures go under reports/figures.
ROOT = Path(__file__).resolve().parents[2]
FIG = ROOT.joinpath("reports", "figures")

# Keep only the runs evaluated on the dynamic_seed202 validation split.
df = pd.read_csv(ROOT.joinpath("validation_runs", "dynamic_summary.csv"))
df = df.loc[df.split.eq("dynamic_seed202")].copy()


# Matches an "l<layers>d<dim>" token anywhere in a run name, e.g. "lgcn_l2d512".
# Both groups accept one or more digits so layer counts >= 10 are parsed too.
_RUN_NAME_RE = re.compile(r"l(\d+)d(\d+)")


def parse(run):
    """Extract the (layers, dim) pair encoded in a run name.

    Returns a two-element ``pd.Series`` ``[layers, dim]`` of ints, or
    ``[nan, nan]`` when the name has no ``l<int>d<int>`` token. Non-string
    values (e.g. a missing run name read from the CSV as NaN) are treated
    as unparseable instead of raising.
    """
    m = _RUN_NAME_RE.search(run) if isinstance(run, str) else None
    if m is None:
        return pd.Series([np.nan, np.nan])
    return pd.Series([int(m.group(1)), int(m.group(2))])


# Configuration outlined in the figure as the selected one.
CHOSEN_LAYERS = 2
CHOSEN_DIM = 512

if df.empty:
    # Fail clearly here; Series.apply on an empty frame would otherwise break the
    # two-column assignment below with an opaque pandas error.
    sys.exit("fig9: no rows with split == 'dynamic_seed202' in dynamic_summary.csv")

# Split each run name into (layers, dim); names without an l<int>d<int> token yield NaN.
df[["layers", "dim"]] = df.run.apply(parse)
df = df.dropna(subset=["layers"])
if df.empty:
    sys.exit("fig9: no run names of the form l{layers}d{dim} found")
# The NaN placeholders force float64; cast back so tick labels read "512", not "512.0".
df = df.astype({"layers": int, "dim": int})
# Rows: embedding dim, columns: layer count, values: best F1 over all runs in that cell.
piv = df.groupby(["dim", "layers"]).f1.max().unstack()

fig, ax = plt.subplots(figsize=(8, 4.6))
sns.heatmap(piv, annot=True, fmt=".5f", cmap="viridis", ax=ax,
            cbar_kws={"label": "best validation F1"}, linewidths=0.5, linecolor="white")
# Outline the chosen config. seaborn draws cell (row, col) over data coords
# [col, col+1] x [row, row+1] with an inverted y-axis, so row + 0.18 sits near
# the top of the cell, above the centred F1 annotation.
if CHOSEN_LAYERS in piv.columns and CHOSEN_DIM in piv.index:
    col = piv.columns.get_loc(CHOSEN_LAYERS)
    row = piv.index.get_loc(CHOSEN_DIM)
    ax.add_patch(plt.Rectangle((col, row), 1, 1, fill=False, edgecolor=C[3], lw=3))
    ax.text(col + 0.5, row + 0.18, "chosen", ha="center", color=C[3], fontsize=8, fontweight="bold")
else:
    # Previously swallowed silently; surface it so a stale figure is noticed.
    print(f"warning: chosen config l{CHOSEN_LAYERS}d{CHOSEN_DIM} not in sweep; not highlighted",
          file=sys.stderr)
ax.set_xlabel("propagation layers")
ax.set_ylabel("embedding dim")
ax.set_title("LightGCN hyperparameter sweep (best F1 per cell, split_seed=202)")
save(fig, "fig9_lgcn_hyperparam", FIG)
print("saved fig9_lgcn_hyperparam")