# netfm-training / netfm/data/datasets.py
# henribonamy: Upload netfm/data/datasets.py with huggingface_hub
# 3648a5a (verified)
import os
import json
import zipfile
import urllib.request
from dataclasses import dataclass
from io import BytesIO
import torch
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.datasets import (
Planetoid,
PPI,
CitationFull,
)
from torch_geometric.utils import to_undirected
from ogb.nodeproppred import PygNodePropPredDataset
from netfm.data.structural_features import compute_structural_features, normalize_features
@dataclass
class GraphDataset:
    """One processed graph together with the metadata recorded at load time.

    Attributes:
        name: Dataset name the graph was requested under.
        domain: Domain tag copied from the dataset's registry entry.
        data: The PyG ``Data`` object holding the graph itself.
        num_nodes: Node count stored next to ``data`` for quick access.
        num_edges: Edge count stored next to ``data`` for quick access.
    """

    # Identification.
    name: str
    domain: str

    # The graph payload.
    data: Data

    # Size summary, kept as plain ints.
    num_nodes: int
    num_edges: int
# Registry of named datasets. Each entry maps a dataset name to the loader key
# that `load_dataset` dispatches on and the domain tag that `load_and_prepare`
# copies into the resulting `GraphDataset`.
PRETRAIN_DATASETS: dict[str, dict] = {
    # Planetoid benchmarks.
    **{
        key: {"loader": "planetoid", "domain": "citation"}
        for key in ("cora", "citeseer", "pubmed")
    },
    # Twitch networks from SNAP, one entry per region code.
    **{
        f"twitch-{region}": {"loader": "twitch_snap", "domain": "social"}
        for region in ("ENGB", "DE", "ES")
    },
    "lastfm": {"loader": "lastfm_snap", "domain": "social"},
    "facebook": {"loader": "facebook_snap", "domain": "social"},
    "ppi-train": {"loader": "ppi", "domain": "biological"},
    "dblp": {"loader": "dblp", "domain": "collaboration"},
}

# Same schema as PRETRAIN_DATASETS, kept in a separate registry
# (presumably reserved for evaluation -- confirm with the training code).
HOLDOUT_DATASETS: dict[str, dict] = {
    "chameleon": {"loader": "wikipedia_snap", "domain": "heterophilic"},
}

# Lower-case registry key -> the capitalisation passed to `Planetoid(name=...)`.
PLANETOID_NAMES: dict[str, str] = {
    canonical.lower(): canonical for canonical in ("Cora", "CiteSeer", "PubMed")
}
def _download_and_extract_zip(url: str, dest_dir: str, timeout: float = 60.0) -> None:
    """Download a zip archive and unpack it into ``dest_dir``.

    The whole archive is read into memory before it is extracted, so this is
    intended for archives that comfortably fit in RAM.

    Args:
        url: Location of the zip file; any URL ``urllib.request.urlopen``
            accepts.
        dest_dir: Directory to extract into. Created if it does not exist.
        timeout: Seconds a blocking network operation may take before the
            download is abandoned. Without a timeout a stalled server would
            hang the caller indefinitely.

    Raises:
        urllib.error.URLError: If the URL cannot be opened.
        TimeoutError: If the connection stalls for longer than ``timeout``
            while reading (may also surface wrapped in ``URLError``).
        zipfile.BadZipFile: If the downloaded payload is not a zip archive.
    """
    os.makedirs(dest_dir, exist_ok=True)
    print(f" Downloading {url}...", flush=True)
    with urllib.request.urlopen(url, timeout=timeout) as resp:
        payload = resp.read()
    # Extract from the in-memory copy; nothing is written outside dest_dir
    # until the full payload has arrived.
    with zipfile.ZipFile(BytesIO(payload)) as archive:
        archive.extractall(dest_dir)
    print(f" Extracted to {dest_dir}", flush=True)
def _load_planetoid(name: str, root: str) -> Data:
    """Return the first graph of a Planetoid dataset.

    ``name`` is the lower-case registry key; it is translated through
    ``PLANETOID_NAMES`` into the spelling handed to ``Planetoid``. The
    dataset root passed to PyG is ``<root>/Planetoid``.

    Raises:
        KeyError: If ``name`` has no entry in ``PLANETOID_NAMES``.
    """
    cache_dir = os.path.join(root, "Planetoid")
    canonical_name = PLANETOID_NAMES[name]
    return Planetoid(root=cache_dir, name=canonical_name)[0]
def _load_ogb(name: str, root: str) -> Data:
    """Return the first graph of an OGB node-property-prediction dataset.

    When the dataset object exposes ``get_idx_split``, the index-based split
    is converted into boolean ``train_mask`` / ``test_mask`` / ``val_mask``
    attributes on the returned graph. The dataset root is ``<root>/OGB``.
    """
    ogb_dataset = PygNodePropPredDataset(name=name, root=os.path.join(root, "OGB"))
    graph = ogb_dataset[0]
    if not hasattr(ogb_dataset, "get_idx_split"):
        return graph

    split_idx = ogb_dataset.get_idx_split()
    # (key in the split dict, attribute name on the graph); note that the
    # split calls the validation set "valid" while the mask is "val_mask".
    split_to_attr = (
        ("train", "train_mask"),
        ("test", "test_mask"),
        ("valid", "val_mask"),
    )
    masks = {}
    for split_key, attr in split_to_attr:
        node_mask = torch.zeros(graph.num_nodes, dtype=torch.bool)
        node_mask[split_idx[split_key]] = True
        masks[attr] = node_mask
    # Attach only after every mask was built successfully.
    for attr, node_mask in masks.items():
        setattr(graph, attr, node_mask)
    return graph
def _load_twitch_snap(region: str, root: str) -> Data:
    """Load one regional Twitch network from the SNAP ``twitch.zip`` archive.

    The archive is downloaded into ``<root>/twitch_snap`` when the region's
    edge list is not found on disk. Self-loops are dropped and the edge list
    is made symmetric with ``to_undirected``, which also merges repeated or
    reciprocal rows. Building the reverse edges by plain concatenation would
    keep such rows as duplicate edges, and ``prepare_graph`` would not clean
    them up because the graph already tests as undirected.

    Args:
        region: Region code such as ``"ENGB"``; upper-cased before use.
        root: Directory under which the archive is cached.

    Returns:
        A ``Data`` object with ``edge_index``, ``num_nodes`` and ``y``.
        ``y`` is the ``mature`` column cast to int and ordered by ``new_id``,
        or ``None`` when the target file or that column is missing.
    """
    snap_dir = os.path.join(root, "twitch_snap")
    region_upper = region.upper()
    region_dir = os.path.join(snap_dir, "twitch", region_upper)
    edges_file = os.path.join(region_dir, f"musae_{region_upper}_edges.csv")
    if not os.path.exists(edges_file):
        _download_and_extract_zip("https://snap.stanford.edu/data/twitch.zip", snap_dir)

    edges_df = pd.read_csv(edges_file)
    src = torch.tensor(edges_df.iloc[:, 0].values, dtype=torch.long)
    dst = torch.tensor(edges_df.iloc[:, 1].values, dtype=torch.long)
    not_loop = src != dst
    pairs = torch.stack([src[not_loop], dst[not_loop]])
    # NOTE(review): the node count is inferred from the largest id that occurs
    # in an edge; this assumes ids are contiguous from 0 and that the highest
    # id is not an isolated node -- confirm against the target file.
    num_nodes = int(pairs.max().item()) + 1
    # Adds the reverse of every pair and coalesces duplicates in one step.
    edge_index = to_undirected(pairs, num_nodes=num_nodes)

    target_file = os.path.join(region_dir, f"musae_{region_upper}_target.csv")
    y = None
    if os.path.exists(target_file):
        target_df = pd.read_csv(target_file)
        if "mature" in target_df.columns:
            y = torch.tensor(
                target_df.sort_values("new_id")["mature"].astype(int).values,
                dtype=torch.long,
            )
    return Data(edge_index=edge_index, num_nodes=num_nodes, y=y)
def _load_lastfm_snap(root: str) -> Data:
    """Load the LastFM Asia network from the SNAP ``lastfm_asia.zip`` archive.

    The archive is downloaded into ``<root>/lastfm_snap`` when the edge list
    is not found on disk. Self-loops are dropped and the edge list is made
    symmetric with ``to_undirected``, which also merges repeated or
    reciprocal rows. Building the reverse edges by plain concatenation would
    keep such rows as duplicate edges, and ``prepare_graph`` would not clean
    them up because the graph already tests as undirected.

    Args:
        root: Directory under which the archive is cached.

    Returns:
        A ``Data`` object with ``edge_index``, ``num_nodes`` and ``y``.
        ``y`` is the ``target`` column ordered by ``id``, or ``None`` when
        the target file is missing.
    """
    snap_dir = os.path.join(root, "lastfm_snap")
    # NOTE(review): "lasftm_asia" looks like a typo but is used for both files
    # here -- presumably it mirrors the folder name inside the SNAP archive.
    # Confirm against the extracted archive before "correcting" it.
    archive_dir = os.path.join(snap_dir, "lasftm_asia")
    edges_file = os.path.join(archive_dir, "lastfm_asia_edges.csv")
    if not os.path.exists(edges_file):
        _download_and_extract_zip("https://snap.stanford.edu/data/lastfm_asia.zip", snap_dir)

    edges_df = pd.read_csv(edges_file)
    src = torch.tensor(edges_df.iloc[:, 0].values, dtype=torch.long)
    dst = torch.tensor(edges_df.iloc[:, 1].values, dtype=torch.long)
    not_loop = src != dst
    pairs = torch.stack([src[not_loop], dst[not_loop]])
    # NOTE(review): the node count is inferred from the largest id that occurs
    # in an edge; this assumes ids are contiguous from 0 and that the highest
    # id is not an isolated node -- confirm against the target file.
    num_nodes = int(pairs.max().item()) + 1
    # Adds the reverse of every pair and coalesces duplicates in one step.
    edge_index = to_undirected(pairs, num_nodes=num_nodes)

    target_file = os.path.join(archive_dir, "lastfm_asia_target.csv")
    y = None
    if os.path.exists(target_file):
        target_df = pd.read_csv(target_file)
        y = torch.tensor(target_df.sort_values("id")["target"].values, dtype=torch.long)
    return Data(edge_index=edge_index, num_nodes=num_nodes, y=y)
def _load_facebook_snap(root: str) -> Data:
    """Load the Facebook Page-Page network from SNAP's ``facebook_large.zip``.

    The archive is downloaded into ``<root>/facebook_snap`` when the edge
    list is not found on disk. Self-loops are dropped and the edge list is
    made symmetric with ``to_undirected``, which also merges repeated or
    reciprocal rows. Building the reverse edges by plain concatenation would
    keep such rows as duplicate edges, and ``prepare_graph`` would not clean
    them up because the graph already tests as undirected.

    Args:
        root: Directory under which the archive is cached.

    Returns:
        A ``Data`` object with ``edge_index``, ``num_nodes`` and ``y``.
        ``y`` encodes ``page_type`` as integers (index of the value in the
        sorted list of distinct page types), ordered by ``id``; it is ``None``
        when the target file is missing.
    """
    snap_dir = os.path.join(root, "facebook_snap")
    archive_dir = os.path.join(snap_dir, "facebook_large")
    edges_file = os.path.join(archive_dir, "musae_facebook_edges.csv")
    if not os.path.exists(edges_file):
        _download_and_extract_zip("https://snap.stanford.edu/data/facebook_large.zip", snap_dir)

    edges_df = pd.read_csv(edges_file)
    src = torch.tensor(edges_df.iloc[:, 0].values, dtype=torch.long)
    dst = torch.tensor(edges_df.iloc[:, 1].values, dtype=torch.long)
    not_loop = src != dst
    pairs = torch.stack([src[not_loop], dst[not_loop]])
    # NOTE(review): the node count is inferred from the largest id that occurs
    # in an edge; this assumes ids are contiguous from 0 and that the highest
    # id is not an isolated node -- confirm against the target file.
    num_nodes = int(pairs.max().item()) + 1
    # Adds the reverse of every pair and coalesces duplicates in one step.
    edge_index = to_undirected(pairs, num_nodes=num_nodes)

    target_file = os.path.join(archive_dir, "musae_facebook_target.csv")
    y = None
    if os.path.exists(target_file):
        target_df = pd.read_csv(target_file)
        # Sorting the distinct labels first makes the integer coding
        # independent of row order in the CSV.
        label_map = {v: i for i, v in enumerate(sorted(target_df["page_type"].unique()))}
        y = torch.tensor(
            target_df.sort_values("id")["page_type"].map(label_map).values,
            dtype=torch.long,
        )
    return Data(edge_index=edge_index, num_nodes=num_nodes, y=y)
def _load_wikipedia_snap(name: str, root: str) -> Data:
    """Load a Wikipedia page network from the SNAP ``wikipedia.zip`` archive.

    The archive is downloaded into ``<root>/wikipedia_snap`` when the edge
    list for ``name`` is not found on disk. Self-loops are dropped and the
    edge list is made symmetric with ``to_undirected``, which also merges
    repeated or reciprocal rows. Building the reverse edges by plain
    concatenation would keep such rows as duplicate edges, and
    ``prepare_graph`` would not clean them up because the graph already
    tests as undirected.

    Args:
        name: Sub-directory / file stem inside the archive, e.g.
            ``"chameleon"``.
        root: Directory under which the archive is cached.

    Returns:
        A ``Data`` object with ``edge_index``, ``num_nodes`` and ``y``.
        ``y`` buckets the ``target`` column (ordered by ``id``) into five
        classes ``0..4`` using its 20/40/60/80 % quantiles as boundaries; it
        is ``None`` when the target file or the column is missing.
    """
    snap_dir = os.path.join(root, "wikipedia_snap")
    archive_dir = os.path.join(snap_dir, "wikipedia", name)
    edges_file = os.path.join(archive_dir, f"musae_{name}_edges.csv")
    if not os.path.exists(edges_file):
        _download_and_extract_zip("https://snap.stanford.edu/data/wikipedia.zip", snap_dir)

    edges_df = pd.read_csv(edges_file)
    src = torch.tensor(edges_df.iloc[:, 0].values, dtype=torch.long)
    dst = torch.tensor(edges_df.iloc[:, 1].values, dtype=torch.long)
    not_loop = src != dst
    pairs = torch.stack([src[not_loop], dst[not_loop]])
    # NOTE(review): the node count is inferred from the largest id that occurs
    # in an edge; this assumes ids are contiguous from 0 and that the highest
    # id is not an isolated node -- confirm against the target file.
    num_nodes = int(pairs.max().item()) + 1
    # Adds the reverse of every pair and coalesces duplicates in one step.
    edge_index = to_undirected(pairs, num_nodes=num_nodes)

    target_file = os.path.join(archive_dir, f"musae_{name}_target.csv")
    y = None
    if os.path.exists(target_file):
        target_df = pd.read_csv(target_file)
        if "target" in target_df.columns:
            targets = target_df.sort_values("id")["target"].values
            # Four quantile boundaries -> np.digitize yields labels 0..4.
            bins = np.quantile(targets, [0.2, 0.4, 0.6, 0.8])
            y = torch.tensor(np.digitize(targets, bins), dtype=torch.long)
    return Data(edge_index=edge_index, num_nodes=num_nodes, y=y)
def _load_ogbl(name: str, root: str) -> Data:
    """Return the first graph of an OGB link-prediction dataset.

    The dataset root passed to OGB is ``<root>/OGBL``.
    """
    # Imported inside the function, so the link-prediction module is only
    # loaded when this loader is actually called.
    from ogb.linkproppred import PygLinkPropPredDataset

    cache_dir = os.path.join(root, "OGBL")
    return PygLinkPropPredDataset(name=name, root=cache_dir)[0]
def _load_ppi(root: str, split: str = "train") -> list[Data]:
    """Return every graph of one PPI split as a plain list.

    Args:
        root: Base data directory; the dataset root is ``<root>/PPI``.
        split: Split name forwarded unchanged to ``PPI``.
    """
    ppi_root = os.path.join(root, "PPI")
    return [*PPI(root=ppi_root, split=split)]
def _load_dblp(root: str) -> Data:
    """Return the first graph of the ``DBLP`` dataset from ``CitationFull``.

    The dataset root passed to PyG is ``<root>/CitationFull``.
    """
    cache_dir = os.path.join(root, "CitationFull")
    return CitationFull(root=cache_dir, name="DBLP")[0]
def load_dataset(name: str, root: str = "./data") -> list[Data]:
    """Load the raw graph(s) registered under ``name``.

    The name is looked up in ``PRETRAIN_DATASETS`` and ``HOLDOUT_DATASETS``
    (the latter wins on a clash) and the entry's ``"loader"`` key selects the
    loading routine.

    Args:
        name: Registered dataset name, e.g. ``"cora"`` or ``"twitch-DE"``.
        root: Base directory handed to the individual loaders.

    Returns:
        A list of ``Data`` objects. Single-graph loaders yield a one-element
        list; the PPI loaders return whatever ``_load_ppi`` returns.

    Raises:
        ValueError: If ``name`` is not registered or its loader key is not
            handled here.
    """
    registry = {**PRETRAIN_DATASETS, **HOLDOUT_DATASETS}
    if name not in registry:
        raise ValueError(f"Unknown dataset: {name}")
    loader = registry[name]["loader"]

    # PPI loaders already return a list of graphs; value is the split name.
    multi_graph_splits = {"ppi": "train", "ppi_test": "test"}
    if loader in multi_graph_splits:
        return _load_ppi(root, split=multi_graph_splits[loader])

    # Every other loader yields exactly one graph. Thunks keep the work lazy,
    # so e.g. the region is only parsed out of the name for Twitch datasets.
    single_graph_loaders = {
        "planetoid": lambda: _load_planetoid(name, root),
        "ogb": lambda: _load_ogb(name, root),
        "twitch_snap": lambda: _load_twitch_snap(name.split("-")[1], root),
        "lastfm_snap": lambda: _load_lastfm_snap(root),
        "facebook_snap": lambda: _load_facebook_snap(root),
        "wikipedia_snap": lambda: _load_wikipedia_snap(name, root),
        "ogbl": lambda: _load_ogbl(name, root),
        "dblp": lambda: _load_dblp(root),
    }
    if loader in single_graph_loaders:
        return [single_graph_loaders[loader]()]

    raise ValueError(f"Unknown loader: {loader}")
def _remove_self_loops(edge_index: torch.Tensor) -> torch.Tensor:
    """Return ``edge_index`` without the columns whose two endpoints coincide.

    Rows 0 and 1 of ``edge_index`` are compared element-wise; only columns
    where they differ are kept. The input tensor is not modified.
    """
    sources = edge_index[0]
    targets = edge_index[1]
    keep = sources.ne(targets)
    return edge_index[:, keep]
def prepare_graph(data: Data) -> Data:
    """Clean a graph's edges and swap its features for structural ones.

    The graph is modified in place and also returned. Steps, in order:

    1. self-loops are removed from ``edge_index``;
    2. if the graph is not undirected, ``edge_index`` is symmetrised;
    3. structural features are computed and normalised, then stored as
       ``structural_features``;
    4. the previous ``x`` is kept as ``x_original`` and ``x`` is replaced by
       the structural features.

    NOTE(review): calling this twice on the same object overwrites
    ``x_original`` with the structural features from the first call.
    """
    loop_free_edges = _remove_self_loops(data.edge_index)
    data.edge_index = loop_free_edges

    already_symmetric = data.is_undirected()
    if not already_symmetric:
        data.edge_index = to_undirected(data.edge_index)

    normalized = normalize_features(compute_structural_features(data))
    data.structural_features = normalized

    # Preserve whatever node features the dataset shipped with before
    # replacing them.
    data.x_original = data.x
    data.x = data.structural_features
    return data
def load_and_prepare(name: str, root: str = "./data") -> list[GraphDataset]:
    """Load a registered dataset and wrap each prepared graph.

    Every graph returned by ``load_dataset`` is passed through
    ``prepare_graph`` and packaged as a ``GraphDataset`` carrying the dataset
    name, the registry's domain tag and the graph's node/edge counts after
    preparation.

    Args:
        name: Registered dataset name.
        root: Base data directory forwarded to ``load_dataset``.

    Raises:
        ValueError: Propagated from ``load_dataset`` for unknown names.
    """
    # Load first so an unknown name surfaces as load_dataset's ValueError.
    raw_graphs = load_dataset(name, root)
    config = {**PRETRAIN_DATASETS, **HOLDOUT_DATASETS}[name]
    # map() is lazy, so each graph is prepared right before it is wrapped.
    return [
        GraphDataset(
            name=name,
            domain=config["domain"],
            data=graph,
            num_nodes=graph.num_nodes,
            num_edges=graph.num_edges,
        )
        for graph in map(prepare_graph, raw_graphs)
    ]