| """Matrix factorization BPR baseline on the notebook validation split.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from sklearn.metrics import precision_recall_curve, roc_auc_score |
|
|
|
|
def best_f1(y, s):
    """Find the score cut-off that maximizes F1 against binary labels ``y``.

    Returns a ``(f1, threshold, roc_auc)`` tuple of floats. ``threshold`` is
    the precision/recall-curve threshold at the best F1; when the best F1
    sits on the curve's final point (which has no threshold of its own),
    0.5 is reported instead.
    """
    precision, recall, thresholds = precision_recall_curve(y, s)
    # Epsilon avoids 0/0 where precision and recall are both zero.
    f1_curve = 2 * precision * recall / (precision + recall + 1e-12)
    best_idx = int(np.argmax(f1_curve))
    if best_idx < len(thresholds):
        cut = thresholds[best_idx]
    else:
        cut = 0.5
    return float(f1_curve[best_idx]), float(cut), float(roc_auc_score(y, s))
|
|
|
|
class MF(nn.Module):
    """Biased dot-product matrix factorization over (author, paper) pairs.

    Every author and paper owns a ``dim``-dimensional embedding and a scalar
    bias. A pair's score is the embedding dot product plus both biases.
    Embeddings are drawn from a normal with std 0.05; biases start at zero.
    """

    def __init__(self, n_author: int, n_paper: int, dim: int):
        super().__init__()
        # Creation order (a, p, ab, pb) is kept fixed so a seeded run draws
        # the same initial weights.
        self.a = nn.Embedding(n_author, dim)
        self.p = nn.Embedding(n_paper, dim)
        self.ab = nn.Embedding(n_author, 1)
        self.pb = nn.Embedding(n_paper, 1)
        for factor_table in (self.a, self.p):
            nn.init.normal_(factor_table.weight, std=0.05)
        for bias_table in (self.ab, self.pb):
            nn.init.zeros_(bias_table.weight)

    def score(self, pairs):
        """Return one unnormalized logit per row of ``pairs``.

        Column 0 of ``pairs`` holds author ids and column 1 holds paper ids.
        """
        author_ids = pairs[:, 0]
        paper_ids = pairs[:, 1]
        interaction = (self.a(author_ids) * self.p(paper_ids)).sum(dim=-1)
        # Bias tables have a trailing size-1 axis; drop it to match `interaction`.
        return interaction + self.ab(author_ids)[..., 0] + self.pb(paper_ids)[..., 0]
|
|
|
|
def _check_id_range(name: str, pairs: np.ndarray, n_authors: int, n_papers: int) -> None:
    """Raise ValueError if any (source, target) id falls outside the embedding tables."""
    if pairs.size == 0:
        return
    if pairs.min() < 0 or pairs[:, 0].max() >= n_authors or pairs[:, 1].max() >= n_papers:
        raise ValueError(
            f"{name}: source ids must lie in [0, {n_authors}) and target ids in "
            f"[0, {n_papers}); pass --n-authors/--n-papers to match the data"
        )


def _sample_negatives(
    authors: np.ndarray,
    train_keys: np.ndarray,
    n_papers: int,
    rng: np.random.Generator,
) -> np.ndarray:
    """Draw one unobserved paper per author, row-aligned with ``authors``.

    Returns an ``(len(authors), 2)`` int64 array whose row ``i`` is
    ``(authors[i], paper)`` with the pair absent from the training edges.
    Colliding rows are redrawn in place, so each negative keeps the author of
    the positive at the same index, which is what the BPR pairwise loss compares.

    ``train_keys`` must hold the encodings ``source * n_papers + target`` of
    the training edges.
    """
    papers = rng.integers(0, n_papers, size=len(authors))
    todo = np.flatnonzero(np.isin(authors * n_papers + papers, train_keys))
    # NOTE(review): loops forever if some author is linked to every paper;
    # not expected for a sparse citation graph.
    while todo.size:
        papers[todo] = rng.integers(0, n_papers, size=todo.size)
        still_taken = np.isin(authors[todo] * n_papers + papers[todo], train_keys)
        todo = todo[still_taken]
    return np.stack([authors, papers], axis=1).astype(np.int64, copy=False)


def _score_pairs(model, pairs, chunk: int = 131072) -> np.ndarray:
    """Score ``pairs`` in chunks without autograd; returns a float32 numpy array."""
    with torch.no_grad():
        parts = [
            model.score(pairs[st : st + chunk]).cpu().numpy()
            for st in range(0, pairs.size(0), chunk)
        ]
    return np.concatenate(parts).astype(np.float32)


def main():
    """Train the MF-BPR baseline and record its best validation F1.

    Reads ``train_refs.csv`` and ``val_pairs.csv`` from
    ``<package-root>/splits/notebook_seed0``. Writes ``model.pt`` (best
    checkpoint), the best validation scores and ``result.txt`` under
    ``<package-root>/validation_runs/notebook_seed0/mf_bpr_s<seed>_d<dim>``.

    Each "epoch" is a single optimizer step on ``--batch-size`` positive edges
    sampled with replacement, not a full pass over the training set.
    Validation runs every 20 steps and on the last step.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dim", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=300)
    parser.add_argument("--batch-size", type=int, default=65536)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--n-authors", type=int, default=6611, help="size of the author (source) embedding table")
    parser.add_argument("--n-papers", type=int, default=79937, help="size of the paper (target) embedding table")
    args = parser.parse_args()

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    root = args.package_root
    split = root / "splits" / "notebook_seed0"
    train = pd.read_csv(split / "train_refs.csv")[["source", "target"]].to_numpy(np.int64)
    val_df = pd.read_csv(split / "val_pairs.csv")
    val = val_df[["source", "target"]].to_numpy(np.int64)
    y = val_df["label"].to_numpy(np.int8)
    if len(train) == 0:
        raise ValueError("train_refs.csv contains no edges")
    _check_id_range("train_refs.csv", train, args.n_authors, args.n_papers)
    _check_id_range("val_pairs.csv", val, args.n_authors, args.n_papers)
    # One int64 key per training edge makes the negative-membership test vectorized.
    train_keys = np.unique(train[:, 0] * args.n_papers + train[:, 1])

    device = torch.device(args.device)
    model = MF(args.n_authors, args.n_papers, args.dim).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=1e-6)
    train_t = torch.as_tensor(train, dtype=torch.long, device=device)
    val_t = torch.as_tensor(val, dtype=torch.long, device=device)
    best = (-1, 0, 0)  # (f1, threshold, auc); -1 so the first evaluation always wins
    best_scores = None
    out = root / "validation_runs" / "notebook_seed0" / f"mf_bpr_s{args.seed}_d{args.dim}"
    out.mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(args.seed)

    for ep in range(args.epochs):
        idx = torch.randint(0, train_t.size(0), (args.batch_size,), device=device)
        pos = train_t[idx]
        authors = pos[:, 0].cpu().numpy()
        neg_np = _sample_negatives(authors, train_keys, args.n_papers, rng)
        neg = torch.as_tensor(neg_np, dtype=torch.long, device=device)
        # BPR: push each positive pair's score above its same-author negative.
        loss = -F.logsigmoid(model.score(pos) - model.score(neg)).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            scores = _score_pairs(model, val_t)
            f1, th, auc = best_f1(y, scores)
            if f1 > best[0]:
                best = (f1, th, auc)
                best_scores = scores
                torch.save(model.state_dict(), out / "model.pt")
            print(f"epoch={ep+1:03d} loss={loss.item():.4f} f1={f1:.5f} th={th:.5f} auc={auc:.5f}")
    # Only persist results when some evaluation actually produced a best score;
    # otherwise `best` still holds the (-1, 0, 0) placeholder.
    if best_scores is not None:
        np.save(out / f"val_mf_bpr_s{args.seed}_d{args.dim}.npy", best_scores)
        (out / "result.txt").write_text(f"f1={best[0]:.8f}\nthreshold={best[1]:.8f}\nauc={best[2]:.8f}\n")
    print("best", best)
|
|
|
|
if __name__ == "__main__":
    # Run training only when executed as a script, not when imported.
    main()
|
|