# netfm-training / netfm/data/structural_features.py
# Uploaded by henribonamy: "Upload netfm/data/structural_features.py with huggingface_hub"
# Commit 2a281a1 (verified)
import torch
import networkx as nx
import numpy as np
from torch_geometric.data import Data
from torch_geometric.utils import to_networkx, degree as pyg_degree
# Node-count cutoff: graphs with more nodes than this skip the eigenvector
# centrality computation in compute_structural_features and get zeros for
# that feature instead.
LARGE_GRAPH_THRESHOLD = 50_000
def _compute_degree(data: Data) -> np.ndarray:
    """Compute node degree directly from edge_index (fast, no NetworkX).

    Counts how often each node appears as a source (row 0) and as a target
    (row 1) of ``data.edge_index``, sums both counts and halves the result.

    NOTE(review): the halving presumably assumes every undirected edge is
    stored in both directions in ``edge_index``; for a one-directional edge
    list the value would be half the undirected degree -- confirm with callers.

    Args:
        data: Graph providing ``edge_index`` and ``num_nodes``.

    Returns:
        A float32 NumPy array with one entry per node.
    """
    n = data.num_nodes
    sources, targets = data.edge_index[0], data.edge_index[1]
    deg = pyg_degree(sources, num_nodes=n) + pyg_degree(targets, num_nodes=n)
    # Tensor.numpy() only works on CPU tensors; .cpu() is a no-op when the
    # tensor already lives there, so graphs on an accelerator no longer raise.
    return deg.cpu().numpy().astype(np.float32) / 2
def _node_values(values: dict, n: int) -> np.ndarray:
    """Lay out a ``{node: value}`` mapping as a float32 array indexed 0..n-1 (missing nodes -> 0)."""
    return np.array([values.get(i, 0.0) for i in range(n)], dtype=np.float32)


def compute_structural_features(data: Data) -> torch.Tensor:
    """Compute 6 domain-agnostic structural node features for a PyG Data object.

    The graph is converted to an undirected NetworkX graph and the following
    per-node quantities are stacked column-wise, in this order:

        0. degree (from ``edge_index``, see ``_compute_degree``)
        1. clustering coefficient
        2. PageRank
        3. triangle count
        4. core number
        5. eigenvector centrality

    Eigenvector centrality is replaced by zeros when the graph has more than
    ``LARGE_GRAPH_THRESHOLD`` nodes, or when the power iteration fails.

    Args:
        data: Graph providing ``edge_index`` and ``num_nodes``.

    Returns:
        A float32 tensor of shape ``(num_nodes, 6)``.
    """
    n = data.num_nodes
    if n == 0:
        # NetworkX refuses to compute eigenvector centrality on a null graph,
        # so short-circuit with an empty feature matrix of the right width.
        return torch.zeros((0, 6), dtype=torch.float32)

    is_large = n > LARGE_GRAPH_THRESHOLD

    degree = _compute_degree(data)

    G = to_networkx(data, to_undirected=True)

    clustering = _node_values(nx.clustering(G), n)
    pagerank = _node_values(nx.pagerank(G, max_iter=100, tol=1e-04), n)
    triangles = _node_values(nx.triangles(G), n)

    # nx.core_number rejects graphs that contain self-loops. Strip them on a
    # copy (only when present) so the other features still see the original
    # graph and loop-free inputs pay no extra cost.
    if nx.number_of_selfloops(G) > 0:
        core_graph = G.copy()
        core_graph.remove_edges_from(list(nx.selfloop_edges(core_graph)))
    else:
        core_graph = G
    core_number = _node_values(nx.core_number(core_graph), n)

    if is_large:
        # Power iteration is skipped for very large graphs.
        eigenvector = np.zeros(n, dtype=np.float32)
    else:
        try:
            eig_dict = nx.eigenvector_centrality(G, max_iter=300, tol=1e-06)
        except (nx.PowerIterationFailedConvergence, nx.NetworkXError):
            # Best-effort feature: fall back to zeros rather than failing.
            eigenvector = np.zeros(n, dtype=np.float32)
        else:
            eigenvector = _node_values(eig_dict, n)

    features = np.stack(
        [degree, clustering, pagerank, triangles, core_number, eigenvector],
        axis=1,
    )
    return torch.from_numpy(features)
def normalize_features(features: torch.Tensor) -> torch.Tensor:
    """Z-score normalize features column-wise.

    Each column has its mean subtracted and is divided by its (unbiased)
    standard deviation, floored at 1e-8 so constant columns map to zeros
    instead of dividing by zero.

    Args:
        features: Floating-point tensor whose first dimension indexes rows
            (nodes) and whose remaining dimensions index features.

    Returns:
        A tensor of the same shape as ``features``.
    """
    mean = features.mean(dim=0, keepdim=True)
    n_rows = features.shape[0] if features.dim() > 0 else 1
    if n_rows < 2:
        # The unbiased std of fewer than two samples is NaN, and clamp() does
        # not remove NaN; treat the spread as zero so a single row normalizes
        # to zeros rather than NaN.
        std = torch.zeros_like(mean)
    else:
        std = features.std(dim=0, keepdim=True)
    return (features - mean) / std.clamp(min=1e-8)