# Viewer header: file size 5,703 bytes, f28d994
"""Evaluate saved validation checkpoints with multiple scoring functions."""

from __future__ import annotations

import argparse
import importlib.util
import re
from pathlib import Path

import numpy as np
import pandas as pd
import torch
from sklearn.metrics import precision_recall_curve, roc_auc_score


def load_train_module(path: Path):
    """Import the training script at ``path`` as a standalone module.

    The file is executed under the fixed module name
    ``train_val_lgcn_ensemble``. The resulting module is returned to the
    caller only; this function does not add it to ``sys.modules``.

    Args:
        path: Location of the Python source file to load.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``
            (for example when the file suffix is not a recognised Python
            source suffix).
    """
    spec = importlib.util.spec_from_file_location("train_val_lgcn_ensemble", path)
    # Explicit check rather than ``assert``: it survives ``python -O`` and
    # reports a ``None`` spec clearly instead of failing with AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load training module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def best_f1(labels: np.ndarray, scores: np.ndarray):
    """Locate the F1-maximising point on the precision-recall curve.

    Args:
        labels: Ground-truth labels passed to the scikit-learn metrics.
        scores: Predicted scores, one per label.

    Returns:
        A ``(f1, threshold, auc)`` tuple of Python floats: the best F1 found
        on the curve, the score threshold at that point, and the ROC AUC of
        ``scores`` against ``labels``.
    """
    prec, rec, cutoffs = precision_recall_curve(labels, scores)
    # The tiny epsilon keeps the ratio finite where precision + recall == 0.
    f1_curve = (2 * prec * rec) / (prec + rec + 1e-12)
    best = int(f1_curve.argmax())
    # The curve can be indexed one past the thresholds array; that final
    # point has no threshold of its own, so 0.5 is reported instead.
    if best < len(cutoffs):
        threshold = float(cutoffs[best])
    else:
        threshold = 0.5
    auc = float(roc_auc_score(labels, scores))
    return float(f1_curve[best]), threshold, auc


def infer_layers(path: Path, state: dict) -> int:
    """Work out how many propagation layers a checkpoint was built with.

    Resolution order:
      1. If ``state`` has a ``"layer_weight"`` entry, use its first dimension
         minus one.
      2. Otherwise search ``"<run dir name>_<file name>"`` (the run directory
         being two levels above the checkpoint file) for ``_l<N>d``.
      3. Otherwise search the same text for ``L<N>``.
      4. Otherwise fall back to 4.

    Args:
        path: Checkpoint file path.
        state: Loaded state dict of the checkpoint.

    Returns:
        The inferred layer count.
    """
    if "layer_weight" in state:
        n_weights = state["layer_weight"].shape[0]
        return int(n_weights - 1)

    label = "_".join((path.parent.parent.name, path.name))
    # Patterns are tried in priority order; the first hit wins.
    for pattern in (r"_l(\d+)d", r"L(\d+)"):
        hit = re.search(pattern, label)
        if hit is not None:
            return int(hit.group(1))
    return 4


@torch.no_grad()
def score_model(module, model, data, pairs: np.ndarray, mode: str, batch_size: int) -> np.ndarray:
    """Score index pairs using the embeddings produced by ``model.encode``.

    The model is switched to eval mode, ``data`` is encoded once, and the
    ``"author"`` and ``"paper"`` embeddings are moved to NumPy on the CPU.
    Pairs are then scored in chunks of ``batch_size`` rows.

    Args:
        module: Object exposing ``cos_sim(a, p)``; only used when ``mode`` is
            ``"cos"``. Presumably returns one score per row -- confirm
            against the training module.
        model: Object with ``eval()`` and ``encode(data)``; ``encode`` must
            return a mapping with ``"author"`` and ``"paper"`` tensors.
        data: Passed through to ``model.encode`` unchanged.
        pairs: 2-D integer array; column 0 indexes rows of the author
            embeddings and column 1 indexes rows of the paper embeddings.
        mode: ``"dot"`` (inner product), ``"cos"`` (``module.cos_sim``) or
            ``"neg_l2"`` (negative squared Euclidean distance).
        batch_size: Number of pairs scored per chunk; must be positive.

    Returns:
        A float32 array of scores in the same order as ``pairs``. An empty
        ``pairs`` array yields an empty float32 array.

    Raises:
        ValueError: If ``mode`` is not one of the supported names or
            ``batch_size`` is not positive.
    """
    # Validate before the (potentially expensive) encode pass so a typo in
    # ``mode`` fails immediately, even when there is nothing to score.
    if mode not in ("dot", "cos", "neg_l2"):
        raise ValueError(f"unknown scoring mode: {mode!r}")
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    # np.concatenate rejects an empty list, so handle "no pairs" explicitly.
    if len(pairs) == 0:
        return np.empty(0, dtype=np.float32)

    z = model.encode(data)
    author = z["author"].detach().cpu().numpy()
    paper = z["paper"].detach().cpu().numpy()

    scores = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        a = author[batch[:, 0]]
        p = paper[batch[:, 1]]
        if mode == "dot":
            s = np.sum(a * p, axis=1)
        elif mode == "cos":
            s = module.cos_sim(a, p)
        else:  # "neg_l2" -- the only remaining validated mode
            s = -np.sum((a - p) ** 2, axis=1)
        scores.append(s.astype(np.float32))
    return np.concatenate(scores)


def main() -> None:
    """Score saved checkpoints under several scoring modes and rank them.

    For every ``*.pt`` file matching
    ``<package-root>/validation_runs/<split>/<run-glob>/checkpoints/`` the
    model is rebuilt, the validation pairs are scored with the ``cos``,
    ``dot`` and ``neg_l2`` modes, and each score vector is saved as
    ``<run>_<checkpoint stem>_<mode>.npy`` under
    ``validation_runs/<split>/score_modes``. A ``score_mode_results.csv``
    sorted by descending F1 is written to the same directory and the top 30
    rows are printed.

    Raises:
        SystemExit: If no checkpoint matches the glob.
    """
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-seed", type=int, default=None)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--run-glob", default="dyn*")
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--batch-size", type=int, default=65536)
    args = parser.parse_args()

    root = args.package_root
    device = torch.device(args.device)
    if args.split_seed is None:
        split_name = "notebook_seed0"
    else:
        split_name = f"dynamic_seed{args.split_seed}"
    runs_dir = root / "validation_runs" / split_name

    # Fail fast, before the data build: with no rows the final
    # ``sort_values("f1")`` would otherwise die with an opaque KeyError.
    checkpoint_paths = sorted(runs_dir.glob(f"{args.run_glob}/checkpoints/*.pt"))
    if not checkpoint_paths:
        raise SystemExit(
            f"no checkpoints found under {runs_dir} matching '{args.run_glob}/checkpoints/*.pt'"
        )

    module = load_train_module(root / "code" / "train_val_lgcn_ensemble.py")
    # NOTE(review): 6611 and 79937 are fixed sizes expected by the training
    # module -- presumably node counts; confirm against build_parts/build_data.
    if args.split_seed is None:
        parts = module.build_parts(root, root / "splits" / split_name, 79937)
    else:
        # --train-frac only applies to seeded (dynamic) splits.
        parts = module.build_parts(root, None, 79937, split_seed=args.split_seed, train_frac=args.train_frac)
    val_pairs = parts["val_pairs"][["source", "target"]].to_numpy(dtype=np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(dtype=np.int8)

    out_dir = runs_dir / "score_modes"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Graph data depends only on which edge types are enabled, so it is built
    # once per (use_citation, use_coauthor) combination and shared.
    data_cache = {}
    rows = []
    for path in checkpoint_paths:
        # NOTE(review): torch.load unpickles the file; only point this script
        # at checkpoints from trusted runs.
        state = torch.load(path, map_location=device)
        embed_dim = state["author_emb.weight"].shape[1]
        variant = "learnw" if "learnw" in path.name else "vanilla"
        layers = infer_layers(path, state)

        # The run directory name encodes which edge types were ablated.
        run_name = path.parent.parent.name
        author_paper_only = "author_paper_only" in run_name
        use_citation = "no_cite" not in run_name and not author_paper_only
        use_coauthor = "no_coauthor" not in run_name and not author_paper_only
        data_key = (use_citation, use_coauthor)
        if data_key not in data_cache:
            data_cache[data_key] = module.build_data(
                parts,
                6611,
                79937,
                device,
                use_citation=use_citation,
                use_coauthor=use_coauthor,
            )
        data = data_cache[data_key]

        model_cls = module.LearnableWeightLightGCN if variant == "learnw" else module.LightGCN
        model = model_cls(6611, parts["paper_feat_aug"].shape[1], embed_dim, layers).to(device)
        model.load_state_dict(state)

        stem = f"{run_name}_{path.stem}"
        for mode in ["cos", "dot", "neg_l2"]:
            scores = score_model(module, model, data, val_pairs, mode, args.batch_size)
            np.save(out_dir / f"{stem}_{mode}.npy", scores)
            f1, th, auc = best_f1(labels, scores)
            rows.append(
                {
                    "run": run_name,
                    "checkpoint": path.name,
                    "variant": variant,
                    "dim": embed_dim,
                    "mode": mode,
                    "f1": f1,
                    "threshold": th,
                    "auc": auc,
                }
            )
            print(f"{stem} {mode}: f1={f1:.6f} th={th:.6f} auc={auc:.6f}")

        # Drop the model before the next checkpoint to keep peak memory down.
        del model
        if device.type == "cuda":
            torch.cuda.empty_cache()

    df = pd.DataFrame(rows).sort_values("f1", ascending=False)
    df.to_csv(out_dir / "score_mode_results.csv", index=False)
    print("\nTop results:")
    print(df.head(30).to_string(index=False))


# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()