netfm-training / netfm /evaluate /baselines.py
henribonamy's picture
Upload netfm/evaluate/baselines.py with huggingface_hub
6a03f6d verified
import numpy as np
import networkx as nx
from sklearn.metrics import roc_auc_score, average_precision_score
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx
def common_neighbors_scores(
    G: nx.Graph, edges: list[tuple[int, int]]
) -> np.ndarray:
    """Score each candidate pair by how many neighbours its endpoints share.

    Args:
        G: Graph whose neighbourhoods are inspected.
        edges: Candidate ``(u, v)`` node pairs to score.

    Returns:
        Array with one entry per pair in ``edges`` (same order): the number of
        nodes yielded by ``nx.common_neighbors(G, u, v)``.
    """
    shared_counts = [
        sum(1 for _ in nx.common_neighbors(G, u, v)) for u, v in edges
    ]
    return np.array(shared_counts)
def jaccard_scores(
    G: nx.Graph, edges: list[tuple[int, int]]
) -> np.ndarray:
    """Score each candidate pair by the Jaccard coefficient of its neighbourhoods.

    Args:
        G: Graph whose neighbourhoods are inspected.
        edges: Candidate ``(u, v)`` node pairs to score.

    Returns:
        Array with one coefficient per pair in ``edges`` (same order), taken
        from the third field of each triple yielded by
        ``nx.jaccard_coefficient``.
    """
    triples = nx.jaccard_coefficient(G, edges)
    coefficients = [triple[2] for triple in triples]
    return np.array(coefficients)
def adamic_adar_scores(
    G: nx.Graph, edges: list[tuple[int, int]]
) -> np.ndarray:
    """Score each candidate pair with the Adamic-Adar index.

    Args:
        G: Graph whose neighbourhoods are inspected.
        edges: Candidate ``(u, v)`` node pairs to score.

    Returns:
        Array with one index value per pair in ``edges`` (same order), as
        produced by ``nx.adamic_adar_index``.
    """
    return np.array(
        [score for (_src, _dst, score) in nx.adamic_adar_index(G, edges)]
    )
def preferential_attachment_scores(
    G: nx.Graph, edges: list[tuple[int, int]]
) -> np.ndarray:
    """Score each candidate pair with the preferential-attachment measure.

    Args:
        G: Graph whose node degrees are inspected.
        edges: Candidate ``(u, v)`` node pairs to score.

    Returns:
        Array with one score per pair in ``edges`` (same order), as produced
        by ``nx.preferential_attachment``.
    """
    products = []
    for _u, _v, product in nx.preferential_attachment(G, edges):
        products.append(product)
    return np.array(products)
def _edge_pairs(edges) -> list[tuple[int, int]]:
    """Convert a ``(2, E)`` edge array into a list of ``(u, v)`` Python-int pairs."""
    if hasattr(edges, "cpu"):
        # Presumably a torch.Tensor edge_index (possibly on GPU); np.asarray
        # cannot read device memory, so move it to the host first.
        edges = edges.cpu()
    arr = np.asarray(edges)
    if arr.ndim != 2 or arr.shape[0] != 2:
        raise ValueError(
            f"expected an edge array of shape (2, E), got {arr.shape}"
        )
    # Plain ints hash like the integer node ids of the networkx graph;
    # 0-d tensors hash by identity and would never be found in it.
    return [(int(u), int(v)) for u, v in zip(arr[0].tolist(), arr[1].tolist())]


def evaluate_link_prediction_baselines(
    data: Data,
    pos_edges: np.ndarray,
    neg_edges: np.ndarray,
) -> dict[str, dict[str, float]]:
    """Run all heuristic baselines for link prediction.

    Builds an undirected networkx graph from ``data``. Every positive and
    negative pair is scored with common neighbours, Jaccard, Adamic-Adar and
    preferential attachment. ROC-AUC and average precision are then reported,
    with positives labelled 1 and negatives labelled 0.

    NOTE(review): the heuristics read the graph built from ``data``. If
    ``data.edge_index`` already contains the positive evaluation edges, the
    scores can leak the answer (e.g. through node degrees). Confirm that
    callers pass the training graph.

    Args:
        data: PyG graph whose edges define the neighbourhoods.
        pos_edges: ``(2, E_pos)`` array of positive node pairs. An object with
            a ``.cpu()`` method (e.g. a torch tensor) is moved to host memory
            first.
        neg_edges: ``(2, E_neg)`` array of negative node pairs, same format.

    Returns:
        ``{baseline_name: {"auc": float, "ap": float}}``. A baseline whose
        scores are (numerically) constant cannot rank anything and gets
        chance-level metrics: AUC 0.5 and AP equal to the positive fraction.

    Raises:
        ValueError: if an edge array is not shaped ``(2, E)``, or if either
            set of pairs is empty (AUC/AP are undefined with a single class).
    """
    pos_list = _edge_pairs(pos_edges)
    neg_list = _edge_pairs(neg_edges)
    if not pos_list or not neg_list:
        raise ValueError(
            "need at least one positive and one negative edge, got "
            f"{len(pos_list)} positive and {len(neg_list)} negative"
        )

    G = to_networkx(data, to_undirected=True)
    all_edges = pos_list + neg_list
    labels = np.concatenate([np.ones(len(pos_list)), np.zeros(len(neg_list))])
    # A constant ranking ties every pair: its AUC is 0.5 and its AP is the
    # prevalence of positives (not 0.5 unless the classes are balanced).
    chance_ap = float(labels.mean())

    results: dict[str, dict[str, float]] = {}
    for name, scorer in (
        ("common_neighbors", common_neighbors_scores),
        ("jaccard", jaccard_scores),
        ("adamic_adar", adamic_adar_scores),
        ("preferential_attachment", preferential_attachment_scores),
    ):
        scores = scorer(G, all_edges)
        # Near-constant scores carry no ranking signal. Report chance level
        # explicitly instead of letting sklearn rank floating-point noise.
        if np.std(scores) < 1e-10:
            results[name] = {"auc": 0.5, "ap": chance_ap}
            continue
        results[name] = {
            "auc": float(roc_auc_score(labels, scores)),
            "ap": float(average_precision_score(labels, scores)),
        }
    return results