# Source: cs3319-project2/code/train_val_sage_bpr.py
# CS3319 Project 2 final deliverable (public F1 = 0.96626)
# Repository page details: author NLP-beginner, revision f28d994, size 7.83 kB.
"""Dynamic-split GraphSAGE hetero recommender for validation fusion."""
from __future__ import annotations
import argparse
import importlib.util
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HeteroConv, SAGEConv
def load_lgcn_module(path: Path):
    """Execute the Python file at ``path`` and return it as a module object.

    The file is loaded under the fixed module name
    ``train_val_lgcn_ensemble``. The module is not inserted into
    ``sys.modules``; callers use the returned object directly.

    Args:
        path: Location of the Python source file to execute.

    Returns:
        The executed module object.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``
            (for example, the file suffix is not a recognised Python suffix).
    """
    spec = importlib.util.spec_from_file_location("train_val_lgcn_ensemble", path)
    # spec_from_file_location returns None when no loader matches the path.
    # Check explicitly: an ``assert`` would be stripped under ``python -O``
    # and a None spec would otherwise fail later with an opaque AttributeError.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot create an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
class ResidualSAGE(nn.Module):
    """Heterogeneous GraphSAGE encoder with residual connections.

    Each layer is a ``HeteroConv`` holding one ``SAGEConv`` per edge type
    (``aggr="mean"``), paired with one ``LayerNorm`` per node type.

    Args:
        metadata: Indexable pair; ``metadata[0]`` is iterated for node types
            and ``metadata[1]`` for edge types.
        hidden_dim: Feature width used for every input and output.
        num_layers: Number of conv/norm layer pairs to stack.
        dropout: Dropout probability applied after the ReLU in each layer.
    """

    def __init__(self, metadata, hidden_dim: int, num_layers: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        for _layer in range(num_layers):
            # One SAGEConv per relation, wrapped in a single HeteroConv.
            per_relation = {
                edge_type: SAGEConv((hidden_dim, hidden_dim), hidden_dim)
                for edge_type in metadata[1]
            }
            self.convs.append(HeteroConv(per_relation, aggr="mean"))
            # One LayerNorm per node type for the post-residual normalisation.
            per_node_norm = {
                node_type: nn.LayerNorm(hidden_dim) for node_type in metadata[0]
            }
            self.norms.append(nn.ModuleDict(per_node_norm))

    def forward(self, x_dict, edge_index_dict):
        """Run every layer and return the final per-node-type feature dict.

        Args:
            x_dict: Mapping from node type to its feature tensor.
            edge_index_dict: Edge indices passed unchanged to each conv layer.

        Returns:
            A dict with the same keys as ``x_dict`` holding the features
            produced by the last layer (``x_dict`` itself if there are no
            layers).
        """
        current = x_dict
        for layer, layer_norms in zip(self.convs, self.norms):
            messages = layer(current, edge_index_dict)
            updated = {}
            for node_type, features in current.items():
                # Node types missing from the conv output reuse their input.
                candidate = messages.get(node_type, features)
                activated = F.relu(candidate)
                dropped = F.dropout(activated, p=self.dropout, training=self.training)
                # Residual sum first, then the node-type-specific LayerNorm.
                updated[node_type] = layer_norms[node_type](features + dropped)
            current = updated
        return current
class SAGERecommender(nn.Module):
    """Author-paper recommender: learned inputs, SAGE encoder, dot-product scores.

    Authors get a learned embedding table; papers get a linear projection of
    their input features. Both feed a ``ResidualSAGE`` encoder.

    Args:
        metadata: Graph metadata forwarded to ``ResidualSAGE``.
        num_authors: Number of rows in the author embedding table.
        paper_dim: Width of the raw paper feature vectors.
        hidden_dim: Embedding width shared by authors and papers.
        num_layers: Number of encoder layers.
        dropout: Dropout probability used inside the encoder.
    """

    def __init__(self, metadata, num_authors: int, paper_dim: int, hidden_dim: int, num_layers: int, dropout: float):
        super().__init__()
        self.author_emb = nn.Embedding(num_embeddings=num_authors, embedding_dim=hidden_dim)
        self.paper_proj = nn.Linear(in_features=paper_dim, out_features=hidden_dim)
        self.encoder = ResidualSAGE(
            metadata=metadata,
            hidden_dim=hidden_dim,
            num_layers=num_layers,
            dropout=dropout,
        )
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the author table (Xavier uniform) and the paper projection."""
        nn.init.xavier_uniform_(self.author_emb.weight)
        self.paper_proj.reset_parameters()

    def encode(self, data):
        """Return the encoder output for all ``"author"`` and ``"paper"`` nodes.

        Reads ``data["paper"].x`` for paper features and
        ``data.edge_index_dict`` for the graph structure.
        """
        author_features = self.author_emb.weight
        paper_features = self.paper_proj(data["paper"].x)
        initial = {"author": author_features, "paper": paper_features}
        return self.encoder(initial, data.edge_index_dict)

    def decode(self, z, edge_index):
        """Score (author, paper) pairs by the dot product of their embeddings.

        Args:
            z: Mapping with ``"author"`` and ``"paper"`` embedding tensors.
            edge_index: Two-row index container; row 0 selects authors and
                row 1 selects papers.

        Returns:
            One score per column of ``edge_index``.
        """
        author_idx, paper_idx = edge_index
        author_vecs = z["author"][author_idx]
        paper_vecs = z["paper"][paper_idx]
        return (author_vecs * paper_vecs).sum(dim=-1)
@torch.no_grad()
def predict_scores(model, data, pairs: np.ndarray, batch_size: int) -> np.ndarray:
    """Score (author, paper) index pairs with the model's current embeddings.

    The model is switched to eval mode (and left there), the whole graph is
    encoded once, and dot products are computed in NumPy in chunks of
    ``batch_size`` rows.

    Args:
        model: Object exposing ``eval()`` and ``encode(data)``; ``encode``
            must return a mapping with ``"author"`` and ``"paper"`` tensors.
        data: Graph object passed straight to ``model.encode``.
        pairs: Integer array whose column 0 indexes authors and column 1
            indexes papers.
        batch_size: Number of pairs scored per chunk.

    Returns:
        A 1-D ``float32`` array with one score per row of ``pairs``; an empty
        ``float32`` array when ``pairs`` has no rows.
    """
    model.eval()
    z = model.encode(data)
    author_vecs = z["author"].detach().cpu().numpy()
    paper_vecs = z["paper"].detach().cpu().numpy()
    chunks = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        dots = np.sum(author_vecs[batch[:, 0]] * paper_vecs[batch[:, 1]], axis=1)
        chunks.append(dots.astype(np.float32))
    if not chunks:
        # np.concatenate raises ValueError on an empty list, so handle the
        # no-pairs case explicitly.
        return np.empty(0, dtype=np.float32)
    return np.concatenate(chunks)
def train_one(args, lgcn, parts, data, seed: int, out_dir: Path):
    """Train one ``SAGERecommender`` for ``seed`` and keep its best validation state.

    Each epoch samples up to ``args.train_batch_size`` positive
    ``("author", "ref", "paper")`` edges, draws ``args.neg_per_pos`` negatives
    per positive via ``lgcn.sample_hard_negatives``, and minimises
    ``-logsigmoid(pos_score - neg_score)``. Every ``args.eval_every`` epochs
    (and on the last epoch) the validation pairs are scored; whenever F1
    improves, the scores are written to ``out_dir / "scores"`` and the weights
    are snapshotted. The best snapshot is saved to ``out_dir / "checkpoints"``
    after training.

    Args:
        args: Parsed CLI namespace (device, model and optimiser settings).
        lgcn: Helper module providing ``set_seed``, ``sample_hard_negatives``
            and ``best_f1``.
        parts: Mapping with at least ``"paper_feat_aug"`` and ``"val_pairs"``
            (the latter with ``source``, ``target`` and ``label`` columns).
        data: Heterogeneous graph object used for encoding.
        seed: Random seed for this run; also embedded in output file names.
        out_dir: Run directory; ``scores`` and ``checkpoints`` sub-directories
            must already exist.

    Returns:
        ``(f1, threshold, auc)`` of the best evaluation, or
        ``(-1.0, 0.0, 0.0)`` if no evaluation ran.
    """
    lgcn.set_seed(seed)
    device = torch.device(args.device)
    model = SAGERecommender(
        data.metadata(),
        args.num_authors,
        parts["paper_feat_aug"].shape[1],
        args.hidden_dim,
        args.layers,
        args.dropout,
    ).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    pos_edges = data["author", "ref", "paper"].edge_index
    batch_size = min(args.train_batch_size, pos_edges.size(1))
    val_arr = parts["val_pairs"][["source", "target"]].to_numpy(np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(np.int8)
    best = (-1.0, 0.0, 0.0)
    best_state = None
    for epoch in range(args.epochs):
        model.train()
        # A fresh random subset of positive edges each epoch.
        perm = torch.randperm(pos_edges.size(1), device=device)[:batch_size]
        pos = pos_edges[:, perm]
        # NOTE(review): assumes the sampler returns a (2, N) index tensor whose
        # columns line up with the repeat_interleave'd positives — confirm
        # against lgcn.sample_hard_negatives.
        neg = lgcn.sample_hard_negatives(parts, pos.size(1) * args.neg_per_pos, args.num_authors, args.num_papers, device)
        z = model.encode(data)
        pos_s = model.decode(z, pos).repeat_interleave(args.neg_per_pos)
        neg_s = model.decode(z, neg)
        loss = -F.logsigmoid(pos_s - neg_s).mean()
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if (epoch + 1) % args.eval_every == 0 or epoch == args.epochs - 1:
            scores = predict_scores(model, data, val_arr, args.pred_batch_size)
            f1, th, auc = lgcn.best_f1(labels, scores)
            if f1 > best[0]:
                best = (f1, th, auc)
                # Tensor.cpu() returns the same tensor when it already lives on
                # the CPU, so clone to get a real snapshot; otherwise later
                # optimiser steps would overwrite the "best" weights in place.
                best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
                np.save(out_dir / "scores" / f"val_sage_dot_s{seed}_d{args.hidden_dim}.npy", scores)
            print(f"seed={seed} epoch={epoch+1:03d} loss={loss.item():.4f} val_f1={f1:.5f} th={th:.5f} auc={auc:.5f}")
    if best_state is not None:
        torch.save(best_state, out_dir / "checkpoints" / f"sage_val_s{seed}_d{args.hidden_dim}.pt")
    return best
def main():
    """Command-line entry point: train one model per seed and average them.

    Loads the helper module from ``<package-root>/code``, builds the dynamic
    train/validation split and graph, trains one model per seed, then writes
    under ``<package-root>/validation_runs/dynamic_seed<split-seed>/<run-name>``:

    * ``model_results.csv`` — per-seed F1 / threshold / AUC, best first;
    * ``scores/val_sage_ensemble_mean.npy`` — mean of the per-seed score files;
    * ``ensemble_result.txt`` — F1 / threshold / AUC of that mean ensemble.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-seed", type=int, required=True)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--run-name", required=True)
    # nargs="+" rejects an empty seed list at parse time; with no seeds the
    # results table would be empty and sorting it by "f1" raises KeyError.
    parser.add_argument("--seeds", nargs="+", type=int, default=[0, 42])
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=140)
    parser.add_argument("--eval-every", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--train-batch-size", type=int, default=32768)
    parser.add_argument("--pred-batch-size", type=int, default=65536)
    parser.add_argument("--neg-per-pos", type=int, default=3)
    parser.add_argument("--num-authors", type=int, default=6611)
    parser.add_argument("--num-papers", type=int, default=79937)
    args = parser.parse_args()

    root = args.package_root
    lgcn = load_lgcn_module(root / "code" / "train_val_lgcn_ensemble.py")
    parts = lgcn.build_parts(root, None, args.num_papers, split_seed=args.split_seed, train_frac=args.train_frac)
    data = lgcn.build_data(parts, args.num_authors, args.num_papers, torch.device(args.device))

    out_dir = root / "validation_runs" / f"dynamic_seed{args.split_seed}" / args.run_name
    (out_dir / "scores").mkdir(parents=True, exist_ok=True)
    (out_dir / "checkpoints").mkdir(parents=True, exist_ok=True)

    # Train one model per seed and record its best validation metrics.
    rows = []
    for seed in args.seeds:
        f1, th, auc = train_one(args, lgcn, parts, data, seed, out_dir)
        rows.append({"seed": seed, "dim": args.hidden_dim, "f1": f1, "threshold": th, "auc": auc})
    pd.DataFrame(rows).sort_values("f1", ascending=False).to_csv(out_dir / "model_results.csv", index=False)

    # Mean-ensemble whichever per-seed score files exist on disk.
    labels = parts["val_pairs"]["label"].to_numpy(np.int8)
    vals = []
    names = []
    for seed in args.seeds:
        p = out_dir / "scores" / f"val_sage_dot_s{seed}_d{args.hidden_dim}.npy"
        if p.exists():
            vals.append(np.load(p))
            names.append(p.stem)
    if vals:
        ens = np.mean(vals, axis=0)
        f1, th, auc = lgcn.best_f1(labels, ens)
        np.save(out_dir / "scores" / "val_sage_ensemble_mean.npy", ens)
        (out_dir / "ensemble_result.txt").write_text(
            f"models={','.join(names)}\nf1={f1:.8f}\nthreshold={th:.8f}\nauc={auc:.8f}\n"
        )
        print(f"\nMean ensemble: f1={f1:.5f} threshold={th:.5f} auc={auc:.5f}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()