File size: 2,208 Bytes
f28d994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
"""Fig 5: Precision-Recall and ROC curves across model stages.

Uses the row-aligned validation labels (val_labels_seed202.npy, alignment-verified)
against cached OOF / val scores for four representative stages.
"""
from pathlib import Path
import sys

import numpy as np
import matplotlib.pyplot as plt
from sklearn.metrics import (
    precision_recall_curve,
    roc_curve,
    auc,
    average_precision_score,
)

sys.path.insert(0, str(Path(__file__).resolve().parent))
from plot_style import apply, save, PALETTE_DEEP as C  # noqa: E402

# Apply the project plotting style (from plot_style) before any figure is built.
apply()

# Path(__file__).resolve().parents[2] is the directory two levels above the
# one containing this script; all inputs and outputs hang off it.
ROOT = Path(__file__).resolve().parents[2]
FIG = ROOT.joinpath("reports", "figures")
VR = ROOT.joinpath("validation_runs", "dynamic_seed202")

# Validation labels, cast to int after loading.
y = np.load(
    VR / "val_labels_seed202.npy"
).astype(int)

# One entry per plotted stage: (legend label, score file under VR, line colour
# taken from the PALETTE_DEEP palette by index).
models = [
    (label, VR / rel_path, C[palette_idx])
    for label, rel_path, palette_idx in (
        (
            "LightGCN ensemble",
            "dyn202_l2d512_bpr_bigbatch_more/scores/val_vanilla_ensemble_mean.npy",
            7,
        ),
        (
            "+ graph stack (post95)",
            "post95_ablation/ensemble_lgcn_oof.npy",
            1,
        ),
        (
            "+ DeepWalk/Node2Vec",
            "node2vec_deepwalk/node2vec_stack_oof.npy",
            2,
        ),
        (
            "+ high-order (final)",
            "high_order_graph_stack/rich_rw7_highorder_directed_oof.npy",
            3,
        ),
    )
]

# Two side-by-side panels: precision-recall on the left, ROC on the right.
fig, (ax1, ax2) = plt.subplots(1, 2, figsize=(12.5, 5.2))

# Draw one PR curve and one ROC curve per model stage, with the summary
# metric (average precision / ROC AUC) embedded in the legend label.
for name, path, color in models:
    s = np.load(path).astype(np.float64)

    # Fail early with the stage name and file path when a cached score array
    # does not have one row per validation label. scikit-learn would reject
    # the mismatch as well (also with ValueError), but its message does not
    # say which stage or file is at fault.
    if s.shape[:1] != y.shape[:1]:
        raise ValueError(
            f"{name}: scores in {path} have shape {s.shape}, "
            f"which does not match the validation labels {y.shape}"
        )

    # Left panel: precision as a function of recall.
    p, r, _ = precision_recall_curve(y, s)
    ap = average_precision_score(y, s)
    ax1.plot(r, p, color=color, lw=1.9, label=f"{name} (AP={ap:.4f})")

    # Right panel: true-positive rate as a function of false-positive rate;
    # the AUC is the trapezoidal area under that same curve.
    fpr, tpr, _ = roc_curve(y, s)
    roc_auc = auc(fpr, tpr)
    ax2.plot(fpr, tpr, color=color, lw=1.9, label=f"{name} (AUC={roc_auc:.4f})")

# Left panel (precision-recall): labels, title and a y-range restricted to
# the top of the precision scale.
ax1.set(
    xlabel="Recall",
    ylabel="Precision",
    title="Precision-Recall",
    xlim=(0, 1),
    ylim=(0.9, 1.005),
)
ax1.legend(loc="lower left", fontsize=9)

# Right panel (ROC): add the diagonal reference line first so that it is
# included in the legend, then apply labels, title and limits.
ax2.plot([0, 1], [0, 1], color="gray", lw=1, ls=":", label="chance")
ax2.set(
    xlabel="False positive rate",
    ylabel="True positive rate",
    title="ROC",
    xlim=(0, 1),
    ylim=(0.5, 1.005),
)
ax2.legend(loc="lower right", fontsize=9)

# Shared headline, placed slightly above the axes (y=1.02 in figure coordinates).
fig.suptitle(
    "Discrimination improves across stacking stages (validation, seed=202)",
    y=1.02,
)

# Hand the finished figure to the plot_style save helper, then report on stdout.
save(fig, "fig5_pr_roc", FIG)
print("saved fig5_pr_roc")