| """Improved heterogeneous GNN for academic paper recommendation. |
| |
| Key improvements over baseline: |
| 1. Proper SAGEConv-based heterogeneous GNN (3 layers, residual) |
| 2. MLP decoder instead of dot product |
| 3. Hard negative sampling (popular papers + co-author papers) |
| 4. Graph structural features (degree features) |
| 5. BCE loss on balanced batches (one sampled negative per positive) |
| 6. Longer training with LR scheduling |
| 7. Exploits train-test overlap for known positives |
| """ |
| import os |
| import pickle as pkl |
| import random |
| import itertools |
|
|
| import numpy as np |
| import pandas as pd |
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
| from torch.optim.lr_scheduler import CosineAnnealingLR, ReduceLROnPlateau |
|
|
| from torch_geometric.data import HeteroData |
| from torch_geometric.nn import SAGEConv, HeteroConv, Linear |
| from sklearn.metrics import f1_score, precision_recall_curve, roc_auc_score |
|
|
# Run on the first CUDA device when one is available, otherwise on the CPU.
if torch.cuda.is_available():
    device = torch.device('cuda:0')
else:
    device = torch.device('cpu')
print('device:', device)
|
|
|
|
def set_seed(seed=0):
    """Seed the global RNGs of Python, NumPy and PyTorch with *seed*.

    All CUDA devices are seeded too when CUDA is available.
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
|
|
|
|
| |
# Directory containing the *_ann.txt edge lists and feature.pkl. Set the
# CS3319_DATA_DIR environment variable to point elsewhere; the default is the
# author's local checkout.
base_path = os.environ.get("CS3319_DATA_DIR", "/home/lzc/cs3319-project")
|
|
|
|
def read_txt(file):
    """Parse a whitespace-separated integer file into a list of int lists.

    Every line of *file* becomes one list; a blank line yields ``[]``.
    """
    with open(file, "r") as handle:
        rows = [[int(token) for token in raw.strip().split()] for raw in handle]
    return rows
|
|
|
|
# Load the raw edge lists (one whitespace-separated integer row per line, as
# parsed by read_txt) and the paper feature object.
citation = read_txt(os.path.join(base_path, "paper_file_ann.txt"))
existing_refs = read_txt(os.path.join(base_path, "bipartite_train_ann.txt"))
refs_to_pred = read_txt(os.path.join(base_path, "bipartite_test_ann.txt"))
coauthor = read_txt(os.path.join(base_path, "author_file_ann.txt"))

# NOTE(review): pickle.load can execute arbitrary code embedded in the file;
# only load feature.pkl from a trusted source.
with open(os.path.join(base_path, "feature.pkl"), 'rb') as f:
    paper_feature = pkl.load(f)

print(f"Citations: {len(citation)}, Train refs: {len(existing_refs)}, "
      f"Test pairs: {len(refs_to_pred)}, Coauthor: {len(coauthor)}")

# Test pairs that also appear among the training refs. The submission code
# further down forces these pairs to a score of 1.0.
# NOTE(review): the percentage below divides by len(test_set) and would raise
# ZeroDivisionError on an empty test file.
train_set = set(map(tuple, existing_refs))
test_arr_full = np.array(refs_to_pred, dtype=np.int64)
test_set = set(map(tuple, refs_to_pred))
overlap = train_set & test_set
print(f"Train-test overlap (known positives): {len(overlap)} / {len(test_set)} "
      f"({100*len(overlap)/len(test_set):.1f}%)")
|
|
| |
# Edge lists as DataFrames with 'source' / 'target' columns.
cite_edges = pd.DataFrame(citation, columns=['source', 'target'])
ref_edges = pd.DataFrame(existing_refs, columns=['source', 'target'])
coauthor_edges = pd.DataFrame(coauthor, columns=['source', 'target'])

# Node sets (unique ids).
# - Papers: every id seen at either end of a citation edge, or as a ref target.
# - Authors: every ref source, or either end of a co-author edge.
# NOTE(review): the degree arrays below are sized by these counts but indexed
# by raw id. That assumes ids are contiguous 0..N-1; confirm against the data.
node_tmp = pd.concat([cite_edges['source'], cite_edges['target'], ref_edges['target']])
node_papers = pd.DataFrame(index=pd.unique(node_tmp))
node_tmp = pd.concat([ref_edges['source'], coauthor_edges['source'], coauthor_edges['target']])
node_authors = pd.DataFrame(index=pd.unique(node_tmp))
num_paper_nodes = len(node_papers)
num_author_nodes = len(node_authors)
print(f"Nodes: {num_author_nodes} authors, {num_paper_nodes} papers")
|
|
| |
# Raw degree counts indexed by node id (float32 so they can feed log_norm).
author_ref_deg = np.zeros(num_author_nodes, dtype=np.float32)       # ref edges per author
paper_ref_deg = np.zeros(num_paper_nodes, dtype=np.float32)         # ref edges per paper
author_coauthor_deg = np.zeros(num_author_nodes, dtype=np.float32)  # co-author edges per author
paper_cite_out = np.zeros(num_paper_nodes, dtype=np.float32)        # citation edges with paper as source
paper_cite_in = np.zeros(num_paper_nodes, dtype=np.float32)         # citation edges with paper as target

# NOTE(review): the ref degrees count ALL known refs, including the ~10% that
# the train/val split further down holds out as validation positives. Those
# validation edges therefore leak into the degree features.
for s, t in existing_refs:
    author_ref_deg[s] += 1
    paper_ref_deg[t] += 1
# Co-author edges count toward both endpoints.
for s, t in coauthor:
    author_coauthor_deg[s] += 1
    author_coauthor_deg[t] += 1
for s, t in citation:
    paper_cite_out[s] += 1
    paper_cite_in[t] += 1
|
|
| |
def log_norm(x):
    """Return ``log1p(x)`` standardised to zero mean and unit variance.

    The epsilon in the denominator keeps constant inputs finite; they map to
    all zeros.
    """
    logged = np.log1p(x)
    centred = logged - logged.mean()
    return centred / (logged.std() + 1e-8)
|
|
|
|
# Structural node features: log-transformed, standardised degree counts.
author_deg_feat = np.stack([
    log_norm(author_ref_deg),
    log_norm(author_coauthor_deg),
], axis=-1)
paper_deg_feat = np.stack([
    log_norm(paper_ref_deg),
    log_norm(paper_cite_out),
    log_norm(paper_cite_in),
], axis=-1)

# Paper input = features from feature.pkl concatenated with the degree
# features. feature.pkl presumably holds a torch tensor; detach/cpu makes the
# conversion safe for grad-tracking or CUDA tensors, and plain arrays are
# accepted as-is.
if isinstance(paper_feature, torch.Tensor):
    paper_feature_np = paper_feature.detach().cpu().numpy()
else:
    paper_feature_np = np.asarray(paper_feature)
paper_feat_aug = np.concatenate([paper_feature_np, paper_deg_feat], axis=-1)
paper_feat_dim = paper_feat_aug.shape[-1]
author_deg_dim = author_deg_feat.shape[-1]
# Report the actual base feature width rather than assuming 512.
print(f"Paper features: {paper_feat_dim}d ({paper_feature_np.shape[-1]} + "
      f"{paper_deg_feat.shape[-1]} degree), "
      f"Author degree features: {author_deg_feat.shape[-1]}d")
|
|
| |
# Hold out 10% of the known refs as validation positives. The remaining 90%
# (train_refs) are the only ref edges placed in the graph and used as
# training positives.
ref_edges_idx = ref_edges.copy()
train_refs = ref_edges_idx.sample(frac=0.9, random_state=0, axis=0)
val_pos = ref_edges_idx[~ref_edges_idx.index.isin(train_refs.index)].copy()
val_pos['label'] = 1

# Validation negatives: one per validation positive, drawn uniformly from
# (known author id, known paper id) pairs that are not ANY known ref (train or
# validation). Duplicate negatives are possible.
existing_ref_set = set(map(tuple, existing_refs))
author_ids = node_authors.index.to_numpy(dtype=np.int64)
paper_ids = node_papers.index.to_numpy(dtype=np.int64)

num_val_neg = len(val_pos)
neg_pairs = []
rng = np.random.default_rng(0)  # dedicated, seeded RNG so the val set is reproducible
while len(neg_pairs) < num_val_neg:
    src = int(rng.choice(author_ids))
    dst = int(rng.choice(paper_ids))
    if (src, dst) not in existing_ref_set:
        neg_pairs.append((src, dst))

# Balanced (1:1), shuffled validation set with a binary 'label' column.
val_neg = pd.DataFrame(neg_pairs, columns=['source', 'target'])
val_neg['label'] = 0
val_set = pd.concat([val_pos, val_neg], ignore_index=True).sample(frac=1, random_state=0)
print(f"Val: {len(val_set)} ({val_set['label'].sum()} pos, {len(val_set)-val_set['label'].sum()} neg)")
|
|
| |
| |
# "Popular" papers: ref degree at or above the 70th percentile of papers with
# at least one ref. sample_hard_negatives draws hard negatives from these.
# NOTE(review): np.percentile raises on an empty selection, i.e. if no paper
# has a ref.
paper_popularity = paper_ref_deg.copy()
popular_threshold = np.percentile(paper_popularity[paper_popularity > 0], 70)
popular_papers = np.where(paper_popularity >= popular_threshold)[0]
print(f"Popular papers (top 30%): {len(popular_papers)}")

# Symmetric co-author adjacency: author id -> set of co-author ids.
coauthor_map = {i: set() for i in range(num_author_nodes)}
for s, t in coauthor:
    coauthor_map[s].add(t)
    coauthor_map[t].add(s)

# author id -> set of paper ids the author references (all known refs).
author_papers = {i: set() for i in range(num_author_nodes)}
for s, t in existing_refs:
    author_papers[s].add(t)
|
|
| |
# Hard-negative candidates per author: papers referenced by at least one
# co-author but not by the author. Authors with no such papers get an empty
# list. Storing a full copy of every paper id per such author instead would
# cost O(num_authors * num_papers) memory. sample_hard_negatives skips empty
# pools and tops up any shortfall with uniformly random pairs.
coauthor_paper_pool = {}
for author in range(num_author_nodes):
    pool = set()
    for coa in coauthor_map[author]:
        pool.update(author_papers[coa])
    pool -= author_papers[author]
    # Sorted so the index -> paper mapping (and hence sampling) is reproducible.
    coauthor_paper_pool[author] = sorted(pool)

print(f"Authors with co-author paper pool: "
      f"{sum(1 for v in coauthor_paper_pool.values() if len(v) > 0)}")
|
|
| |
# Edge lists as (E, 2) long tensors. Only the training split of the refs
# (train_refs) enters the graph, so validation edges are never message-passed.
train_ref_tensor = torch.as_tensor(train_refs[['source', 'target']].to_numpy(), dtype=torch.long)
cite_tensor = torch.as_tensor(cite_edges[['source', 'target']].to_numpy(), dtype=torch.long)
coauthor_tensor = torch.as_tensor(coauthor_edges[['source', 'target']].to_numpy(), dtype=torch.long)

num_authors = num_author_nodes
num_papers = num_paper_nodes

# Node inputs:
# - papers: feature.pkl features plus degree features;
# - authors: degree features only.
paper_x = torch.as_tensor(paper_feat_aug, dtype=torch.float)
author_x = torch.as_tensor(author_deg_feat, dtype=torch.float)

# Heterogeneous graph in which every relation exists in both directions:
# - 'ref' / 'beref' are explicit forward/reverse author<->paper relations;
# - 'cite' and 'coauthor' edges are concatenated with their own reversal.
data = HeteroData()
data['author'].num_nodes = num_authors
data['author'].x = author_x
data['paper'].num_nodes = num_papers
data['paper'].x = paper_x
data['author', 'ref', 'paper'].edge_index = train_ref_tensor.t().contiguous()
data['paper', 'beref', 'author'].edge_index = train_ref_tensor[:, [1, 0]].t().contiguous()
data['paper', 'cite', 'paper'].edge_index = torch.cat([
    cite_tensor, cite_tensor[:, [1, 0]],
], dim=0).t().contiguous()
data['author', 'coauthor', 'author'].edge_index = torch.cat([
    coauthor_tensor, coauthor_tensor[:, [1, 0]],
], dim=0).t().contiguous()
data = data.to(device)
print(data)
|
|
|
|
| |
class ResidualHeteroConv(nn.Module):
    """One heterogeneous SAGE message-passing layer with a skip connection.

    For every destination node type:
    1. HeteroConv mean-aggregates the per-relation SAGEConv outputs.
    2. The result is added to the input embedding, which is linearly
       projected when its width differs from ``out_dim``.
    3. The sum goes through LayerNorm -> ReLU -> Dropout.

    Args:
        metadata: ``(node_types, edge_types)`` pair, as returned by
            ``HeteroData.metadata()``.
        in_dims: mapping node type -> input feature width.
        out_dim: output width for every node type.
        dropout: dropout probability applied after the activation.
    """

    def __init__(self, metadata, in_dims, out_dim, dropout=0.2):
        super().__init__()
        node_types, edge_types = metadata
        # Relations that get a convolution (only those present in the graph).
        relations = (
            ('author', 'ref', 'paper'),
            ('paper', 'beref', 'author'),
            ('paper', 'cite', 'paper'),
            ('author', 'coauthor', 'author'),
        )
        # (src_dim, dst_dim) lets source and destination widths differ.
        self.conv = HeteroConv(
            {
                rel: SAGEConv((in_dims[rel[0]], in_dims[rel[2]]), out_dim)
                for rel in relations
                if rel in edge_types
            },
            aggr='mean',
        )
        self.norms = nn.ModuleDict({nt: nn.LayerNorm(out_dim) for nt in node_types})
        self.dropout = nn.Dropout(dropout)
        # Skip-path projection only for node types whose width changes.
        self.res_proj = nn.ModuleDict({
            nt: nn.Linear(in_dims[nt], out_dim)
            for nt in node_types
            if in_dims.get(nt, out_dim) != out_dim
        })

    def forward(self, x_dict, edge_index_dict):
        aggregated = self.conv(x_dict, edge_index_dict)
        result = {}
        for node_type, h in aggregated.items():
            skip = x_dict[node_type]
            if node_type in self.res_proj:
                skip = self.res_proj[node_type](skip)
            result[node_type] = self.dropout(F.relu(self.norms[node_type](h + skip)))
        return result
|
|
|
|
class MLPDecoder(nn.Module):
    """Score (author, paper) pairs with an MLP.

    The input is ``[author || paper || author * paper]`` (width ``3 * in_dim``).
    Two hidden blocks (Linear -> BatchNorm1d -> ReLU -> Dropout(0.3)) of widths
    ``hidden`` and ``hidden // 2`` follow, then a final Linear to one logit.
    """

    def __init__(self, in_dim, hidden=128):
        super().__init__()
        widths = [in_dim * 3, hidden, hidden // 2]
        layers = []
        for fan_in, fan_out in zip(widths[:-1], widths[1:]):
            layers += [
                nn.Linear(fan_in, fan_out),
                nn.BatchNorm1d(fan_out),
                nn.ReLU(),
                nn.Dropout(0.3),
            ]
        layers.append(nn.Linear(widths[-1], 1))
        self.mlp = nn.Sequential(*layers)

    def forward(self, author_emb, paper_emb, edge_label_index):
        """Return one logit per column of ``edge_label_index`` (row 0 authors, row 1 papers)."""
        author_idx, paper_idx = edge_label_index
        a = author_emb[author_idx]
        p = paper_emb[paper_idx]
        features = torch.cat((a, p, a * p), dim=-1)
        return self.mlp(features).squeeze(-1)
|
|
|
|
class ImprovedHeteroGNN(nn.Module):
    """Author/paper encoder with an MLP link decoder.

    Pipeline:
    1. Author and paper inputs are each linearly projected to ``hidden_dim``.
    2. ``num_layers`` ResidualHeteroConv layers (hidden_dim -> hidden_dim)
       are applied.
    3. A single ``post_lin`` layer, shared by both node types, maps the
       result to ``out_dim``.
    4. ``decoder`` (an MLPDecoder over ``out_dim``) scores pairs.
    """

    def __init__(self, metadata, author_in_dim, paper_in_dim,
                 hidden_dim=128, num_layers=3, out_dim=64):
        super().__init__()
        self.author_proj = nn.Linear(author_in_dim, hidden_dim)
        self.paper_proj = nn.Linear(paper_in_dim, hidden_dim)

        # Every layer sees hidden_dim-wide inputs for both node types.
        layer_in_dims = {'author': hidden_dim, 'paper': hidden_dim}
        self.convs = nn.ModuleList(
            ResidualHeteroConv(metadata, layer_in_dims, hidden_dim, dropout=0.2)
            for _ in range(num_layers)
        )

        self.post_lin = nn.Linear(hidden_dim, out_dim)
        self.decoder = MLPDecoder(out_dim, hidden=128)
        self.reset_parameters()

    def reset_parameters(self):
        """Xavier-init every nn.Linear (bias zeroed) and let each SAGEConv reset itself."""
        for module in self.modules():
            if isinstance(module, SAGEConv):
                module.reset_parameters()
                continue
            if not isinstance(module, nn.Linear):
                continue
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def encode(self, data):
        """Return ``{'author': ..., 'paper': ...}`` embeddings of width ``out_dim``."""
        h = {
            'author': self.author_proj(data['author'].x),
            'paper': self.paper_proj(data['paper'].x),
        }
        for layer in self.convs:
            h = layer(h, data.edge_index_dict)
        return {node_type: self.post_lin(h[node_type]) for node_type in ('author', 'paper')}

    def decode(self, z_dict, edge_label_index):
        """Return one logit per (author, paper) column of ``edge_label_index``."""
        return self.decoder(z_dict['author'], z_dict['paper'], edge_label_index)
|
|
|
|
| |
def sample_hard_negatives(pos_batch_size, num_authors, num_papers,
                          existing_set, device, pos_src=None,
                          popular=None, pools=None):
    """Sample ``pos_batch_size`` (author, paper) negatives not in *existing_set*.

    Target mix:
    - 50% uniformly random pairs;
    - 25% (random author, popular paper);
    - 25% (author, paper referenced by one of that author's co-authors).

    When *pos_src* (the author ids of the positive batch) is given, the
    co-author negatives are anchored on those authors. Negatives then share
    the positives' author distribution, so author identity/degree alone cannot
    separate them. The popular and co-author phases use a bounded number of
    attempts; any shortfall is filled with uniform random pairs, so the mix
    is approximate.

    Args:
        pos_batch_size: number of negatives to return.
        num_authors, num_papers: ids are drawn from ``range(num_authors)`` /
            ``range(num_papers)``.
        existing_set: set of ``(author, paper)`` int tuples to exclude.
        device: device of the returned tensor.
        pos_src: optional sequence of author ids to anchor co-author negatives.
        popular: array of popular paper ids; defaults to the module-level
            ``popular_papers``.
        pools: mapping author id -> list of candidate paper ids; defaults to
            the module-level ``coauthor_paper_pool``.

    Returns:
        LongTensor of shape ``(2, pos_batch_size)``: row 0 authors, row 1 papers.
    """
    if popular is None:
        popular = popular_papers
    if pools is None:
        pools = coauthor_paper_pool

    neg_list = []
    n_random = pos_batch_size // 2
    n_popular = pos_batch_size // 4
    n_coauthor = pos_batch_size - n_random - n_popular

    def _try_add(src, dst):
        """Append (src, dst) as plain ints unless it is a known positive."""
        pair = (int(src), int(dst))
        if pair not in existing_set:
            neg_list.append(pair)

    # 1) Uniform random pairs, drawn in vectorised chunks.
    while len(neg_list) < n_random:
        src = np.random.randint(0, num_authors, size=n_random)
        dst = np.random.randint(0, num_papers, size=n_random)
        for s, d in zip(src, dst):
            _try_add(s, d)
            if len(neg_list) >= n_random:
                break

    # 2) Random author x popular paper (skipped when there are no popular papers).
    cnt = 0
    while (len(popular) > 0 and len(neg_list) < n_random + n_popular
           and cnt < n_popular * 5):
        cnt += 1
        src = np.random.randint(0, num_authors)
        dst = popular[np.random.randint(0, len(popular))]
        _try_add(src, dst)

    # 3) Co-author papers, anchored on the positive batch's authors if given.
    anchors = None if pos_src is None or len(pos_src) == 0 else np.asarray(pos_src)
    cnt = 0
    while len(neg_list) < pos_batch_size and cnt < n_coauthor * 10:
        cnt += 1
        if anchors is None:
            src = int(np.random.randint(0, num_authors))
        else:
            src = int(anchors[np.random.randint(0, len(anchors))])
        pool = pools.get(src, [])
        if len(pool) > 0:
            _try_add(src, pool[np.random.randint(0, len(pool))])

    # 4) Top up any shortfall with uniform random pairs.
    while len(neg_list) < pos_batch_size:
        _try_add(np.random.randint(0, num_authors), np.random.randint(0, num_papers))

    # reshape(-1, 2) keeps the (2, N) layout even when N == 0.
    return (torch.tensor(neg_list[:pos_batch_size], dtype=torch.long, device=device)
            .reshape(-1, 2).t().contiguous())
|
|
|
|
| |
def run_experiment(seed, hidden_dim=128, num_layers=3, lr=0.003,
                   num_epochs=250, use_hard_neg=True):
    """Train one ImprovedHeteroGNN from scratch and return its best checkpoint.

    Training loop:
    - Each epoch encodes the full graph and samples up to 32768 training
      (author, paper) edges as positives, plus the same number of negatives.
    - The loss is BCE on their logits.
    - Every 5 epochs, and on the last epoch, the model is scored on the global
      ``val_set``. The state with the highest validation F1 is kept.
    - ReduceLROnPlateau halves the LR on plateaus.
    - Training stops after 30 evaluations without improvement.

    Returns:
        ``(model, best_val_f1)`` with the best validation state loaded.
    """
    set_seed(seed)
    model = ImprovedHeteroGNN(
        data.metadata(),
        author_in_dim=author_deg_dim,
        paper_in_dim=paper_feat_dim,
        hidden_dim=hidden_dim,
        num_layers=num_layers,
        out_dim=64,
    ).to(device)

    # Exempt biases and normalisation affine parameters from weight decay.
    # The rank check also catches the decoder's BatchNorm1d weights, whose
    # names (e.g. ``decoder.mlp.1.weight``) do not contain "norm".
    decay_params = []
    no_decay_params = []
    for name, param in model.named_parameters():
        if param.ndim <= 1 or 'norm' in name or 'bias' in name:
            no_decay_params.append(param)
        else:
            decay_params.append(param)

    optimizer = torch.optim.AdamW([
        {'params': decay_params, 'weight_decay': 1e-4},
        {'params': no_decay_params, 'weight_decay': 0},
    ], lr=lr)

    scheduler = ReduceLROnPlateau(optimizer, mode='max', factor=0.5,
                                  patience=20, min_lr=1e-6)

    # NOTE(review): the supervised positives are also message-passing edges,
    # so during training the encoder sees every edge it is asked to score.
    # Held-out val edges are not in the graph.
    pos_edge_index = data['author', 'ref', 'paper'].edge_index
    existing_train_set = set(map(tuple, train_refs[['source', 'target']].to_numpy().tolist()))

    batch_size = min(32768, pos_edge_index.size(1))
    best_val_f1 = float('-inf')  # any real F1 (>= 0) beats this on the first evaluation
    best_state = None
    patience_counter = 0

    for epoch in range(num_epochs):
        model.train()
        optimizer.zero_grad()

        # Random subset of training edges as this epoch's positives.
        perm = torch.randperm(pos_edge_index.size(1), device=device)[:batch_size]
        pos_batch = pos_edge_index[:, perm]
        num_pos = pos_batch.size(1)

        if use_hard_neg:
            neg_batch = sample_hard_negatives(
                num_pos, num_authors, num_papers,
                existing_train_set, device,
                pos_src=pos_batch[0].cpu().numpy(),
            )
        else:
            # Uniform random negatives that are not training edges.
            neg_list = []
            while len(neg_list) < num_pos:
                s = torch.randint(0, num_authors, (num_pos,))
                d = torch.randint(0, num_papers, (num_pos,))
                for si, di in zip(s.tolist(), d.tolist()):
                    if (si, di) not in existing_train_set:
                        neg_list.append((si, di))
                        if len(neg_list) >= num_pos:
                            break
            neg_batch = torch.tensor(neg_list, dtype=torch.long,
                                     device=device).t().contiguous()

        z_dict = model.encode(data)
        # Score positives and negatives in ONE decoder call. In train mode the
        # decoder's BatchNorm1d layers normalise with the current batch's
        # statistics. Separate all-positive / all-negative calls would
        # normalise each class by its own mean and hide exactly the shift that
        # separates them; the running stats used at eval would also not match.
        edge_label_index = torch.cat([pos_batch, neg_batch], dim=1)
        scores = model.decode(z_dict, edge_label_index)
        labels = torch.cat([
            torch.ones(num_pos, device=scores.device),
            torch.zeros(neg_batch.size(1), device=scores.device),
        ])
        loss = F.binary_cross_entropy_with_logits(scores, labels)

        loss.backward()
        torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)
        optimizer.step()

        if epoch % 5 == 0 or epoch == num_epochs - 1:
            val_f1, val_auc, val_thresh = evaluate(model, data, val_set, device)
            scheduler.step(val_f1)

            if val_f1 > best_val_f1:
                best_val_f1 = val_f1
                best_state = {k: v.cpu().clone() for k, v in model.state_dict().items()}
                patience_counter = 0
            else:
                patience_counter += 1

            if epoch % 20 == 0 or epoch == num_epochs - 1:
                print(f'Epoch {epoch:03d} | Loss={loss.item():.4f} | '
                      f'Val F1={val_f1:.4f} AUC={val_auc:.4f} Thresh={val_thresh:.3f} | '
                      f'Best F1={best_val_f1:.4f}')

            # Patience counts evaluations (every 5 epochs), not epochs.
            if patience_counter >= 30:
                print(f'Early stopping at epoch {epoch}')
                break

    if best_state is not None:
        model.load_state_dict(best_state)
    return model, best_val_f1
|
|
|
|
@torch.no_grad()
def evaluate(model, data, val_df, device):
    """Score labelled (author, paper) pairs and report the best achievable F1.

    Args:
        model: object exposing ``encode(data)`` and ``decoder(author_emb,
            paper_emb, edge_index)`` (an ImprovedHeteroGNN here).
        data: graph passed to ``model.encode``.
        val_df: DataFrame with integer ``source`` (author) and ``target``
            (paper) columns and a binary ``label`` column.
        device: device for the edge-index tensor.

    Returns:
        ``(best_f1, auc, best_thresh)``:
        - best_f1: the maximum F1 over all thresholds on the sigmoid scores;
        - auc: the ROC AUC;
        - best_thresh: the threshold attaining best_f1 (0.5 if the maximum
          falls on the curve's terminal point).

    The model is scored in eval mode and then returned to whatever
    train/eval mode it was in before the call, even if scoring raises.
    """
    was_training = model.training
    model.eval()
    try:
        z_dict = model.encode(data)

        val_arr = val_df[['source', 'target']].to_numpy(dtype=np.int64)
        val_labels = val_df['label'].to_numpy()

        edge_idx = torch.as_tensor(val_arr, device=device).t()
        scores = model.decoder(
            z_dict['author'], z_dict['paper'],
            edge_idx,
        ).sigmoid().cpu().numpy()

        # precision/recall have one more entry than thresholds: the final
        # (precision=1, recall=0) point has no threshold.
        precision, recall, thresholds = precision_recall_curve(val_labels, scores)
        f1s = 2 * precision * recall / (precision + recall + 1e-12)
        best_idx = np.argmax(f1s)
        best_thresh = thresholds[best_idx] if best_idx < len(thresholds) else 0.5
        best_f1 = f1s[best_idx]
        auc = roc_auc_score(val_labels, scores)
    finally:
        model.train(was_training)
    return best_f1, auc, best_thresh
|
|
|
|
| |
# Train two models with identical hyperparameters that differ only in the
# random seed; their scores are averaged into an ensemble below.
print("\n" + "=" * 60)
print("Experiment 1: Improved GNN with hard negatives")
print("=" * 60)
model1, f1_1 = run_experiment(seed=0, hidden_dim=128, num_layers=3)

print("\n" + "=" * 60)
print("Experiment 2: Improved GNN (seed=42)")
print("=" * 60)
model2, f1_2 = run_experiment(seed=42, hidden_dim=128, num_layers=3)

print(f"\nModel 1 best val F1: {f1_1:.4f}")
print(f"Model 2 best val F1: {f1_2:.4f}")

print("\n" + "=" * 60)
print("Generating ensemble submission...")
print("=" * 60)
|
|
|
|
@torch.no_grad()
def predict_all(model, data, test_pairs, overlap_set, device):
    """Return sigmoid link scores for every (author, paper) row of *test_pairs*.

    The model is switched to eval mode and encodes the graph once. Pairs are
    then decoded in chunks of 131072 to bound peak memory. *overlap_set* is
    accepted for call compatibility but is not used here.
    """
    model.eval()
    embeddings = model.encode(data)

    chunk = 131072
    total = len(test_pairs)
    pieces = []
    lo = 0
    while lo < total:
        hi = min(lo + chunk, total)
        edge_idx = torch.as_tensor(test_pairs[lo:hi], device=device).t()
        logits = model.decoder(embeddings['author'], embeddings['paper'], edge_idx)
        pieces.append(logits.sigmoid().cpu().numpy())
        lo += chunk
    return np.concatenate(pieces)
|
|
|
|
test_arr = np.array(refs_to_pred, dtype=np.int64)

# Ensemble score = mean of the two models' sigmoid scores on the test pairs.
scores1 = predict_all(model1, data, test_arr, overlap, device)
scores2 = predict_all(model2, data, test_arr, overlap, device)
ensemble_scores = (scores1 + scores2) / 2.0

# Test pairs already present in the training refs are forced to 1.0. The tuned
# threshold is a sigmoid score or 0.5, so it never exceeds 1.0, and these
# pairs are always predicted positive.
known_pos_mask = np.array([tuple(p) in overlap for p in test_arr])
ensemble_scores[known_pos_mask] = 1.0

# Tune the ensemble's decision threshold on the validation set by maximising
# F1 along the precision-recall curve.
# NOTE(review): the validation set is balanced 1:1. If the test pairs have a
# different positive rate, this threshold may not transfer well.
val_scores1 = predict_all(model1, data,
                          val_set[['source', 'target']].to_numpy(dtype=np.int64),
                          set(), device)
val_scores2 = predict_all(model2, data,
                          val_set[['source', 'target']].to_numpy(dtype=np.int64),
                          set(), device)
val_ens = (val_scores1 + val_scores2) / 2.0
val_labels = val_set['label'].to_numpy()

precision, recall, thresholds = precision_recall_curve(val_labels, val_ens)
f1s = 2 * precision * recall / (precision + recall + 1e-12)
best_idx = np.argmax(f1s)
best_thresh = thresholds[best_idx] if best_idx < len(thresholds) else 0.5
val_f1_ens = f1s[best_idx]
print(f"Ensemble val F1: {val_f1_ens:.4f} @ threshold={best_thresh:.4f}")

# Binarise the ensemble scores with the tuned threshold.
predictions = (ensemble_scores >= best_thresh).astype(int)
print(f"Predicted positive ratio: {predictions.mean():.4f} "
      f"({predictions.sum()} / {len(predictions)})")
print(f"Known positives set to 1: {known_pos_mask.sum()}")

# Write the submission CSV: one row per test pair, columns Index / Predicted
# (prediction stored as the string "0" or "1").
output_path = "/home/lzc/submission_improved.csv"
data_out = [[idx, str(int(p))] for idx, p in enumerate(predictions)]
df = pd.DataFrame(data_out, columns=['Index', 'Predicted'], dtype=object)
df.to_csv(output_path, index=False)
print(f"\nSubmission saved to: {output_path}")
|
|
| |
# Also write one submission per individual model. Each model is binarised with
# its OWN F1-optimal threshold, tuned on the same validation set. A single
# model's scores are not calibrated to the ensemble's threshold. Known
# (train-overlapping) test pairs are forced positive, as for the ensemble.
for name, scores_i, val_scores_i in [
    ('model1', scores1, val_scores1),
    ('model2', scores2, val_scores2),
]:
    prec_i, rec_i, thr_i = precision_recall_curve(val_labels, val_scores_i)
    f1_i = 2 * prec_i * rec_i / (prec_i + rec_i + 1e-12)
    idx_i = np.argmax(f1_i)
    thresh_i = thr_i[idx_i] if idx_i < len(thr_i) else 0.5

    s = scores_i.copy()
    s[known_pos_mask] = 1.0
    preds = (s >= thresh_i).astype(int)
    out_path = f"/home/lzc/submission_{name}.csv"
    data_out = [[idx, str(int(p))] for idx, p in enumerate(preds)]
    pd.DataFrame(data_out, columns=['Index', 'Predicted'], dtype=object).to_csv(
        out_path, index=False)
    print(f" {name} saved to: {out_path}")
|
|