| """Train and validate LightGCN ensembles on the notebook-style split.""" |
|
|
| from __future__ import annotations |
|
|
| import argparse |
| import pickle as pkl |
| import random |
| from pathlib import Path |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from numpy.linalg import norm |
| from sklearn.metrics import precision_recall_curve, roc_auc_score |
| from torch_geometric.data import HeteroData |
|
|
|
|
# (src_type, relation, dst_type) triples. LightGCNLayer walks them in this
# order and skips any triple missing from the graph's edge_index_dict.
EDGE_TYPES: list[tuple[str, str, str]] = [
    ("author", "ref", "paper"),  # author -> paper reference edges
    ("paper", "beref", "author"),  # the same reference edges, reversed
    ("paper", "cite", "paper"),  # citation edges
    ("author", "coauthor", "author"),  # coauthorship edges
]
|
|
|
|
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``.

    Every visible CUDA device is seeded as well when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
def read_txt(path: Path) -> list[list[int]]:
    """Read a whitespace-separated integer table, one row per line.

    Blank (or whitespace-only) lines, such as a trailing empty line, are
    skipped. Without this they become empty rows that break callers'
    ``for s, t in rows`` unpacking and two-column ``DataFrame`` construction.

    Args:
        path: Text file to read.

    Returns:
        One list of ints per non-blank line, in file order.

    Raises:
        ValueError: If a token is not a valid integer.
    """
    rows: list[list[int]] = []
    with path.open("r", encoding="utf-8") as f:
        for line in f:
            fields = line.split()  # split() with no args already ignores edge whitespace
            if fields:
                rows.append([int(tok) for tok in fields])
    return rows
|
|
|
|
def log_norm(x: np.ndarray) -> np.ndarray:
    """Apply ``log1p`` and standardise the result to zero mean / unit std.

    A 1e-8 epsilon in the denominator keeps constant input finite (all zeros).
    """
    logged = np.log1p(x)
    centered = logged - logged.mean()
    return centered / (logged.std() + 1e-8)
|
|
|
|
class LightGCNLayer(nn.Module):
    """Parameter-free heterogeneous propagation step.

    For every relation in ``EDGE_TYPES`` that is present in
    ``edge_index_dict``, each destination node receives the mean of its
    source neighbours' embeddings. Nodes with no incoming edge of that
    relation receive zeros. The per-relation results for one destination type
    are then averaged with equal weight. A node type that receives no relation
    at all keeps its input embeddings.
    """

    def forward(self, x_dict, edge_index_dict):
        incoming = {node_type: [] for node_type in x_dict}
        for relation in EDGE_TYPES:
            if relation not in edge_index_dict:
                continue
            src_type, _, dst_type = relation
            src_idx, dst_idx = edge_index_dict[relation]
            src_emb = x_dict[src_type]
            n_dst = x_dict[dst_type].size(0)
            summed = src_emb.new_zeros((n_dst, src_emb.size(-1)))
            counts = src_emb.new_zeros((n_dst, 1))
            summed.index_add_(0, dst_idx, src_emb[src_idx])
            ones = torch.ones((dst_idx.numel(), 1), dtype=src_emb.dtype, device=src_emb.device)
            counts.index_add_(0, dst_idx, ones)
            # clamp keeps isolated destinations at 0 instead of 0 / 0 = NaN
            incoming[dst_type].append(summed / counts.clamp(min=1.0))

        out = {}
        for node_type, messages in incoming.items():
            out[node_type] = sum(messages) / len(messages) if messages else x_dict[node_type]
        return out
|
|
|
|
class LightGCN(nn.Module):
    """LightGCN-style encoder/decoder over the author/paper graph.

    Authors get a free embedding table, and papers are linearly projected from
    ``data["paper"].x``. ``num_layers`` parameter-free ``LightGCNLayer``
    passes follow. The final embedding of each node type is the uniform
    average of the input embeddings and every layer's output.
    """

    def __init__(self, num_authors: int, paper_feat_dim: int, embed_dim: int, num_layers: int = 4):
        super().__init__()
        self.author_emb = nn.Embedding(num_authors, embed_dim)
        self.paper_proj = nn.Linear(paper_feat_dim, embed_dim)
        self.layers = nn.ModuleList(LightGCNLayer() for _ in range(num_layers))
        self.num_layers = num_layers
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-uniform init for the author table and projection weight; zero the projection bias."""
        nn.init.xavier_uniform_(self.author_emb.weight)
        nn.init.xavier_uniform_(self.paper_proj.weight)
        nn.init.zeros_(self.paper_proj.bias)

    def encode(self, data):
        """Return ``{"author": ..., "paper": ...}`` embeddings averaged over all propagation depths."""
        h = {"author": self.author_emb.weight, "paper": self.paper_proj(data["paper"].x)}
        depth_outputs = [h]
        for layer in self.layers:
            h = layer(h, data.edge_index_dict)
            depth_outputs.append(h)
        coeff = 1.0 / (self.num_layers + 1)
        return {
            node_type: sum(coeff * out[node_type] for out in depth_outputs)
            for node_type in h
        }

    def decode(self, z_dict, edge_index):
        """Dot-product score per column of ``edge_index`` (row 0: author idx, row 1: paper idx)."""
        author_idx, paper_idx = edge_index
        return torch.sum(z_dict["author"][author_idx] * z_dict["paper"][paper_idx], dim=-1)
|
|
|
|
class LearnableWeightLightGCN(LightGCN):
    """LightGCN variant that learns how to mix propagation depths.

    ``layer_weight`` holds one logit per depth (the input plus each layer).
    ``encode`` mixes depths with ``softmax(layer_weight)``. The logits start at
    zero, so the initial mixture is uniform.
    """

    def __init__(self, num_authors: int, paper_feat_dim: int, embed_dim: int, num_layers: int = 4):
        super().__init__(num_authors, paper_feat_dim, embed_dim, num_layers)
        self.layer_weight = nn.Parameter(torch.zeros(num_layers + 1))

    def encode(self, data):
        current = {
            "author": self.author_emb.weight,
            "paper": self.paper_proj(data["paper"].x),
        }
        per_depth = [current]
        for gcn_layer in self.layers:
            current = gcn_layer(current, data.edge_index_dict)
            per_depth.append(current)
        mix = torch.softmax(self.layer_weight, dim=0)
        return {
            node_type: sum(mix[depth] * emb[node_type] for depth, emb in enumerate(per_depth))
            for node_type in current
        }
|
|
|
|
def cos_sim(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise (axis 1) cosine similarity between ``a`` and ``b``.

    ``eps`` is added to the norm product, so an all-zero row scores 0, not NaN.
    """
    dots = (a * b).sum(axis=1)
    denom = norm(a, axis=1) * norm(b, axis=1) + eps
    return dots / denom
|
|
|
|
def make_notebook_style_split(root: Path, seed: int, train_frac: float):
    """Split (author, paper) reference edges into train edges and labelled validation pairs.

    Reads ``bipartite_train_ann.txt`` (reference edges), ``author_file_ann.txt``
    (coauthor edges) and ``paper_file_ann.txt`` (citation edges) from
    ``root / "data_and_docs"``. A ``train_frac`` fraction of the reference
    edges is kept for training. The remaining edges become validation
    positives (label 1) and are paired with an equal number of sampled
    negatives (label 0).

    Args:
        root: Package root containing the ``data_and_docs`` directory.
        seed: Seeds ``DataFrame.sample`` and ``np.random.default_rng``. For a
            fixed seed and fixed input files the split is reproducible.
        train_frac: Fraction of reference edges sampled into the train split.

    Returns:
        ``(train_refs, val_pairs)``:
        - ``train_refs`` has columns ``source`` (author) and ``target``
          (paper) and a fresh 0..n-1 index.
        - ``val_pairs`` has columns ``source``/``target``/``label`` and is
          shuffled.
    """
    # NOTE: the order of the random draws below defines the split for a given
    # seed; reordering them changes previously generated splits.
    data_dir = root / "data_and_docs"
    refs = read_txt(data_dir / "bipartite_train_ann.txt")
    coauthor = read_txt(data_dir / "author_file_ann.txt")
    citation = read_txt(data_dir / "paper_file_ann.txt")

    ref_edges = pd.DataFrame(refs, columns=["source", "target"])
    # String row labels "r-<i>" let isin() find which edges were held out.
    ref_edges = ref_edges.set_index("r-" + ref_edges.index.astype(str))
    coauthor_edges = pd.DataFrame(coauthor, columns=["source", "target"])
    citation_edges = pd.DataFrame(citation, columns=["source", "target"])

    # Candidate papers: citation endpoints and reference targets.
    node_tmp = pd.concat([citation_edges["source"], citation_edges["target"], ref_edges["target"]])
    paper_ids = pd.unique(node_tmp).astype(np.int64)
    # Candidate authors: reference sources and coauthor endpoints.
    node_tmp = pd.concat([ref_edges["source"], coauthor_edges["source"], coauthor_edges["target"]])
    author_ids = pd.unique(node_tmp).astype(np.int64)

    train_refs = ref_edges.sample(frac=train_frac, random_state=seed, axis=0)
    val_pos = ref_edges[~ref_edges.index.isin(train_refs.index)].copy()
    val_pos.loc[:, "label"] = 1

    # All reference edges (train + held-out) are excluded, so a sampled
    # negative never coincides with a known positive.
    existing_ref_set = set(map(tuple, ref_edges[["source", "target"]].to_numpy().tolist()))
    neg_pairs: list[tuple[int, int]] = []
    rng = np.random.default_rng(seed)
    # Rejection sampling until there is one negative per held-out positive.
    # NOTE(review): pairs are drawn with replacement, so duplicate validation
    # negatives are possible.
    while len(neg_pairs) < len(val_pos):
        src = int(rng.choice(author_ids))
        dst = int(rng.choice(paper_ids))
        if (src, dst) not in existing_ref_set:
            neg_pairs.append((src, dst))

    val_neg = pd.DataFrame(neg_pairs, columns=["source", "target"])
    val_neg.loc[:, "label"] = 0
    val_pairs = pd.concat([val_pos.reset_index(drop=True), val_neg], ignore_index=True)
    # Shuffle so positives and negatives are interleaved.
    val_pairs = val_pairs.sample(frac=1, random_state=seed, axis=0).reset_index(drop=True)
    return train_refs[["source", "target"]].reset_index(drop=True), val_pairs
|
|
|
|
def _count_occurrences(ids, size: int) -> np.ndarray:
    """Return a float32 histogram of the integer ``ids`` over ``range(size)``.

    ``np.add.at`` counts repeated ids individually. Ids ``>= size`` raise
    ``IndexError``; negative ids wrap, as with ordinary NumPy indexing.
    """
    counts = np.zeros(size, dtype=np.float32)
    np.add.at(counts, np.asarray(ids, dtype=np.int64), 1)
    return counts


def _coauthor_paper_pools(coauthor, train_refs: pd.DataFrame, num_authors: int) -> dict:
    """Map each author in ``range(num_authors)`` to papers their coauthors reference but they do not.

    Only training reference edges are used. Authors with an empty pool map to
    ``None``.
    """
    coauthor_map: dict[int, set[int]] = {}
    for s, t in coauthor:
        coauthor_map.setdefault(s, set()).add(t)
        coauthor_map.setdefault(t, set()).add(s)
    author_papers: dict[int, set[int]] = {}
    for s, t in train_refs[["source", "target"]].to_numpy():
        author_papers.setdefault(int(s), set()).add(int(t))
    pools = {}
    for author in range(num_authors):
        pool = set()
        for co in coauthor_map.get(author, ()):
            pool.update(author_papers.get(co, ()))
        pool -= author_papers.get(author, set())
        pools[author] = np.array(list(pool), dtype=np.int64) if pool else None
    return pools


def build_parts(
    root: Path,
    split_dir: Path | None,
    num_papers: int,
    split_seed: int | None = None,
    train_frac: float = 0.9,
    num_authors: int = 6611,
):
    """Load the split, graph edges and paper features, and derive training helpers.

    Args:
        root: Package root containing ``data_and_docs``.
        split_dir: Directory with ``train_refs.csv`` / ``val_pairs.csv``. Used
            only when ``split_seed`` is None.
        num_papers: Number of paper ids (length of the per-paper degree arrays).
        split_seed: When set, a split is generated on the fly with
            ``make_notebook_style_split`` instead of being read from ``split_dir``.
        train_frac: Train fraction for the generated split.
        num_authors: Author ids ``0..num_authors-1`` that get a coauthor pool
            entry. The default 6611 matches the previous hard-coded value.

    Returns:
        Dict with:
        - ``train_refs``, ``val_pairs``: the split.
        - ``citation``, ``coauthor``: edge DataFrames with ``source``/``target``.
        - ``paper_feat_aug``: z-scored paper features plus three degree
          features.
        - ``popular``: paper ids at or above the 70th percentile of non-zero
          training reference degree.
        - ``coauthor_pool``: author -> int64 array of candidate papers, or None.
        - ``train_set``: set of training (author, paper) tuples.

    Raises:
        ValueError: If both ``split_seed`` and ``split_dir`` are None.
    """
    data_dir = root / "data_and_docs"
    if split_seed is None:
        if split_dir is None:
            raise ValueError("split_dir is required when split_seed is not set")
        train_refs = pd.read_csv(split_dir / "train_refs.csv")
        val_pairs = pd.read_csv(split_dir / "val_pairs.csv")
    else:
        train_refs, val_pairs = make_notebook_style_split(root, split_seed, train_frac)
    citation = read_txt(data_dir / "paper_file_ann.txt")
    coauthor = read_txt(data_dir / "author_file_ann.txt")
    citation_df = pd.DataFrame(citation, columns=["source", "target"])
    coauthor_df = pd.DataFrame(coauthor, columns=["source", "target"])

    # NOTE: pickle.load can execute arbitrary code; only load trusted local files.
    with (data_dir / "feature.pkl").open("rb") as f:
        paper_feature = pkl.load(f)
    # Presumably a (num_papers, d) torch tensor; ndarray-like data is accepted too.
    if isinstance(paper_feature, torch.Tensor):
        paper_feat_np = paper_feature.numpy().astype(np.float32)
    else:
        paper_feat_np = np.asarray(paper_feature, dtype=np.float32)

    # The reference degree comes from training edges only, so held-out
    # validation edges do not leak into the features.
    paper_ref_deg = _count_occurrences(train_refs["target"], num_papers)
    paper_cite_out = _count_occurrences(citation_df["source"], num_papers)
    paper_cite_in = _count_occurrences(citation_df["target"], num_papers)
    paper_deg_feat = np.stack(
        [log_norm(paper_ref_deg), log_norm(paper_cite_out), log_norm(paper_cite_in)],
        axis=-1,
    )
    paper_feat_aug = np.concatenate([paper_feat_np, paper_deg_feat], axis=-1)
    # Column-wise z-score; the epsilon maps constant columns to 0 instead of NaN.
    paper_feat_aug = (paper_feat_aug - paper_feat_aug.mean(axis=0)) / (
        paper_feat_aug.std(axis=0) + 1e-8
    )

    coauthor_pool = _coauthor_paper_pools(coauthor, train_refs, num_authors)

    popular_threshold = np.percentile(paper_ref_deg[paper_ref_deg > 0], 70)
    popular = np.where(paper_ref_deg >= popular_threshold)[0]
    train_set = set(map(tuple, train_refs[["source", "target"]].to_numpy().tolist()))

    return {
        "train_refs": train_refs,
        "val_pairs": val_pairs,
        "citation": citation_df,
        "coauthor": coauthor_df,
        "paper_feat_aug": paper_feat_aug,
        "popular": popular,
        "coauthor_pool": coauthor_pool,
        "train_set": train_set,
    }
|
|
|
|
def build_data(
    parts,
    num_authors: int,
    num_papers: int,
    device: torch.device,
    use_citation: bool = True,
    use_coauthor: bool = True,
):
    """Assemble the training ``HeteroData`` graph and move it to ``device``.

    ``parts`` must provide ``train_refs``, ``citation`` and ``coauthor``
    (DataFrames with ``source``/``target`` columns) plus ``paper_feat_aug``.
    Reference edges are stored twice: forward as ``ref`` and reversed as
    ``beref``. Citation and coauthor edges are symmetrised within a single
    relation when enabled.
    """

    def as_edge_tensor(frame) -> torch.Tensor:
        return torch.as_tensor(frame[["source", "target"]].to_numpy(), dtype=torch.long)

    def both_directions(pairs: torch.Tensor) -> torch.Tensor:
        return torch.cat([pairs, pairs[:, [1, 0]]], dim=0).t().contiguous()

    refs = as_edge_tensor(parts["train_refs"])
    cites = as_edge_tensor(parts["citation"])
    coauthors = as_edge_tensor(parts["coauthor"])

    graph = HeteroData()
    graph["author"].num_nodes = num_authors
    graph["paper"].num_nodes = num_papers
    graph["paper"].x = torch.as_tensor(parts["paper_feat_aug"], dtype=torch.float)
    graph["author", "ref", "paper"].edge_index = refs.t().contiguous()
    graph["paper", "beref", "author"].edge_index = refs[:, [1, 0]].t().contiguous()
    if use_citation:
        graph["paper", "cite", "paper"].edge_index = both_directions(cites)
    if use_coauthor:
        graph["author", "coauthor", "author"].edge_index = both_directions(coauthors)
    return graph.to(device)
|
|
|
|
def sample_hard_negatives(parts, n_samples: int, num_authors: int, num_papers: int, device):
    """Sample ``n_samples`` (author, paper) pairs that are not training edges.

    Target mix:
    - 50% uniform-random pairs;
    - up to 25% pairs whose paper is in ``parts["popular"]``
      (at most ``3 * n_samples`` draws);
    - up to 25% pairs whose paper comes from the author's
      ``parts["coauthor_pool"]`` entry (at most ``4 * n_samples`` draws).

    Any shortfall is topped up with uniform-random pairs. The popular phase is
    skipped when no paper is popular. Draws use NumPy's global RNG. Pairs may
    repeat, and only ``parts["train_set"]`` is excluded.

    Returns:
        LongTensor of shape ``(2, n_samples)`` on ``device``: row 0 author
        ids, row 1 paper ids. It has shape ``(2, 0)`` when ``n_samples`` is 0.
    """
    neg_list: list[tuple[int, int]] = []
    train_set = parts["train_set"]
    popular = parts["popular"]
    coauthor_pool = parts["coauthor_pool"]

    def add_random(target: int) -> None:
        while len(neg_list) < target:
            s = np.random.randint(0, num_authors)
            d = np.random.randint(0, num_papers)
            if (s, d) not in train_set:
                neg_list.append((s, d))

    add_random(int(n_samples * 0.5))

    # Popular-paper negatives. With no popular papers,
    # np.random.randint(0, 0) would raise, so the phase is skipped.
    attempts = 0
    while len(popular) > 0 and len(neg_list) < int(n_samples * 0.75) and attempts < n_samples * 3:
        attempts += 1
        s = np.random.randint(0, num_authors)
        d = int(popular[np.random.randint(0, len(popular))])
        if (s, d) not in train_set:
            neg_list.append((s, d))

    # Coauthor-pool negatives: papers referenced by the author's coauthors.
    attempts = 0
    while len(neg_list) < n_samples and attempts < n_samples * 4:
        attempts += 1
        s = np.random.randint(0, num_authors)
        pool = coauthor_pool.get(s)
        if pool is None or len(pool) == 0:
            continue
        d = int(pool[np.random.randint(0, len(pool))])
        if (s, d) not in train_set:
            neg_list.append((s, d))

    add_random(n_samples)
    # reshape(-1, 2) keeps the result 2-D when the list is empty, so the
    # transpose is always (2, n) and callers can unpack it as (src, dst).
    neg = torch.tensor(neg_list[:n_samples], dtype=torch.long, device=device).reshape(-1, 2)
    return neg.t().contiguous()
|
|
|
|
@torch.no_grad()
def predict_scores(
    model: LightGCN,
    data,
    pairs: np.ndarray,
    batch_size: int,
    mode: str = "cos",
    normalize_embeddings: bool = False,
) -> np.ndarray:
    """Score (author, paper) index pairs with the model's final embeddings.

    Args:
        model: Encoder whose ``encode(data)`` returns ``"author"``/``"paper"``
            embeddings. It is switched to eval mode.
        data: Graph passed to ``model.encode``.
        pairs: Integer array; column 0 holds author indices, column 1 paper indices.
        batch_size: Number of pairs scored per NumPy chunk.
        mode: One of
            - ``"cos"``: cosine similarity;
            - ``"dot"``: inner product;
            - ``"neg_l2"``: negative squared Euclidean distance.
        normalize_embeddings: L2-normalise embeddings before scoring.

    Returns:
        float32 array with one score per row of ``pairs``. It is empty when
        ``pairs`` is empty; previously ``np.concatenate([])`` raised.

    Raises:
        ValueError: If ``mode`` is unknown. This is checked before the encoder
            runs.
    """
    if mode not in ("cos", "dot", "neg_l2"):
        raise ValueError(mode)
    model.eval()
    if len(pairs) == 0:
        return np.empty(0, dtype=np.float32)
    z_dict = model.encode(data)
    if normalize_embeddings:
        z_dict = {k: F.normalize(v, p=2, dim=1) for k, v in z_dict.items()}
    author_z = z_dict["author"].detach().cpu().numpy()
    paper_z = z_dict["paper"].detach().cpu().numpy()
    scores = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        a = author_z[batch[:, 0]]
        p = paper_z[batch[:, 1]]
        if mode == "cos":
            score = cos_sim(a, p)
        elif mode == "dot":
            score = np.sum(a * p, axis=1)
        else:  # "neg_l2" (mode validated above)
            score = -np.sum((a - p) ** 2, axis=1)
        scores.append(score.astype(np.float32))
    return np.concatenate(scores)
|
|
|
|
def best_f1(labels: np.ndarray, scores: np.ndarray) -> tuple[float, float, float]:
    """Return ``(best F1, threshold achieving it, ROC-AUC)`` for binary ``labels``.

    F1 is evaluated at every point of ``precision_recall_curve``.
    ``thresholds`` is shorter than the precision/recall arrays (sklearn's
    final curve point has no threshold). If that final point is the best,
    0.5 is reported as the threshold.
    """
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    f1_curve = 2 * precision * recall / (precision + recall + 1e-12)
    best_idx = int(np.argmax(f1_curve))
    if best_idx < len(thresholds):
        best_threshold = float(thresholds[best_idx])
    else:
        best_threshold = 0.5
    best_score = float(f1_curve[best_idx])
    return best_score, best_threshold, float(roc_auc_score(labels, scores))
|
|
|
|
def _pairwise_loss(loss_name: str, raw_pos_scores, neg_scores, neg_per_pos: int, margin: float):
    """Compute the training loss from positive and negative scores.

    ``neg_scores`` holds ``neg_per_pos`` consecutive negatives per positive,
    so for ``"bpr"`` and ``"hinge"`` each positive score is repeated to line
    up with its negatives. ``"bce"`` treats every score as an independent
    logit.

    Raises:
        ValueError: For an unknown ``loss_name``.
    """
    if loss_name == "bpr":
        pos_scores = raw_pos_scores.repeat_interleave(neg_per_pos)
        return -F.logsigmoid(pos_scores - neg_scores).mean()
    if loss_name == "hinge":
        pos_scores = raw_pos_scores.repeat_interleave(neg_per_pos)
        return (margin - pos_scores + neg_scores).clamp(min=0).mean()
    if loss_name == "bce":
        logits = torch.cat([raw_pos_scores, neg_scores])
        targets = torch.cat([torch.ones_like(raw_pos_scores), torch.zeros_like(neg_scores)])
        return F.binary_cross_entropy_with_logits(logits, targets)
    raise ValueError(loss_name)


def train_one(args, parts, data, seed: int, embed_dim: int, model_dir: Path, score_dir: Path):
    """Train one model for ``(seed, embed_dim)``, keeping the state with the best validation F1.

    Each epoch takes one random mini-batch of training positives plus sampled
    negatives and runs one optimizer step. The model is evaluated every
    ``args.eval_every`` epochs and after the last epoch. When validation F1
    improves, the validation scores are saved to ``score_dir``. The best
    state is saved to ``model_dir`` at the end.

    Returns:
        ``(f1, threshold, auc)`` of the best evaluation, or ``(-1.0, 0.0, 0.0)``
        if no evaluation ran (e.g. ``args.epochs == 0``).
    """
    set_seed(seed)
    device = torch.device(args.device)
    model_cls = LearnableWeightLightGCN if args.variant == "learnw" else LightGCN
    model = model_cls(args.num_authors, parts["paper_feat_aug"].shape[1], embed_dim, args.layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    pos_edges = data["author", "ref", "paper"].edge_index
    batch_size = min(args.train_batch_size, pos_edges.size(1))
    best = (-1.0, 0.0, 0.0)
    best_state = None

    val_arr = parts["val_pairs"][["source", "target"]].to_numpy(dtype=np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(dtype=np.int8)

    for epoch in range(args.epochs):
        model.train()
        # One random mini-batch of positives per epoch (not a full pass over the edges).
        perm = torch.randperm(pos_edges.size(1), device=device)[:batch_size]
        pos = pos_edges[:, perm]
        neg = sample_hard_negatives(
            parts, pos.size(1) * args.neg_per_pos, args.num_authors, args.num_papers, device
        )
        z_dict = model.encode(data)
        if args.normalize_embeddings:
            z_dict = {k: F.normalize(v, p=2, dim=1) for k, v in z_dict.items()}
        raw_pos_scores = model.decode(z_dict, pos)
        neg_scores = model.decode(z_dict, neg)
        loss = _pairwise_loss(args.loss, raw_pos_scores, neg_scores, args.neg_per_pos, args.margin)

        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        should_eval = (epoch + 1) % args.eval_every == 0 or epoch == args.epochs - 1
        if should_eval:
            scores = predict_scores(
                model,
                data,
                val_arr,
                args.pred_batch_size,
                args.eval_mode,
                args.normalize_embeddings,
            )
            f1, threshold, auc = best_f1(labels, scores)
            if f1 > best[0]:
                best = (f1, threshold, auc)
                # Force a copy: on CPU, .detach().cpu() returns tensors that
                # share storage with the live parameters. Later in-place
                # optimizer steps would then overwrite this "best" snapshot.
                best_state = {
                    k: v.detach().to("cpu", copy=True) for k, v in model.state_dict().items()
                }
                np.save(score_dir / f"val_{args.variant}_{args.eval_mode}_s{seed}_d{embed_dim}.npy", scores)
            print(
                f"seed={seed} dim={embed_dim} epoch={epoch+1:03d} "
                f"loss={loss.item():.4f} val_f1={f1:.5f} th={threshold:.5f} auc={auc:.5f}"
            )

    if best_state is not None:
        torch.save(best_state, model_dir / f"{args.variant}_val_s{seed}_d{embed_dim}.pt")
    return best
|
|
|
|
def main() -> None:
    """Command-line entry point: train every (dim, seed) model, then score their mean ensemble.

    Outputs go to ``<package-root>/validation_runs/<split>/<run-name>/``:
    - per-model checkpoints and validation scores;
    - ``model_results.csv``;
    - ``ensemble_result.txt``, written when at least one model produced scores.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-dir", type=Path, default=None)
    parser.add_argument("--split-seed", type=int, default=None)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seeds", nargs="*", type=int, default=[0, 42, 2024])
    parser.add_argument("--dims", nargs="*", type=int, default=[256])
    parser.add_argument("--layers", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=160)
    parser.add_argument("--eval-every", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--weight-decay", type=float, default=1e-5)
    parser.add_argument("--train-batch-size", type=int, default=32768)
    parser.add_argument("--pred-batch-size", type=int, default=65536)
    parser.add_argument("--neg-per-pos", type=int, default=2)
    parser.add_argument("--loss", choices=["bpr", "bce", "hinge"], default="bpr")
    parser.add_argument("--margin", type=float, default=1.0)
    parser.add_argument("--normalize-embeddings", action="store_true")
    parser.add_argument("--num-authors", type=int, default=6611)
    parser.add_argument("--num-papers", type=int, default=79937)
    parser.add_argument("--run-name", default=None)
    parser.add_argument("--variant", choices=["vanilla", "learnw"], default="vanilla")
    parser.add_argument("--eval-mode", choices=["cos", "dot", "neg_l2"], default="cos")
    parser.add_argument("--drop-citation", action="store_true")
    parser.add_argument("--drop-coauthor", action="store_true")
    args = parser.parse_args()
    # Reject values that would otherwise fail deep inside training: modulo by
    # zero in the eval schedule, empty batches, or a zero-step range.
    for flag, value in (
        ("--eval-every", args.eval_every),
        ("--neg-per-pos", args.neg_per_pos),
        ("--train-batch-size", args.train_batch_size),
        ("--pred-batch-size", args.pred_batch_size),
    ):
        if value < 1:
            parser.error(f"{flag} must be >= 1 (got {value})")

    root = args.package_root
    split_dir = args.split_dir or root / "splits" / "notebook_seed0"
    run_name = args.run_name or (
        f"dims{'-'.join(map(str, args.dims))}_"
        f"seeds{'-'.join(map(str, args.seeds))}_"
        f"L{args.layers}_E{args.epochs}"
    )
    split_name = f"dynamic_seed{args.split_seed}" if args.split_seed is not None else split_dir.name
    out_dir = root / "validation_runs" / split_name / run_name
    model_dir = out_dir / "checkpoints"
    score_dir = out_dir / "scores"
    model_dir.mkdir(parents=True, exist_ok=True)
    score_dir.mkdir(parents=True, exist_ok=True)

    parts = build_parts(
        root,
        split_dir if args.split_seed is None else None,
        args.num_papers,
        split_seed=args.split_seed,
        train_frac=args.train_frac,
    )
    data = build_data(
        parts,
        args.num_authors,
        args.num_papers,
        torch.device(args.device),
        use_citation=not args.drop_citation,
        use_coauthor=not args.drop_coauthor,
    )
    labels = parts["val_pairs"]["label"].to_numpy(dtype=np.int8)

    rows = []
    for dim in args.dims:
        for seed in args.seeds:
            best = train_one(args, parts, data, seed, dim, model_dir, score_dir)
            rows.append({"seed": seed, "dim": dim, "f1": best[0], "threshold": best[1], "auc": best[2]})

    result = pd.DataFrame(rows).sort_values("f1", ascending=False)
    result.to_csv(out_dir / "model_results.csv", index=False)
    print("\nModel results:")
    print(result.to_string(index=False))

    val_scores = []
    names = []
    for row in rows:
        # f1 < 0 means train_one never evaluated this model (e.g. --epochs 0).
        # Any matching score file on disk is then left over from an earlier
        # run and must not enter the ensemble.
        if row["f1"] < 0:
            continue
        path = score_dir / f"val_{args.variant}_{args.eval_mode}_s{int(row['seed'])}_d{int(row['dim'])}.npy"
        if path.exists():
            val_scores.append(np.load(path))
            names.append(path.stem)
    if val_scores:
        ensemble = np.mean(val_scores, axis=0)
        f1, threshold, auc = best_f1(labels, ensemble)
        np.save(score_dir / f"val_{args.variant}_ensemble_mean.npy", ensemble)
        with (out_dir / "ensemble_result.txt").open("w") as f:
            f.write(f"models={','.join(names)}\n")
            f.write(f"f1={f1:.8f}\nthreshold={threshold:.8f}\nauc={auc:.8f}\n")
        print(f"\nMean ensemble: f1={f1:.5f} threshold={threshold:.5f} auc={auc:.5f}")


if __name__ == "__main__":
    main()
|
|