cs3319-project2 / code /train_val_hgt_bpr.py
NLP-beginner's picture
CS3319 Project 2 final deliverable (public F1 = 0.96626)
f28d994
Raw
History Blame Contribute Delete
7.55 kB
"""Dynamic-split HGT attention recommender for author-paper validation."""
from __future__ import annotations
import argparse
import importlib.util
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.nn import HGTConv
def load_lgcn_module(path: Path, module_name: str = "train_val_lgcn_ensemble"):
    """Execute the Python source file at ``path`` and return it as a module.

    This script uses the returned module for its data pipeline and helpers
    (``build_parts``, ``build_data``, ``set_seed``, ``sample_hard_negatives``,
    ``best_f1``).

    Args:
        path: Filesystem path of the Python source file to execute.
        module_name: ``__name__`` given to the created module object.

    Returns:
        The executed module object. It is not inserted into ``sys.modules``.

    Raises:
        ImportError: If no import spec or loader can be created for ``path``
            (e.g. the file does not have a Python source suffix).
        OSError: Propagated from the loader if the file cannot be read.
    """
    spec = importlib.util.spec_from_file_location(module_name, path)
    # spec_from_file_location returns None for unrecognised suffixes, and a
    # spec without a loader cannot be executed. Check both *before* building
    # the module, and with a real exception (``assert`` disappears under -O).
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot load module {module_name!r} from {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
class HGTRecommender(nn.Module):
    """Heterogeneous Graph Transformer encoder with a dot-product link decoder.

    Authors receive a free learned embedding and papers a linear projection of
    their input features. Both are refined by ``layers`` rounds of ``HGTConv``
    message passing. Each round is wrapped in a residual connection followed
    by a per-node-type ``LayerNorm``. An author/paper pair is scored by the
    inner product of their final embeddings.
    """

    def __init__(
        self,
        metadata,
        num_authors: int,
        paper_dim: int,
        hidden_dim: int,
        layers: int,
        heads: int,
        dropout: float,
    ):
        """Build the model.

        Args:
            metadata: ``(node_types, edge_types)`` graph metadata handed to
                ``HGTConv``. ``metadata[0]`` keys the per-layer LayerNorms, so
                it must include ``"author"`` and ``"paper"`` when
                ``layers > 0``.
            num_authors: Number of rows in the author embedding table.
            paper_dim: Width of the raw paper feature vectors.
            hidden_dim: Embedding width used throughout the encoder.
            layers: Number of HGTConv + LayerNorm rounds.
            heads: Attention heads per HGTConv.
            dropout: Dropout probability applied to each conv output.
        """
        super().__init__()
        self.dropout = dropout
        self.author_emb = nn.Embedding(num_authors, hidden_dim)
        self.paper_proj = nn.Linear(paper_dim, hidden_dim)
        self.convs = nn.ModuleList(
            [HGTConv(hidden_dim, hidden_dim, metadata, heads=heads) for _ in range(layers)]
        )
        self.norms = nn.ModuleList(
            [nn.ModuleDict({nt: nn.LayerNorm(hidden_dim) for nt in metadata[0]}) for _ in range(layers)]
        )
        # Only the input layers are re-drawn here. The conv/norm sub-modules
        # were just initialised by their own constructors, and re-drawing them
        # would consume extra RNG and change seeded initialisation.
        self._reset_input_parameters()

    def _reset_input_parameters(self):
        """Re-initialise the author embedding (Xavier-uniform) and paper projection."""
        nn.init.xavier_uniform_(self.author_emb.weight)
        self.paper_proj.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every learnable parameter (inputs, convs and norms)."""
        self._reset_input_parameters()
        for conv in self.convs:
            conv.reset_parameters()
        for norm in self.norms:
            for layer_norm in norm.values():
                layer_norm.reset_parameters()

    def encode(self, data):
        """Return the final ``{"author": ..., "paper": ...}`` node embeddings.

        Args:
            data: Graph object where ``data["paper"].x`` holds the paper input
                features and ``data.edge_index_dict`` the typed edges
                (presumably a PyG ``HeteroData``; TODO confirm).
        """
        x = {"author": self.author_emb.weight, "paper": self.paper_proj(data["paper"].x)}
        for conv, norm in zip(self.convs, self.norms):
            h = conv(x, data.edge_index_dict)
            updated = {}
            for nt, feat in x.items():
                msg = h.get(nt)
                if msg is None:
                    # HGTConv produces no output (None, or a missing key) for
                    # node types that receive no messages; keep the embedding.
                    updated[nt] = feat
                    continue
                # Post-norm residual: x <- LayerNorm(x + dropout(relu(conv(x)))).
                updated[nt] = norm[nt](
                    feat + F.dropout(F.relu(msg), p=self.dropout, training=self.training)
                )
            x = updated
        return x

    def decode(self, z, edge_index):
        """Score (author, paper) pairs by the dot product of their embeddings.

        Args:
            z: Embedding dict as returned by :meth:`encode`.
            edge_index: Two index rows. Row 0 indexes authors, row 1 papers.

        Returns:
            One raw (un-normalised) score per column of ``edge_index``.
        """
        src, dst = edge_index
        return (z["author"][src] * z["paper"][dst]).sum(-1)
@torch.no_grad()
def predict_scores(model, data, pairs: np.ndarray, batch_size: int) -> np.ndarray:
    """Score (author, paper) index pairs after one full-graph encoder pass.

    Switches ``model`` to eval mode and leaves it there. The whole graph is
    encoded once, the embeddings are copied to host memory, and per-pair dot
    products are computed in chunks of ``batch_size`` to bound temporaries.

    Args:
        model: Object with ``eval()`` and ``encode(data)``. ``encode`` must
            return a dict holding ``"author"`` and ``"paper"`` tensors.
        data: Graph forwarded to ``model.encode``.
        pairs: Integer array of shape ``(N, 2)``. Column 0 holds author
            indices and column 1 paper indices.
        batch_size: Number of pairs scored per chunk; must be positive.

    Returns:
        ``float32`` array of shape ``(N,)``, empty when ``pairs`` is empty.

    Raises:
        ValueError: If ``batch_size`` is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    model.eval()
    pairs = np.asarray(pairs)
    if len(pairs) == 0:
        # np.concatenate([]) would raise; an empty score vector is the answer.
        return np.empty(0, dtype=np.float32)
    z = model.encode(data)
    author = z["author"].detach().cpu().numpy()
    paper = z["paper"].detach().cpu().numpy()
    scores = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        scores.append(np.sum(author[batch[:, 0]] * paper[batch[:, 1]], axis=1).astype(np.float32))
    return np.concatenate(scores)
def train_one(args, lgcn, parts, data, seed: int, out_dir: Path):
    """Train one seeded HGT recommender with a BPR loss and keep its best epoch.

    Each "epoch" is a single optimiser step, done as follows:

    * Take a random subset of at most ``args.train_batch_size`` positive
      ``("author", "ref", "paper")`` edges.
    * Draw ``args.neg_per_pos`` negatives per positive from
      ``lgcn.sample_hard_negatives``.
    * Minimise ``-logsigmoid(pos - neg)``.

    The model scores ``parts["val_pairs"]`` every ``args.eval_every`` epochs
    and after the last epoch. A non-positive ``eval_every`` means "only after
    the last epoch". Whenever validation F1 improves, the raw scores are saved
    to ``out_dir/scores/val_hgt_dot_s{seed}_d{hidden_dim}.npy`` and the
    weights are snapshotted. The best snapshot is written to
    ``out_dir/checkpoints/hgt_val_s{seed}_d{hidden_dim}.pt``.

    Returns:
        ``(f1, threshold, auc)`` from ``lgcn.best_f1`` at the best evaluation,
        or ``(-1.0, 0.0, 0.0)`` if no evaluation ran (``args.epochs == 0``).
    """
    lgcn.set_seed(seed)
    device = torch.device(args.device)
    model = HGTRecommender(
        data.metadata(),
        args.num_authors,
        parts["paper_feat_aug"].shape[1],
        args.hidden_dim,
        args.layers,
        args.heads,
        args.dropout,
    ).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    pos_edges = data["author", "ref", "paper"].edge_index
    num_pos = pos_edges.size(1)
    batch_size = min(args.train_batch_size, num_pos)
    val_arr = parts["val_pairs"][["source", "target"]].to_numpy(np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(np.int8)
    score_path = out_dir / "scores" / f"val_hgt_dot_s{seed}_d{args.hidden_dim}.npy"
    ckpt_path = out_dir / "checkpoints" / f"hgt_val_s{seed}_d{args.hidden_dim}.pt"
    best = (-1.0, 0.0, 0.0)
    best_state = None
    for epoch in range(args.epochs):
        model.train()
        perm = torch.randperm(num_pos, device=device)[:batch_size]
        pos = pos_edges[:, perm]
        neg = lgcn.sample_hard_negatives(
            parts, pos.size(1) * args.neg_per_pos, args.num_authors, args.num_papers, device
        )
        z = model.encode(data)
        # Repeat each positive score neg_per_pos times so it lines up with a
        # consecutive block of negatives (assumes that layout; TODO confirm).
        pos_scores = model.decode(z, pos).repeat_interleave(args.neg_per_pos)
        neg_scores = model.decode(z, neg)
        loss = -F.logsigmoid(pos_scores - neg_scores).mean()
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        is_last = epoch == args.epochs - 1
        # Guard the modulo: eval_every <= 0 would otherwise raise
        # ZeroDivisionError (or never trigger for negatives).
        is_periodic = args.eval_every > 0 and (epoch + 1) % args.eval_every == 0
        if not (is_periodic or is_last):
            continue
        scores = predict_scores(model, data, val_arr, args.pred_batch_size)
        f1, th, auc = lgcn.best_f1(labels, scores)
        if f1 > best[0]:
            best = (f1, th, auc)
            best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()}
            np.save(score_path, scores)
        print(
            f"seed={seed} epoch={epoch+1:03d} loss={loss.item():.4f} "
            f"val_f1={f1:.5f} th={th:.5f} auc={auc:.5f}"
        )
    if best_state is not None:
        torch.save(best_state, ckpt_path)
    return best
def _build_arg_parser() -> argparse.ArgumentParser:
    """Return the command-line parser for the HGT validation run."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-seed", type=int, required=True)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--run-name", required=True)
    parser.add_argument("--seeds", nargs="*", type=int, default=[0, 42])
    parser.add_argument("--hidden-dim", type=int, default=128)
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--heads", type=int, default=4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--epochs", type=int, default=120)
    parser.add_argument("--eval-every", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.002)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--train-batch-size", type=int, default=32768)
    parser.add_argument("--pred-batch-size", type=int, default=65536)
    parser.add_argument("--neg-per-pos", type=int, default=3)
    parser.add_argument("--num-authors", type=int, default=6611)
    parser.add_argument("--num-papers", type=int, default=79937)
    return parser


def _write_seed_ensemble(args, lgcn, parts, out_dir: Path) -> None:
    """Average the saved per-seed validation scores and record the ensemble F1.

    Seeds whose score file is missing are skipped. Nothing is written when no
    score file exists.
    """
    labels = parts["val_pairs"]["label"].to_numpy(np.int8)
    vals, names = [], []
    for seed in args.seeds:
        p = out_dir / "scores" / f"val_hgt_dot_s{seed}_d{args.hidden_dim}.npy"
        if p.exists():
            vals.append(np.load(p))
            names.append(p.stem)
    if not vals:
        return
    ens = np.mean(vals, axis=0)
    f1, th, auc = lgcn.best_f1(labels, ens)
    np.save(out_dir / "scores" / "val_hgt_ensemble_mean.npy", ens)
    (out_dir / "ensemble_result.txt").write_text(
        f"models={','.join(names)}\nf1={f1:.8f}\nthreshold={th:.8f}\nauc={auc:.8f}\n"
    )
    print(f"\nMean ensemble: f1={f1:.5f} threshold={th:.5f} auc={auc:.5f}")


def main():
    """Train one HGT model per seed on a dynamic split, then mean-ensemble them.

    Outputs go to ``<package-root>/validation_runs/dynamic_seed<split-seed>/<run-name>``:
    per-seed validation scores and checkpoints, ``model_results.csv`` sorted
    by F1, and (if any scores exist) ``ensemble_result.txt`` plus the ensemble
    score array.
    """
    args = _build_arg_parser().parse_args()
    root = args.package_root
    lgcn = load_lgcn_module(root / "code" / "train_val_lgcn_ensemble.py")
    parts = lgcn.build_parts(root, None, args.num_papers, split_seed=args.split_seed, train_frac=args.train_frac)
    data = lgcn.build_data(parts, args.num_authors, args.num_papers, torch.device(args.device))
    out_dir = root / "validation_runs" / f"dynamic_seed{args.split_seed}" / args.run_name
    (out_dir / "scores").mkdir(parents=True, exist_ok=True)
    (out_dir / "checkpoints").mkdir(parents=True, exist_ok=True)
    rows = []
    for seed in args.seeds:
        f1, th, auc = train_one(args, lgcn, parts, data, seed, out_dir)
        rows.append({"seed": seed, "dim": args.hidden_dim, "f1": f1, "threshold": th, "auc": auc})
    # Explicit columns keep sort_values("f1") valid when `--seeds` is given no
    # values (nargs="*"). An empty DataFrame would otherwise have no "f1"
    # column and raise KeyError.
    results = pd.DataFrame(rows, columns=["seed", "dim", "f1", "threshold", "auc"])
    results.sort_values("f1", ascending=False).to_csv(out_dir / "model_results.csv", index=False)
    _write_seed_ensemble(args, lgcn, parts, out_dir)


if __name__ == "__main__":
    main()