| """Dynamic-split GraphSAGE hetero recommender for validation fusion.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import importlib.util |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch_geometric.nn import HeteroConv, SAGEConv |
|
|
|
|
def load_lgcn_module(path: Path):
    """Execute the Python source file at *path* and return it as a module.

    The module is created under the fixed name ``train_val_lgcn_ensemble``.
    It is not inserted into ``sys.modules`` by this function.

    Args:
        path: Location of the Python source file to load.

    Returns:
        The module object after its top-level code has been executed.

    Raises:
        ImportError: If no import spec or loader can be derived from *path*
            (for example, when the file suffix is not a recognised Python
            source/bytecode suffix).
    """
    spec = importlib.util.spec_from_file_location("train_val_lgcn_ensemble", path)
    # An explicit raise (instead of ``assert``) still runs under ``python -O``
    # and also covers the case where ``spec`` itself is ``None``.
    if spec is None or spec.loader is None:
        raise ImportError(f"cannot build an import spec for {path}")
    module = importlib.util.module_from_spec(spec)
    spec.loader.exec_module(module)
    return module
|
|
|
|
class ResidualSAGE(nn.Module):
    """Stack of heterogeneous SAGE layers with residual links and LayerNorm.

    Each layer owns one ``HeteroConv`` (a ``SAGEConv`` per edge type, combined
    with ``aggr="mean"``) and one ``LayerNorm`` per node type.

    Args:
        metadata: Indexable pair; ``metadata[0]`` lists node types and
            ``metadata[1]`` lists edge types.
        hidden_dim: Feature width of every layer input and output.
        num_layers: Number of conv/norm layers to stack.
        dropout: Dropout probability applied to each layer's activated update.
    """

    def __init__(self, metadata, hidden_dim: int, num_layers: int, dropout: float):
        super().__init__()
        self.dropout = dropout
        self.convs = nn.ModuleList()
        self.norms = nn.ModuleList()
        node_types, edge_types = metadata[0], metadata[1]
        for _layer in range(num_layers):
            per_edge_convs = {}
            for edge_type in edge_types:
                per_edge_convs[edge_type] = SAGEConv((hidden_dim, hidden_dim), hidden_dim)
            self.convs.append(HeteroConv(per_edge_convs, aggr="mean"))

            per_node_norms = {}
            for node_type in node_types:
                per_node_norms[node_type] = nn.LayerNorm(hidden_dim)
            self.norms.append(nn.ModuleDict(per_node_norms))

    def forward(self, x_dict, edge_index_dict):
        """Run every layer and return the final per-node-type feature dict.

        For each node type the layer output is
        ``LayerNorm(x + dropout(relu(update)))``, where ``update`` is the conv
        result for that type, or the layer input itself when the conv
        produced nothing for it.
        """
        current = x_dict
        for conv, layer_norms in zip(self.convs, self.norms):
            messages = conv(current, edge_index_dict)
            refreshed = {}
            for node_type, feats in current.items():
                # Node types missing from the conv output reuse their input.
                update = messages[node_type] if node_type in messages else feats
                update = F.relu(update)
                update = F.dropout(update, p=self.dropout, training=self.training)
                refreshed[node_type] = layer_norms[node_type](feats + update)
            current = refreshed
        return current
|
|
|
|
class SAGERecommender(nn.Module):
    """Author-paper scorer built on a ``ResidualSAGE`` encoder.

    Authors start from a learned embedding table, papers from a linear
    projection of their input features; pairs are scored by dot product.

    Args:
        metadata: Graph metadata forwarded to ``ResidualSAGE``.
        num_authors: Number of rows in the author embedding table.
        paper_dim: Width of the raw paper feature vectors.
        hidden_dim: Width of all hidden representations.
        num_layers: Number of encoder layers.
        dropout: Dropout probability used inside the encoder.
    """

    def __init__(self, metadata, num_authors: int, paper_dim: int, hidden_dim: int, num_layers: int, dropout: float):
        super().__init__()
        self.author_emb = nn.Embedding(num_authors, hidden_dim)
        self.paper_proj = nn.Linear(paper_dim, hidden_dim)
        self.encoder = ResidualSAGE(metadata, hidden_dim, num_layers, dropout)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise the author table (Xavier uniform) and the paper projection."""
        nn.init.xavier_uniform_(self.author_emb.weight)
        self.paper_proj.reset_parameters()

    def encode(self, data):
        """Return the encoder output dict for the whole graph in *data*."""
        projected_papers = self.paper_proj(data["paper"].x)
        initial = {"author": self.author_emb.weight, "paper": projected_papers}
        return self.encoder(initial, data.edge_index_dict)

    def decode(self, z, edge_index):
        """Score (author, paper) pairs by the dot product of their embeddings.

        ``edge_index`` unpacks into two index sequences: author indices first,
        paper indices second.
        """
        author_idx, paper_idx = edge_index
        products = z["author"][author_idx] * z["paper"][paper_idx]
        return products.sum(-1)
|
|
|
|
@torch.no_grad()
def predict_scores(model, data, pairs: np.ndarray, batch_size: int) -> np.ndarray:
    """Score (author, paper) index pairs with the model's dot-product decoder.

    The model is switched to eval mode, the graph is encoded once, and the
    embeddings are moved to NumPy; scoring then runs in chunks of
    *batch_size* rows to bound the size of the temporary product arrays.

    Args:
        model: Object exposing ``eval()`` and ``encode(data)``; ``encode``
            must return a mapping with ``"author"`` and ``"paper"`` tensors.
        data: Graph object passed through to ``model.encode``.
        pairs: Integer array whose column 0 indexes authors and column 1
            indexes papers.
        batch_size: Number of pairs scored per chunk; must be positive.

    Returns:
        A 1-D ``float32`` array with one score per row of *pairs*
        (empty when *pairs* is empty).

    Raises:
        ValueError: If *batch_size* is not positive.
    """
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")
    model.eval()
    # np.concatenate rejects an empty list, so short-circuit the no-pairs case.
    if len(pairs) == 0:
        return np.empty(0, dtype=np.float32)
    z = model.encode(data)
    author_vecs = z["author"].detach().cpu().numpy()
    paper_vecs = z["paper"].detach().cpu().numpy()
    chunks = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        dots = np.sum(author_vecs[batch[:, 0]] * paper_vecs[batch[:, 1]], axis=1)
        chunks.append(dots.astype(np.float32))
    return np.concatenate(chunks)
|
|
|
|
def train_one(args, lgcn, parts, data, seed: int, out_dir: Path):
    """Train one seeded ``SAGERecommender`` and keep its best validation state.

    Every epoch draws a random batch of positive ``("author", "ref", "paper")``
    edges, asks ``lgcn.sample_hard_negatives`` for negatives, and minimises a
    pairwise log-sigmoid ranking loss on ``positive - negative`` scores.
    Every ``args.eval_every`` epochs (and on the final epoch) the validation
    pairs are scored; whenever the F1 from ``lgcn.best_f1`` improves, the
    scores are written to ``out_dir / "scores"`` and the weights are
    snapshotted. The best snapshot is saved to ``out_dir / "checkpoints"``
    once training ends.

    Args:
        args: Parsed CLI namespace; reads ``device``, ``num_authors``,
            ``num_papers``, ``hidden_dim``, ``layers``, ``dropout``, ``lr``,
            ``weight_decay``, ``train_batch_size``, ``pred_batch_size``,
            ``neg_per_pos``, ``epochs`` and ``eval_every``.
        lgcn: Helper module providing ``set_seed``, ``sample_hard_negatives``
            and ``best_f1``.
        parts: Mapping with ``"paper_feat_aug"`` (2-D, feature width on axis 1)
            and ``"val_pairs"`` (table with ``source``, ``target``, ``label``).
        data: Heterogeneous graph object used for encoding.
        seed: Random seed for this run; also part of the output file names.
        out_dir: Run directory; its ``scores`` and ``checkpoints``
            sub-directories must already exist.

    Returns:
        ``(f1, threshold, auc)`` of the best evaluation, or
        ``(-1.0, 0.0, 0.0)`` if no evaluation ever improved on the sentinel.
    """
    lgcn.set_seed(seed)
    device = torch.device(args.device)
    model = SAGERecommender(
        data.metadata(),
        args.num_authors,
        parts["paper_feat_aug"].shape[1],
        args.hidden_dim,
        args.layers,
        args.dropout,
    ).to(device)
    opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    pos_edges = data["author", "ref", "paper"].edge_index
    batch_size = min(args.train_batch_size, pos_edges.size(1))
    val_arr = parts["val_pairs"][["source", "target"]].to_numpy(np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(np.int8)

    best = (-1.0, 0.0, 0.0)
    best_state = None
    for epoch in range(args.epochs):
        model.train()
        perm = torch.randperm(pos_edges.size(1), device=device)[:batch_size]
        pos = pos_edges[:, perm]
        neg = lgcn.sample_hard_negatives(
            parts, pos.size(1) * args.neg_per_pos, args.num_authors, args.num_papers, device
        )
        z = model.encode(data)
        # Each positive score is repeated neg_per_pos times to pair with the
        # negatives. NOTE(review): assumes the sampler returns exactly
        # pos.size(1) * neg_per_pos negatives in matching order - confirm.
        pos_s = model.decode(z, pos).repeat_interleave(args.neg_per_pos)
        neg_s = model.decode(z, neg)
        loss = -F.logsigmoid(pos_s - neg_s).mean()
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if (epoch + 1) % args.eval_every == 0 or epoch == args.epochs - 1:
            scores = predict_scores(model, data, val_arr, args.pred_batch_size)
            f1, th, auc = lgcn.best_f1(labels, scores)
            if f1 > best[0]:
                best = (f1, th, auc)
                # copy=True is essential: when the model already lives on the
                # CPU, a plain .cpu() returns a view of the live parameters,
                # and later optimizer steps would overwrite the snapshot.
                best_state = {
                    k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()
                }
                np.save(out_dir / "scores" / f"val_sage_dot_s{seed}_d{args.hidden_dim}.npy", scores)
            print(f"seed={seed} epoch={epoch+1:03d} loss={loss.item():.4f} val_f1={f1:.5f} th={th:.5f} auc={auc:.5f}")

    if best_state is not None:
        torch.save(best_state, out_dir / "checkpoints" / f"sage_val_s{seed}_d{args.hidden_dim}.pt")
    return best
|
|
|
|
def main():
    """CLI entry point: train one model per seed, then mean-ensemble them.

    Steps:
      1. Parse arguments and reject an empty ``--seeds`` list before any
         expensive work starts.
      2. Load the helper script ``code/train_val_lgcn_ensemble.py`` from the
         package root and build the split parts and graph data with it.
      3. Train one model per seed via ``train_one``, rewriting
         ``model_results.csv`` (sorted by F1, best first) after each seed.
      4. Average whichever per-seed validation score files exist, evaluate
         the mean with ``lgcn.best_f1``, and write the ensemble scores and a
         small text summary.

    All outputs go under
    ``<package-root>/validation_runs/dynamic_seed<split-seed>/<run-name>``.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-seed", type=int, required=True)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--device", default="cuda:0")
    parser.add_argument("--run-name", required=True)
    parser.add_argument("--seeds", nargs="*", type=int, default=[0, 42])
    parser.add_argument("--hidden-dim", type=int, default=256)
    parser.add_argument("--layers", type=int, default=2)
    parser.add_argument("--epochs", type=int, default=140)
    parser.add_argument("--eval-every", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.003)
    parser.add_argument("--weight-decay", type=float, default=1e-4)
    parser.add_argument("--dropout", type=float, default=0.1)
    parser.add_argument("--train-batch-size", type=int, default=32768)
    parser.add_argument("--pred-batch-size", type=int, default=65536)
    parser.add_argument("--neg-per-pos", type=int, default=3)
    parser.add_argument("--num-authors", type=int, default=6611)
    parser.add_argument("--num-papers", type=int, default=79937)
    args = parser.parse_args()

    # nargs="*" lets a bare "--seeds" through as an empty list; fail fast here
    # instead of building the whole dataset and then producing no results.
    if not args.seeds:
        parser.error("--seeds needs at least one value")

    root = args.package_root
    lgcn = load_lgcn_module(root / "code" / "train_val_lgcn_ensemble.py")
    parts = lgcn.build_parts(root, None, args.num_papers, split_seed=args.split_seed, train_frac=args.train_frac)
    data = lgcn.build_data(parts, args.num_authors, args.num_papers, torch.device(args.device))
    out_dir = root / "validation_runs" / f"dynamic_seed{args.split_seed}" / args.run_name
    (out_dir / "scores").mkdir(parents=True, exist_ok=True)
    (out_dir / "checkpoints").mkdir(parents=True, exist_ok=True)

    rows = []
    for seed in args.seeds:
        f1, th, auc = train_one(args, lgcn, parts, data, seed, out_dir)
        rows.append({"seed": seed, "dim": args.hidden_dim, "f1": f1, "threshold": th, "auc": auc})
        # Rewritten after every seed so partial results survive an interrupted run.
        pd.DataFrame(rows).sort_values("f1", ascending=False).to_csv(out_dir / "model_results.csv", index=False)

    labels = parts["val_pairs"]["label"].to_numpy(np.int8)
    vals = []
    names = []
    for seed in args.seeds:
        p = out_dir / "scores" / f"val_sage_dot_s{seed}_d{args.hidden_dim}.npy"
        # A seed only has a score file if train_one saved one for it.
        if p.exists():
            vals.append(np.load(p))
            names.append(p.stem)
    if vals:
        ens = np.mean(vals, axis=0)
        f1, th, auc = lgcn.best_f1(labels, ens)
        np.save(out_dir / "scores" / "val_sage_ensemble_mean.npy", ens)
        (out_dir / "ensemble_result.txt").write_text(
            f"models={','.join(names)}\nf1={f1:.8f}\nthreshold={th:.8f}\nauc={auc:.8f}\n"
        )
        print(f"\nMean ensemble: f1={f1:.5f} threshold={th:.5f} auc={auc:.5f}")
|
|
|
|
| if __name__ == "__main__": |
| main() |
|
|