"""Fig 5: Precision-Recall and ROC curves across model stages.
Uses the row-aligned validation labels (val_labels_seed202.npy, alignment-verified)
against cached OOF / val scores for four representative stages.
"""
from pathlib import Path
import sys
import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
precision_recall_curve,
roc_curve,
auc,
average_precision_score,
)
sys.path.insert(0, str(Path(__file__).resolve().parent))
from plot_style import apply, save, PALETTE_DEEP as C # noqa: E402
# Activate the project's plotting style before any figure is created
# (behaviour of `apply` lives in plot_style -- presumably rcParams; TODO confirm).
apply()

# ROOT is the directory two levels above the folder that contains this script.
ROOT = Path(__file__).resolve().parents[2]
FIG = ROOT.joinpath("reports", "figures")
VR = ROOT.joinpath("validation_runs", "dynamic_seed202")

# Ground-truth validation labels, cast to int for the sklearn metric functions.
_labels_file = VR / "val_labels_seed202.npy"
y = np.load(_labels_file).astype(int)

# One row per model stage: legend label, score file relative to VR, and the
# index into the PALETTE_DEEP colour list used for that stage's curves.
_STAGES = (
    ("LightGCN ensemble", "dyn202_l2d512_bpr_bigbatch_more/scores/val_vanilla_ensemble_mean.npy", 7),
    ("+ graph stack (post95)", "post95_ablation/ensemble_lgcn_oof.npy", 1),
    ("+ DeepWalk/Node2Vec", "node2vec_deepwalk/node2vec_stack_oof.npy", 2),
    ("+ high-order (final)", "high_order_graph_stack/rich_rw7_highorder_directed_oof.npy", 3),
)

# Expand into (label, absolute score path, colour) triples, in plotting order.
models = [
    (stage_label, VR / relative_scores, C[palette_index])
    for stage_label, relative_scores, palette_index in _STAGES
]
# One row, two panels: ax1 holds the precision-recall curves, ax2 the ROC curves.
fig, (ax1, ax2) = plt.subplots(nrows=1, ncols=2, figsize=(12.5, 5.2))

for stage_label, score_file, line_color in models:
    # Scores are promoted to float64 before being handed to the metric functions.
    scores = np.load(score_file).astype(np.float64)

    # Precision-recall panel; the legend entry carries the average precision.
    precision, recall, _ = precision_recall_curve(y, scores)
    avg_precision = average_precision_score(y, scores)
    ax1.plot(
        recall,
        precision,
        color=line_color,
        lw=1.9,
        label=f"{stage_label} (AP={avg_precision:.4f})",
    )

    # ROC panel; the legend entry carries the area under the ROC curve.
    false_pos_rate, true_pos_rate, _ = roc_curve(y, scores)
    roc_area = auc(false_pos_rate, true_pos_rate)
    ax2.plot(
        false_pos_rate,
        true_pos_rate,
        color=line_color,
        lw=1.9,
        label=f"{stage_label} (AUC={roc_area:.4f})",
    )
# Precision-recall panel: full recall range, y-axis limited to 0.9-1.005.
ax1.set(
    xlabel="Recall",
    ylabel="Precision",
    title="Precision-Recall",
    xlim=(0, 1),
    ylim=(0.9, 1.005),
)
ax1.legend(fontsize=9, loc="lower left")

# ROC panel: draw the dotted diagonal reference first so it is part of the
# legend, then apply labels and limits (y-axis limited to 0.5-1.005).
ax2.plot([0, 1], [0, 1], color="gray", lw=1, ls=":", label="chance")
ax2.set(
    xlabel="False positive rate",
    ylabel="True positive rate",
    title="ROC",
    xlim=(0, 1),
    ylim=(0.5, 1.005),
)
ax2.legend(fontsize=9, loc="lower right")

# Figure-level title, placed slightly above the top edge of the axes area.
fig.suptitle(
    "Discrimination improves across stacking stages (validation, seed=202)",
    y=1.02,
)

# Hand the figure to the project's save helper, then report on stdout.
save(fig, "fig5_pr_roc", FIG)
print("saved fig5_pr_roc")
|