"""Evaluate saved validation checkpoints with multiple scoring functions.

For every checkpoint matching
``<package-root>/validation_runs/<split>/<run-glob>/checkpoints/*.pt`` the
script rebuilds the model, scores the validation pairs with each mode in
``SCORE_MODES``, saves the raw scores as ``.npy`` files, and writes a CSV
summary sorted by best F1.
"""

from __future__ import annotations

import argparse
import importlib.util
import re
from pathlib import Path

import numpy as np
import pandas as pd
import torch

# Sizes passed to the training module's builders and model constructors.
# NOTE(review): these presumably are the author / paper node counts of the
# graph (the model takes N_AUTHORS and scores z["author"] / z["paper"]) —
# confirm against train_val_lgcn_ensemble.py.
N_AUTHORS = 6611
N_PAPERS = 79937

# Scoring modes evaluated for every checkpoint, in output order.
SCORE_MODES = ("cos", "dot", "neg_l2")


def load_train_module(path: Path):
    """Import the training script at *path* as a module and return it.

    Raises:
        ImportError: if no import spec / loader can be created for *path*
            (e.g. the suffix is not a recognised Python source suffix).
    """
    spec = importlib.util.spec_from_file_location("train_val_lgcn_ensemble", path)
    # Check before module_from_spec: a None spec would otherwise fail with an
    # unhelpful AttributeError, and `assert` is stripped under `python -O`.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load training module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module


def best_f1(labels: np.ndarray, scores: np.ndarray):
    """Return ``(best_f1, threshold, roc_auc)`` for binary labels and scores.

    ``threshold`` is the score cut-off (predict positive when
    ``score >= threshold``) that maximises F1 along the precision-recall curve.
    """
    # Imported lazily: this is the only scikit-learn user, so the loading and
    # scoring helpers in this module stay importable without it.
    from sklearn.metrics import precision_recall_curve, roc_auc_score

    precision, recall, thresholds = precision_recall_curve(labels, scores)
    f1s = 2 * precision * recall / (precision + recall + 1e-12)
    idx = int(np.argmax(f1s))
    # precision/recall have one more entry than thresholds (the final
    # precision=1, recall=0 point has no threshold); fall back to 0.5 there.
    threshold = float(thresholds[idx]) if idx < len(thresholds) else 0.5
    return float(f1s[idx]), threshold, float(roc_auc_score(labels, scores))


def infer_layers(path: Path, state: dict) -> int:
    """Infer the propagation-layer count for the checkpoint at *path*.

    Uses ``len(state["layer_weight"]) - 1`` when that key exists (NOTE: this
    assumes one weight per layer plus one extra — confirm against the model).
    Otherwise it parses ``_l<N>d`` or ``L<N>`` from ``"<run>_<filename>"``.
    It falls back to 4 when nothing matches.
    """
    if "layer_weight" in state:
        return int(state["layer_weight"].shape[0] - 1)
    text = f"{path.parent.parent.name}_{path.name}"
    for pattern in (r"_l(\d+)d", r"L(\d+)"):
        match = re.search(pattern, text)
        if match:
            return int(match.group(1))
    return 4


@torch.no_grad()
def _encode(model, data) -> tuple[np.ndarray, np.ndarray]:
    """Run ``model.encode(data)`` in eval mode; return (author, paper) embeddings as NumPy."""
    model.eval()
    z = model.encode(data)
    return z["author"].detach().cpu().numpy(), z["paper"].detach().cpu().numpy()


def _score_embeddings(
    module, author: np.ndarray, paper: np.ndarray, pairs: np.ndarray, mode: str, batch_size: int
) -> np.ndarray:
    """Score ``pairs`` (column 0 indexes *author*, column 1 indexes *paper*) in batches.

    Raises:
        ValueError: if *mode* is not one of ``SCORE_MODES``.
    """
    if mode not in SCORE_MODES:
        raise ValueError(mode)
    if len(pairs) == 0:
        # np.concatenate([]) would raise; an empty input yields empty scores.
        return np.empty(0, dtype=np.float32)
    chunks = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        a = author[batch[:, 0]]
        p = paper[batch[:, 1]]
        if mode == "dot":
            s = np.sum(a * p, axis=1)
        elif mode == "cos":
            s = module.cos_sim(a, p)
        else:  # "neg_l2": larger (closer to 0) means more similar
            s = -np.sum((a - p) ** 2, axis=1)
        chunks.append(np.asarray(s, dtype=np.float32))
    return np.concatenate(chunks)


@torch.no_grad()
def score_model(module, model, data, pairs: np.ndarray, mode: str, batch_size: int) -> np.ndarray:
    """Encode *data* with *model* and score *pairs* with *mode*.

    Args:
        module: the loaded training module (provides ``cos_sim`` for ``"cos"``).
        model: object exposing ``eval()`` and ``encode(data)``, which returns a
            mapping with ``"author"`` and ``"paper"`` tensors.
        pairs: integer array whose column 0 indexes authors and column 1 papers.
        mode: one of ``SCORE_MODES``.
        batch_size: number of pairs scored per chunk.

    Returns:
        float32 array with one score per pair.

    Raises:
        ValueError: if *mode* is unknown (checked before the costly encode).
    """
    if mode not in SCORE_MODES:
        raise ValueError(mode)
    author, paper = _encode(model, data)
    return _score_embeddings(module, author, paper, pairs, mode, batch_size)


def _parse_args() -> argparse.Namespace:
    """Parse the command-line options of this script."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-seed", type=int, default=None)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--run-glob", default="dyn*")
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--batch-size", type=int, default=65536)
    return parser.parse_args()


def main() -> None:
    """Score every matching checkpoint with every mode and write a ranked CSV."""
    args = _parse_args()
    root = args.package_root
    device = torch.device(args.device)

    if args.split_seed is None:
        split_name = "notebook_seed0"
    else:
        split_name = f"dynamic_seed{args.split_seed}"
    split_root = root / "validation_runs" / split_name

    # Look for checkpoints before building the (expensive) split. With no
    # matches, the final sort would otherwise fail with KeyError: 'f1',
    # because an empty DataFrame has no columns.
    checkpoint_paths = sorted(split_root.glob(f"{args.run_glob}/checkpoints/*.pt"))
    if not checkpoint_paths:
        raise SystemExit(f"No checkpoints match {split_root}/{args.run_glob}/checkpoints/*.pt")

    module = load_train_module(root / "code" / "train_val_lgcn_ensemble.py")
    if args.split_seed is None:
        parts = module.build_parts(root, root / "splits" / split_name, N_PAPERS)
    else:
        parts = module.build_parts(
            root, None, N_PAPERS, split_seed=args.split_seed, train_frac=args.train_frac
        )

    val_pairs = parts["val_pairs"][["source", "target"]].to_numpy(dtype=np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(dtype=np.int8)
    out_dir = split_root / "score_modes"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Graph data depends only on which edge types are enabled, so it is built
    # once per (use_citation, use_coauthor) combination and shared across runs.
    data_cache: dict[tuple[bool, bool], object] = {}
    rows = []
    for path in checkpoint_paths:
        state = torch.load(path, map_location=device)
        run_name = path.parent.parent.name
        embed_dim = state["author_emb.weight"].shape[1]
        variant = "learnw" if "learnw" in path.name else "vanilla"
        layers = infer_layers(path, state)
        use_citation = "no_cite" not in run_name and "author_paper_only" not in run_name
        use_coauthor = "no_coauthor" not in run_name and "author_paper_only" not in run_name

        data_key = (use_citation, use_coauthor)
        if data_key not in data_cache:
            data_cache[data_key] = module.build_data(
                parts,
                N_AUTHORS,
                N_PAPERS,
                device,
                use_citation=use_citation,
                use_coauthor=use_coauthor,
            )
        data = data_cache[data_key]

        model_cls = module.LearnableWeightLightGCN if variant == "learnw" else module.LightGCN
        model = model_cls(N_AUTHORS, parts["paper_feat_aug"].shape[1], embed_dim, layers).to(device)
        model.load_state_dict(state)
        # Encode once per checkpoint; all scoring modes reuse these embeddings.
        author, paper = _encode(model, data)
        del model, state
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

        stem = f"{run_name}_{path.stem}"
        for mode in SCORE_MODES:
            scores = _score_embeddings(module, author, paper, val_pairs, mode, args.batch_size)
            np.save(out_dir / f"{stem}_{mode}.npy", scores)
            f1, th, auc = best_f1(labels, scores)
            rows.append(
                {
                    "run": run_name,
                    "checkpoint": path.name,
                    "variant": variant,
                    "dim": embed_dim,
                    "mode": mode,
                    "f1": f1,
                    "threshold": th,
                    "auc": auc,
                }
            )
            print(f"{stem} {mode}: f1={f1:.6f} th={th:.6f} auc={auc:.6f}")

    df = pd.DataFrame(rows).sort_values("f1", ascending=False)
    df.to_csv(out_dir / "score_mode_results.csv", index=False)
    print("\nTop results:")
    print(df.head(30).to_string(index=False))


if __name__ == "__main__":
    main()