| """Compare different GNN architectures on validation set.""" |
| import os, pickle as pkl, random, time |
| import numpy as np |
| import pandas as pd |
| import torch, torch.nn as nn, torch.nn.functional as F |
| from torch_geometric.data import HeteroData |
| from torch_geometric.nn import GATv2Conv, HeteroConv, SAGEConv |
| from sklearn.metrics import precision_recall_curve, roc_auc_score |
| from numpy.linalg import norm |
|
|
# Use the first CUDA GPU when one is present; otherwise fall back to the CPU
# so the script still runs on machines without a GPU.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')

# Edge types of the heterogeneous graph, written as
# (source node type, relation name, destination node type).
ETS = [
    ('author', 'ref', 'paper'),
    ('paper', 'beref', 'author'),
    ('paper', 'cite', 'paper'),
    ('author', 'coauthor', 'author'),
]
|
|
|
|
def set_seed(seed=0):
    """Seed the Python, NumPy and PyTorch random number generators.

    The CUDA generators are seeded as well, but only when CUDA is available.

    Args:
        seed: Integer seed applied to every generator.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
|
|
|
|
| |
base = '/home/lzc/cs3319-project'


def read_txt(f):
    """Read a whitespace-separated integer table from a text file.

    Args:
        f: Path of the file to read.

    Returns:
        A list with one entry per non-blank line; each entry is the list of
        integers found on that line.

    Raises:
        ValueError: If a field cannot be parsed as an integer.
    """
    rows = []
    with open(f) as fh:
        for line in fh:
            fields = line.split()
            # Skip blank lines (e.g. a trailing newline at end of file): an
            # empty row would break callers that unpack every row as a pair.
            if not fields:
                continue
            rows.append([int(tok) for tok in fields])
    return rows
|
|
# Raw integer tables read from the project directory. Judging by how they are
# used below, the bipartite files hold (author, paper) pairs, the paper file
# holds citation pairs and the author file holds co-author pairs.
train_raw = read_txt(f'{base}/bipartite_train_ann.txt')
test_raw = read_txt(f'{base}/bipartite_test_ann.txt')
citing_raw = read_txt(f'{base}/paper_file_ann.txt')
coauthor_raw = read_txt(f'{base}/author_file_ann.txt')

# NOTE(security): unpickling can execute arbitrary code; only load
# feature.pkl from a trusted source.
with open(f'{base}/feature.pkl', 'rb') as f:
    paper_feat_raw = pkl.load(f)
|
|
| |
# Edge lists as two-column frames. In df_ref the source is an author and the
# target a paper; df_cite is paper -> paper and df_coa is author -> author.
_EDGE_COLUMNS = ['source', 'target']
df_cite = pd.DataFrame(citing_raw, columns=_EDGE_COLUMNS)
df_ref = pd.DataFrame(train_raw, columns=_EDGE_COLUMNS)
df_coa = pd.DataFrame(coauthor_raw, columns=_EDGE_COLUMNS)

# Distinct paper ids seen in any citation edge or as a reference target.
tmp = pd.concat([df_cite['source'], df_cite['target'], df_ref['target']])
paper_nodes = pd.DataFrame(index=pd.unique(tmp))

# Distinct author ids seen as a reference source or in any co-author edge.
tmp = pd.concat([df_ref['source'], df_coa['source'], df_coa['target']])
author_nodes = pd.DataFrame(index=pd.unique(tmp))

# NOTE(review): these are counts of distinct ids, yet later code indexes
# arrays of this length by raw id, so ids are presumably the contiguous
# range 0..N-1 -- confirm against the data files.
N_AUTHORS = author_nodes.shape[0]
N_PAPERS = paper_nodes.shape[0]
print(f"Authors: {N_AUTHORS}, Papers: {N_PAPERS}")
|
|
| |
# Per-node edge counts, indexed directly by raw node id.
# NOTE(review): the counts use every edge in train_raw, including the edges
# later held out as validation positives -- confirm this is intended.
author_deg = np.zeros(N_AUTHORS, np.float32)   # reference edges per author
paper_deg = np.zeros(N_PAPERS, np.float32)     # reference edges per paper
paper_cout = np.zeros(N_PAPERS, np.float32)    # citation edges leaving a paper
paper_cin = np.zeros(N_PAPERS, np.float32)     # citation edges entering a paper

for s, t in train_raw:
    author_deg[s] += 1
    paper_deg[t] += 1

for s, t in citing_raw:
    paper_cout[s] += 1
    paper_cin[t] += 1
|
|
|
|
def log_norm(x):
    """Apply ``log(1 + x)`` and standardise the result.

    The transformed values are shifted to zero mean and divided by their
    standard deviation; a small constant is added to the divisor so that a
    constant input does not divide by zero.

    Args:
        x: NumPy array of values.

    Returns:
        Array of the same shape holding the standardised log values.
    """
    logged = np.log1p(x)
    centred = logged - logged.mean()
    return centred / (logged.std() + 1e-8)
|
|
|
|
# Paper input features: the loaded feature matrix followed by three
# log-normalised degree columns, then standardised column by column.
# NOTE(review): assumes the unpickled object offers .numpy() (e.g. a torch
# tensor) and has exactly N_PAPERS rows -- confirm.
pf = paper_feat_raw.numpy().astype(np.float32)
pdeg = np.stack(
    [log_norm(paper_deg), log_norm(paper_cout), log_norm(paper_cin)],
    axis=-1,
)
PAPER_FEAT = np.concatenate([pf, pdeg], axis=-1)
# The epsilon keeps constant columns from dividing by zero.
PAPER_FEAT = (PAPER_FEAT - PAPER_FEAT.mean(0)) / (PAPER_FEAT.std(0) + 1e-8)
|
|
| |
# Papers whose reference count is at or above the 70th percentile of the
# non-zero counts; used as "popular" hard negatives.
popular = np.where(paper_deg >= np.percentile(paper_deg[paper_deg > 0], 70))[0]

# Undirected co-author adjacency.
coauthor_set = {i: set() for i in range(N_AUTHORS)}
for s, t in coauthor_raw:
    coauthor_set[s].add(t)
    coauthor_set[t].add(s)

# Papers each author is linked to in train_raw.
author_papers_set = {i: set() for i in range(N_AUTHORS)}
for s, t in train_raw:
    author_papers_set[s].add(t)

# Shared fallback for authors with no candidates. A range supports len() and
# indexing just like the list it replaces, but a single shared object avoids
# allocating a separate N_PAPERS-long list for every such author.
_ALL_PAPERS = range(N_PAPERS)

# For each author: papers linked to their co-authors but not to themselves.
coauthor_pool = {}
for a in range(N_AUTHORS):
    pool = set()
    for c in coauthor_set[a]:
        pool.update(author_papers_set[c])
    pool -= author_papers_set[a]
    coauthor_pool[a] = list(pool) if pool else _ALL_PAPERS

# All known (author, paper) pairs, used to reject false negatives.
TRAIN_SET = set(map(tuple, train_raw))
|
|
|
|
| |
# Hold out 10% of the known (author, paper) edges as validation positives;
# the remaining 90% form the training graph.
df_ref_idx = df_ref.copy()
train_90 = df_ref_idx.sample(frac=0.9, random_state=0, axis=0)
val_pos_df = df_ref_idx[~df_ref_idx.index.isin(train_90.index)].copy()
val_pos_df['label'] = 1

# Draw one random negative per validation positive. A dedicated seeded
# generator is used because the global NumPy RNG has not been seeded yet at
# this point, which would make the validation set differ between runs.
_val_rng = np.random.RandomState(0)
neg_list = []
while len(neg_list) < len(val_pos_df):
    s = _val_rng.randint(0, N_AUTHORS)
    d = _val_rng.randint(0, N_PAPERS)
    # Reject pairs that are known positives (training or held-out).
    if (s, d) not in TRAIN_SET:
        neg_list.append((s, d))
val_neg_df = pd.DataFrame(neg_list, columns=['source', 'target'])
val_neg_df['label'] = 0

# Combine and shuffle with a fixed seed.
VAL_DF = pd.concat([val_pos_df, val_neg_df], ignore_index=True)
VAL_DF = VAL_DF.sample(frac=1, random_state=0)
|
|
|
|
| |
def build_data(edges_df):
    """Assemble the heterogeneous graph and move it to ``device``.

    Args:
        edges_df: Frame with 'source' (author id) and 'target' (paper id)
            columns giving the author -> paper reference edges to include.

    Returns:
        A HeteroData object on ``device`` with N_AUTHORS author nodes,
        N_PAPERS paper nodes carrying PAPER_FEAT as features, the reference
        edges ('ref') and their reverse ('beref'), and the citation and
        co-author edges from df_cite / df_coa stored in both directions.
    """
    def _pairs(frame):
        # (num_edges, 2) int64 tensor of (source, target) rows.
        return torch.as_tensor(frame[['source', 'target']].to_numpy(), dtype=torch.long)

    def _both_directions(pairs):
        # Append the reversed pairs, then lay out as a (2, 2 * num_edges) index.
        return torch.cat([pairs, pairs[:, [1, 0]]], 0).t().contiguous()

    ref_pairs = _pairs(edges_df)
    cite_pairs = _pairs(df_cite)
    coauthor_pairs = _pairs(df_coa)

    graph = HeteroData()
    graph['author'].num_nodes = N_AUTHORS
    graph['paper'].num_nodes = N_PAPERS
    graph['paper'].x = torch.as_tensor(PAPER_FEAT, dtype=torch.float)
    graph['author', 'ref', 'paper'].edge_index = ref_pairs.t().contiguous()
    graph['paper', 'beref', 'author'].edge_index = ref_pairs[:, [1, 0]].t().contiguous()
    graph['paper', 'cite', 'paper'].edge_index = _both_directions(cite_pairs)
    graph['author', 'coauthor', 'author'].edge_index = _both_directions(coauthor_pairs)
    return graph.to(device)
|
|
|
|
def sample_hard_neg(n):
    """Sample ``n`` negative (author, paper) pairs from a mix of strategies.

    Pairs are collected in stages, each rejecting any pair found in
    TRAIN_SET:
      1. uniformly random pairs until half of ``n`` is reached;
      2. random authors paired with papers from ``popular`` until three
         quarters of ``n`` is reached (at most ``2 * n`` attempts);
      3. random authors paired with a paper from their ``coauthor_pool``
         entry until ``n`` is reached (at most ``3 * n`` attempts);
      4. uniformly random pairs for anything still missing.
    Duplicate pairs are not filtered out.

    Args:
        n: Number of negative pairs to return.

    Returns:
        Long tensor on ``device`` with authors in row 0 and papers in row 1.
    """
    picked = []

    def fill_uniform(target):
        # Top up ``picked`` with uniformly random unseen pairs.
        while len(picked) < target:
            author = np.random.randint(0, N_AUTHORS)
            paper = np.random.randint(0, N_PAPERS)
            if (author, paper) not in TRAIN_SET:
                picked.append((author, paper))

    fill_uniform(int(n * 0.5))

    # Stage 2: popular papers.
    popular_quota = int(n * 0.75)
    attempts = 0
    while len(picked) < popular_quota and attempts < n * 2:
        attempts += 1
        author = np.random.randint(0, N_AUTHORS)
        paper = popular[np.random.randint(0, len(popular))]
        if (author, paper) not in TRAIN_SET:
            picked.append((author, paper))

    # Stage 3: papers taken from the author's co-author pool.
    attempts = 0
    while len(picked) < n and attempts < n * 3:
        attempts += 1
        author = np.random.randint(0, N_AUTHORS)
        candidates = coauthor_pool.get(author, [])
        if candidates:
            paper = candidates[np.random.randint(0, len(candidates))]
        else:
            paper = np.random.randint(0, N_PAPERS)
        if (author, paper) not in TRAIN_SET:
            picked.append((author, paper))

    # Stage 4: the attempt caps above may leave a shortfall.
    fill_uniform(n)
    return torch.tensor(picked[:n], dtype=torch.long, device=device).t().contiguous()
|
|
|
|
def cos_sim(a, b, eps=1e-12):
    """Row-wise cosine similarity between two 2-D arrays of equal shape.

    Args:
        a: Array of shape (n, d).
        b: Array of shape (n, d).
        eps: Small constant added to the denominator so that all-zero rows
            yield 0 instead of dividing by zero.

    Returns:
        Array of shape (n,) whose i-th entry is the cosine similarity of
        ``a[i]`` and ``b[i]``.
    """
    # The second positional argument of numpy.linalg.norm is ``ord``, not
    # ``axis``; ``axis=1`` must be passed by keyword to get per-row L2 norms.
    return np.sum(a * b, axis=1) / (norm(a, axis=1) * norm(b, axis=1) + eps)
|
|
|
|
@torch.no_grad()
def evaluate(model, data):
    """Score the validation pairs in VAL_DF and summarise the result.

    The model is switched to eval mode and left there. Every (author, paper)
    pair in VAL_DF is scored by applying ``cos_sim`` to the two node
    embeddings.

    Args:
        model: Object exposing ``eval()`` and ``encode(data)``, where
            ``encode`` returns a dict with 'author' and 'paper' embeddings.
        data: Graph passed through to ``model.encode``.

    Returns:
        Tuple ``(best_f1, auc, threshold)``: the highest F1 along the
        precision-recall curve, the ROC AUC, and the score threshold at the
        best-F1 point (0.5 when that point has no associated threshold).
    """
    model.eval()
    embeddings = {ntype: h.cpu() for ntype, h in model.encode(data).items()}

    pairs = VAL_DF[['source', 'target']].to_numpy(dtype=np.int64)
    author_vecs = embeddings['author'][pairs[:, 0]].numpy()
    paper_vecs = embeddings['paper'][pairs[:, 1]].numpy()
    scores = cos_sim(author_vecs, paper_vecs)
    labels = VAL_DF['label'].to_numpy()

    precision, recall, thresholds = precision_recall_curve(labels, scores)
    # The epsilon avoids 0/0 where precision and recall are both zero.
    f1_curve = 2 * precision * recall / (precision + recall + 1e-12)
    best = np.argmax(f1_curve)

    # The curve has one more point than there are thresholds.
    if best < len(thresholds):
        best_threshold = thresholds[best]
    else:
        best_threshold = 0.5
    return f1_curve[best], roc_auc_score(labels, scores), best_threshold
|
|
|
|
| |
| |
| |
|
|
class MeanAggLayer(nn.Module):
    """Parameter-free mean aggregation over heterogeneous edges."""

    def forward(self, xd, eid):
        """Average neighbour features per relation, then across relations.

        Args:
            xd: Dict mapping node type to its feature matrix.
            eid: Dict mapping edge type to a (source index, destination
                index) pair; edge types in ETS missing from it are skipped.

        Returns:
            Dict with the same node types as ``xd``. Each node type gets the
            mean of its per-relation neighbour averages; a node type that
            receives no relation keeps its input features.
        """
        messages = {ntype: [] for ntype in xd}
        for etype in ETS:
            if etype not in eid:
                continue
            src_type, _, dst_type = etype
            src_idx, dst_idx = eid[etype]
            src_feat = xd[src_type]
            n_dst = xd[dst_type].size(0)

            # Sum of incoming neighbour features and in-degree per destination.
            summed = src_feat.new_zeros((n_dst, src_feat.size(-1)))
            degree = src_feat.new_zeros((n_dst, 1))
            summed.index_add_(0, dst_idx, src_feat[src_idx])
            unit = torch.ones((dst_idx.numel(), 1), dtype=src_feat.dtype, device=src_feat.device)
            degree.index_add_(0, dst_idx, unit)

            # Clamp so nodes without incoming edges divide by one, not zero.
            messages[dst_type].append(summed / degree.clamp(min=1.0))

        out = {}
        for ntype in xd:
            received = messages[ntype]
            out[ntype] = sum(received) / len(received) if received else xd[ntype]
        return out
|
|
|
|
class GATAggLayer(nn.Module):
    """GATv2 attention-based aggregation.

    One GATv2Conv is created per edge type in ETS and the per-relation
    outputs are combined by a HeteroConv with mean aggregation.

    Args:
        hdim: Input feature width. It must be a multiple of ``heads`` so
            that ``heads`` heads of ``hdim // heads`` channels add up to
            ``hdim`` again and layers can be stacked and averaged.
        heads: Number of attention heads.
        dropout: Dropout rate passed to each GATv2Conv.

    Raises:
        ValueError: If ``hdim`` is not a multiple of ``heads``.
    """

    def __init__(self, hdim, heads=2, dropout=0.1):
        super().__init__()
        # Fail early with a clear message; otherwise the mismatch only shows
        # up later as an opaque shape error inside the encoder.
        if hdim % heads != 0:
            raise ValueError(
                f'hdim ({hdim}) must be divisible by heads ({heads})')
        self.conv = HeteroConv({
            et: GATv2Conv(hdim, hdim // heads, heads=heads,
                          add_self_loops=False, dropout=dropout)
            for et in ETS
        }, aggr='mean')

    def forward(self, xd, eid):
        """Run the attention convolution over every edge type.

        Args:
            xd: Dict mapping node type to its feature matrix.
            eid: Dict mapping edge type to its edge index.

        Returns:
            Dict with the same node types as ``xd``; a node type the
            convolution produced no output for keeps its input features.
        """
        h = self.conv(xd, eid)
        return {nt: h.get(nt, xd[nt]) for nt in xd}
|
|
|
|
class SAGEAggLayer(nn.Module):
    """GraphSAGE aggregation with mean pooling.

    One SAGEConv is created per edge type in ETS and the per-relation
    outputs are combined by a HeteroConv with mean aggregation.

    Args:
        hdim: Input and output feature width of every SAGEConv.
    """

    def __init__(self, hdim):
        super().__init__()
        self.conv = HeteroConv({
            et: SAGEConv(hdim, hdim) for et in ETS
        }, aggr='mean')

    def forward(self, xd, eid):
        """Run the SAGE convolution over every edge type.

        Args:
            xd: Dict mapping node type to its feature matrix.
            eid: Dict mapping edge type to its edge index.

        Returns:
            Dict with the same node types as ``xd``; a node type the
            convolution produced no output for keeps its input features,
            matching the behaviour of GATAggLayer.
        """
        h = self.conv(xd, eid)
        # Keep every node type present so callers can index the result by
        # any key of ``xd`` even if some relation is absent from ``eid``.
        return {nt: h.get(nt, xd[nt]) for nt in xd}
|
|
|
|
| |
| |
| |
|
|
def make_vanilla_lgcn(edim=256, nlayers=4):
    """Build the baseline LightGCN-style model.

    Authors get a learned embedding table, papers a linear projection of
    their input features. ``nlayers`` MeanAggLayer steps are applied and the
    final embedding is the unweighted average of the input and every layer
    output.

    Args:
        edim: Embedding width.
        nlayers: Number of aggregation layers.

    Returns:
        An nn.Module with ``encode(data)``, ``decode(zd, ei)`` and
        ``reset_parameters()`` attached as attributes.
    """
    m = nn.Module()
    m.ae = nn.Embedding(N_AUTHORS, edim)
    m.pp = nn.Linear(PAPER_FEAT.shape[1], edim)
    m.layers = nn.ModuleList([MeanAggLayer() for _ in range(nlayers)])
    m.L = nlayers
    m._type = 'vanilla'

    def encode(data):
        # Layer 0 is the raw author table plus the projected paper features.
        h = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)}
        outputs = [h]
        for layer in m.layers:
            h = layer(h, data.edge_index_dict)
            outputs.append(h)
        w = 1.0 / (m.L + 1)
        z = {}
        for nt in h:
            acc = 0
            for out in outputs:
                acc = acc + w * out[nt]
            z[nt] = acc
        return z

    def decode(zd, ei):
        # Dot product between the author and paper embedding of each pair.
        authors, papers = ei
        return (zd['author'][authors] * zd['paper'][papers]).sum(-1)

    def reset():
        nn.init.xavier_uniform_(m.ae.weight)
        nn.init.xavier_uniform_(m.pp.weight)
        nn.init.zeros_(m.pp.bias)

    m.encode = encode
    m.decode = decode
    m.reset_parameters = reset
    m.reset_parameters()
    return m
|
|
|
|
def make_learnw_lgcn(edim=256, nlayers=4):
    """Build a LightGCN-style model with learnable layer weights.

    Same structure as the baseline, except that the input and the
    ``nlayers`` layer outputs are combined with a softmax over the learnable
    vector ``layer_w`` instead of a fixed uniform average.

    Args:
        edim: Embedding width.
        nlayers: Number of aggregation layers.

    Returns:
        An nn.Module with ``encode(data)``, ``decode(zd, ei)`` and
        ``reset_parameters()`` attached as attributes.
    """
    m = nn.Module()
    m.ae = nn.Embedding(N_AUTHORS, edim)
    m.pp = nn.Linear(PAPER_FEAT.shape[1], edim)
    m.layers = nn.ModuleList([MeanAggLayer() for _ in range(nlayers)])
    m.L = nlayers
    m.layer_w = nn.Parameter(torch.ones(nlayers + 1) / (nlayers + 1))
    m._type = 'learnw'

    def encode(data):
        # Layer 0 is the raw author table plus the projected paper features.
        xd = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)}
        als = [xd]
        for layer in m.layers:
            xd = layer(xd, data.edge_index_dict)
            als.append(xd)
        # Softmax keeps the mixing weights positive and summing to one.
        w = F.softmax(m.layer_w, dim=0)
        return {nt: sum(w[i] * out[nt] for i, out in enumerate(als)) for nt in xd}

    def decode(zd, ei):
        # Dot product between the author and paper embedding of each pair.
        s, d = ei
        return (zd['author'][s] * zd['paper'][d]).sum(-1)

    def reset():
        nn.init.xavier_uniform_(m.ae.weight)
        nn.init.xavier_uniform_(m.pp.weight)
        nn.init.zeros_(m.pp.bias)
        # Also restore the layer weights to their uniform starting value so
        # that a reset really returns the model to its initial state.
        nn.init.constant_(m.layer_w, 1.0 / (m.L + 1))

    m.encode = encode
    m.decode = decode
    m.reset_parameters = reset
    m.reset_parameters()
    return m
|
|
|
|
def make_gat_lgcn(edim=256, nlayers=3, heads=2):
    """Build a LightGCN-style model that aggregates with GATAggLayer.

    Authors get a learned embedding table, papers a linear projection of
    their input features. The final embedding is the unweighted average of
    the input and every layer output.

    Args:
        edim: Embedding width.
        nlayers: Number of GATAggLayer steps.
        heads: Attention heads per layer.

    Returns:
        An nn.Module with ``encode(data)``, ``decode(zd, ei)`` and
        ``reset_parameters()`` attached as attributes.
    """
    m = nn.Module()
    m.ae = nn.Embedding(N_AUTHORS, edim)
    m.pp = nn.Linear(PAPER_FEAT.shape[1], edim)
    m.layers = nn.ModuleList([GATAggLayer(edim, heads) for _ in range(nlayers)])
    m.L = nlayers
    m._type = 'gat'

    def encode(data):
        # Layer 0 is the raw author table plus the projected paper features.
        state = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)}
        history = [state]
        for layer in m.layers:
            state = layer(state, data.edge_index_dict)
            history.append(state)
        w = 1.0 / (m.L + 1)
        return {nt: sum(w * snap[nt] for snap in history) for nt in state}

    def decode(zd, ei):
        # Dot product between the author and paper embedding of each pair.
        author_idx, paper_idx = ei
        return (zd['author'][author_idx] * zd['paper'][paper_idx]).sum(-1)

    def reset():
        # Only the embedding table and the paper projection are re-initialised;
        # the attention layers keep their current weights.
        nn.init.xavier_uniform_(m.ae.weight)
        nn.init.xavier_uniform_(m.pp.weight)
        nn.init.zeros_(m.pp.bias)

    m.encode = encode
    m.decode = decode
    m.reset_parameters = reset
    m.reset_parameters()
    return m
|
|
|
|
def make_sage_lgcn(edim=256, nlayers=3):
    """Build a LightGCN-style model that aggregates with SAGEAggLayer.

    Authors get a learned embedding table, papers a linear projection of
    their input features. The final embedding is the ReLU of the unweighted
    average of the input and every layer output.

    Args:
        edim: Embedding width.
        nlayers: Number of SAGEAggLayer steps.

    Returns:
        An nn.Module with ``encode(data)``, ``decode(zd, ei)`` and
        ``reset_parameters()`` attached as attributes.
    """
    m = nn.Module()
    m.ae = nn.Embedding(N_AUTHORS, edim)
    m.pp = nn.Linear(PAPER_FEAT.shape[1], edim)
    m.layers = nn.ModuleList([SAGEAggLayer(edim) for _ in range(nlayers)])
    m.L = nlayers
    m._type = 'sage'

    def encode(data):
        # Layer 0 is the raw author table plus the projected paper features.
        state = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)}
        history = [state]
        for layer in m.layers:
            state = layer(state, data.edge_index_dict)
            history.append(state)
        w = 1.0 / (m.L + 1)
        z = {}
        for nt in state:
            averaged = sum(w * snap[nt] for snap in history)
            z[nt] = F.relu(averaged)
        return z

    def decode(zd, ei):
        # Dot product between the author and paper embedding of each pair.
        author_idx, paper_idx = ei
        return (zd['author'][author_idx] * zd['paper'][paper_idx]).sum(-1)

    def reset():
        # Only the embedding table and the paper projection are re-initialised;
        # the SAGE layers keep their current weights.
        nn.init.xavier_uniform_(m.ae.weight)
        nn.init.xavier_uniform_(m.pp.weight)
        nn.init.zeros_(m.pp.bias)

    m.encode = encode
    m.decode = decode
    m.reset_parameters = reset
    m.reset_parameters()
    return m
|
|
|
|
def make_deep_lgcn(edim=256, nlayers=6):
    """Build a deeper LightGCN (six layers by default).

    The architecture is the baseline one with a different default depth, so
    construction is delegated to ``make_vanilla_lgcn`` (as ``make_wide_lgcn``
    does) rather than repeating its code; only the ``_type`` tag differs.

    Args:
        edim: Embedding width.
        nlayers: Number of aggregation layers.

    Returns:
        An nn.Module with ``encode(data)``, ``decode(zd, ei)`` and
        ``reset_parameters()`` attached as attributes and ``_type`` set to
        'deep'.
    """
    m = make_vanilla_lgcn(edim=edim, nlayers=nlayers)
    m._type = 'deep'
    return m
|
|
|
|
def make_wide_lgcn(edim=384, nlayers=3):
    """Build a wider LightGCN (384-dimensional by default).

    This is the baseline model from ``make_vanilla_lgcn`` with different
    defaults; the returned module keeps the baseline's ``_type`` tag.

    Args:
        edim: Embedding width.
        nlayers: Number of aggregation layers.

    Returns:
        The module produced by ``make_vanilla_lgcn``.
    """
    model = make_vanilla_lgcn(edim, nlayers)
    return model
|
|
|
|
| |
| |
| |
def run_trial(name, make_fn, epochs=120):
    """Train one model on the 90% training graph and report validation F1.

    All RNGs are re-seeded with 0 before the graph and model are built. Each
    epoch draws one batch of up to 32768 positive edges, pairs every
    positive with two sampled negatives and minimises
    ``-log sigmoid(pos_score - neg_score)`` with Adam. The model is
    evaluated every 20 epochs and after the final epoch.

    Args:
        name: Label printed in the log.
        make_fn: Zero-argument callable returning a model with ``encode``
            and ``decode``.
        epochs: Number of training steps.

    Returns:
        The best validation F1 observed at any evaluation point.
    """
    print(f'\n--- {name} ---')
    t0 = time.time()
    set_seed(0)
    data = build_data(train_90)
    model = make_fn().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=0.005, weight_decay=1e-5)

    pos_edges = data['author', 'ref', 'paper'].edge_index
    n_pos = pos_edges.size(1)
    batch = min(32768, n_pos)
    best_f1, best_th = 0, 0.5
    last_epoch = epochs - 1

    for ep in range(epochs):
        model.train()
        chosen = torch.randperm(n_pos, device=device)[:batch]
        pos = pos_edges[:, chosen]
        neg = sample_hard_neg(pos.size(1) * 2)

        z = model.encode(data)
        # Repeat each positive score so it lines up with its two negatives.
        pos_score = model.decode(z, pos).repeat_interleave(2)
        neg_score = model.decode(z, neg)
        loss = -F.logsigmoid(pos_score - neg_score).mean()

        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()

        if ep % 20 != 0 and ep != last_epoch:
            continue

        f1, auc, th = evaluate(model, data)
        if f1 > best_f1:
            best_f1, best_th = f1, th
        # '>' marks an evaluation that matches the best F1 so far.
        mkr = '>' if f1 == best_f1 else ' '
        print(f'{mkr} ep={ep:03d} loss={loss.item():.4f} f1={f1:.4f} auc={auc:.4f}')

    t = time.time() - t0
    npar = sum(p.numel() for p in model.parameters())
    print(f' Best: F1={best_f1:.4f} Thresh={best_th:.4f} Params={npar/1e6:.1f}M Time={t:.0f}s')
    return best_f1
|
|
|
|
| |
# Validation F1 per configuration label; failed runs are left out.
results = {}

# (label, zero-argument model factory) for every architecture to compare.
configs = [
    ('1. Vanilla LightGCN (4L, 256d)', lambda: make_vanilla_lgcn(256, 4)),
    ('2. Learnable Weights (4L, 256d)', lambda: make_learnw_lgcn(256, 4)),
    ('3. GAT Aggregation (3L, 256d, 2h)', lambda: make_gat_lgcn(256, 3, 2)),
    ('4. SAGE Aggregation (3L, 256d)', lambda: make_sage_lgcn(256, 3)),
    ('5. Deep LightGCN (6L, 256d)', lambda: make_deep_lgcn(256, 6)),
    ('6. Wide LightGCN (3L, 384d)', lambda: make_wide_lgcn(384, 3)),
    ('7. Vanilla LightGCN (5L, 256d)', lambda: make_vanilla_lgcn(256, 5)),
    ('8. GAT Aggregation (4L, 256d, 4h)', lambda: make_gat_lgcn(256, 4, 4)),
]

for name, fn in configs:
    # Best effort: a failing configuration is reported and skipped so the
    # remaining ones still run.
    try:
        f1 = run_trial(name, fn)
    except Exception as e:
        print(f' FAILED: {e}')
    else:
        results[name] = f1

banner = '=' * 60
print('\n' + banner)
print('RESULTS (sorted by val F1):')
print(banner)
ranked = sorted(results.items(), key=lambda item: -item[1])
for name, f1 in ranked:
    print(f' {f1:.4f} {name}')
|
|