# Source: cs3319-project2 / code/compare_gnn.py
# Commit f28d994 (NLP-beginner): "CS3319 Project 2 final deliverable (public F1 = 0.96626)"
# File size: 16.1 kB
"""Compare different GNN architectures on validation set."""
import os, pickle as pkl, random, time
import numpy as np
import pandas as pd
import torch, torch.nn as nn, torch.nn.functional as F
from torch_geometric.data import HeteroData
from torch_geometric.nn import GATv2Conv, HeteroConv, SAGEConv
from sklearn.metrics import precision_recall_curve, roc_auc_score
from numpy.linalg import norm
# Use the first GPU when one is available; otherwise fall back to the CPU so
# the script still runs (slowly) on machines without CUDA instead of failing
# on the first ``.to(device)``.
device = torch.device('cuda:0' if torch.cuda.is_available() else 'cpu')
# Edge types of the heterogeneous graph: author->paper references, their
# reversal ('beref'), paper citations, and author coauthorships.
ETS = [('author', 'ref', 'paper'), ('paper', 'beref', 'author'),
       ('paper', 'cite', 'paper'), ('author', 'coauthor', 'author')]
def set_seed(seed=0):
    """Seed Python's, NumPy's and PyTorch's global RNGs with ``seed``.

    CUDA generators on all devices are seeded as well when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# ── Data loading ─────────────────────────────────────────────────
# Directory containing the input files. Defaults to the original author's
# location; set CS3319_DATA_DIR to point the script at another copy.
base = os.environ.get('CS3319_DATA_DIR', '/home/lzc/cs3319-project')
def read_txt(f):
    """Parse a whitespace-separated integer file into a list of int rows.

    Args:
        f: Path of the text file to read.

    Returns:
        One ``list[int]`` per non-blank line, in file order. Blank or
        whitespace-only lines (e.g. a trailing empty line) are skipped: kept,
        they would become empty rows that break the ``for s, t in ...``
        unpacking applied to these lists.
    """
    rows = []
    with open(f) as fh:
        for line in fh:
            fields = line.split()
            if fields:
                rows.append([int(v) for v in fields])
    return rows
# Raw inputs, read in this order: training author->paper edges, test
# author->paper edges (not used further in this script), paper->paper
# citation edges, author-author coauthor edges, and the pickled paper
# feature object (converted with ``.numpy()`` further down).
train_raw, test_raw, citing_raw, coauthor_raw = (
    read_txt(f'{base}/{fname}')
    for fname in ('bipartite_train_ann.txt', 'bipartite_test_ann.txt',
                  'paper_file_ann.txt', 'author_file_ann.txt')
)
with open(f'{base}/feature.pkl', 'rb') as fh:
    paper_feat_raw = pkl.load(fh)
# Build node sets
df_cite, df_ref, df_coa = (
    pd.DataFrame(rows, columns=['source', 'target'])
    for rows in (citing_raw, train_raw, coauthor_raw)
)
# Papers: every id that cites, is cited, or is referenced by an author.
paper_nodes = pd.DataFrame(index=pd.unique(pd.concat(
    [df_cite['source'], df_cite['target'], df_ref['target']])))
# Authors: every id that references a paper or appears in a coauthor pair.
author_nodes = pd.DataFrame(index=pd.unique(pd.concat(
    [df_ref['source'], df_coa['source'], df_coa['target']])))
# NOTE(review): later code allocates arrays of length N_AUTHORS / N_PAPERS
# and indexes them directly with raw ids, which assumes the ids are dense
# 0..N-1 — confirm against the data files.
N_AUTHORS, N_PAPERS = len(author_nodes), len(paper_nodes)
print(f"Authors: {N_AUTHORS}, Papers: {N_PAPERS}")
# Degree features: float32 counts indexed directly by node id.
#   author_deg / paper_deg  - occurrences as source / target in train_raw
#   paper_cout / paper_cin  - occurrences as source / target in citing_raw
author_deg, paper_deg, paper_cout, paper_cin = (
    np.zeros(size, np.float32)
    for size in (N_AUTHORS, N_PAPERS, N_PAPERS, N_PAPERS)
)
for src_deg, dst_deg, edges in ((author_deg, paper_deg, train_raw),
                                (paper_cout, paper_cin, citing_raw)):
    for s, t in edges:
        src_deg[s] += 1
        dst_deg[t] += 1
def log_norm(x):
    """Return ``log1p(x)`` standardised to zero mean and unit std.

    A small epsilon (1e-8) in the denominator keeps constant inputs finite
    (they map to all zeros).
    """
    logged = np.log1p(x)
    mu = logged.mean()
    sigma = logged.std()
    return (logged - mu) / (sigma + 1e-8)
# Paper features: raw features followed by three standardised log-degree
# columns (training in-degree, citation out-degree, citation in-degree);
# every resulting column is then z-scored across papers.
pf = paper_feat_raw.numpy().astype(np.float32)
pdeg = np.column_stack([log_norm(v) for v in (paper_deg, paper_cout, paper_cin)])
_paper_feat = np.concatenate((pf, pdeg), axis=-1)
PAPER_FEAT = (_paper_feat - _paper_feat.mean(0)) / (_paper_feat.std(0) + 1e-8)
# Hard negative pools
# "Popular" papers: training in-degree at or above the 70th percentile of the
# non-zero training in-degrees.
popular = np.where(paper_deg >= np.percentile(paper_deg[paper_deg > 0], 70))[0]
# Undirected coauthor adjacency, and each author's training papers.
coauthor_set = {i: set() for i in range(N_AUTHORS)}
for s, t in coauthor_raw:
    coauthor_set[s].add(t)
    coauthor_set[t].add(s)
author_papers_set = {i: set() for i in range(N_AUTHORS)}
for s, t in train_raw:
    author_papers_set[s].add(t)
# coauthor_pool[a]: papers referenced by a's coauthors but not by a.
# An author with no such papers gets an EMPTY list rather than a full copy of
# range(N_PAPERS). sample_hard_neg already draws np.random.randint(0, N_PAPERS)
# when the pool is empty, which is the same value from the same RNG call as
# indexing a full range list, but without allocating N_PAPERS ints for every
# such author.
coauthor_pool = {}
for a in range(N_AUTHORS):
    pool = set()
    for c in coauthor_set[a]:
        pool.update(author_papers_set[c])
    pool -= author_papers_set[a]
    coauthor_pool[a] = list(pool)
# Every known positive (author, paper) pair, used to reject sampled negatives.
TRAIN_SET = set(map(tuple, train_raw))
# Train/val split (90/10)
df_ref_idx = df_ref.copy()
train_90 = df_ref_idx.sample(frac=0.9, random_state=0, axis=0)
val_pos_df = df_ref_idx[~df_ref_idx.index.isin(train_90.index)].copy()
val_pos_df['label'] = 1
# Validation negatives: uniform (author, paper) pairs that are not training
# edges. They are drawn from a dedicated, seeded RNG so the validation set is
# identical across runs. set_seed() is only called later, inside run_trial,
# so the global NumPy RNG is still unseeded at this point.
_val_rng = np.random.RandomState(0)
neg_list = []
while len(neg_list) < len(val_pos_df):
    s = _val_rng.randint(0, N_AUTHORS)
    d = _val_rng.randint(0, N_PAPERS)
    if (s, d) not in TRAIN_SET:
        neg_list.append((s, d))
val_neg_df = pd.DataFrame(neg_list, columns=['source', 'target'])
val_neg_df['label'] = 0
VAL_DF = pd.concat([val_pos_df, val_neg_df], ignore_index=True)
VAL_DF = VAL_DF.sample(frac=1, random_state=0)
# ── Graph building ────────────────────────────────────────────────
def build_data(edges_df):
    """Assemble the heterogeneous author/paper graph and move it to ``device``.

    Args:
        edges_df: DataFrame with integer ``source`` (author) and ``target``
            (paper) columns. These become the 'ref' edges, and their reversal
            the 'beref' edges.

    The citation (``df_cite``) and coauthor (``df_coa``) edges always come from
    the module-level frames and are stored in both directions. Paper nodes carry
    ``PAPER_FEAT`` as ``x``; no author features are set.
    """
    def as_edge_tensor(frame):
        # (E, 2) long tensor of (source, target) rows.
        return torch.as_tensor(frame[['source', 'target']].to_numpy(), dtype=torch.long)

    def both_directions(pairs):
        # (2, 2E) edge_index holding every edge followed by every reversed edge.
        return torch.cat([pairs, pairs.flip(1)], 0).t().contiguous()

    ref_pairs = as_edge_tensor(edges_df)
    cite_pairs = as_edge_tensor(df_cite)
    coauthor_pairs = as_edge_tensor(df_coa)

    graph = HeteroData()
    graph['author'].num_nodes = N_AUTHORS
    graph['paper'].num_nodes = N_PAPERS
    graph['paper'].x = torch.as_tensor(PAPER_FEAT, dtype=torch.float)
    graph['author', 'ref', 'paper'].edge_index = ref_pairs.t().contiguous()
    graph['paper', 'beref', 'author'].edge_index = ref_pairs.flip(1).t().contiguous()
    graph['paper', 'cite', 'paper'].edge_index = both_directions(cite_pairs)
    graph['author', 'coauthor', 'author'].edge_index = both_directions(coauthor_pairs)
    return graph.to(device)
def sample_hard_neg(n):
    """Draw ``n`` negative (author, paper) pairs, none of which is in TRAIN_SET.

    Pairs are collected in stages; every stage draws the author uniformly
    and the author before the paper:
      1. uniform random papers until ``int(n * 0.5)`` pairs exist;
      2. papers from ``popular`` until ``int(n * 0.75)`` pairs exist
         (at most ``n * 2`` attempts);
      3. papers from the author's ``coauthor_pool`` entry (a uniform paper
         when that entry is missing or empty) until ``n`` pairs exist
         (at most ``n * 3`` attempts);
      4. uniform random papers for whatever is still missing.
    Duplicate pairs are not removed.

    Returns:
        A long tensor on ``device`` holding the pairs transposed: row 0 authors,
        row 1 papers (shape ``(2, n)`` for ``n > 0``).
    """
    picked = []

    def draw(target, budget, pick_paper):
        # Append accepted pairs until ``target`` pairs exist or ``budget``
        # attempts have been made.
        attempts = 0
        while len(picked) < target and attempts < budget:
            attempts += 1
            author = np.random.randint(0, N_AUTHORS)
            paper = pick_paper(author)
            if (author, paper) not in TRAIN_SET:
                picked.append((author, paper))

    def uniform_paper(_author):
        return np.random.randint(0, N_PAPERS)

    def popular_paper(_author):
        return popular[np.random.randint(0, len(popular))]

    def coauthor_paper(author):
        pool = coauthor_pool.get(author, [])
        if pool:
            return pool[np.random.randint(0, len(pool))]
        return np.random.randint(0, N_PAPERS)

    unlimited = float('inf')
    draw(int(n * 0.5), unlimited, uniform_paper)
    draw(int(n * 0.75), n * 2, popular_paper)
    draw(n, n * 3, coauthor_paper)
    draw(n, unlimited, uniform_paper)
    return torch.tensor(picked[:n], dtype=torch.long, device=device).t().contiguous()
def cos_sim(a, b, eps=1e-12):
    """Row-wise cosine similarity of two equally shaped 2-D arrays.

    Args:
        a, b: Arrays of shape (n, d); row i of ``a`` is compared with row i
            of ``b``.
        eps: Added to the product of norms to avoid division by zero, so a
            zero row scores 0.

    Returns:
        Array of shape (n,) with values (up to ``eps``) in [-1, 1].
    """
    # ``axis`` must be passed by keyword: the second positional parameter of
    # numpy.linalg.norm is ``ord``, so ``norm(a, 1)`` would compute the matrix
    # 1-norm (a single scalar) instead of per-row Euclidean norms.
    return np.sum(a * b, axis=1) / (norm(a, axis=1) * norm(b, axis=1) + eps)
@torch.no_grad()
def evaluate(model, data):
    """Score the VAL_DF pairs with ``cos_sim`` of the model's embeddings.

    Switches ``model`` to eval mode (and leaves it there).

    Returns:
        ``(best_f1, roc_auc, threshold)``. ``best_f1`` is the largest F1 along
        the precision-recall curve; ``threshold`` is the score threshold at
        that point, or 0.5 when the best index has no corresponding threshold.
    """
    model.eval()
    emb = {k: v.cpu() for k, v in model.encode(data).items()}
    pairs = VAL_DF[['source', 'target']].to_numpy(dtype=np.int64)
    authors, papers = pairs[:, 0], pairs[:, 1]
    scores = cos_sim(emb['author'][authors].numpy(), emb['paper'][papers].numpy())
    labels = VAL_DF['label'].to_numpy()
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    f1 = 2 * precision * recall / (precision + recall + 1e-12)
    best = np.argmax(f1)
    auc = roc_auc_score(labels, scores)
    threshold = thresholds[best] if best < len(thresholds) else 0.5
    return f1[best], auc, threshold
# ═══════════════════════════════════════════════════════════════════
# GNN Layer Implementations
# ═══════════════════════════════════════════════════════════════════
class MeanAggLayer(nn.Module):
    """LightGCN-style mean aggregation (no params).

    For every edge type in ETS that is present in ``eid``, each destination
    node receives the mean of its source neighbours' features (zeros if it
    has no neighbours of that type). The per-edge-type results are then
    averaged per destination node type. Node types that no present edge type
    points into are returned unchanged.
    """

    def forward(self, xd, eid):
        per_type = {nt: [] for nt in xd}
        active = [et for et in ETS if et in eid]
        for et in active:
            src_type, _, dst_type = et
            src, dst = eid[et]
            feats = xd[src_type]
            n_dst = xd[dst_type].size(0)
            summed = feats.new_zeros((n_dst, feats.size(-1)))
            counts = feats.new_zeros((n_dst, 1))
            summed.index_add_(0, dst, feats[src])
            ones = torch.ones((dst.numel(), 1), dtype=feats.dtype, device=feats.device)
            counts.index_add_(0, dst, ones)
            # clamp avoids 0/0 for destinations without incoming edges.
            per_type[dst_type].append(summed / counts.clamp(min=1.0))
        out = {}
        for nt, msgs in per_type.items():
            out[nt] = sum(msgs) / len(msgs) if msgs else xd[nt]
        return out
class GATAggLayer(nn.Module):
    """GATv2 attention-based aggregation.

    One ``GATv2Conv`` per edge type in ETS (``heads`` heads of width
    ``hdim // heads``, no self loops, attention dropout 0.1), combined by a
    ``HeteroConv`` that averages the per-edge-type outputs.
    """

    def __init__(self, hdim, heads=2):
        super().__init__()
        convs = {}
        for et in ETS:
            convs[et] = GATv2Conv(hdim, hdim // heads, heads=heads,
                                  add_self_loops=False, dropout=0.1)
        self.conv = HeteroConv(convs, aggr='mean')

    def forward(self, xd, eid):
        out = self.conv(xd, eid)
        # Node types missing from the HeteroConv output keep their inputs.
        merged = {}
        for nt, x in xd.items():
            merged[nt] = out[nt] if nt in out else x
        return merged
class SAGEAggLayer(nn.Module):
    """GraphSAGE aggregation with mean pooling.

    One ``SAGEConv(hdim, hdim)`` per edge type in ETS, combined by a
    ``HeteroConv`` that averages the per-edge-type outputs (``aggr='mean'``).
    """

    def __init__(self, hdim):
        super().__init__()
        self.conv = HeteroConv({
            et: SAGEConv(hdim, hdim) for et in ETS
        }, aggr='mean')

    def forward(self, xd, eid):
        """Return updated features for every node type in ``xd``.

        HeteroConv only returns node types that received messages. Any other
        node type keeps its input features, as GATAggLayer does, so callers
        that index the result by every key of ``xd`` never hit a KeyError.
        """
        h = self.conv(xd, eid)
        return {nt: h.get(nt, xd[nt]) for nt in xd}
# ═══════════════════════════════════════════════════════════════════
# Model Builders
# ═══════════════════════════════════════════════════════════════════
def _make_lgcn(edim, nlayers, make_layer, type_name, *,
               learn_weights=False, out_act=None):
    """Assemble a LightGCN-style link-prediction model (shared by all builders).

    Args:
        edim: Embedding width for both node types.
        nlayers: Number of message-passing layers.
        make_layer: Zero-argument callable returning one aggregation layer;
            called ``nlayers`` times.
        type_name: Stored on the model as ``_type``.
        learn_weights: If True, layer outputs are combined with
            ``softmax(layer_w)``, a learnable parameter initialised uniformly.
            Otherwise each of the ``nlayers + 1`` outputs gets weight
            ``1 / (nlayers + 1)``.
        out_act: Optional function applied to each combined node embedding.

    Returns:
        An ``nn.Module`` with ``encode(data)``, ``decode(zd, ei)`` and
        ``reset_parameters()`` attached as attributes. Author embeddings come
        from a free ``nn.Embedding``; paper embeddings are a linear projection
        of ``data['paper'].x``.
    """
    m = nn.Module()
    # Keep this creation order: parameter initialisation draws from the global
    # RNG, so reordering would change seeded runs.
    m.ae = nn.Embedding(N_AUTHORS, edim)
    m.pp = nn.Linear(PAPER_FEAT.shape[1], edim)
    m.layers = nn.ModuleList([make_layer() for _ in range(nlayers)])
    m.L = nlayers
    if learn_weights:
        m.layer_w = nn.Parameter(torch.ones(nlayers + 1) / (nlayers + 1))
    m._type = type_name

    def encode(data):
        xd = {'author': m.ae.weight, 'paper': m.pp(data['paper'].x)}
        als = [xd]
        for layer in m.layers:
            xd = layer(xd, data.edge_index_dict)
            als.append(xd)
        # Combine the input embeddings and every layer's output.
        if learn_weights:
            w = F.softmax(m.layer_w, dim=0)
            z = {nt: sum(w[i] * h[nt] for i, h in enumerate(als)) for nt in xd}
        else:
            w = 1.0 / (m.L + 1)
            z = {nt: sum(w * h[nt] for h in als) for nt in xd}
        if out_act is not None:
            z = {nt: out_act(v) for nt, v in z.items()}
        return z

    def decode(zd, ei):
        # Dot-product score for each (author, paper) column of ``ei``.
        s, d = ei
        return (zd['author'][s] * zd['paper'][d]).sum(-1)

    def reset():
        # Re-initialises only the input embeddings / projection.
        nn.init.xavier_uniform_(m.ae.weight)
        nn.init.xavier_uniform_(m.pp.weight)
        nn.init.zeros_(m.pp.bias)

    m.encode = encode
    m.decode = decode
    m.reset_parameters = reset
    m.reset_parameters()
    return m


def make_vanilla_lgcn(edim=256, nlayers=4):
    """Baseline LightGCN (parameter-free mean aggregation, uniform layer weights)."""
    return _make_lgcn(edim, nlayers, MeanAggLayer, 'vanilla')


def make_learnw_lgcn(edim=256, nlayers=4):
    """LightGCN with learnable (softmax-normalised) layer weights."""
    return _make_lgcn(edim, nlayers, MeanAggLayer, 'learnw', learn_weights=True)


def make_gat_lgcn(edim=256, nlayers=3, heads=2):
    """LightGCN framework with GAT aggregation.

    Raises:
        ValueError: If ``edim`` is not divisible by ``heads``. Each layer
            concatenates ``heads`` outputs of width ``edim // heads``, which
            must add back up to ``edim``.
    """
    if edim % heads:
        raise ValueError(f'edim ({edim}) must be divisible by heads ({heads})')
    return _make_lgcn(edim, nlayers, lambda: GATAggLayer(edim, heads), 'gat')


def make_sage_lgcn(edim=256, nlayers=3):
    """LightGCN framework with SAGE aggregation; ReLU on the final embeddings."""
    return _make_lgcn(edim, nlayers, lambda: SAGEAggLayer(edim), 'sage',
                      out_act=F.relu)


def make_deep_lgcn(edim=256, nlayers=6):
    """Deeper LightGCN (6 layers)."""
    return _make_lgcn(edim, nlayers, MeanAggLayer, 'deep')


def make_wide_lgcn(edim=384, nlayers=3):
    """Wider LightGCN (384 dim); built by make_vanilla_lgcn, so _type is 'vanilla'."""
    return make_vanilla_lgcn(edim=edim, nlayers=nlayers)
# ═══════════════════════════════════════════════════════════════════
# Training
# ═══════════════════════════════════════════════════════════════════
def run_trial(name, make_fn, epochs=120, *, lr=0.005, weight_decay=1e-5,
              batch_size=32768, neg_per_pos=2, eval_every=20):
    """Train one model on the 90% split and report its best validation F1.

    Each "epoch" is a single optimisation step on one random batch of at most
    ``batch_size`` training edges, with BPR loss against ``neg_per_pos``
    sampled negatives per positive and gradient-norm clipping at 1.0.

    Args:
        name: Label printed in the log.
        make_fn: Zero-argument callable returning a model exposing
            ``encode`` / ``decode``.
        epochs: Number of optimisation steps.
        lr, weight_decay: Adam hyper-parameters.
        batch_size: Maximum number of positive edges per step.
        neg_per_pos: Negatives sampled per positive edge.
        eval_every: Evaluate every this many steps (must be positive) and
            after the last step.

    Returns:
        The best validation F1 seen at any evaluation point.
    """
    print(f'\n--- {name} ---')
    t0 = time.time()
    set_seed(0)
    data = build_data(train_90)
    model = make_fn().to(device)
    opt = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=weight_decay)
    pe = data['author', 'ref', 'paper'].edge_index
    bs = min(batch_size, pe.size(1))
    best_f1, best_th = 0, 0.5
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(pe.size(1), device=device)[:bs]
        pos = pe[:, perm]
        # NOTE: negatives are sampled independently of ``pos`` (their authors
        # are drawn at random), so each BPR pair compares a positive edge with
        # a negative edge that generally belongs to a different author.
        neg = sample_hard_neg(pos.size(1) * neg_per_pos)
        z = model.encode(data)
        ps = model.decode(z, pos).repeat_interleave(neg_per_pos)
        ns = model.decode(z, neg)
        loss = -F.logsigmoid(ps - ns).mean()
        opt.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        opt.step()
        if ep % eval_every == 0 or ep == epochs - 1:
            f1, auc, th = evaluate(model, data)
            if f1 > best_f1:
                best_f1, best_th = f1, th
            mkr = '>' if f1 == best_f1 else ' '
            print(f'{mkr} ep={ep:03d} loss={loss.item():.4f} f1={f1:.4f} auc={auc:.4f}')
    t = time.time() - t0
    npar = sum(p.numel() for p in model.parameters())
    print(f'  Best: F1={best_f1:.4f} Thresh={best_th:.4f} Params={npar/1e6:.1f}M Time={t:.0f}s')
    return best_f1
# Run all
if __name__ == '__main__':
    import traceback

    results = {}
    configs = [
        ('1. Vanilla LightGCN (4L, 256d)', lambda: make_vanilla_lgcn(256, 4)),
        ('2. Learnable Weights (4L, 256d)', lambda: make_learnw_lgcn(256, 4)),
        ('3. GAT Aggregation (3L, 256d, 2h)', lambda: make_gat_lgcn(256, 3, 2)),
        ('4. SAGE Aggregation (3L, 256d)', lambda: make_sage_lgcn(256, 3)),
        ('5. Deep LightGCN (6L, 256d)', lambda: make_deep_lgcn(256, 6)),
        ('6. Wide LightGCN (3L, 384d)', lambda: make_wide_lgcn(384, 3)),
        ('7. Vanilla LightGCN (5L, 256d)', lambda: make_vanilla_lgcn(256, 5)),
        ('8. GAT Aggregation (4L, 256d, 4h)', lambda: make_gat_lgcn(256, 4, 4)),
    ]
    for name, fn in configs:
        try:
            results[name] = run_trial(name, fn)
        except Exception as e:
            # Keep comparing the remaining configurations, but print the full
            # traceback so the failure can actually be diagnosed.
            print(f'  FAILED: {e}')
            traceback.print_exc()
    print('\n' + '=' * 60)
    print('RESULTS (sorted by val F1):')
    print('=' * 60)
    for name, f1 in sorted(results.items(), key=lambda kv: kv[1], reverse=True):
        print(f'  {f1:.4f}  {name}')