"""Dynamic-split HGT attention recommender for author-paper validation.""" from __future__ import annotations import argparse import importlib.util from pathlib import Path import numpy as np import pandas as pd import torch import torch.nn as nn import torch.nn.functional as F from torch_geometric.nn import HGTConv def load_lgcn_module(path: Path): spec = importlib.util.spec_from_file_location("train_val_lgcn_ensemble", path) module = importlib.util.module_from_spec(spec) assert spec.loader is not None spec.loader.exec_module(module) return module class HGTRecommender(nn.Module): def __init__( self, metadata, num_authors: int, paper_dim: int, hidden_dim: int, layers: int, heads: int, dropout: float, ): super().__init__() self.dropout = dropout self.author_emb = nn.Embedding(num_authors, hidden_dim) self.paper_proj = nn.Linear(paper_dim, hidden_dim) self.convs = nn.ModuleList( [HGTConv(hidden_dim, hidden_dim, metadata, heads=heads) for _ in range(layers)] ) self.norms = nn.ModuleList( [nn.ModuleDict({nt: nn.LayerNorm(hidden_dim) for nt in metadata[0]}) for _ in range(layers)] ) self.reset_parameters() def reset_parameters(self): nn.init.xavier_uniform_(self.author_emb.weight) self.paper_proj.reset_parameters() def encode(self, data): x = {"author": self.author_emb.weight, "paper": self.paper_proj(data["paper"].x)} for conv, norm in zip(self.convs, self.norms): h = conv(x, data.edge_index_dict) x = { nt: norm[nt](x[nt] + F.dropout(F.relu(h[nt]), p=self.dropout, training=self.training)) for nt in x } return x def decode(self, z, edge_index): src, dst = edge_index return (z["author"][src] * z["paper"][dst]).sum(-1) @torch.no_grad() def predict_scores(model, data, pairs: np.ndarray, batch_size: int) -> np.ndarray: model.eval() z = model.encode(data) author = z["author"].detach().cpu().numpy() paper = z["paper"].detach().cpu().numpy() scores = [] for start in range(0, len(pairs), batch_size): batch = pairs[start : start + batch_size] scores.append(np.sum(author[batch[:, 0]] * paper[batch[:, 1]], axis=1).astype(np.float32)) return np.concatenate(scores) def train_one(args, lgcn, parts, data, seed: int, out_dir: Path): lgcn.set_seed(seed) device = torch.device(args.device) model = HGTRecommender( data.metadata(), args.num_authors, parts["paper_feat_aug"].shape[1], args.hidden_dim, args.layers, args.heads, args.dropout, ).to(device) opt = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay) pos_edges = data["author", "ref", "paper"].edge_index batch_size = min(args.train_batch_size, pos_edges.size(1)) val_arr = parts["val_pairs"][["source", "target"]].to_numpy(np.int64) labels = parts["val_pairs"]["label"].to_numpy(np.int8) best = (-1.0, 0.0, 0.0) best_state = None for epoch in range(args.epochs): model.train() perm = torch.randperm(pos_edges.size(1), device=device)[:batch_size] pos = pos_edges[:, perm] neg = lgcn.sample_hard_negatives( parts, pos.size(1) * args.neg_per_pos, args.num_authors, args.num_papers, device ) z = model.encode(data) pos_scores = model.decode(z, pos).repeat_interleave(args.neg_per_pos) neg_scores = model.decode(z, neg) loss = -F.logsigmoid(pos_scores - neg_scores).mean() opt.zero_grad() loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() if (epoch + 1) % args.eval_every == 0 or epoch == args.epochs - 1: scores = predict_scores(model, data, val_arr, args.pred_batch_size) f1, th, auc = lgcn.best_f1(labels, scores) if f1 > best[0]: best = (f1, th, auc) best_state = {k: v.detach().cpu() for k, v in model.state_dict().items()} np.save(out_dir / "scores" / f"val_hgt_dot_s{seed}_d{args.hidden_dim}.npy", scores) print( f"seed={seed} epoch={epoch+1:03d} loss={loss.item():.4f} " f"val_f1={f1:.5f} th={th:.5f} auc={auc:.5f}" ) if best_state is not None: torch.save(best_state, out_dir / "checkpoints" / f"hgt_val_s{seed}_d{args.hidden_dim}.pt") return best def main(): parser = argparse.ArgumentParser() parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1]) parser.add_argument("--split-seed", type=int, required=True) parser.add_argument("--train-frac", type=float, default=0.9) parser.add_argument("--device", default="cuda:0") parser.add_argument("--run-name", required=True) parser.add_argument("--seeds", nargs="*", type=int, default=[0, 42]) parser.add_argument("--hidden-dim", type=int, default=128) parser.add_argument("--layers", type=int, default=2) parser.add_argument("--heads", type=int, default=4) parser.add_argument("--dropout", type=float, default=0.1) parser.add_argument("--epochs", type=int, default=120) parser.add_argument("--eval-every", type=int, default=20) parser.add_argument("--lr", type=float, default=0.002) parser.add_argument("--weight-decay", type=float, default=1e-4) parser.add_argument("--train-batch-size", type=int, default=32768) parser.add_argument("--pred-batch-size", type=int, default=65536) parser.add_argument("--neg-per-pos", type=int, default=3) parser.add_argument("--num-authors", type=int, default=6611) parser.add_argument("--num-papers", type=int, default=79937) args = parser.parse_args() root = args.package_root lgcn = load_lgcn_module(root / "code" / "train_val_lgcn_ensemble.py") parts = lgcn.build_parts(root, None, args.num_papers, split_seed=args.split_seed, train_frac=args.train_frac) data = lgcn.build_data(parts, args.num_authors, args.num_papers, torch.device(args.device)) out_dir = root / "validation_runs" / f"dynamic_seed{args.split_seed}" / args.run_name (out_dir / "scores").mkdir(parents=True, exist_ok=True) (out_dir / "checkpoints").mkdir(parents=True, exist_ok=True) rows = [] for seed in args.seeds: f1, th, auc = train_one(args, lgcn, parts, data, seed, out_dir) rows.append({"seed": seed, "dim": args.hidden_dim, "f1": f1, "threshold": th, "auc": auc}) pd.DataFrame(rows).sort_values("f1", ascending=False).to_csv(out_dir / "model_results.csv", index=False) labels = parts["val_pairs"]["label"].to_numpy(np.int8) vals, names = [], [] for seed in args.seeds: p = out_dir / "scores" / f"val_hgt_dot_s{seed}_d{args.hidden_dim}.npy" if p.exists(): vals.append(np.load(p)) names.append(p.stem) if vals: ens = np.mean(vals, axis=0) f1, th, auc = lgcn.best_f1(labels, ens) np.save(out_dir / "scores" / "val_hgt_ensemble_mean.npy", ens) (out_dir / "ensemble_result.txt").write_text( f"models={','.join(names)}\nf1={f1:.8f}\nthreshold={th:.8f}\nauc={auc:.8f}\n" ) print(f"\nMean ensemble: f1={f1:.5f} threshold={th:.5f} auc={auc:.5f}") if __name__ == "__main__": main()