"""Fig 5: Precision-Recall and ROC curves across model stages.

Uses the row-aligned validation labels (val_labels_seed202.npy, alignment-verified)
against cached OOF / val scores for four representative stages.

PR curves are drawn as right-continuous step functions (``steps-post``), which
is exactly the interpolation that ``average_precision_score`` integrates, so the
plotted curves agree with the AP values quoted in the legend. ROC curves use
straight segments, matching the trapezoidal ``auc`` shown in their legend.
"""
from pathlib import Path
import sys

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    precision_recall_curve,
    roc_curve,
    auc,
    average_precision_score,
)

# plot_style lives next to this script, so make it importable before importing it.
sys.path.insert(0, str(Path(__file__).resolve().parent))
from plot_style import apply, save, PALETTE_DEEP as C  # noqa: E402


apply()
ROOT = Path(__file__).resolve().parents[2]
FIG = ROOT / "reports" / "figures"
VR = ROOT / "validation_runs" / "dynamic_seed202"

# Flatten so an (n, 1)-shaped cache compares cleanly against 1-D score vectors.
y = np.load(VR / "val_labels_seed202.npy").astype(int).ravel()

# PR/ROC need both classes present; with a single class sklearn only warns and
# the figure silently shows NaN curves, so fail loudly instead.
_classes = np.unique(y)
if _classes.size != 2:
    raise ValueError(
        "expected binary validation labels with both classes present, got "
        f"{_classes.size} distinct value(s): {_classes[:10].tolist()}"
    )


models = [
    ("LightGCN ensemble", VR / "dyn202_l2d512_bpr_bigbatch_more/scores/val_vanilla_ensemble_mean.npy", C[7]),
    ("+ graph stack (post95)", VR / "post95_ablation/ensemble_lgcn_oof.npy", C[1]),
    ("+ DeepWalk/Node2Vec", VR / "node2vec_deepwalk/node2vec_stack_oof.npy", C[2]),
    ("+ high-order (final)", VR / "high_order_graph_stack/rich_rw7_highorder_directed_oof.npy", C[3]),
]


def _load_scores(path: Path, n: int) -> np.ndarray:
    """Load a cached score array as a flat float64 vector of length ``n``.

    Args:
        path: ``.npy`` file holding one score per validation row.
        n: number of validation labels the scores must align with.

    Returns:
        1-D float64 array of shape ``(n,)``.

    Raises:
        ValueError: if the file does not hold exactly ``n`` scores, i.e. it
            cannot be row-aligned with the validation labels.
    """
    raw = np.load(path)
    scores = raw.astype(np.float64).ravel()
    if scores.shape != (n,):
        raise ValueError(
            f"{path}: expected {n} scores to match the validation labels, "
            f"got array of shape {raw.shape}"
        )
    return scores


fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12.5, 5.2))

for name, path, color in models:
    s = _load_scores(path, y.size)

    p, r, _ = precision_recall_curve(y, s)
    ap = average_precision_score(y, s)
    # steps-post: precision p[i] holds over recall interval [r[i+1], r[i]],
    # the same step function whose area is the AP reported in the label.
    ax1.plot(r, p, color=color, lw=1.9, drawstyle="steps-post",
             label=f"{name} (AP={ap:.4f})")

    fpr, tpr, _ = roc_curve(y, s)
    ax2.plot(fpr, tpr, color=color, lw=1.9,
             label=f"{name} (AUC={auc(fpr, tpr):.4f})")


ax1.set_xlabel("Recall")
ax1.set_ylabel("Precision")
ax1.set_title("Precision-Recall")
ax1.set_xlim(0, 1)
# Deliberate zoom on the high-precision region; the curve tail may be clipped.
ax1.set_ylim(0.9, 1.005)
ax1.legend(loc="lower left", fontsize=9)


ax2.plot([0, 1], [0, 1], color="gray", lw=1, ls=":", label="chance")
ax2.set_xlabel("False positive rate")
ax2.set_ylabel("True positive rate")
ax2.set_title("ROC")
ax2.set_xlim(0, 1)
# Deliberate zoom; only the upper half of the chance diagonal is visible.
ax2.set_ylim(0.5, 1.005)
ax2.legend(loc="lower right", fontsize=9)


fig.suptitle("Discrimination improves across stacking stages (validation, seed=202)", y=1.02)
save(fig, "fig5_pr_roc", FIG)
# Release the figure so drivers that run many figure scripts don't accumulate them.
plt.close(fig)
print("saved fig5_pr_roc")
|
|