"""Compare different GNN architectures on validation set.""" import os, pickle as pkl, random, time import numpy as np import pandas as pd import torch, torch.nn as nn, torch.nn.functional as F from torch_geometric.data import HeteroData from torch_geometric.nn import GATv2Conv, HeteroConv, SAGEConv from sklearn.metrics import precision_recall_curve, roc_auc_score from numpy.linalg import norm device = torch.device('cuda:0') ETS = [('author','ref','paper'),('paper','beref','author'), ('paper','cite','paper'),('author','coauthor','author')] def set_seed(seed=0): random.seed(seed); np.random.seed(seed) torch.manual_seed(seed) if torch.cuda.is_available(): torch.cuda.manual_seed_all(seed) # ── Data loading ───────────────────────────────────────────────── base = '/home/lzc/cs3319-project' def read_txt(f): res = [] with open(f) as fh: for line in fh: res.append(list(map(int, line.strip().split()))) return res train_raw = read_txt(f'{base}/bipartite_train_ann.txt') test_raw = read_txt(f'{base}/bipartite_test_ann.txt') citing_raw = read_txt(f'{base}/paper_file_ann.txt') coauthor_raw = read_txt(f'{base}/author_file_ann.txt') with open(f'{base}/feature.pkl', 'rb') as f: paper_feat_raw = pkl.load(f) # Build node sets df_cite = pd.DataFrame(citing_raw, columns=['source','target']) df_ref = pd.DataFrame(train_raw, columns=['source','target']) df_coa = pd.DataFrame(coauthor_raw, columns=['source','target']) tmp = pd.concat([df_cite['source'], df_cite['target'], df_ref['target']]) paper_nodes = pd.DataFrame(index=pd.unique(tmp)) tmp = pd.concat([df_ref['source'], df_coa['source'], df_coa['target']]) author_nodes = pd.DataFrame(index=pd.unique(tmp)) N_AUTHORS = len(author_nodes) N_PAPERS = len(paper_nodes) print(f"Authors: {N_AUTHORS}, Papers: {N_PAPERS}") # Degree features author_deg = np.zeros(N_AUTHORS, np.float32) paper_deg = np.zeros(N_PAPERS, np.float32) paper_cout = np.zeros(N_PAPERS, np.float32) paper_cin = np.zeros(N_PAPERS, np.float32) for s, t in train_raw: author_deg[s] += 1; paper_deg[t] += 1 for s, t in citing_raw: paper_cout[s] += 1; paper_cin[t] += 1 def log_norm(x): x = np.log1p(x); return (x - x.mean()) / (x.std() + 1e-8) pf = paper_feat_raw.numpy().astype(np.float32) pdeg = np.stack([log_norm(paper_deg), log_norm(paper_cout), log_norm(paper_cin)], -1) PAPER_FEAT = np.concatenate([pf, pdeg], -1) PAPER_FEAT = (PAPER_FEAT - PAPER_FEAT.mean(0)) / (PAPER_FEAT.std(0) + 1e-8) # Hard negative pools popular = np.where(paper_deg >= np.percentile(paper_deg[paper_deg > 0], 70))[0] coauthor_set = {i: set() for i in range(N_AUTHORS)} for s, t in coauthor_raw: coauthor_set[s].add(t); coauthor_set[t].add(s) author_papers_set = {i: set() for i in range(N_AUTHORS)} for s, t in train_raw: author_papers_set[s].add(t) coauthor_pool = {} for a in range(N_AUTHORS): pool = set() for c in coauthor_set[a]: pool.update(author_papers_set[c]) pool -= author_papers_set[a] coauthor_pool[a] = list(pool) if pool else list(range(N_PAPERS)) TRAIN_SET = set(map(tuple, train_raw)) # Train/val split (90/10) df_ref_idx = df_ref.copy() train_90 = df_ref_idx.sample(frac=0.9, random_state=0, axis=0) val_pos_df = df_ref_idx[~df_ref_idx.index.isin(train_90.index)].copy() val_pos_df['label'] = 1 neg_list = [] while len(neg_list) < len(val_pos_df): s = np.random.randint(0, N_AUTHORS); d = np.random.randint(0, N_PAPERS) if (s, d) not in TRAIN_SET: neg_list.append((s, d)) val_neg_df = pd.DataFrame(neg_list, columns=['source', 'target']) val_neg_df['label'] = 0 VAL_DF = pd.concat([val_pos_df, val_neg_df], ignore_index=True) VAL_DF = VAL_DF.sample(frac=1, random_state=0) # ── Graph building ──────────────────────────────────────────────── def build_data(edges_df): rt = torch.as_tensor(edges_df[['source','target']].to_numpy(), dtype=torch.long) ct = torch.as_tensor(df_cite[['source','target']].to_numpy(), dtype=torch.long) cot = torch.as_tensor(df_coa[['source','target']].to_numpy(), dtype=torch.long) d = HeteroData() d['author'].num_nodes = N_AUTHORS d['paper'].num_nodes = N_PAPERS d['paper'].x = torch.as_tensor(PAPER_FEAT, dtype=torch.float) d['author','ref','paper'].edge_index = rt.t().contiguous() d['paper','beref','author'].edge_index = rt[:, [1,0]].t().contiguous() d['paper','cite','paper'].edge_index = torch.cat([ct, ct[:, [1,0]]], 0).t().contiguous() d['author','coauthor','author'].edge_index = torch.cat([cot, cot[:, [1,0]]], 0).t().contiguous() return d.to(device) def sample_hard_neg(n): nl = [] def add_rand(tgt): nonlocal nl while len(nl) < tgt: s = np.random.randint(0, N_AUTHORS); d = np.random.randint(0, N_PAPERS) if (s, d) not in TRAIN_SET: nl.append((s, d)) add_rand(int(n * 0.5)) cnt = 0 while len(nl) < int(n * 0.75) and cnt < n * 2: cnt += 1; s = np.random.randint(0, N_AUTHORS) d = popular[np.random.randint(0, len(popular))] if (s, d) not in TRAIN_SET: nl.append((s, d)) cnt = 0 while len(nl) < n and cnt < n * 3: cnt += 1; s = np.random.randint(0, N_AUTHORS) pl = coauthor_pool.get(s, []) d = pl[np.random.randint(0, len(pl))] if pl else np.random.randint(0, N_PAPERS) if (s, d) not in TRAIN_SET: nl.append((s, d)) add_rand(n) return torch.tensor(nl[:n], dtype=torch.long, device=device).t().contiguous() def cos_sim(a, b, eps=1e-12): return np.sum(a * b, 1) / (norm(a, 1) * norm(b, 1) + eps) @torch.no_grad() def evaluate(model, data): model.eval() z = model.encode(data) zc = {k: v.cpu() for k, v in z.items()} va = VAL_DF[['source','target']].to_numpy(dtype=np.int64) sc = cos_sim(zc['author'][va[:,0]].numpy(), zc['paper'][va[:,1]].numpy()) lb = VAL_DF['label'].to_numpy() p, r, t = precision_recall_curve(lb, sc) f1s = 2 * p * r / (p + r + 1e-12) bi = np.argmax(f1s) return f1s[bi], roc_auc_score(lb, sc), t[bi] if bi < len(t) else 0.5 # ═══════════════════════════════════════════════════════════════════ # GNN Layer Implementations # ═══════════════════════════════════════════════════════════════════ class MeanAggLayer(nn.Module): """LightGCN-style mean aggregation (no params).""" def forward(self, xd, eid): ad = {nt: [] for nt in xd} for et in ETS: if et not in eid: continue st, _, dt = et; src, dst = eid[et]; sx = xd[st] a = sx.new_zeros((xd[dt].size(0), sx.size(-1))) d = sx.new_zeros((xd[dt].size(0), 1)) a.index_add_(0, dst, sx[src]) d.index_add_(0, dst, torch.ones((dst.numel(), 1), dtype=sx.dtype, device=sx.device)) ad[dt].append(a / d.clamp(min=1.0)) return {nt: sum(ad[nt]) / len(ad[nt]) if ad[nt] else xd[nt] for nt in xd} class GATAggLayer(nn.Module): """GATv2 attention-based aggregation.""" def __init__(self, hdim, heads=2): super().__init__() self.conv = HeteroConv({ et: GATv2Conv(hdim, hdim // heads, heads=heads, add_self_loops=False, dropout=0.1) for et in ETS }, aggr='mean') def forward(self, xd, eid): h = self.conv(xd, eid) return {nt: h.get(nt, xd[nt]) for nt in xd} class SAGEAggLayer(nn.Module): """GraphSAGE aggregation with mean pooling.""" def __init__(self, hdim): super().__init__() self.conv = HeteroConv({ et: SAGEConv(hdim, hdim) for et in ETS }, aggr='mean') def forward(self, xd, eid): return self.conv(xd, eid) # ═══════════════════════════════════════════════════════════════════ # Model Builders # ═══════════════════════════════════════════════════════════════════ def make_vanilla_lgcn(edim=256, nlayers=4): """Baseline LightGCN.""" m = nn.Module() m.ae = nn.Embedding(N_AUTHORS, edim) m.pp = nn.Linear(PAPER_FEAT.shape[1], edim) m.layers = nn.ModuleList([MeanAggLayer() for _ in range(nlayers)]) m.L = nlayers m._type = 'vanilla' def encode(data): xd = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)} als = [xd] for l in m.layers: xd = l(xd, data.edge_index_dict); als.append(xd) w = 1.0 / (m.L + 1) return {nt: sum(w * l[nt] for l in als) for nt in xd} def decode(zd, ei): s, d = ei; return (zd['author'][s] * zd['paper'][d]).sum(-1) def reset(): nn.init.xavier_uniform_(m.ae.weight) nn.init.xavier_uniform_(m.pp.weight); nn.init.zeros_(m.pp.bias) m.encode = encode; m.decode = decode; m.reset_parameters = reset m.reset_parameters() return m def make_learnw_lgcn(edim=256, nlayers=4): """LightGCN with learnable layer weights.""" m = nn.Module() m.ae = nn.Embedding(N_AUTHORS, edim) m.pp = nn.Linear(PAPER_FEAT.shape[1], edim) m.layers = nn.ModuleList([MeanAggLayer() for _ in range(nlayers)]) m.L = nlayers m.layer_w = nn.Parameter(torch.ones(nlayers + 1) / (nlayers + 1)) m._type = 'learnw' def encode(data): xd = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)} als = [xd] for l in m.layers: xd = l(xd, data.edge_index_dict); als.append(xd) w = F.softmax(m.layer_w, dim=0) return {nt: sum(w[i] * l[nt] for i, l in enumerate(als)) for nt in xd} def decode(zd, ei): s, d = ei; return (zd['author'][s] * zd['paper'][d]).sum(-1) def reset(): nn.init.xavier_uniform_(m.ae.weight) nn.init.xavier_uniform_(m.pp.weight); nn.init.zeros_(m.pp.bias) m.encode = encode; m.decode = decode; m.reset_parameters = reset m.reset_parameters() return m def make_gat_lgcn(edim=256, nlayers=3, heads=2): """LightGCN framework with GAT aggregation.""" m = nn.Module() m.ae = nn.Embedding(N_AUTHORS, edim) m.pp = nn.Linear(PAPER_FEAT.shape[1], edim) m.layers = nn.ModuleList([GATAggLayer(edim, heads) for _ in range(nlayers)]) m.L = nlayers m._type = 'gat' def encode(data): xd = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)} als = [xd] for l in m.layers: xd = l(xd, data.edge_index_dict); als.append(xd) w = 1.0 / (m.L + 1) return {nt: sum(w * l[nt] for l in als) for nt in xd} def decode(zd, ei): s, d = ei; return (zd['author'][s] * zd['paper'][d]).sum(-1) def reset(): nn.init.xavier_uniform_(m.ae.weight) nn.init.xavier_uniform_(m.pp.weight); nn.init.zeros_(m.pp.bias) m.encode = encode; m.decode = decode; m.reset_parameters = reset m.reset_parameters() return m def make_sage_lgcn(edim=256, nlayers=3): """LightGCN framework with SAGE aggregation.""" m = nn.Module() m.ae = nn.Embedding(N_AUTHORS, edim) m.pp = nn.Linear(PAPER_FEAT.shape[1], edim) m.layers = nn.ModuleList([SAGEAggLayer(edim) for _ in range(nlayers)]) m.L = nlayers m._type = 'sage' def encode(data): xd = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)} als = [xd] for l in m.layers: xd = l(xd, data.edge_index_dict); als.append(xd) w = 1.0 / (m.L + 1) return {nt: F.relu(sum(w * l[nt] for l in als)) for nt in xd} def decode(zd, ei): s, d = ei; return (zd['author'][s] * zd['paper'][d]).sum(-1) def reset(): nn.init.xavier_uniform_(m.ae.weight) nn.init.xavier_uniform_(m.pp.weight); nn.init.zeros_(m.pp.bias) m.encode = encode; m.decode = decode; m.reset_parameters = reset m.reset_parameters() return m def make_deep_lgcn(edim=256, nlayers=6): """Deeper LightGCN (6 layers).""" m = nn.Module() m.ae = nn.Embedding(N_AUTHORS, edim) m.pp = nn.Linear(PAPER_FEAT.shape[1], edim) m.layers = nn.ModuleList([MeanAggLayer() for _ in range(nlayers)]) m.L = nlayers m._type = 'deep' def encode(data): xd = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)} als = [xd] for l in m.layers: xd = l(xd, data.edge_index_dict); als.append(xd) w = 1.0 / (m.L + 1) return {nt: sum(w * l[nt] for l in als) for nt in xd} def decode(zd, ei): s, d = ei; return (zd['author'][s] * zd['paper'][d]).sum(-1) def reset(): nn.init.xavier_uniform_(m.ae.weight) nn.init.xavier_uniform_(m.pp.weight); nn.init.zeros_(m.pp.bias) m.encode = encode; m.decode = decode; m.reset_parameters = reset m.reset_parameters() return m def make_wide_lgcn(edim=384, nlayers=3): """Wider LightGCN (384 dim).""" return make_vanilla_lgcn(edim=edim, nlayers=nlayers) # ═══════════════════════════════════════════════════════════════════ # Training # ═══════════════════════════════════════════════════════════════════ def run_trial(name, make_fn, epochs=120): print(f'\n--- {name} ---') t0 = time.time() set_seed(0) data = build_data(train_90) model = make_fn().to(device) opt = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-5) pe = data['author', 'ref', 'paper'].edge_index bs = min(32768, pe.size(1)) best_f1, best_th = 0, 0.5 for ep in range(epochs): model.train() perm = torch.randperm(pe.size(1), device=device)[:bs] pos = pe[:, perm] neg = sample_hard_neg(pos.size(1) * 2) z = model.encode(data) ps = model.decode(z, pos).repeat_interleave(2) ns = model.decode(z, neg) loss = -F.logsigmoid(ps - ns).mean() opt.zero_grad(); loss.backward() torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0) opt.step() if ep % 20 == 0 or ep == epochs - 1: f1, auc, th = evaluate(model, data) if f1 > best_f1: best_f1 = f1; best_th = th mkr = '>' if f1 == best_f1 else ' ' print(f'{mkr} ep={ep:03d} loss={loss.item():.4f} f1={f1:.4f} auc={auc:.4f}') t = time.time() - t0 npar = sum(p.numel() for p in model.parameters()) print(f' Best: F1={best_f1:.4f} Thresh={best_th:.4f} Params={npar/1e6:.1f}M Time={t:.0f}s') return best_f1 # Run all results = {} configs = [ ('1. Vanilla LightGCN (4L, 256d)', lambda: make_vanilla_lgcn(256, 4)), ('2. Learnable Weights (4L, 256d)', lambda: make_learnw_lgcn(256, 4)), ('3. GAT Aggregation (3L, 256d, 2h)', lambda: make_gat_lgcn(256, 3, 2)), ('4. SAGE Aggregation (3L, 256d)', lambda: make_sage_lgcn(256, 3)), ('5. Deep LightGCN (6L, 256d)', lambda: make_deep_lgcn(256, 6)), ('6. Wide LightGCN (3L, 384d)', lambda: make_wide_lgcn(384, 3)), ('7. Vanilla LightGCN (5L, 256d)', lambda: make_vanilla_lgcn(256, 5)), ('8. GAT Aggregation (4L, 256d, 4h)', lambda: make_gat_lgcn(256, 4, 4)), ] for name, fn in configs: try: f1 = run_trial(name, fn) results[name] = f1 except Exception as e: print(f' FAILED: {e}') print('\n' + '=' * 60) print('RESULTS (sorted by val F1):') print('=' * 60) for name, f1 in sorted(results.items(), key=lambda x: -x[1]): print(f' {f1:.4f} {name}')