# Source: cs3319-project2/code/train_val_lgcn_ensemble.py (author: NLP-beginner)
# CS3319 Project 2 final deliverable (public F1 = 0.96626), commit f28d994.
"""Train and validate LightGCN ensembles on the notebook-style split."""
from __future__ import annotations
import argparse
import pickle as pkl
import random
from pathlib import Path
import numpy as np
import pandas as pd
import torch
import torch.nn as nn
import torch.nn.functional as F
from numpy.linalg import norm
from sklearn.metrics import precision_recall_curve, roc_auc_score
from torch_geometric.data import HeteroData
# Relations propagated by LightGCNLayer, as (src_type, relation, dst_type)
# triples; the layer visits them in exactly this order.
EDGE_TYPES = [
    ("author", "ref", "paper"),  # author -> paper reference edges
    ("paper", "beref", "author"),  # the same reference edges, reversed
    ("paper", "cite", "paper"),  # paper citation edges
    ("author", "coauthor", "author"),  # co-authorship edges
]
def set_seed(seed: int) -> None:
    """Seed Python's ``random``, NumPy's global RNG and torch with ``seed``.

    All CUDA devices are seeded too when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed_all(seed)
def read_txt(path: Path) -> list[list[int]]:
    """Parse a whitespace-separated integer file into one list of ints per line.

    Blank lines produce empty lists; a non-integer token raises ``ValueError``.
    """
    with path.open("r") as handle:
        return [[int(token) for token in raw.split()] for raw in handle]
def log_norm(x: np.ndarray) -> np.ndarray:
    """Return ``log1p(x)`` standardised to zero mean and unit (population) std.

    The ``1e-8`` in the denominator keeps a constant input from dividing by zero.
    """
    logged = np.log1p(x)
    centred = logged - logged.mean()
    return centred / (logged.std() + 1e-8)
class LightGCNLayer(nn.Module):
    """One parameter-free LightGCN propagation step over a heterogeneous graph.

    For every relation of ``EDGE_TYPES`` present in ``edge_index_dict``, each
    destination node receives the mean of its source neighbours' embeddings
    (destination nodes without edges of that relation receive zeros). The
    per-relation means are then averaged over all relations pointing at the
    same destination type. A node type that no present relation targets keeps
    its input embedding unchanged.
    """

    @staticmethod
    def _neighbour_mean(src_x, src, dst, num_dst):
        """Mean of ``src_x[src]`` grouped by ``dst``; rows with no edges stay 0."""
        total = src_x.new_zeros((num_dst, src_x.size(-1)))
        total.index_add_(0, dst, src_x[src])
        count = src_x.new_zeros((num_dst, 1))
        count.index_add_(
            0,
            dst,
            torch.ones((dst.numel(), 1), dtype=src_x.dtype, device=src_x.device),
        )
        # clamp avoids 0/0 for isolated nodes (their total is already 0).
        return total / count.clamp(min=1.0)

    def forward(self, x_dict, edge_index_dict):
        by_dst = {ntype: [] for ntype in x_dict}
        for relation in EDGE_TYPES:
            if relation not in edge_index_dict:
                continue
            src_type, _, dst_type = relation
            src, dst = edge_index_dict[relation]
            by_dst[dst_type].append(
                self._neighbour_mean(x_dict[src_type], src, dst, x_dict[dst_type].size(0))
            )
        return {
            ntype: (sum(msgs) / len(msgs)) if msgs else x_dict[ntype]
            for ntype, msgs in by_dst.items()
        }
class LightGCN(nn.Module):
    """Heterogeneous LightGCN link predictor for (author, paper) pairs.

    Authors get free embeddings and papers get a linear projection of their
    input features. ``encode`` runs ``num_layers`` propagation steps and
    returns, per node type, the uniform average of the layer-0..L embeddings.
    ``decode`` scores (author, paper) index pairs by dot product.
    """

    def __init__(self, num_authors: int, paper_feat_dim: int, embed_dim: int, num_layers: int = 4):
        super().__init__()
        # Keep this construction order: each module's default init consumes
        # the torch RNG, and it also fixes the state_dict key order.
        self.author_emb = nn.Embedding(num_authors, embed_dim)
        self.paper_proj = nn.Linear(paper_feat_dim, embed_dim)
        self.layers = nn.ModuleList([LightGCNLayer() for _ in range(num_layers)])
        self.num_layers = num_layers
        self.reset_parameters()

    def reset_parameters(self) -> None:
        """Xavier-initialise the author table and projection weight; zero the bias."""
        nn.init.xavier_uniform_(self.author_emb.weight)
        nn.init.xavier_uniform_(self.paper_proj.weight)
        nn.init.zeros_(self.paper_proj.bias)

    def encode(self, data):
        """Return ``{"author": ..., "paper": ...}`` final node embeddings.

        ``data["paper"].x`` supplies the paper features and
        ``data.edge_index_dict`` the edges used by each propagation layer.
        """
        h = {"author": self.author_emb.weight, "paper": self.paper_proj(data["paper"].x)}
        history = [h]
        for step in self.layers:
            h = step(h, data.edge_index_dict)
            history.append(h)
        coef = 1.0 / (self.num_layers + 1)
        return {ntype: sum(coef * snap[ntype] for snap in history) for ntype in h}

    def decode(self, z_dict, edge_index):
        """Dot-product score per pair; row 0 of ``edge_index`` holds authors, row 1 papers."""
        authors, papers = edge_index
        return torch.sum(z_dict["author"][authors] * z_dict["paper"][papers], dim=-1)
class LearnableWeightLightGCN(LightGCN):
    """LightGCN variant whose layer-combination weights are learned.

    ``layer_weight`` holds one logit per embedding snapshot (the input plus
    each propagation layer). ``encode`` mixes the snapshots with the softmax of
    these logits. The logits start at zero, so the initial mix is uniform.
    """

    def __init__(self, num_authors: int, paper_feat_dim: int, embed_dim: int, num_layers: int = 4):
        super().__init__(num_authors, paper_feat_dim, embed_dim, num_layers)
        self.layer_weight = nn.Parameter(torch.zeros(num_layers + 1))

    def encode(self, data):
        """Return softmax-weighted sums of the layer-0..L embeddings per node type."""
        h = {"author": self.author_emb.weight, "paper": self.paper_proj(data["paper"].x)}
        history = [h]
        for step in self.layers:
            h = step(h, data.edge_index_dict)
            history.append(h)
        mix = F.softmax(self.layer_weight, dim=0)
        return {
            ntype: sum(mix[k] * snap[ntype] for k, snap in enumerate(history))
            for ntype in h
        }
def cos_sim(a: np.ndarray, b: np.ndarray, eps: float = 1e-12) -> np.ndarray:
    """Row-wise (axis 1) cosine similarity of two equally shaped arrays.

    ``eps`` is added to the norm product, so all-zero rows score 0 instead of NaN.
    """
    dots = (a * b).sum(axis=1)
    denom = norm(a, axis=1) * norm(b, axis=1) + eps
    return dots / denom
def make_notebook_style_split(root: Path, seed: int, train_frac: float):
    """Build a train/validation split of author -> paper reference edges.

    A ``train_frac`` fraction of the rows of ``bipartite_train_ann.txt`` is
    sampled (pandas ``sample`` with ``random_state=seed``) as training edges.
    The remaining rows become validation positives (label 1). An equal number
    of validation negatives (label 0) is drawn one pair at a time. Authors come
    from ids seen in the reference/co-author files and papers from ids seen in
    the citation/reference files. Any known reference edge is rejected;
    negatives are not de-duplicated. Positives and negatives are then shuffled
    together with the same seed.

    Returns:
        ``(train_refs, val_pairs)``. ``train_refs`` has columns
        ``source``/``target`` and a fresh RangeIndex. ``val_pairs`` has columns
        ``source``/``target``/``label``.
    """
    data_dir = root / "data_and_docs"
    refs = read_txt(data_dir / "bipartite_train_ann.txt")
    coauthor = read_txt(data_dir / "author_file_ann.txt")
    citation = read_txt(data_dir / "paper_file_ann.txt")
    cols = ["source", "target"]

    ref_edges = pd.DataFrame(refs, columns=cols)
    # String row labels ("r-0", "r-1", ...) let the held-out rows be found by
    # index membership after sampling.
    ref_edges = ref_edges.set_index("r-" + ref_edges.index.astype(str))
    coauthor_edges = pd.DataFrame(coauthor, columns=cols)
    citation_edges = pd.DataFrame(citation, columns=cols)

    paper_ids = pd.unique(
        pd.concat([citation_edges["source"], citation_edges["target"], ref_edges["target"]])
    ).astype(np.int64)
    author_ids = pd.unique(
        pd.concat([ref_edges["source"], coauthor_edges["source"], coauthor_edges["target"]])
    ).astype(np.int64)

    train_refs = ref_edges.sample(frac=train_frac, random_state=seed, axis=0)
    held_out = ~ref_edges.index.isin(train_refs.index)
    val_pos = ref_edges[held_out].copy()
    val_pos.loc[:, "label"] = 1

    known = set(map(tuple, ref_edges[cols].to_numpy().tolist()))
    rng = np.random.default_rng(seed)
    negatives: list[tuple[int, int]] = []
    while len(negatives) < len(val_pos):
        # Author first, then paper: this draw order fixes the RNG stream.
        cand = (int(rng.choice(author_ids)), int(rng.choice(paper_ids)))
        if cand not in known:
            negatives.append(cand)

    val_neg = pd.DataFrame(negatives, columns=cols)
    val_neg.loc[:, "label"] = 0
    val_pairs = pd.concat([val_pos.reset_index(drop=True), val_neg], ignore_index=True)
    val_pairs = val_pairs.sample(frac=1, random_state=seed, axis=0).reset_index(drop=True)
    return train_refs[cols].reset_index(drop=True), val_pairs
def _paper_degrees(ref_targets: np.ndarray, cite_pairs: np.ndarray, num_papers: int):
    """Return float32 (reference in-degree, citation out-degree, citation in-degree) per paper.

    ``ref_targets`` are the paper ids of the training reference edges.
    ``cite_pairs`` is an (E, 2) array of (source, target) citation ids: the
    out-degree counts sources and the in-degree counts targets.
    """
    ref_deg = np.zeros(num_papers, dtype=np.float32)
    cite_out = np.zeros(num_papers, dtype=np.float32)
    cite_in = np.zeros(num_papers, dtype=np.float32)
    # np.add.at accumulates repeated indices (``arr[idx] += 1`` would count each
    # distinct index only once). Like scalar indexing, it raises IndexError for
    # ids >= num_papers.
    np.add.at(ref_deg, ref_targets, 1)
    np.add.at(cite_out, cite_pairs[:, 0], 1)
    np.add.at(cite_in, cite_pairs[:, 1], 1)
    return ref_deg, cite_out, cite_in


def _coauthor_pools(coauthor, train_pairs: np.ndarray, num_authors: int):
    """Map each author in ``range(num_authors)`` to its co-authors' papers it has not referenced.

    Pools are built from training references only. Each value is an int64
    array, or None when the pool is empty.
    """
    coauthor_map = {}
    for s, t in coauthor:
        coauthor_map.setdefault(s, set()).add(t)
        coauthor_map.setdefault(t, set()).add(s)
    author_papers = {}
    for s, t in train_pairs:
        author_papers.setdefault(int(s), set()).add(int(t))
    pools = {}
    for author in range(num_authors):
        pool = set()
        for co in coauthor_map.get(author, ()):
            pool.update(author_papers.get(co, ()))
        pool -= author_papers.get(author, set())
        pools[author] = np.array(list(pool), dtype=np.int64) if pool else None
    return pools


def build_parts(
    root: Path,
    split_dir: Path | None,
    num_papers: int,
    split_seed: int | None = None,
    train_frac: float = 0.9,
    num_authors: int = 6611,
):
    """Load the graph data and precompute everything training needs.

    Args:
        root: package root containing ``data_and_docs/``.
        split_dir: directory holding ``train_refs.csv`` and ``val_pairs.csv``.
            It is required when ``split_seed`` is None and ignored otherwise.
        num_papers: size of the paper id space. Every paper id in the
            reference/citation data must be below it.
        split_seed: if set, regenerate the split with
            ``make_notebook_style_split(root, split_seed, train_frac)`` instead
            of reading CSVs.
        train_frac: training fraction used when regenerating the split.
        num_authors: number of author ids (``0..num_authors-1``) for which
            co-author negative pools are built. The default is the previously
            hard-coded value.

    Returns:
        A dict with these keys:

        - ``train_refs``, ``val_pairs``, ``citation``, ``coauthor``: DataFrames.
        - ``paper_feat_aug``: column-standardised paper features plus three
          log-degree columns.
        - ``popular``: ids of papers whose training reference degree is at least
          the 70th percentile of the non-zero degrees.
        - ``coauthor_pool``: see ``_coauthor_pools``.
        - ``train_set``: set of training ``(author, paper)`` tuples.

    Raises:
        ValueError: if both ``split_seed`` and ``split_dir`` are None.
    """
    data_dir = root / "data_and_docs"
    if split_seed is None:
        if split_dir is None:
            raise ValueError("split_dir is required when split_seed is not set")
        train_refs = pd.read_csv(split_dir / "train_refs.csv")
        val_pairs = pd.read_csv(split_dir / "val_pairs.csv")
    else:
        train_refs, val_pairs = make_notebook_style_split(root, split_seed, train_frac)

    citation = read_txt(data_dir / "paper_file_ann.txt")
    coauthor = read_txt(data_dir / "author_file_ann.txt")
    citation_df = pd.DataFrame(citation, columns=["source", "target"])
    coauthor_df = pd.DataFrame(coauthor, columns=["source", "target"])
    # NOTE: pickle can execute arbitrary code; only load a trusted feature.pkl.
    with (data_dir / "feature.pkl").open("rb") as f:
        paper_feature = pkl.load(f)

    train_pairs = train_refs[["source", "target"]].to_numpy()
    paper_ref_deg, paper_cite_out, paper_cite_in = _paper_degrees(
        train_pairs[:, 1], citation_df.to_numpy(dtype=np.int64), num_papers
    )

    # feature.pkl is expected to hold an object with ``.numpy()`` (e.g. a CPU torch tensor).
    paper_feat_np = paper_feature.numpy().astype(np.float32)
    paper_deg_feat = np.stack(
        [log_norm(paper_ref_deg), log_norm(paper_cite_out), log_norm(paper_cite_in)],
        axis=-1,
    )
    paper_feat_aug = np.concatenate([paper_feat_np, paper_deg_feat], axis=-1)
    # Per-column standardisation of the combined feature matrix.
    paper_feat_aug = (paper_feat_aug - paper_feat_aug.mean(axis=0)) / (
        paper_feat_aug.std(axis=0) + 1e-8
    )

    coauthor_pool = _coauthor_pools(coauthor, train_pairs, num_authors)

    popular_threshold = np.percentile(paper_ref_deg[paper_ref_deg > 0], 70)
    popular = np.where(paper_ref_deg >= popular_threshold)[0]
    train_set = set(map(tuple, train_pairs.tolist()))
    return {
        "train_refs": train_refs,
        "val_pairs": val_pairs,
        "citation": citation_df,
        "coauthor": coauthor_df,
        "paper_feat_aug": paper_feat_aug,
        "popular": popular,
        "coauthor_pool": coauthor_pool,
        "train_set": train_set,
    }
def build_data(
    parts,
    num_authors: int,
    num_papers: int,
    device: torch.device,
    use_citation: bool = True,
    use_coauthor: bool = True,
):
    """Assemble the heterogeneous training graph from ``parts`` and move it to ``device``.

    Nodes are ``num_authors`` authors and ``num_papers`` papers, with paper
    features taken from ``parts["paper_feat_aug"]``. The edges are:

    - training references as ``ref`` (author -> paper);
    - their reverse as ``beref``;
    - optionally, citation and co-author edges, each stored in both
      orientations.
    """
    def edges_of(frame):
        return torch.as_tensor(frame[["source", "target"]].to_numpy(), dtype=torch.long)

    def both_directions(pairs):
        # (E, 2) -> (2, 2E): original orientation first, then the reversed copy.
        forward = pairs.t()
        return torch.cat([forward, forward.flip(0)], dim=1).contiguous()

    refs = edges_of(parts["train_refs"])
    cites = edges_of(parts["citation"])
    coauthors = edges_of(parts["coauthor"])

    graph = HeteroData()
    graph["author"].num_nodes = num_authors
    graph["paper"].num_nodes = num_papers
    graph["paper"].x = torch.as_tensor(parts["paper_feat_aug"], dtype=torch.float)
    graph["author", "ref", "paper"].edge_index = refs.t().contiguous()
    graph["paper", "beref", "author"].edge_index = refs.t().flip(0).contiguous()
    if use_citation:
        graph["paper", "cite", "paper"].edge_index = both_directions(cites)
    if use_coauthor:
        graph["author", "coauthor", "author"].edge_index = both_directions(coauthors)
    return graph.to(device)
def sample_hard_negatives(parts, n_samples: int, num_authors: int, num_papers: int, device):
    """Sample ``n_samples`` negative (author, paper) pairs as a (2, n_samples) long tensor.

    No returned pair is in ``parts["train_set"]``. Pairs are drawn with NumPy's
    global RNG in four stages:

    1. uniform random pairs until half of the budget is filled;
    2. a random author with a paper from ``parts["popular"]`` until 75% is
       filled (at most ``3 * n_samples`` attempts; skipped if there are no
       popular papers);
    3. a random author with a paper from its ``parts["coauthor_pool"]`` entry
       until the budget is filled (at most ``4 * n_samples`` attempts);
    4. uniform random pairs for whatever is still missing.

    Duplicates are possible. Row 0 holds author ids and row 1 paper ids. The
    shape stays (2, 0) when ``n_samples`` is 0.
    """
    train_set = parts["train_set"]
    popular = parts["popular"]
    coauthor_pool = parts["coauthor_pool"]
    neg_list: list[tuple[int, int]] = []

    def add_random(target: int) -> None:
        while len(neg_list) < target:
            s = np.random.randint(0, num_authors)
            d = np.random.randint(0, num_papers)
            if (s, d) not in train_set:
                neg_list.append((s, d))

    add_random(int(n_samples * 0.5))

    # np.random.randint(0, 0) raises, so skip this stage when there are no
    # popular papers; stage 4 fills the gap instead.
    if len(popular) > 0:
        attempts = 0
        while len(neg_list) < int(n_samples * 0.75) and attempts < n_samples * 3:
            attempts += 1
            s = np.random.randint(0, num_authors)
            d = int(popular[np.random.randint(0, len(popular))])
            if (s, d) not in train_set:
                neg_list.append((s, d))

    attempts = 0
    while len(neg_list) < n_samples and attempts < n_samples * 4:
        attempts += 1
        s = np.random.randint(0, num_authors)
        pool = coauthor_pool.get(s)
        if pool is None or len(pool) == 0:
            continue
        d = int(pool[np.random.randint(0, len(pool))])
        if (s, d) not in train_set:
            neg_list.append((s, d))

    add_random(n_samples)
    # reshape keeps an (N, 2) layout even for N == 0; an empty list would
    # otherwise give a 1-D tensor that callers cannot unpack into src/dst.
    neg = torch.tensor(neg_list[:n_samples], dtype=torch.long, device=device).reshape(-1, 2)
    return neg.t().contiguous()
@torch.no_grad()
def predict_scores(
    model: "LightGCN",
    data,
    pairs: np.ndarray,
    batch_size: int,
    mode: str = "cos",
    normalize_embeddings: bool = False,
) -> np.ndarray:
    """Score (author, paper) pairs using the model's final embeddings.

    Args:
        model: object with ``eval()`` and ``encode(data)``. ``encode`` must
            return a dict with ``"author"`` and ``"paper"`` embedding tensors.
            The model is left in eval mode.
        data: graph passed to ``model.encode``.
        pairs: integer array; column 0 holds author ids, column 1 paper ids.
        batch_size: number of pairs scored per NumPy batch.
        mode: ``"cos"`` (cosine similarity), ``"dot"`` (dot product) or
            ``"neg_l2"`` (negative squared Euclidean distance).
        normalize_embeddings: L2-normalise the embeddings before scoring.

    Returns:
        float32 array with one score per row of ``pairs``. It is empty when
        there are no pairs.

    Raises:
        ValueError: for an unknown ``mode``. This is checked before the
            (expensive) encode.
    """
    if mode not in ("cos", "dot", "neg_l2"):
        raise ValueError(mode)
    model.eval()
    z_dict = model.encode(data)
    if normalize_embeddings:
        z_dict = {k: F.normalize(v, p=2, dim=1) for k, v in z_dict.items()}
    author_z = z_dict["author"].detach().cpu().numpy()
    paper_z = z_dict["paper"].detach().cpu().numpy()
    scores = []
    for start in range(0, len(pairs), batch_size):
        batch = pairs[start : start + batch_size]
        a = author_z[batch[:, 0]]
        p = paper_z[batch[:, 1]]
        if mode == "cos":
            score = cos_sim(a, p)
        elif mode == "dot":
            score = np.sum(a * p, axis=1)
        else:  # "neg_l2"
            score = -np.sum((a - p) ** 2, axis=1)
        scores.append(score.astype(np.float32))
    if not scores:
        # np.concatenate([]) would raise on an empty pair set.
        return np.empty(0, dtype=np.float32)
    return np.concatenate(scores)
def best_f1(labels: np.ndarray, scores: np.ndarray) -> tuple[float, float, float]:
    """Return ``(best_f1, threshold, roc_auc)`` for binary ``labels`` and real-valued ``scores``.

    ``best_f1`` is the largest F1 over the points of sklearn's
    precision/recall curve. ``threshold`` is the curve threshold at that
    point. When the best point has no matching threshold entry, 0.5 is
    reported instead.
    """
    precision, recall, thresholds = precision_recall_curve(labels, scores)
    f1_curve = 2 * precision * recall / (precision + recall + 1e-12)
    best_idx = int(np.argmax(f1_curve))
    has_threshold = best_idx < len(thresholds)
    cutoff = float(thresholds[best_idx]) if has_threshold else 0.5
    return float(f1_curve[best_idx]), cutoff, float(roc_auc_score(labels, scores))
def _pair_loss(args, raw_pos_scores, neg_scores):
    """Loss for one batch of positive and negative pair scores, selected by ``args.loss``.

    ``bpr`` and ``hinge`` compare each positive score (repeated
    ``args.neg_per_pos`` times) element-wise with the negative scores. ``bce``
    treats all scores as independent logits.
    """
    pos_scores = raw_pos_scores.repeat_interleave(args.neg_per_pos)
    if args.loss == "bpr":
        return -F.logsigmoid(pos_scores - neg_scores).mean()
    if args.loss == "hinge":
        return (args.margin - pos_scores + neg_scores).clamp(min=0).mean()
    if args.loss == "bce":
        logits = torch.cat([raw_pos_scores, neg_scores])
        targets = torch.cat([torch.ones_like(raw_pos_scores), torch.zeros_like(neg_scores)])
        return F.binary_cross_entropy_with_logits(logits, targets)
    raise ValueError(args.loss)


def train_one(args, parts, data, seed: int, embed_dim: int, model_dir: Path, score_dir: Path):
    """Train one model and keep its best validation checkpoint.

    Each "epoch" is a single optimisation step on a random batch of at most
    ``args.train_batch_size`` training reference edges, with
    ``args.neg_per_pos`` sampled negatives per positive. Every
    ``args.eval_every`` steps, and after the last step, the validation pairs
    are scored. Whenever the best-threshold F1 improves, the scores are saved
    to ``score_dir`` and a CPU copy of the weights is kept. The best weights
    are written to ``model_dir`` at the end.

    Returns:
        ``(best_f1, threshold, auc)`` of the best evaluation.
    """
    set_seed(seed)
    device = torch.device(args.device)
    model_cls = LearnableWeightLightGCN if args.variant == "learnw" else LightGCN
    model = model_cls(args.num_authors, parts["paper_feat_aug"].shape[1], embed_dim, args.layers).to(device)
    optimizer = torch.optim.Adam(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    pos_edges = data["author", "ref", "paper"].edge_index
    batch_size = min(args.train_batch_size, pos_edges.size(1))
    best = (-1.0, 0.0, 0.0)
    best_state = None
    val_arr = parts["val_pairs"][["source", "target"]].to_numpy(dtype=np.int64)
    labels = parts["val_pairs"]["label"].to_numpy(dtype=np.int8)
    score_path = score_dir / f"val_{args.variant}_{args.eval_mode}_s{seed}_d{embed_dim}.npy"
    for epoch in range(args.epochs):
        model.train()
        # RNG order (torch randperm, then NumPy negative sampling) fixes reproducibility.
        perm = torch.randperm(pos_edges.size(1), device=device)[:batch_size]
        pos = pos_edges[:, perm]
        neg = sample_hard_negatives(
            parts, pos.size(1) * args.neg_per_pos, args.num_authors, args.num_papers, device
        )
        z_dict = model.encode(data)
        if args.normalize_embeddings:
            z_dict = {k: F.normalize(v, p=2, dim=1) for k, v in z_dict.items()}
        loss = _pair_loss(args, model.decode(z_dict, pos), model.decode(z_dict, neg))
        optimizer.zero_grad()
        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()
        should_eval = (epoch + 1) % args.eval_every == 0 or epoch == args.epochs - 1
        if not should_eval:
            continue
        scores = predict_scores(
            model,
            data,
            val_arr,
            args.pred_batch_size,
            args.eval_mode,
            args.normalize_embeddings,
        )
        f1, threshold, auc = best_f1(labels, scores)
        if f1 > best[0]:
            best = (f1, threshold, auc)
            # .cpu() returns the *same* storage for tensors already on the CPU,
            # so clone: otherwise later optimizer steps would overwrite the
            # "best" snapshot and the checkpoint would hold the final weights.
            best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
            np.save(score_path, scores)
        print(
            f"seed={seed} dim={embed_dim} epoch={epoch+1:03d} "
            f"loss={loss.item():.4f} val_f1={f1:.5f} th={threshold:.5f} auc={auc:.5f}"
        )
    if best_state is not None:
        torch.save(best_state, model_dir / f"{args.variant}_val_s{seed}_d{embed_dim}.pt")
    return best
def _build_arg_parser() -> argparse.ArgumentParser:
    """Command-line options for training and validating a LightGCN ensemble."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--package-root", type=Path, default=Path(__file__).resolve().parents[1])
    parser.add_argument("--split-dir", type=Path, default=None)
    parser.add_argument("--split-seed", type=int, default=None)
    parser.add_argument("--train-frac", type=float, default=0.9)
    parser.add_argument("--device", default="cuda:0" if torch.cuda.is_available() else "cpu")
    parser.add_argument("--seeds", nargs="*", type=int, default=[0, 42, 2024])
    parser.add_argument("--dims", nargs="*", type=int, default=[256])
    parser.add_argument("--layers", type=int, default=4)
    parser.add_argument("--epochs", type=int, default=160)
    parser.add_argument("--eval-every", type=int, default=20)
    parser.add_argument("--lr", type=float, default=0.005)
    parser.add_argument("--weight-decay", type=float, default=1e-5)
    parser.add_argument("--train-batch-size", type=int, default=32768)
    parser.add_argument("--pred-batch-size", type=int, default=65536)
    parser.add_argument("--neg-per-pos", type=int, default=2)
    parser.add_argument("--loss", choices=["bpr", "bce", "hinge"], default="bpr")
    parser.add_argument("--margin", type=float, default=1.0)
    parser.add_argument("--normalize-embeddings", action="store_true")
    parser.add_argument("--num-authors", type=int, default=6611)
    parser.add_argument("--num-papers", type=int, default=79937)
    parser.add_argument("--run-name", default=None)
    parser.add_argument("--variant", choices=["vanilla", "learnw"], default="vanilla")
    parser.add_argument("--eval-mode", choices=["cos", "dot", "neg_l2"], default="cos")
    parser.add_argument("--drop-citation", action="store_true")
    parser.add_argument("--drop-coauthor", action="store_true")
    return parser


def _default_run_name(args) -> str:
    """Run directory name derived from the dims, seeds, layer count and epoch count."""
    return (
        f"dims{'-'.join(map(str, args.dims))}_"
        f"seeds{'-'.join(map(str, args.seeds))}_"
        f"L{args.layers}_E{args.epochs}"
    )


def _write_mean_ensemble(rows, args, labels, score_dir: Path, out_dir: Path) -> None:
    """Average the saved best validation scores of the trained models and report F1/AUC.

    Does nothing if none of the expected score files exist.
    """
    loaded = []
    for row in rows:
        path = score_dir / f"val_{args.variant}_{args.eval_mode}_s{int(row['seed'])}_d{int(row['dim'])}.npy"
        if path.exists():
            loaded.append((path.stem, np.load(path)))
    if not loaded:
        return
    names = [stem for stem, _ in loaded]
    ensemble = np.mean([scores for _, scores in loaded], axis=0)
    f1, threshold, auc = best_f1(labels, ensemble)
    np.save(score_dir / f"val_{args.variant}_ensemble_mean.npy", ensemble)
    with (out_dir / "ensemble_result.txt").open("w") as f:
        f.write(f"models={','.join(names)}\n")
        f.write(f"f1={f1:.8f}\nthreshold={threshold:.8f}\nauc={auc:.8f}\n")
    print(f"\nMean ensemble: f1={f1:.5f} threshold={threshold:.5f} auc={auc:.5f}")


def main() -> None:
    """Train every (dim, seed) model, then evaluate their mean ensemble.

    Outputs go to ``<package-root>/validation_runs/<split>/<run-name>/``:
    ``checkpoints/``, ``scores/``, ``model_results.csv`` and, when score
    files exist, ``ensemble_result.txt``.
    """
    args = _build_arg_parser().parse_args()
    root = args.package_root
    split_dir = args.split_dir or root / "splits" / "notebook_seed0"
    run_name = args.run_name or _default_run_name(args)
    if args.split_seed is None:
        split_name = split_dir.name
    else:
        split_name = f"dynamic_seed{args.split_seed}"
    out_dir = root / "validation_runs" / split_name / run_name
    model_dir = out_dir / "checkpoints"
    score_dir = out_dir / "scores"
    for directory in (model_dir, score_dir):
        directory.mkdir(parents=True, exist_ok=True)

    parts = build_parts(
        root,
        None if args.split_seed is not None else split_dir,
        args.num_papers,
        split_seed=args.split_seed,
        train_frac=args.train_frac,
    )
    data = build_data(
        parts,
        args.num_authors,
        args.num_papers,
        torch.device(args.device),
        use_citation=not args.drop_citation,
        use_coauthor=not args.drop_coauthor,
    )
    labels = parts["val_pairs"]["label"].to_numpy(dtype=np.int8)

    rows = []
    for dim in args.dims:
        for seed in args.seeds:
            f1, threshold, auc = train_one(args, parts, data, seed, dim, model_dir, score_dir)
            rows.append({"seed": seed, "dim": dim, "f1": f1, "threshold": threshold, "auc": auc})
    result = pd.DataFrame(rows).sort_values("f1", ascending=False)
    result.to_csv(out_dir / "model_results.csv", index=False)
    print("\nModel results:")
    print(result.to_string(index=False))

    _write_mean_ensemble(rows, args, labels, score_dir, out_dir)


if __name__ == "__main__":
    main()