# cs3319-project2 / code / evaluate_val_checkpoints.py
# NLP-beginner's picture
# CS3319 Project 2 final deliverable (public F1 = 0.96626)
# f28d994
# Raw
# History Blame Contribute Delete
# 5.7 kB
"""Evaluate saved validation checkpoints with multiple scoring functions."""
from __future__ import annotations
import argparse
import importlib.util
import re
from pathlib import Path
import numpy as np
import pandas as pd
import torch
from sklearn.metrics import precision_recall_curve, roc_auc_score
def load_train_module(path: Path):
    """Import the training script at ``path`` as a standalone module.

    The file is loaded under the fixed module name
    ``train_val_lgcn_ensemble`` and executed immediately. The module is not
    registered in ``sys.modules``.

    Args:
        path: Location of the Python source file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be derived from
            ``path`` (for example, an unrecognised file suffix).
    """
    spec = importlib.util.spec_from_file_location("train_val_lgcn_ensemble", path)
    # The spec can be None when no loader matches the file; check it before
    # building the module, and raise instead of asserting so the check
    # survives ``python -O``.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load training module from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
def best_f1(labels: np.ndarray, scores: np.ndarray):
    """Find the operating point with the highest F1 on the PR curve.

    Args:
        labels: Ground-truth labels, passed to the scikit-learn metrics.
        scores: Predicted scores, one per label.

    Returns:
        A ``(f1, threshold, auc)`` tuple of Python floats: the maximum F1
        along the precision-recall curve, the threshold at that point, and
        the ROC AUC of ``scores`` against ``labels``.
    """
    prec, rec, cutoffs = precision_recall_curve(labels, scores)
    # The tiny epsilon keeps the division finite when both terms are zero.
    f1_curve = 2 * prec * rec / (prec + rec + 1e-12)
    best = int(np.argmax(f1_curve))

    # If the best index lies beyond the threshold array, use 0.5 instead.
    if best < len(cutoffs):
        cutoff = float(cutoffs[best])
    else:
        cutoff = 0.5

    top_f1 = float(f1_curve[best])
    auc = float(roc_auc_score(labels, scores))
    return top_f1, cutoff, auc
def infer_layers(path: Path, state: dict) -> int:
    """Work out how many propagation layers a checkpoint was trained with.

    Lookup order:
      1. A ``"layer_weight"`` entry in ``state``: its first dimension minus one.
      2. ``_l<N>d`` in ``"<run directory name>_<checkpoint file name>"``.
      3. ``L<N>`` in that same string.
      4. Otherwise the default of 4.

    Args:
        path: Checkpoint file; the run directory is taken to be two levels up.
        state: Loaded state dict of the checkpoint.

    Returns:
        The inferred layer count.
    """
    if "layer_weight" in state:
        n_weights = state["layer_weight"].shape[0]
        return int(n_weights - 1)

    label = "_".join((path.parent.parent.name, path.name))
    # Patterns are tried in priority order; the first hit wins.
    for pattern in (r"_l(\d+)d", r"L(\d+)"):
        hit = re.search(pattern, label)
        if hit is not None:
            return int(hit.group(1))
    return 4
@torch.no_grad()
def score_model(module, model, data, pairs: np.ndarray, mode: str, batch_size: int) -> np.ndarray:
    """Score index pairs against the model's encoded embeddings.

    Args:
        module: Object providing ``cos_sim(a, p)``; only used when
            ``mode == "cos"``.
        model: Object exposing ``eval()`` and ``encode(data)``. ``encode``
            must return a mapping with ``"author"`` and ``"paper"`` tensors.
        data: Passed to ``model.encode`` unchanged.
        pairs: Integer array; column 0 indexes rows of the author embeddings
            and column 1 indexes rows of the paper embeddings.
        mode: ``"dot"`` (inner product), ``"cos"`` (``module.cos_sim``) or
            ``"neg_l2"`` (negated squared Euclidean distance).
        batch_size: Number of pairs scored per chunk; must be positive.

    Returns:
        A float32 array with one score per row of ``pairs``, in the same
        order. An empty ``pairs`` yields an empty float32 array.

    Raises:
        ValueError: If ``mode`` is not one of the supported names, or if
            ``batch_size`` is not positive.
    """
    # Validate before the (potentially expensive) encode pass so that bad
    # arguments fail fast, even when ``pairs`` is empty.
    if mode not in ("dot", "cos", "neg_l2"):
        raise ValueError(mode)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    model.eval()
    if len(pairs) == 0:
        # np.concatenate rejects an empty list; return a well-typed result.
        return np.empty(0, dtype=np.float32)

    z = model.encode(data)
    author = z["author"].detach().cpu().numpy()
    paper = z["paper"].detach().cpu().numpy()

    scores = []
    # Chunking bounds the size of the gathered (batch, dim) temporaries.
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        a = author[batch[:, 0]]
        p = paper[batch[:, 1]]
        if mode == "dot":
            s = np.sum(a * p, axis=1)
        elif mode == "cos":
            s = module.cos_sim(a, p)
        else:  # "neg_l2" -- the only remaining validated mode
            s = -np.sum((a - p) ** 2, axis=1)
        scores.append(s.astype(np.float32))
    return np.concatenate(scores)
def main() -> None:
    """Score every matching validation checkpoint under each scoring mode.

    Parses the command-line options, loads
    ``code/train_val_lgcn_ensemble.py`` from the package root and builds the
    validation split. Then, for each checkpoint matching
    ``validation_runs/<split>/<run-glob>/checkpoints/*.pt``, it rebuilds the
    model, scores the validation pairs with the ``cos``, ``dot`` and
    ``neg_l2`` modes, saves each score vector as ``.npy`` and records the
    best F1, its threshold and the ROC AUC.

    All results are written to ``score_modes/score_mode_results.csv`` sorted
    by F1 (descending), and the top 30 rows are printed.

    Raises:
        SystemExit: If no checkpoint matches the glob.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-seed", type=int, default=None)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--run-glob", default="dyn*")
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--batch-size", type=int, default=65536)
    args = parser.parse_args()

    # Size constants handed to the training module's builders and model
    # constructors. NOTE(review): presumably the author and paper node
    # counts of the dataset -- confirm against train_val_lgcn_ensemble.py.
    num_authors = 6611
    num_papers = 79937

    root = args.package_root
    device = torch.device(args.device)
    module = load_train_module(root / "code" / "train_val_lgcn_ensemble.py")

    if args.split_seed is None:
        # No seed given: use the split stored on disk.
        split_name = "notebook_seed0"
        split_dir = root / "splits" / split_name
        parts = module.build_parts(root, split_dir, num_papers)
    else:
        split_name = f"dynamic_seed{args.split_seed}"
        parts = module.build_parts(
            root,
            None,
            num_papers,
            split_seed=args.split_seed,
            train_frac=args.train_frac,
        )

    val_pairs = parts["val_pairs"][["source", "target"]].to_numpy(dtype=np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(dtype=np.int8)

    run_root = root / "validation_runs" / split_name
    checkpoint_paths = sorted(run_root.glob(f"{args.run_glob}/checkpoints/*.pt"))
    if not checkpoint_paths:
        # Without this guard the empty result table has no "f1" column and
        # the final sort_values call fails with an unhelpful KeyError.
        raise SystemExit(
            f"no checkpoints found under {run_root} matching "
            f"'{args.run_glob}/checkpoints/*.pt'"
        )

    out_dir = run_root / "score_modes"
    out_dir.mkdir(parents=True, exist_ok=True)

    # Graph data depends only on which edge types are enabled, so build it
    # once per (use_citation, use_coauthor) combination.
    data_cache = {}
    rows = []
    for path in checkpoint_paths:
        # NOTE(review): torch.load may unpickle arbitrary objects depending
        # on the installed torch version's ``weights_only`` default; only
        # evaluate checkpoints from a trusted source.
        state = torch.load(path, map_location=args.device)
        embed_dim = state["author_emb.weight"].shape[1]
        variant = "learnw" if "learnw" in path.name else "vanilla"
        layers = infer_layers(path, state)

        # The run directory name encodes which edge types were ablated.
        run_name = path.parent.parent.name
        author_paper_only = "author_paper_only" in run_name
        use_citation = not author_paper_only and "no_cite" not in run_name
        use_coauthor = not author_paper_only and "no_coauthor" not in run_name

        data_key = (use_citation, use_coauthor)
        if data_key not in data_cache:
            data_cache[data_key] = module.build_data(
                parts,
                num_authors,
                num_papers,
                device,
                use_citation=use_citation,
                use_coauthor=use_coauthor,
            )
        data = data_cache[data_key]

        model_cls = module.LearnableWeightLightGCN if variant == "learnw" else module.LightGCN
        model = model_cls(num_authors, parts["paper_feat_aug"].shape[1], embed_dim, layers).to(device)
        model.load_state_dict(state)

        stem = f"{run_name}_{path.stem}"
        for mode in ("cos", "dot", "neg_l2"):
            scores = score_model(module, model, data, val_pairs, mode, args.batch_size)
            np.save(out_dir / f"{stem}_{mode}.npy", scores)
            f1, th, auc = best_f1(labels, scores)
            rows.append(
                {
                    "run": run_name,
                    "checkpoint": path.name,
                    "variant": variant,
                    "dim": embed_dim,
                    "mode": mode,
                    "f1": f1,
                    "threshold": th,
                    "auc": auc,
                }
            )
            print(f"{stem} {mode}: f1={f1:.6f} th={th:.6f} auc={auc:.6f}")

        # Release the model before the next checkpoint is loaded.
        del model
        if torch.cuda.is_available():
            torch.cuda.empty_cache()

    df = pd.DataFrame(rows).sort_values("f1", ascending=False)
    df.to_csv(out_dir / "score_mode_results.csv", index=False)
    print("\nTop results:")
    print(df.head(30).to_string(index=False))
# Run the evaluation only when executed as a script, not when imported.
if __name__ == "__main__":
    main()