| """Evaluate saved validation checkpoints with multiple scoring functions.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import importlib.util |
| import re |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| from sklearn.metrics import precision_recall_curve, roc_auc_score |
|
|
|
|
def load_train_module(path: Path):
    """Import the training script at ``path`` and return it as a module object.

    The module is created under the name ``"train_val_lgcn_ensemble"`` and executed
    immediately. It is not inserted into ``sys.modules``.

    Args:
        path: Filesystem path to the training script.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``
            (for example, when the file suffix is not a recognised source suffix).
        OSError: If the file cannot be read while executing it.
    """
    spec = importlib.util.spec_from_file_location("train_val_lgcn_ensemble", path)
    # spec_from_file_location returns None for unsupported suffixes. Check before
    # touching the spec, and use a real exception so the check survives `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
def best_f1(labels: np.ndarray, scores: np.ndarray):
    """Find the best achievable F1 over all precision-recall thresholds.

    Args:
        labels: Binary ground-truth labels, one per scored pair.
        scores: Real-valued scores; larger means "more likely positive".

    Returns:
        Tuple ``(f1, threshold, auc)``:
        - ``f1``: the largest F1 along the PR curve;
        - ``threshold``: the score cutoff that attains it;
        - ``auc``: ROC-AUC of ``scores`` against ``labels``.

        ``threshold`` falls back to ``0.5`` only if the best point is the trailing
        PR-curve entry, which has no threshold of its own.
    """
    prec, rec, cutoffs = precision_recall_curve(labels, scores)
    # The small epsilon keeps points where precision == recall == 0 at F1 = 0.
    denom = prec + rec + 1e-12
    f1_curve = 2 * prec * rec / denom
    best = int(np.argmax(f1_curve))

    # precision/recall carry one extra trailing entry beyond the thresholds array.
    best_threshold = 0.5
    if best < len(cutoffs):
        best_threshold = float(cutoffs[best])

    auc = float(roc_auc_score(labels, scores))
    return float(f1_curve[best]), best_threshold, auc
|
|
|
|
def infer_layers(path: Path, state: dict) -> int:
    """Determine how many layers a checkpoint's model was built with.

    Resolution order:
      1. If the state dict has ``"layer_weight"``, use its first dimension minus one.
      2. Otherwise search ``"<run dir name>_<checkpoint file name>"`` for
         ``_l<N>d`` and then ``L<N>``, taking the first match.
      3. Otherwise fall back to 4.

    Args:
        path: Checkpoint path, expected at ``<run>/checkpoints/<file>.pt``.
        state: Loaded state dict of the checkpoint.

    Returns:
        The inferred layer count.
    """
    if "layer_weight" in state:
        # NOTE(review): presumably one weight per propagation step including the
        # input embedding, hence the -1 — confirm against the training module.
        return int(state["layer_weight"].shape[0] - 1)

    tag = f"{path.parent.parent.name}_{path.name}"
    # Order matters: the more specific "_l<N>d" form wins over a bare "L<N>".
    for pattern in (r"_l(\d+)d", r"L(\d+)"):
        found = re.search(pattern, tag)
        if found is not None:
            return int(found.group(1))
    return 4
|
|
|
|
@torch.no_grad()
def score_model(module, model, data, pairs: np.ndarray, mode: str, batch_size: int) -> np.ndarray:
    """Score (author, paper) index pairs using embeddings from ``model.encode``.

    The model is switched to eval mode and encoded once. Scoring is then done in
    numpy, ``batch_size`` pairs at a time.

    Args:
        module: Loaded training module. Only ``module.cos_sim`` is used, and only
            for ``mode="cos"``.
        model: Object exposing ``eval()`` and ``encode(data)``. ``encode`` must
            return a mapping with ``"author"`` and ``"paper"`` tensors.
        data: Graph data passed straight to ``model.encode``.
        pairs: Integer array of shape ``(n, 2)``. Column 0 indexes author
            embeddings and column 1 indexes paper embeddings.
        mode: ``"dot"`` (inner product), ``"cos"`` (``module.cos_sim``) or
            ``"neg_l2"`` (negated squared Euclidean distance). In every mode,
            larger means more similar.
        batch_size: Number of pairs scored per batch; must be positive.

    Returns:
        float32 array of shape ``(n,)``, one score per pair. The array is empty
        when ``pairs`` is empty.

    Raises:
        ValueError: If ``mode`` is unknown or ``batch_size`` is not positive.
    """
    # The lambda defers the ``module.cos_sim`` lookup until "cos" is actually used.
    scorers = {
        "dot": lambda a, p: np.sum(a * p, axis=1),
        "cos": lambda a, p: module.cos_sim(a, p),
        "neg_l2": lambda a, p: -np.sum((a - p) ** 2, axis=1),
    }
    # Validate before doing any work, so bad arguments fail even for empty input.
    if mode not in scorers:
        raise ValueError(mode)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    score_fn = scorers[mode]

    model.eval()
    if len(pairs) == 0:
        # np.concatenate([]) would raise, so return an empty result directly.
        return np.empty(0, dtype=np.float32)

    z = model.encode(data)
    author = z["author"].detach().cpu().numpy()
    paper = z["paper"].detach().cpu().numpy()
    scores = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        s = score_fn(author[batch[:, 0]], paper[batch[:, 1]])
        scores.append(s.astype(np.float32))
    return np.concatenate(scores)
|
|
|
|
def main() -> None:
    """Re-score saved validation checkpoints with the cos / dot / neg_l2 modes.

    Processing:
      - Selects every checkpoint matching
        ``<package-root>/validation_runs/<split>/<run-glob>/checkpoints/*.pt``.
      - Rebuilds the model described by each state dict.
      - Scores the validation pairs with each mode and saves the raw scores as
        ``<run>_<checkpoint stem>_<mode>.npy``.
      - Records the best-threshold F1 and the ROC-AUC for each score file.

    Output:
      - The per-checkpoint ``.npy`` files and a summary table sorted by F1,
        ``score_mode_results.csv``.
      - Both go to ``validation_runs/<split>/score_modes``.

    Split selection:
      - Without ``--split-seed``: the fixed ``splits/notebook_seed0`` directory.
      - Otherwise: a dynamic split built with that seed and ``--train-frac``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-seed", type=int, default=None)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--run-glob", default="dyn*")
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--batch-size", type=int, default=65536)
    args = parser.parse_args()

    root = args.package_root
    device = torch.device(args.device)
    module = load_train_module(root / "code" / "train_val_lgcn_ensemble.py")

    # NOTE(review): 6611 and 79937 are hard-coded arguments to the training module
    # (presumably the author / paper node counts) — confirm against
    # build_parts / build_data before changing the dataset.
    if args.split_seed is None:
        split_name = "notebook_seed0"
        parts = module.build_parts(root, root / "splits" / split_name, 79937)
    else:
        split_name = f"dynamic_seed{args.split_seed}"
        parts = module.build_parts(root, None, 79937, split_seed=args.split_seed, train_frac=args.train_frac)
    val_pairs = parts["val_pairs"][["source", "target"]].to_numpy(dtype=np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(dtype=np.int8)

    split_root = root / "validation_runs" / split_name
    out_dir = split_root / "score_modes"
    out_dir.mkdir(parents=True, exist_ok=True)

    columns = ["run", "checkpoint", "variant", "dim", "mode", "f1", "threshold", "auc"]
    rows = []
    # Graph data keyed by (use_citation, use_coauthor); built once per combination.
    data_cache = {}
    checkpoint_paths = sorted(split_root.glob(f"{args.run_glob}/checkpoints/*.pt"))
    for path in checkpoint_paths:
        state = torch.load(path, map_location=device)
        embed_dim = state["author_emb.weight"].shape[1]
        variant = "learnw" if "learnw" in path.name else "vanilla"
        layers = infer_layers(path, state)
        run_name = path.parent.parent.name
        # Ablation runs are recognised by substrings in the run directory name.
        use_citation = "no_cite" not in run_name and "author_paper_only" not in run_name
        use_coauthor = "no_coauthor" not in run_name and "author_paper_only" not in run_name
        data_key = (use_citation, use_coauthor)
        if data_key not in data_cache:
            data_cache[data_key] = module.build_data(
                parts,
                6611,
                79937,
                device,
                use_citation=use_citation,
                use_coauthor=use_coauthor,
            )
        data = data_cache[data_key]
        model_cls = module.LearnableWeightLightGCN if variant == "learnw" else module.LightGCN
        model = model_cls(6611, parts["paper_feat_aug"].shape[1], embed_dim, layers).to(device)
        model.load_state_dict(state)
        stem = f"{run_name}_{path.stem}"
        for mode in ("cos", "dot", "neg_l2"):
            scores = score_model(module, model, data, val_pairs, mode, args.batch_size)
            np.save(out_dir / f"{stem}_{mode}.npy", scores)
            f1, th, auc = best_f1(labels, scores)
            rows.append(
                {
                    "run": run_name,
                    "checkpoint": path.name,
                    "variant": variant,
                    "dim": embed_dim,
                    "mode": mode,
                    "f1": f1,
                    "threshold": th,
                    "auc": auc,
                }
            )
            print(f"{stem} {mode}: f1={f1:.6f} th={th:.6f} auc={auc:.6f}")
        # Drop the model and its device-resident state dict before loading the next one.
        del model, state
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    # Explicit columns keep the header and the "f1" sort valid when no checkpoint matched.
    df = pd.DataFrame(rows, columns=columns).sort_values("f1", ascending=False)
    df.to_csv(out_dir / "score_mode_results.csv", index=False)
    if df.empty:
        print(f"\nNo checkpoints matched {split_root / args.run_glob / 'checkpoints' / '*.pt'}")
        return
    print("\nTop results:")
    print(df.head(30).to_string(index=False))
|
|
|
|
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|