cs3319-project2 / code /train_val_mf_bpr.py
NLP-beginner's picture
CS3319 Project 2 final deliverable (public F1 = 0.96626)
f28d994
Raw
History Blame Contribute Delete
4.45 kB
"""Matrix factorization BPR baseline on the notebook validation split."""
from __future__ import annotations
import argparse
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve, roc_auc_score
def best_f1(y, s):
    """Locate the F1-maximizing operating point of scores ``s`` against labels ``y``.

    Returns ``(f1, threshold, roc_auc)`` as Python floats. The precision /
    recall arrays have one more entry than the threshold array; if the best
    F1 lands on that extra final entry, the threshold falls back to 0.5.
    """
    precision, recall, thresholds = precision_recall_curve(y, s)
    # Small epsilon keeps the division finite where precision + recall == 0.
    f1_curve = 2 * precision * recall / (precision + recall + 1e-12)
    best_idx = int(np.argmax(f1_curve))
    if best_idx < len(thresholds):
        threshold = thresholds[best_idx]
    else:
        threshold = 0.5
    auc = roc_auc_score(y, s)
    return float(f1_curve[best_idx]), float(threshold), float(auc)
class MF(nn.Module):
    """Biased matrix-factorization scorer for (author, paper) index pairs.

    A pair's score is the dot product of its author and paper embeddings
    plus a learned scalar bias for the author and one for the paper.
    """

    def __init__(self, n_author: int, n_paper: int, dim: int):
        super().__init__()
        # Attribute names (a, p, ab, pb) become the state_dict keys of saved checkpoints.
        self.a = nn.Embedding(n_author, dim)
        self.p = nn.Embedding(n_paper, dim)
        self.ab = nn.Embedding(n_author, 1)
        self.pb = nn.Embedding(n_paper, 1)
        # Small random latent factors; biases start at zero.
        for factors in (self.a, self.p):
            nn.init.normal_(factors.weight, std=0.05)
        for bias in (self.ab, self.pb):
            nn.init.zeros_(bias.weight)

    def score(self, pairs):
        """Return one score per row of ``pairs``.

        Column 0 of ``pairs`` indexes the author tables and column 1 the
        paper tables.
        """
        author_ids, paper_ids = pairs[:, 0], pairs[:, 1]
        interaction = torch.sum(self.a(author_ids) * self.p(paper_ids), dim=-1)
        author_bias = self.ab(author_ids)[..., 0]
        paper_bias = self.pb(paper_ids)[..., 0]
        return interaction + author_bias + paper_bias
def _sample_negatives(authors, train_keys, n_paper, rng):
    """Pair every author in ``authors`` with a random paper not observed for it in training.

    Row ``i`` of the returned ``(len(authors), 2)`` int64 array is
    ``(authors[i], paper)``, so it stays aligned with the positive pair that
    supplied ``authors[i]``. BPR compares ``pos[i]`` with ``neg[i]`` and needs
    both rows to share the author. ``train_keys`` holds the observed pairs
    encoded as ``author * n_paper + paper``. Draws that hit an observed pair
    are re-sampled until every row is unobserved.
    """
    authors = np.asarray(authors, dtype=np.int64)
    papers = rng.integers(0, n_paper, size=authors.shape[0])
    pending = np.arange(authors.shape[0])
    while pending.size:
        keys = authors[pending] * n_paper + papers[pending]
        # Keep only the rows whose candidate is a known positive and redraw those.
        pending = pending[np.isin(keys, train_keys)]
        papers[pending] = rng.integers(0, n_paper, size=pending.size)
    return np.stack([authors, papers], axis=1)


def _score_in_chunks(model, pairs_t, chunk=131072):
    """Score the rows of ``pairs_t`` in chunks of ``chunk`` rows without tracking gradients.

    Returns a float32 NumPy array with one score per row.
    """
    with torch.no_grad():
        parts = [
            model.score(pairs_t[st : st + chunk]).cpu().numpy()
            for st in range(0, pairs_t.size(0), chunk)
        ]
    return np.concatenate(parts).astype(np.float32)


def main():
    """Train the MF model with a BPR loss and keep the best-validation-F1 checkpoint.

    Inputs (under ``<package-root>/splits/notebook_seed0``):
      * ``train_refs.csv``: ``source``/``target`` pairs used as positives.
      * ``val_pairs.csv``: ``source``/``target`` pairs plus ``label``.
    ``source`` ids index the author tables and ``target`` ids index the paper
    tables.

    Each "epoch" is a single optimizer step on ``--batch-size`` positives,
    drawn with replacement. Each positive gets one negative that shares its
    author. Validation F1/AUC is computed every 20 steps and after the last one.
    The best-F1 model state, its validation scores, and a ``result.txt``
    summary are written to
    ``<package-root>/validation_runs/notebook_seed0/mf_bpr_s{seed}_d{dim}``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dim", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=300)
    parser.add_argument("--batch-size", type=int, default=65536)
    parser.add_argument("--seed", type=int, default=0)
    # Embedding table sizes; defaults match the values previously hard-coded here.
    parser.add_argument("--n-author", type=int, default=6611)
    parser.add_argument("--n-paper", type=int, default=79937)
    args = parser.parse_args()
    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    n_author, n_paper = args.n_author, args.n_paper

    root = args.package_root
    split = root / "splits" / "notebook_seed0"
    train = pd.read_csv(split / "train_refs.csv")[["source", "target"]].to_numpy(np.int64)
    val_df = pd.read_csv(split / "val_pairs.csv")
    val = val_df[["source", "target"]].to_numpy(np.int64)
    y = val_df["label"].to_numpy(np.int8)

    # Fail early with a readable message instead of an embedding index error.
    for name, arr in (("train_refs.csv", train), ("val_pairs.csv", val)):
        if arr.size and (
            arr.min() < 0 or arr[:, 0].max() >= n_author or arr[:, 1].max() >= n_paper
        ):
            raise ValueError(
                f"{name} has ids outside [0, {n_author}) x [0, {n_paper}); "
                "adjust --n-author / --n-paper"
            )
    # One int64 key per observed pair so negative rejection can be vectorized.
    train_keys = np.unique(train[:, 0] * n_paper + train[:, 1])

    device = torch.device(args.device)
    model = MF(n_author, n_paper, args.dim).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=0.01, weight_decay=1e-6)
    train_t = torch.as_tensor(train, dtype=torch.long, device=device)
    val_t = torch.as_tensor(val, dtype=torch.long, device=device)
    best = (-1, 0, 0)  # (f1, threshold, auc) of the best evaluation so far
    best_scores = None
    out = root / "validation_runs" / "notebook_seed0" / f"mf_bpr_s{args.seed}_d{args.dim}"
    out.mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(args.seed)

    for ep in range(args.epochs):
        idx = torch.randint(0, train_t.size(0), (args.batch_size,), device=device)
        pos = train_t[idx]
        authors = pos[:, 0].cpu().numpy()
        neg_np = _sample_negatives(authors, train_keys, n_paper, rng)
        neg = torch.as_tensor(neg_np, dtype=torch.long, device=device)
        loss = -F.logsigmoid(model.score(pos) - model.score(neg)).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            scores = _score_in_chunks(model, val_t)
            f1, th, auc = best_f1(y, scores)
            if f1 > best[0]:
                best = (f1, th, auc)
                best_scores = scores
                torch.save(model.state_dict(), out / "model.pt")
            print(f"epoch={ep+1:03d} loss={loss.item():.4f} f1={f1:.5f} th={th:.5f} auc={auc:.5f}")

    if best_scores is not None:
        np.save(out / f"val_mf_bpr_s{args.seed}_d{args.dim}.npy", best_scores)
        (out / "result.txt").write_text(f"f1={best[0]:.8f}\nthreshold={best[1]:.8f}\nauc={best[2]:.8f}\n")
    print("best", best)


if __name__ == "__main__":
    main()