import os
import json
import zipfile
import urllib.request
from dataclasses import dataclass
from io import BytesIO
from typing import Optional

import torch
import numpy as np
import pandas as pd
from torch_geometric.data import Data
from torch_geometric.datasets import (
    Planetoid,
    PPI,
    CitationFull,
)
from torch_geometric.utils import to_undirected
from ogb.nodeproppred import PygNodePropPredDataset

from netfm.data.structural_features import compute_structural_features, normalize_features


@dataclass
class GraphDataset:
    """Container for a processed graph with structural features."""

    # Registry key the graph was loaded under (e.g. "cora", "twitch-DE").
    name: str
    # Domain tag copied from the registry entry (e.g. "citation", "social").
    domain: str
    # The PyG graph after `prepare_graph` has been applied.
    data: Data
    # Snapshot of `data.num_nodes` taken when the container was built.
    num_nodes: int
    # Snapshot of `data.num_edges` taken when the container was built.
    num_edges: int


# Registry of graphs used for pre-training: name -> {"loader", "domain"}.
PRETRAIN_DATASETS: dict[str, dict] = {
    "cora": {"loader": "planetoid", "domain": "citation"},
    "citeseer": {"loader": "planetoid", "domain": "citation"},
    "pubmed": {"loader": "planetoid", "domain": "citation"},
    "twitch-ENGB": {"loader": "twitch_snap", "domain": "social"},
    "twitch-DE": {"loader": "twitch_snap", "domain": "social"},
    "twitch-ES": {"loader": "twitch_snap", "domain": "social"},
    "lastfm": {"loader": "lastfm_snap", "domain": "social"},
    "facebook": {"loader": "facebook_snap", "domain": "social"},
    "ppi-train": {"loader": "ppi", "domain": "biological"},
    "dblp": {"loader": "dblp", "domain": "collaboration"},
}

# Registry of graphs held out from pre-training: name -> {"loader", "domain"}.
HOLDOUT_DATASETS: dict[str, dict] = {
    "chameleon": {"loader": "wikipedia_snap", "domain": "heterophilic"},
}

# Maps lower-case registry names to the names the Planetoid loader expects.
PLANETOID_NAMES: dict[str, str] = {
    "cora": "Cora",
    "citeseer": "CiteSeer",
    "pubmed": "PubMed",
}


def _get_dataset_config(name: str) -> dict:
    """Return the registry entry for *name* from the pretrain/holdout tables.

    Raises:
        ValueError: If *name* is in neither registry.
    """
    config = {**PRETRAIN_DATASETS, **HOLDOUT_DATASETS}.get(name)
    if config is None:
        raise ValueError(f"Unknown dataset: {name}")
    return config


def _download_and_extract_zip(url: str, dest_dir: str) -> None:
    """Download a zip file and extract it to dest_dir.

    The archive is read fully into memory before extraction, and *dest_dir*
    is created if it does not exist. Errors from the download or from
    unpacking propagate to the caller.
    """
    os.makedirs(dest_dir, exist_ok=True)
    print(f"  Downloading {url}...", flush=True)
    with urllib.request.urlopen(url) as resp:
        data = resp.read()
    with zipfile.ZipFile(BytesIO(data)) as zf:
        zf.extractall(dest_dir)
    print(f"  Extracted to {dest_dir}", flush=True)


def _ensure_snap_file(path: str, url: str, dest_dir: str) -> None:
    """Download and unpack *url* into *dest_dir* unless *path* already exists."""
    if not os.path.exists(path):
        _download_and_extract_zip(url, dest_dir)


def _read_snap_edges(edges_file: str) -> torch.Tensor:
    """Read an edge-list CSV into a symmetric ``edge_index`` without self-loops.

    The first two columns of the CSV are taken as source and destination ids.
    Rows whose endpoints coincide are dropped, then every remaining edge is
    emitted in both directions. Duplicate rows in the CSV are not collapsed.

    Returns:
        A ``(2, 2 * E)`` long tensor, where ``E`` is the number of
        non-self-loop rows in the CSV.
    """
    edges_df = pd.read_csv(edges_file)
    src = torch.tensor(edges_df.iloc[:, 0].values, dtype=torch.long)
    dst = torch.tensor(edges_df.iloc[:, 1].values, dtype=torch.long)
    keep = src != dst
    src, dst = src[keep], dst[keep]
    return torch.stack([
        torch.cat([src, dst]),
        torch.cat([dst, src]),
    ])


def _build_snap_data(edge_index: torch.Tensor, y: Optional[torch.Tensor]) -> Data:
    """Assemble a ``Data`` object from an edge index and optional labels.

    The node count is ``max node id + 1`` over the edges, but never smaller
    than the number of labels. Without that floor, trailing nodes that are
    isolated (or only had self-loops, which are filtered out earlier) would
    leave ``y`` longer than ``num_nodes``. An empty edge list yields zero
    nodes from the edges instead of failing on ``max()`` of an empty tensor.
    """
    num_nodes = int(edge_index.max().item()) + 1 if edge_index.numel() > 0 else 0
    if y is not None:
        # assumes node ids are contiguous 0..N-1 in the label file — TODO confirm
        num_nodes = max(num_nodes, int(y.size(0)))
    return Data(edge_index=edge_index, num_nodes=num_nodes, y=y)


def _load_planetoid(name: str, root: str) -> Data:
    """Load a Planetoid dataset (Cora, CiteSeer, PubMed).

    Raises:
        KeyError: If *name* is not a key of ``PLANETOID_NAMES``.
    """
    dataset = Planetoid(root=os.path.join(root, "Planetoid"), name=PLANETOID_NAMES[name])
    return dataset[0]


def _load_ogb(name: str, root: str) -> Data:
    """Load an OGB node property prediction dataset.

    When the dataset object exposes ``get_idx_split``, its train/valid/test
    index split is converted into boolean ``train_mask`` / ``val_mask`` /
    ``test_mask`` attributes on the returned graph.
    """
    dataset = PygNodePropPredDataset(name=name, root=os.path.join(root, "OGB"))
    data = dataset[0]

    if hasattr(dataset, "get_idx_split"):
        split = dataset.get_idx_split()
        train_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        test_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        val_mask = torch.zeros(data.num_nodes, dtype=torch.bool)
        train_mask[split["train"]] = True
        test_mask[split["test"]] = True
        val_mask[split["valid"]] = True
        data.train_mask = train_mask
        data.test_mask = test_mask
        data.val_mask = val_mask

    return data


def _load_twitch_snap(region: str, root: str) -> Data:
    """Load Twitch network from SNAP zip.

    Args:
        region: Region code; upper-cased to locate the files (e.g. "DE").
        root: Directory under which ``twitch_snap`` is created/read.

    Returns:
        Graph whose ``y`` holds the ``mature`` column as integers ordered by
        ``new_id``, or ``None`` if the target file or column is missing.
    """
    snap_dir = os.path.join(root, "twitch_snap")
    region_upper = region.upper()
    edges_file = os.path.join(snap_dir, "twitch", region_upper, f"musae_{region_upper}_edges.csv")
    _ensure_snap_file(edges_file, "https://snap.stanford.edu/data/twitch.zip", snap_dir)

    edge_index = _read_snap_edges(edges_file)

    target_file = os.path.join(snap_dir, "twitch", region_upper, f"musae_{region_upper}_target.csv")
    y = None
    if os.path.exists(target_file):
        target_df = pd.read_csv(target_file)
        if "mature" in target_df.columns:
            y = torch.tensor(
                target_df.sort_values("new_id")["mature"].astype(int).values,
                dtype=torch.long,
            )

    return _build_snap_data(edge_index, y)


def _load_lastfm_snap(root: str) -> Data:
    """Load LastFM Asia network from SNAP zip.

    Returns:
        Graph whose ``y`` holds the ``target`` column ordered by ``id``, or
        ``None`` if the target file is missing.
    """
    snap_dir = os.path.join(root, "lastfm_snap")
    # NOTE(review): the "lasftm_asia" spelling is kept exactly as originally
    # written; presumably it mirrors the directory name inside the archive —
    # confirm against the extracted zip before "correcting" it.
    edges_file = os.path.join(snap_dir, "lasftm_asia", "lastfm_asia_edges.csv")
    _ensure_snap_file(edges_file, "https://snap.stanford.edu/data/lastfm_asia.zip", snap_dir)

    edge_index = _read_snap_edges(edges_file)

    target_file = os.path.join(snap_dir, "lasftm_asia", "lastfm_asia_target.csv")
    y = None
    if os.path.exists(target_file):
        target_df = pd.read_csv(target_file)
        y = torch.tensor(target_df.sort_values("id")["target"].values, dtype=torch.long)

    return _build_snap_data(edge_index, y)


def _load_facebook_snap(root: str) -> Data:
    """Load Facebook Page-Page network from SNAP zip.

    Returns:
        Graph whose ``y`` holds ``page_type`` encoded as integers (index of
        the value among the sorted unique page types) ordered by ``id``, or
        ``None`` if the target file is missing.
    """
    snap_dir = os.path.join(root, "facebook_snap")
    edges_file = os.path.join(snap_dir, "facebook_large", "musae_facebook_edges.csv")
    _ensure_snap_file(edges_file, "https://snap.stanford.edu/data/facebook_large.zip", snap_dir)

    edge_index = _read_snap_edges(edges_file)

    target_file = os.path.join(snap_dir, "facebook_large", "musae_facebook_target.csv")
    y = None
    if os.path.exists(target_file):
        target_df = pd.read_csv(target_file)
        # Sorting the categories first makes the integer encoding deterministic.
        label_map = {v: i for i, v in enumerate(sorted(target_df["page_type"].unique()))}
        y = torch.tensor(
            target_df.sort_values("id")["page_type"].map(label_map).values,
            dtype=torch.long,
        )

    return _build_snap_data(edge_index, y)


def _load_wikipedia_snap(name: str, root: str) -> Data:
    """Load heterophilic Wikipedia network (chameleon/squirrel) from SNAP zip.

    Returns:
        Graph whose ``y`` is the ``target`` column (ordered by ``id``)
        bucketed into five classes at its 20/40/60/80% quantiles, or ``None``
        if the target file or column is missing.
    """
    snap_dir = os.path.join(root, "wikipedia_snap")
    edges_file = os.path.join(snap_dir, "wikipedia", name, f"musae_{name}_edges.csv")
    _ensure_snap_file(edges_file, "https://snap.stanford.edu/data/wikipedia.zip", snap_dir)

    edge_index = _read_snap_edges(edges_file)

    target_file = os.path.join(snap_dir, "wikipedia", name, f"musae_{name}_target.csv")
    y = None
    if os.path.exists(target_file):
        target_df = pd.read_csv(target_file)
        if "target" in target_df.columns:
            targets = target_df.sort_values("id")["target"].values
            # Four cut points -> np.digitize yields class ids 0..4.
            bins = np.quantile(targets, [0.2, 0.4, 0.6, 0.8])
            y = torch.tensor(np.digitize(targets, bins), dtype=torch.long)

    return _build_snap_data(edge_index, y)


def _load_ogbl(name: str, root: str) -> Data:
    """Load an OGB link prediction dataset."""
    # Imported lazily so the link-prediction module is only needed on this path.
    from ogb.linkproppred import PygLinkPropPredDataset

    dataset = PygLinkPropPredDataset(name=name, root=os.path.join(root, "OGBL"))
    return dataset[0]


def _load_ppi(root: str, split: str = "train") -> list[Data]:
    """Load the PPI dataset (multiple graphs) for the given split."""
    dataset = PPI(root=os.path.join(root, "PPI"), split=split)
    return list(dataset)


def _load_dblp(root: str) -> Data:
    """Load the DBLP co-authorship citation network."""
    dataset = CitationFull(root=os.path.join(root, "CitationFull"), name="DBLP")
    return dataset[0]


def load_dataset(name: str, root: str = "./data") -> list[Data]:
    """Load a dataset by name and return a list of PyG Data objects.

    Args:
        name: Key of ``PRETRAIN_DATASETS`` or ``HOLDOUT_DATASETS``.
        root: Directory under which per-loader data folders live.

    Returns:
        One ``Data`` per graph; single-graph datasets return a one-item list.

    Raises:
        ValueError: If *name* is not registered, or its registry entry names
            a loader this function does not know.
    """
    loader = _get_dataset_config(name)["loader"]

    # Each entry is a thunk so only the selected loader runs (and so the
    # twitch region is parsed from the name only for twitch datasets).
    dispatch = {
        "planetoid": lambda: [_load_planetoid(name, root)],
        "ogb": lambda: [_load_ogb(name, root)],
        "twitch_snap": lambda: [_load_twitch_snap(name.split("-")[1], root)],
        "lastfm_snap": lambda: [_load_lastfm_snap(root)],
        "facebook_snap": lambda: [_load_facebook_snap(root)],
        "wikipedia_snap": lambda: [_load_wikipedia_snap(name, root)],
        "ogbl": lambda: [_load_ogbl(name, root)],
        "ppi": lambda: _load_ppi(root, split="train"),
        "ppi_test": lambda: _load_ppi(root, split="test"),
        "dblp": lambda: [_load_dblp(root)],
    }

    load = dispatch.get(loader)
    if load is None:
        raise ValueError(f"Unknown loader: {loader}")
    return load()


def _remove_self_loops(edge_index: torch.Tensor) -> torch.Tensor:
    """Remove self-loops from edge_index."""
    mask = edge_index[0] != edge_index[1]
    return edge_index[:, mask]


def prepare_graph(data: Data) -> Data:
    """Compute structural features and attach them to a PyG Data object.

    The graph is modified in place and also returned. Self-loops are removed,
    the edge set is symmetrised if it is not already undirected, and the
    normalised structural features replace ``data.x``. The previous ``data.x``
    is kept as ``data.x_original``.
    """
    data.edge_index = _remove_self_loops(data.edge_index)

    if not data.is_undirected():
        data.edge_index = to_undirected(data.edge_index)

    structural = compute_structural_features(data)
    data.structural_features = normalize_features(structural)

    # Preserve the raw node features before overwriting `x`.
    data.x_original = data.x
    data.x = data.structural_features

    return data


def load_and_prepare(name: str, root: str = "./data") -> list[GraphDataset]:
    """Load a dataset and compute structural features for all graphs.

    Args:
        name: Key of ``PRETRAIN_DATASETS`` or ``HOLDOUT_DATASETS``.
        root: Directory under which per-loader data folders live.

    Returns:
        One ``GraphDataset`` per graph in the dataset, each tagged with the
        dataset's name and registry domain.

    Raises:
        ValueError: If *name* is not registered.
    """
    raw_graphs = load_dataset(name, root)
    domain = _get_dataset_config(name)["domain"]

    results = []
    for data in raw_graphs:
        data = prepare_graph(data)
        results.append(
            GraphDataset(
                name=name,
                domain=domain,
                data=data,
                num_nodes=data.num_nodes,
                num_edges=data.num_edges,
            )
        )
    return results