import torch
import numpy as np
from torch_geometric.data import Data
from torch_geometric.nn import Node2Vec
from sklearn.linear_model import LogisticRegression
from sklearn.metrics import accuracy_score, f1_score, roc_auc_score, average_precision_score


def train_node2vec(
    data: Data,
    embedding_dim: int = 64,
    walk_length: int = 20,
    context_size: int = 10,
    walks_per_node: int = 10,
    num_epochs: int = 5,
    lr: float = 0.01,
    device: torch.device = torch.device("cpu"),
) -> np.ndarray:
    """Train Node2Vec on ``data.edge_index`` and return the node embeddings.

    Args:
        data: Graph whose ``edge_index`` and ``num_nodes`` define the walks.
        embedding_dim: Size of each node embedding vector.
        walk_length: Length of each random walk.
        context_size: Window size used to form positive samples from a walk.
        walks_per_node: Number of walks started from every node.
        num_epochs: Number of full passes over the walk loader.
        lr: Learning rate passed to ``SparseAdam``.
        device: Device the model and the sampled walks are moved to.

    Returns:
        The learned embedding table as a NumPy array on the CPU
        (expected to hold one row per node; confirm against the installed
        ``Node2Vec.forward`` if the exact shape matters).
    """
    model = Node2Vec(
        data.edge_index,
        embedding_dim=embedding_dim,
        walk_length=walk_length,
        context_size=context_size,
        walks_per_node=walks_per_node,
        num_nodes=data.num_nodes,
        # SparseAdam (below) refuses dense gradients, so the embedding table
        # has to be created with sparse gradients enabled.
        sparse=True,
    ).to(device)
    loader = model.loader(batch_size=128, shuffle=True)
    optimizer = torch.optim.SparseAdam(list(model.parameters()), lr=lr)

    model.train()
    for _ in range(num_epochs):
        for pos_rw, neg_rw in loader:
            optimizer.zero_grad()
            loss = model.loss(pos_rw.to(device), neg_rw.to(device))
            loss.backward()
            optimizer.step()

    model.eval()
    with torch.no_grad():
        embeddings = model().cpu().numpy()
    return embeddings


def evaluate_node2vec(
    data: Data,
    device: torch.device,
) -> dict[str, dict[str, float]]:
    """Train Node2Vec and score its embeddings on node classification.

    A logistic-regression classifier is fitted on the embeddings of the
    training nodes and evaluated on the test nodes.

    Split selection:
        * ``data.train_mask`` missing or ``None``: a random 60/40 train/test
          split is drawn with ``np.random.permutation`` (any existing
          ``test_mask`` is ignored so the two sets stay disjoint).
        * ``train_mask`` present but ``test_mask`` missing: every node outside
          the training mask is used for testing.
        * both present: they are used as given.

    Args:
        data: Graph with node labels in ``data.y`` and optional boolean
            ``train_mask`` / ``test_mask`` attributes.
        device: Device used for Node2Vec training.

    Returns:
        ``{"node2vec": {"accuracy": ..., "macro_f1": ...}}``, or an empty dict
        when evaluation is not possible: no labels, labels that are not
        one-dimensional after squeezing, an empty train or test set, or fewer
        than two classes among the training labels.
    """
    if data.y is None:
        return {}
    labels = data.y.squeeze().cpu().numpy()
    if labels.ndim != 1:
        return {}

    embeddings = train_node2vec(data, device=device)

    train_mask = getattr(data, "train_mask", None)
    test_mask = getattr(data, "test_mask", None)
    train_mask = None if train_mask is None else train_mask.cpu().numpy()
    test_mask = None if test_mask is None else test_mask.cpu().numpy()

    if train_mask is None:
        n = data.num_nodes
        idx = np.random.permutation(n)
        split = int(0.6 * n)
        train_mask = np.zeros(n, dtype=bool)
        test_mask = np.zeros(n, dtype=bool)
        train_mask[idx[:split]] = True
        test_mask[idx[split:]] = True
    elif test_mask is None:
        # Indexing with None would silently add an axis instead of selecting
        # rows; fall back to every node that is not used for training.
        test_mask = ~train_mask.astype(bool)

    train_labels = labels[train_mask]
    test_labels = labels[test_mask]
    # LogisticRegression cannot be fitted on an empty or single-class set and
    # cannot predict on an empty one; report "no result" like the checks above.
    if train_labels.size == 0 or test_labels.size == 0:
        return {}
    if np.unique(train_labels).size < 2:
        return {}

    clf = LogisticRegression(max_iter=1000, random_state=42)
    clf.fit(embeddings[train_mask], train_labels)
    preds = clf.predict(embeddings[test_mask])

    return {
        "node2vec": {
            "accuracy": float(accuracy_score(test_labels, preds)),
            "macro_f1": float(
                f1_score(test_labels, preds, average="macro", zero_division=0)
            ),
        }
    }