# cs3319-project2 / code/run_lgcn_v2.py — author: NLP-beginner
# CS3319 Project 2 final deliverable (public F1 = 0.96626), commit f28d994, 12.4 kB
"""LightGCN V2: L2-normalized embeddings for consistent training/eval + bigger model."""
import os
import pickle as pkl
import random
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch_geometric.data import HeteroData
from numpy.linalg import norm
# First CUDA device when available, otherwise CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device:', device)
REBUILD = True  # Set False to reuse saved models
def set_seed(seed=0):
    """Seed Python's `random`, NumPy's global RNG and torch (plus all CUDA devices)."""
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# ── Data ──────────────────────────────────────────────────────────
# Directory holding the dataset files read below. Defaults to the original
# machine's location; set CS3319_DATA_DIR to run elsewhere without editing code.
base_path = os.environ.get("CS3319_DATA_DIR", "/home/lzc/cs3319-project")
def read_txt(file):
    """Parse a whitespace-separated integer file into one list of ints per line.

    Blank lines produce empty lists.
    """
    with open(file, "r") as handle:
        return [[int(token) for token in line.split()] for line in handle]
# Edge lists, one integer row per line: paper-paper citations, training
# author->paper references, the author->paper pairs to predict, and
# author-author coauthor links (read in that order).
citation, existing_refs, refs_to_pred, coauthor = (
    read_txt(os.path.join(base_path, fname))
    for fname in ("paper_file_ann.txt", "bipartite_train_ann.txt",
                  "bipartite_test_ann.txt", "author_file_ann.txt")
)
# NOTE: pickle can execute code on load — only use a trusted feature file.
with open(os.path.join(base_path, "feature.pkl"), 'rb') as f:
    paper_feature = pkl.load(f)
# Edge tables sharing the column names 'source' / 'target'.
cite_edges, ref_edges, coauthor_edges = (
    pd.DataFrame(rows, columns=['source', 'target'])
    for rows in (citation, existing_refs, coauthor)
)
# Distinct node ids in first-seen order. Papers: citation endpoints plus
# referenced papers. Authors: referencing authors plus coauthor endpoints.
node_tmp = pd.concat([cite_edges['source'], cite_edges['target'], ref_edges['target']])
node_papers = pd.DataFrame(index=pd.unique(node_tmp))
node_tmp = pd.concat([ref_edges['source'], coauthor_edges['source'], coauthor_edges['target']])
node_authors = pd.DataFrame(index=pd.unique(node_tmp))
# NOTE(review): these are counts of distinct ids, while later code indexes
# arrays directly by id — assumes ids are contiguous 0..N-1; confirm.
num_papers, num_authors = len(node_papers), len(node_authors)
# Features
# Per-node degree counts (float32), indexed directly by node id.
author_ref_deg = np.zeros(num_authors, dtype=np.float32)
paper_ref_deg = np.zeros(num_papers, dtype=np.float32)
paper_cite_out = np.zeros(num_papers, dtype=np.float32)
paper_cite_in = np.zeros(num_papers, dtype=np.float32)
_refs = np.asarray(existing_refs, dtype=np.int64).reshape(-1, 2)
_cites = np.asarray(citation, dtype=np.int64).reshape(-1, 2)
# np.add.at is unbuffered: an id that appears k times is incremented k times.
np.add.at(author_ref_deg, _refs[:, 0], 1)   # references made by each author
np.add.at(paper_ref_deg, _refs[:, 1], 1)    # references received by each paper
np.add.at(paper_cite_out, _cites[:, 0], 1)  # citations made by each paper
np.add.at(paper_cite_in, _cites[:, 1], 1)   # citations received by each paper
def log_norm(x):
    """Return ``log1p(x)`` standardized to zero mean / unit std (eps-guarded)."""
    logged = np.log1p(x)
    centered = logged - logged.mean()
    return centered / (logged.std() + 1e-8)
# Paper input features: raw features (paper_feature presumably a CPU torch
# tensor, given .numpy()) plus three log-standardized degree columns, after
# which every column is z-scored.
paper_feat_np = paper_feature.numpy().astype(np.float32)
paper_deg_feat = np.stack(
    [log_norm(deg) for deg in (paper_ref_deg, paper_cite_out, paper_cite_in)],
    axis=-1,
)
paper_feat_aug = np.concatenate([paper_feat_np, paper_deg_feat], axis=-1)
paper_feat_aug = np.divide(paper_feat_aug - paper_feat_aug.mean(axis=0),
                           paper_feat_aug.std(axis=0) + 1e-8)
print(f"Paper features: {paper_feat_aug.shape[1]}d")
# Negative pools
# "Popular" papers: at least the 70th percentile of reference counts among
# papers that have any training reference.
popular_threshold = np.percentile(paper_ref_deg[paper_ref_deg > 0], 70)
popular_papers = np.where(paper_ref_deg >= popular_threshold)[0]

# Undirected coauthor adjacency.
coauthor_map = {i: set() for i in range(num_authors)}
for s, t in coauthor:
    coauthor_map[s].add(t)
    coauthor_map[t].add(s)

# Papers each author references in the training edges.
author_papers = {i: set() for i in range(num_authors)}
for s, t in existing_refs:
    author_papers[s].add(t)

# Per author: papers referenced by their coauthors but not by themselves.
# Authors with an empty pool fall back to "every paper". That fallback list is
# built ONCE and shared (read-only) instead of allocating a fresh
# num_papers-long list per author, which scaled as num_authors * num_papers.
all_papers = list(range(num_papers))
coauthor_paper_pool = {}
for a in range(num_authors):
    pool = set()
    for c in coauthor_map[a]:
        pool.update(author_papers[c])
    pool -= author_papers[a]
    coauthor_paper_pool[a] = list(pool) if pool else all_papers

existing_ref_set = set(map(tuple, existing_refs))
train_set = existing_ref_set  # identical contents; avoid building a second copy
# Test pairs that already appear among the training references.
overlap = train_set & set(map(tuple, refs_to_pred))
print(f"Known positives: {len(overlap)}")
# ── Graph ─────────────────────────────────────────────────────────
def build_data(ref_edges_use):
    """Assemble the heterogeneous author/paper graph and move it to ``device``.

    Args:
        ref_edges_use: DataFrame with 'source' (author) and 'target' (paper)
            columns giving the author->paper reference edges to include.

    Returns:
        HeteroData with ``paper.x`` set to ``paper_feat_aug`` and four edge
        types: ref (author->paper), beref (its reverse), and cite / coauthor,
        each stored in both directions.
    """
    def as_edge_tensor(frame):
        # (E, 2) long tensor of (source, target) rows.
        return torch.as_tensor(frame[['source', 'target']].to_numpy(), dtype=torch.long)

    def both_directions(edges):
        # Stack (u, v) and (v, u) rows, then transpose to a 2 x 2E edge_index.
        return torch.cat([edges, edges[:, [1, 0]]], dim=0).t().contiguous()

    refs = as_edge_tensor(ref_edges_use)
    cites = as_edge_tensor(cite_edges)
    coauthors = as_edge_tensor(coauthor_edges)

    graph = HeteroData()
    graph['author'].num_nodes = num_authors
    graph['paper'].num_nodes = num_papers
    graph['paper'].x = torch.as_tensor(paper_feat_aug, dtype=torch.float)
    graph['author', 'ref', 'paper'].edge_index = refs.t().contiguous()
    graph['paper', 'beref', 'author'].edge_index = refs[:, [1, 0]].t().contiguous()
    graph['paper', 'cite', 'paper'].edge_index = both_directions(cites)
    graph['author', 'coauthor', 'author'].edge_index = both_directions(coauthors)
    return graph.to(device)
def sample_hard_negatives(n_samples):
    """Draw ``n_samples`` (author, paper) pairs absent from the training refs.

    Stages, in order (every stage rejects pairs in ``existing_ref_set``):
      * up to 50%: uniformly random pairs;
      * up to 75%: random author + a paper from ``popular_papers``
        (at most ``2 * n_samples`` attempts);
      * up to 100%: random author + a paper from that author's
        ``coauthor_paper_pool`` (at most ``3 * n_samples`` attempts);
      * any shortfall is topped up with uniformly random pairs.

    Returns:
        Long tensor on ``device`` of shape (2, n_samples) for n_samples > 0:
        row 0 author ids, row 1 paper ids.
    """
    picked = []

    def top_up_uniform(limit):
        # Uniform rejection sampling until ``picked`` holds ``limit`` pairs.
        while len(picked) < limit:
            author = np.random.randint(0, num_authors)
            paper = np.random.randint(0, num_papers)
            if (author, paper) not in existing_ref_set:
                picked.append((author, paper))

    top_up_uniform(int(n_samples * 0.5))

    popular_quota = int(n_samples * 0.75)
    for _ in range(n_samples * 2):
        if len(picked) >= popular_quota:
            break
        author = np.random.randint(0, num_authors)
        paper = popular_papers[np.random.randint(0, len(popular_papers))]
        if (author, paper) not in existing_ref_set:
            picked.append((author, paper))

    for _ in range(n_samples * 3):
        if len(picked) >= n_samples:
            break
        author = np.random.randint(0, num_authors)
        pool = coauthor_paper_pool.get(author, [])
        if not pool:
            continue
        paper = pool[np.random.randint(0, len(pool))]
        if (author, paper) not in existing_ref_set:
            picked.append((author, paper))

    top_up_uniform(n_samples)
    return torch.tensor(picked[:n_samples], dtype=torch.long, device=device).t().contiguous()
# ── Model with L2 normalization ───────────────────────────────────
class LightGCNLayer(nn.Module):
    """Parameter-free propagation step over the heterogeneous graph.

    For every edge type in ``self.ets`` present in ``edge_index_dict``, each
    destination node receives the mean of its source neighbours' vectors
    (zero if it has no such neighbours). The messages a node type receives from
    different edge types are then averaged. Node types that receive no edge
    type keep their input vectors.
    """

    def __init__(self):
        super().__init__()
        # (source type, relation, destination type) triples propagated per step.
        self.ets = [('author', 'ref', 'paper'), ('paper', 'beref', 'author'),
                    ('paper', 'cite', 'paper'), ('author', 'coauthor', 'author')]

    def forward(self, x_dict, edge_index_dict):
        incoming = {node_type: [] for node_type in x_dict}
        for edge_type in self.ets:
            if edge_type not in edge_index_dict:
                continue
            src_type, _, dst_type = edge_type
            src, dst = edge_index_dict[edge_type]
            feats = x_dict[src_type]
            n_dst = x_dict[dst_type].size(0)
            summed = feats.new_zeros((n_dst, feats.size(-1)))
            summed.index_add_(0, dst, feats[src])
            # In-degree per destination node for this edge type.
            counts = torch.bincount(dst, minlength=n_dst).to(feats.dtype).unsqueeze(-1)
            incoming[dst_type].append(summed / counts.clamp(min=1.0))

        out = {}
        for node_type, x in x_dict.items():
            msgs = incoming[node_type]
            out[node_type] = sum(msgs) / len(msgs) if msgs else x
        return out
class LightGCN_Norm(nn.Module):
    """LightGCN with L2-normalized embeddings for cosine-similarity training.

    Authors get a learned embedding table; papers are a linear projection of
    ``data['paper'].x``. ``encode`` averages the initial vectors and every
    layer's output with equal weight and, by default, L2-normalizes the
    result, so ``decode`` (a dot product) is then a cosine similarity.
    """

    def __init__(self, embed_dim=256, num_layers=4):
        super().__init__()
        self.author_emb = nn.Embedding(num_authors, embed_dim)
        self.paper_proj = nn.Linear(paper_feat_aug.shape[1], embed_dim)
        self.layers = nn.ModuleList(LightGCNLayer() for _ in range(num_layers))
        self.num_layers = num_layers
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-uniform author table and projection weight; zero projection bias."""
        for weight in (self.author_emb.weight, self.paper_proj.weight):
            nn.init.xavier_uniform_(weight)
        nn.init.zeros_(self.paper_proj.bias)

    def encode(self, data, normalize=True):
        """Return ``{'author': ..., 'paper': ...}`` final node embeddings.

        Args:
            data: graph providing ``data['paper'].x`` and ``edge_index_dict``.
            normalize: L2-normalize each row of the output when True.
        """
        layer_out = {'author': self.author_emb.weight,
                     'paper': self.paper_proj(data['paper'].x)}
        history = [layer_out]
        for layer in self.layers:
            layer_out = layer(layer_out, data.edge_index_dict)
            history.append(layer_out)
        alpha = 1.0 / (self.num_layers + 1)
        combined = {}
        for node_type in layer_out:
            combined[node_type] = sum(alpha * snap[node_type] for snap in history)
        if not normalize:
            return combined
        return {node_type: F.normalize(emb, p=2, dim=-1)
                for node_type, emb in combined.items()}

    def decode(self, z_dict, edge_index):
        """Dot-product score for each (author, paper) column of ``edge_index``."""
        author_idx, paper_idx = edge_index
        return torch.sum(z_dict['author'][author_idx] * z_dict['paper'][paper_idx], dim=-1)
def cos_sim(a, b, eps=1e-12):
    """Row-wise cosine similarity of two equally shaped 2-D arrays.

    ``eps`` guards the denominator, so an all-zero row scores 0.
    """
    dots = (a * b).sum(axis=1)
    lengths = norm(a, axis=1) * norm(b, axis=1)
    return dots / (lengths + eps)
@torch.no_grad()
def predict_cos(model, data, pairs, batch_size=65536):
    """Score (author, paper) pairs by cosine similarity of the final embeddings.

    Args:
        model: object with ``eval()`` and ``encode(data, normalize=True)``
            returning a dict with 'author' and 'paper' embedding tensors.
        data: graph passed through to ``model.encode``.
        pairs: integer array-like of rows; column 0 is the author id and
            column 1 the paper id.
        batch_size: number of pairs scored per chunk.

    Returns:
        1-D NumPy array with one score per row of ``pairs``. An empty ``pairs``
        gives an empty array; the original crashed in ``np.concatenate([])``.
    """
    pairs = np.asarray(pairs, dtype=np.int64)
    model.eval()
    z_dict = model.encode(data, normalize=True)
    # Move to host and convert to NumPy once, rather than on every batch.
    authors = z_dict['author'].cpu().numpy()
    papers = z_dict['paper'].cpu().numpy()
    if pairs.size == 0:
        return np.empty(0, dtype=authors.dtype)
    all_scores = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start:start + batch_size]
        all_scores.append(cos_sim(authors[batch[:, 0]], papers[batch[:, 1]]))
    return np.concatenate(all_scores)
# ── Training ──────────────────────────────────────────────────────
def train(seed, embed_dim=256, num_layers=4, lr=0.005, epochs=200):
    """Train one LightGCN_Norm model and checkpoint its weights.

    Each "epoch" is a single optimisation step. A random batch of at most
    32768 training references is drawn, and each positive is paired with two
    negatives from ``sample_hard_negatives``. The BPR-style loss
    ``-log sigmoid(pos - neg)`` is minimised with Adam (weight decay 1e-5),
    clipping the gradient norm at 1.0.

    When the module-level ``REBUILD`` flag is False and a checkpoint for this
    (seed, embed_dim) already exists, that checkpoint is loaded instead of
    retraining.

    Args:
        seed: seed passed to ``set_seed``; also part of the checkpoint name.
        embed_dim: embedding width.
        num_layers: number of propagation layers.
        lr: Adam learning rate.
        epochs: number of optimisation steps.

    Returns:
        ``(model, data)``: the model moved to CPU, and the training graph
        (which stays on ``device``).
    """
    set_seed(seed)
    data = build_data(ref_edges)
    model = LightGCN_Norm(embed_dim, num_layers).to(device)
    # NOTE(review): the file name does not encode num_layers, so configs that
    # differ only in depth share (and overwrite) one checkpoint.
    save_path = f'/home/lzc/model_lgcn_norm_s{seed}_d{embed_dim}.pt'

    if not REBUILD and os.path.exists(save_path):
        # REBUILD=False: reuse the saved weights instead of retraining.
        model.load_state_dict(torch.load(save_path, map_location=device))
        print(f' [{seed}] loaded {save_path}')
        return model.cpu(), data

    optimizer = torch.optim.Adam(model.parameters(), lr=lr, weight_decay=1e-5)
    pos_edges = data['author', 'ref', 'paper'].edge_index
    bs = min(32768, pos_edges.size(1))
    for ep in range(epochs):
        model.train()
        perm = torch.randperm(pos_edges.size(1), device=device)[:bs]
        pos = pos_edges[:, perm]
        neg = sample_hard_negatives(pos.size(1) * 2)
        z = model.encode(data, normalize=True)  # L2-normalized
        # Each positive score is repeated to face its two negatives.
        pos_s = model.decode(z, pos).repeat_interleave(2)
        neg_s = model.decode(z, neg)
        loss = -F.logsigmoid(pos_s - neg_s).mean()
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        if ep % 50 == 0 or ep == epochs - 1:
            print(f' [{seed}] ep={ep:03d} loss={loss.item():.4f}')
    torch.save(model.state_dict(), save_path)
    return model.cpu(), data
# ── Main ──────────────────────────────────────────────────────────
test_arr = np.array(refs_to_pred, dtype=np.int64)
# Train 5 models with L2 normalization: (seed, embed_dim, num_layers)
configs = [
    (0, 256, 4),
    (42, 256, 4),
    (2024, 256, 4),
    (10, 256, 4),
    (100, 256, 4),
]
models = []
for seed, emb_dim, n_layers in configs:
    print(f"\n{'='*50}\nLGCN-norm seed={seed} dim={emb_dim} layers={n_layers}\n{'='*50}")
    # train() also returns its device-resident graph; keep only the model so
    # five redundant graph copies are not held on the device.
    models.append(train(seed, embed_dim=emb_dim, num_layers=n_layers, epochs=200)[0])
# ── Prediction ────────────────────────────────────────────────────
print(f"\n{'='*50}\nEnsemble prediction\n{'='*50}")
data_full = build_data(ref_edges)
all_scores = []
for i, ((seed, _, _), model) in enumerate(zip(configs, models)):
    model = model.to(device)
    scores = predict_cos(model, data_full, test_arr)
    all_scores.append(scores)
    print(f" Model {i} (s={seed}): mean={scores.mean():.4f} std={scores.std():.4f}")
    model.cpu()
ensemble = np.mean(all_scores, axis=0)
# Force known positives: test pairs already present among training references.
known_mask = np.array([tuple(p) in overlap for p in refs_to_pred], dtype=bool)
ensemble[known_mask] = 1.0
print(f"\nEnsemble: mean={ensemble.mean():.4f} min={ensemble.min():.4f} max={ensemble.max():.4f}")
# Multiple thresholds: one submission CSV per cut-off.
for thresh in [0.30, 0.32, 0.34, 0.35, 0.36, 0.37, 0.38, 0.40, 0.42, 0.45]:
    preds = (ensemble >= thresh).astype(int)
    path = f"/home/lzc/sub_norm_t{thresh:.2f}.csv"
    submission = pd.DataFrame({'Index': np.arange(len(preds)),
                               'Predicted': preds.astype(str)})
    submission.to_csv(path, index=False)
    print(f" t={thresh:.2f}: pos={preds.mean():.4f} (extra={preds.sum()-known_mask.sum()})")
print("\nDone!")