File size: 4,452 Bytes
f28d994
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
"""Matrix factorization BPR baseline on the notebook validation split."""

from __future__ import annotations

import argparse
from pathlib import Path

import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from sklearn.metrics import precision_recall_curve, roc_auc_score


def best_f1(y, s):
    """Return ``(best_f1, threshold, roc_auc)`` for binary labels ``y`` and scores ``s``.

    F1 is computed at every point of ``precision_recall_curve`` and the maximum
    is taken (``np.argmax`` keeps the first index on ties). The reported
    threshold is the one at that index. If that index has no matching
    threshold (precision/recall carry one more entry than the thresholds
    array), 0.5 is reported instead. ROC AUC is computed over all of ``s``.
    """
    precision, recall, thresholds = precision_recall_curve(y, s)
    # The epsilon keeps the division finite where precision + recall == 0.
    f1_curve = 2 * precision * recall / (precision + recall + 1e-12)
    best_idx = int(np.argmax(f1_curve))
    if best_idx < len(thresholds):
        threshold = thresholds[best_idx]
    else:
        threshold = 0.5
    return float(f1_curve[best_idx]), float(threshold), float(roc_auc_score(y, s))


class MF(nn.Module):
    """Dot-product matrix factorization with per-author and per-paper biases.

    ``score(pairs)`` takes an integer tensor whose column 0 indexes authors and
    column 1 indexes papers. For each row it returns
    ``<author_vec, paper_vec> + author_bias + paper_bias``.
    """

    def __init__(self, n_author: int, n_paper: int, dim: int):
        super().__init__()
        # Creation order (a, p, ab, pb) matters: each nn.Embedding draws its
        # default init from the global RNG before being re-initialised below.
        self.a = nn.Embedding(n_author, dim)  # author latent factors
        self.p = nn.Embedding(n_paper, dim)  # paper latent factors
        self.ab = nn.Embedding(n_author, 1)  # author bias
        self.pb = nn.Embedding(n_paper, 1)  # paper bias
        for factors in (self.a, self.p):
            nn.init.normal_(factors.weight, std=0.05)
        for bias in (self.ab, self.pb):
            nn.init.zeros_(bias.weight)

    def score(self, pairs):
        author_idx, paper_idx = pairs[:, 0], pairs[:, 1]
        interaction = torch.sum(self.a(author_idx) * self.p(paper_idx), dim=-1)
        author_bias = self.ab(author_idx)[..., 0]
        paper_bias = self.pb(paper_idx)[..., 0]
        return interaction + author_bias + paper_bias


def _sample_negatives(authors, train_keys, n_papers, rng):
    """Draw one unobserved ``(author, paper)`` pair per row, keeping each row's author.

    Args:
        authors: 1-D integer array of author ids, one per positive in the batch.
        train_keys: int64 array of ``author * n_papers + paper`` codes for every
            observed training pair.
        n_papers: size of the paper id space; papers are drawn from ``[0, n_papers)``.
        rng: ``numpy.random.Generator`` used for sampling.

    Returns:
        ``(len(authors), 2)`` int64 array. Row ``i`` is ``(authors[i], paper)``
        and the pair does not appear in ``train_keys``.
    """
    authors = np.asarray(authors, dtype=np.int64)
    papers = rng.integers(0, n_papers, size=len(authors))
    # Only rows whose candidate collides with a training pair are redrawn, so
    # row i always keeps authors[i] and stays aligned with positive row i.
    # (This would loop forever only if an author were linked to every paper.)
    redo = np.flatnonzero(np.isin(authors * n_papers + papers, train_keys))
    while redo.size:
        papers[redo] = rng.integers(0, n_papers, size=redo.size)
        redo = redo[np.isin(authors[redo] * n_papers + papers[redo], train_keys)]
    return np.stack([authors, papers], axis=1)


def _score_in_chunks(model, pairs, chunk_size=131072):
    """Score ``pairs`` with ``model.score`` in chunks (no autograd); return a float32 array."""
    with torch.no_grad():
        parts = [
            model.score(pairs[st : st + chunk_size]).detach().cpu().numpy()
            for st in range(0, pairs.size(0), chunk_size)
        ]
    return np.concatenate(parts).astype(np.float32)


def main():
    """Train the MF-BPR baseline and keep the checkpoint with the best validation F1.

    Reads ``splits/notebook_seed0/train_refs.csv`` (observed pairs) and
    ``val_pairs.csv`` (pairs plus a ``label`` column) under ``--package-root``.
    Column ``source`` is used as the author id and ``target`` as the paper id.

    Each "epoch" is a single optimizer step. It samples ``--batch-size``
    training positives with replacement and pairs each with a negative paper
    for the same author that is not in the training set. Every 20 steps, and
    at the final step, the validation pairs are scored. Whenever F1 improves,
    the weights are saved to ``model.pt``. At the end, the best validation
    scores (``.npy``) and ``result.txt`` (F1, threshold, AUC) are written to
    the run directory.

    Raises:
        ValueError: if either CSV is empty or contains ids outside the
            configured author/paper id ranges.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--dim", type=int, default=256)
    parser.add_argument("--epochs", type=int, default=300, help="number of optimizer steps (one batch each)")
    parser.add_argument("--batch-size", type=int, default=65536)
    parser.add_argument("--seed", type=int, default=0)
    parser.add_argument("--n-authors", type=int, default=6611, help="size of the author (source) id space")
    parser.add_argument("--n-papers", type=int, default=79937, help="size of the paper (target) id space")
    parser.add_argument("--lr", type=float, default=0.01)
    parser.add_argument("--weight-decay", type=float, default=1e-6)
    args = parser.parse_args()
    if args.epochs < 1:
        parser.error("--epochs must be at least 1")

    torch.manual_seed(args.seed)
    np.random.seed(args.seed)
    root = args.package_root
    split = root / "splits" / "notebook_seed0"
    train = pd.read_csv(split / "train_refs.csv")[["source", "target"]].to_numpy(np.int64)
    val_df = pd.read_csv(split / "val_pairs.csv")
    val = val_df[["source", "target"]].to_numpy(np.int64)
    y = val_df["label"].to_numpy(np.int8)
    for name, pairs in (("train_refs.csv", train), ("val_pairs.csv", val)):
        if len(pairs) == 0:
            raise ValueError(f"{name} is empty")
        if pairs.min() < 0 or pairs[:, 0].max() >= args.n_authors or pairs[:, 1].max() >= args.n_papers:
            raise ValueError(f"{name} has ids outside [0, {args.n_authors}) x [0, {args.n_papers})")
    # One int64 code per observed pair lets negative rejection use a single
    # vectorized np.isin call instead of a Python set lookup per row.
    train_keys = np.unique(train[:, 0] * args.n_papers + train[:, 1])

    device = torch.device(args.device)
    model = MF(args.n_authors, args.n_papers, args.dim).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    train_t = torch.as_tensor(train, dtype=torch.long, device=device)
    val_t = torch.as_tensor(val, dtype=torch.long, device=device)
    best = (-1, 0, 0)  # (f1, threshold, auc) of the best evaluation so far
    best_scores = None
    out = root / "validation_runs" / "notebook_seed0" / f"mf_bpr_s{args.seed}_d{args.dim}"
    out.mkdir(parents=True, exist_ok=True)
    rng = np.random.default_rng(args.seed)

    for ep in range(args.epochs):
        idx = torch.randint(0, train_t.size(0), (args.batch_size,), device=device)
        pos = train_t[idx]
        authors = pos[:, 0].detach().cpu().numpy()
        neg_np = _sample_negatives(authors, train_keys, args.n_papers, rng)
        neg = torch.as_tensor(neg_np, dtype=torch.long, device=device)
        # BPR loss: each positive should outscore the negative drawn for the same author.
        loss = -F.logsigmoid(model.score(pos) - model.score(neg)).mean()
        opt.zero_grad()
        loss.backward()
        opt.step()
        if (ep + 1) % 20 == 0 or ep == args.epochs - 1:
            scores = _score_in_chunks(model, val_t)
            f1, th, auc = best_f1(y, scores)
            if f1 > best[0]:
                best = (f1, th, auc)
                best_scores = scores
                torch.save(model.state_dict(), out / "model.pt")
            print(f"epoch={ep+1:03d} loss={loss.item():.4f} f1={f1:.5f} th={th:.5f} auc={auc:.5f}")
    if best_scores is not None:
        np.save(out / f"val_mf_bpr_s{args.seed}_d{args.dim}.npy", best_scores)
    (out / "result.txt").write_text(f"f1={best[0]:.8f}\nthreshold={best[1]:.8f}\nauc={best[2]:.8f}\n")
    print("best", best)


# Run the training/evaluation CLI only when executed as a script, not on import.
if __name__ == "__main__":
    main()