Spaces:
Sleeping
Sleeping
| import os | |
| import json | |
| import zipfile | |
| import urllib.request | |
| from dataclasses import dataclass | |
| from io import BytesIO | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| from torch_geometric.data import Data | |
| from torch_geometric.datasets import ( | |
| Planetoid, | |
| PPI, | |
| CitationFull, | |
| ) | |
| from torch_geometric.utils import to_undirected | |
| from ogb.nodeproppred import PygNodePropPredDataset | |
| from netfm.data.structural_features import compute_structural_features, normalize_features | |
@dataclass
class GraphDataset:
    """Container for a processed graph with structural features.

    Declared as a dataclass so that instances can be created with keyword
    arguments, which is how ``load_and_prepare`` builds them; without the
    decorator the class has no generated ``__init__`` and that call fails
    with ``TypeError``.
    """

    # Registry key the graph was loaded under (e.g. "cora").
    name: str
    # Domain label copied from the dataset's registry entry.
    domain: str
    # The PyG graph object itself.
    data: Data
    # Node count, filled from ``data.num_nodes`` by ``load_and_prepare``.
    num_nodes: int
    # Edge count, filled from ``data.num_edges`` by ``load_and_prepare``.
    num_edges: int
# Registry of graphs used for pre-training. Each entry maps a dataset key to
# the loader tag that ``load_dataset`` dispatches on and a domain label.
PRETRAIN_DATASETS: dict[str, dict] = {
    # Citation graphs (Planetoid loader).
    "cora": {"loader": "planetoid", "domain": "citation"},
    "citeseer": {"loader": "planetoid", "domain": "citation"},
    "pubmed": {"loader": "planetoid", "domain": "citation"},
    # Social graphs fetched from SNAP zip archives.
    "twitch-ENGB": {"loader": "twitch_snap", "domain": "social"},
    "twitch-DE": {"loader": "twitch_snap", "domain": "social"},
    "twitch-ES": {"loader": "twitch_snap", "domain": "social"},
    "lastfm": {"loader": "lastfm_snap", "domain": "social"},
    "facebook": {"loader": "facebook_snap", "domain": "social"},
    # Biological graphs (PPI training split).
    "ppi-train": {"loader": "ppi", "domain": "biological"},
    # Collaboration graph (DBLP).
    "dblp": {"loader": "dblp", "domain": "collaboration"},
}
# Registry of held-out graphs; same entry layout as the pre-training registry
# (a loader tag plus a domain label per dataset key).
HOLDOUT_DATASETS: dict[str, dict] = {
    "chameleon": {
        "loader": "wikipedia_snap",
        "domain": "heterophilic",
    },
}
# Maps the lowercase registry key of each Planetoid dataset to the name string
# handed to the ``Planetoid`` dataset class.
PLANETOID_NAMES: dict[str, str] = {
    "cora": "Cora",
    "citeseer": "CiteSeer",
    "pubmed": "PubMed",
}
| def _download_and_extract_zip(url: str, dest_dir: str) -> None: | |
| """Download a zip file and extract it to dest_dir.""" | |
| os.makedirs(dest_dir, exist_ok=True) | |
| print(f" Downloading {url}...", flush=True) | |
| with urllib.request.urlopen(url) as resp: | |
| data = resp.read() | |
| with zipfile.ZipFile(BytesIO(data)) as zf: | |
| zf.extractall(dest_dir) | |
| print(f" Extracted to {dest_dir}", flush=True) | |
def _load_planetoid(name: str, root: str) -> Data:
    """Return the first graph of a Planetoid dataset.

    ``name`` is looked up in ``PLANETOID_NAMES`` to obtain the name string
    passed to ``Planetoid``; a key missing from that mapping raises
    ``KeyError``. Files are stored under ``<root>/Planetoid``.
    """
    storage_dir = os.path.join(root, "Planetoid")
    pyg_name = PLANETOID_NAMES[name]
    return Planetoid(root=storage_dir, name=pyg_name)[0]
def _load_ogb(name: str, root: str) -> Data:
    """Return the first graph of an OGB node-property-prediction dataset.

    When the dataset object exposes ``get_idx_split``, its ``train``/``test``/
    ``valid`` index sets are converted into boolean node masks and attached to
    the graph as ``train_mask``, ``test_mask`` and ``val_mask``. Files are
    stored under ``<root>/OGB``.
    """
    dataset = PygNodePropPredDataset(name=name, root=os.path.join(root, "OGB"))
    graph = dataset[0]
    if not hasattr(dataset, "get_idx_split"):
        return graph

    split_idx = dataset.get_idx_split()
    # One all-False mask per split, then switch on the nodes listed in it.
    masks = {
        key: torch.zeros(graph.num_nodes, dtype=torch.bool)
        for key in ("train", "test", "valid")
    }
    for key, mask in masks.items():
        mask[split_idx[key]] = True

    graph.train_mask = masks["train"]
    graph.test_mask = masks["test"]
    graph.val_mask = masks["valid"]
    return graph
def _load_twitch_snap(region: str, root: str) -> Data:
    """Build the Twitch graph of one region from the SNAP ``twitch.zip`` archive.

    The archive is downloaded into ``<root>/twitch_snap`` when the region's
    edge CSV is not on disk. The first two CSV columns are read as edge
    endpoints, rows whose endpoints coincide are dropped, and each remaining
    edge is stored in both directions. ``num_nodes`` is the largest node id
    appearing in an edge plus one. ``y`` is the ``mature`` column ordered by
    ``new_id`` when the target CSV exists and has that column, else ``None``.
    """
    snap_dir = os.path.join(root, "twitch_snap")
    code = region.upper()
    region_dir = os.path.join(snap_dir, "twitch", code)

    edges_path = os.path.join(region_dir, f"musae_{code}_edges.csv")
    if not os.path.exists(edges_path):
        _download_and_extract_zip("https://snap.stanford.edu/data/twitch.zip", snap_dir)

    edge_table = pd.read_csv(edges_path)
    heads = torch.tensor(edge_table.iloc[:, 0].values, dtype=torch.long)
    tails = torch.tensor(edge_table.iloc[:, 1].values, dtype=torch.long)
    keep = heads != tails
    heads, tails = heads[keep], tails[keep]

    # Forward edges followed by their reversed copies.
    edge_index = torch.cat(
        [torch.stack([heads, tails]), torch.stack([tails, heads])],
        dim=1,
    )
    num_nodes = int(edge_index.max().item()) + 1

    labels = None
    target_path = os.path.join(region_dir, f"musae_{code}_target.csv")
    if os.path.exists(target_path):
        target_table = pd.read_csv(target_path)
        if "mature" in target_table.columns:
            mature = target_table.sort_values("new_id")["mature"].astype(int)
            labels = torch.tensor(mature.values, dtype=torch.long)

    return Data(edge_index=edge_index, num_nodes=num_nodes, y=labels)
def _load_lastfm_snap(root: str) -> Data:
    """Build the LastFM Asia graph from the SNAP ``lastfm_asia.zip`` archive.

    The archive is downloaded into ``<root>/lastfm_snap`` when the edge CSV is
    not on disk. The first two CSV columns are read as edge endpoints, rows
    whose endpoints coincide are dropped, and each remaining edge is stored in
    both directions. ``num_nodes`` is the largest node id appearing in an edge
    plus one. ``y`` is the ``target`` column ordered by ``id`` when the target
    CSV exists, else ``None``.
    """
    snap_dir = os.path.join(root, "lastfm_snap")
    # NOTE(review): "lasftm_asia" (sic) looks misspelled; presumably it mirrors
    # the folder name inside the upstream archive - confirm before changing it.
    archive_dir = os.path.join(snap_dir, "lasftm_asia")

    edges_path = os.path.join(archive_dir, "lastfm_asia_edges.csv")
    if not os.path.exists(edges_path):
        _download_and_extract_zip("https://snap.stanford.edu/data/lastfm_asia.zip", snap_dir)

    edge_table = pd.read_csv(edges_path)
    heads = torch.tensor(edge_table.iloc[:, 0].values, dtype=torch.long)
    tails = torch.tensor(edge_table.iloc[:, 1].values, dtype=torch.long)
    keep = heads != tails
    heads, tails = heads[keep], tails[keep]

    # Forward edges followed by their reversed copies.
    edge_index = torch.cat(
        [torch.stack([heads, tails]), torch.stack([tails, heads])],
        dim=1,
    )
    num_nodes = int(edge_index.max().item()) + 1

    labels = None
    target_path = os.path.join(archive_dir, "lastfm_asia_target.csv")
    if os.path.exists(target_path):
        ordered = pd.read_csv(target_path).sort_values("id")
        labels = torch.tensor(ordered["target"].values, dtype=torch.long)

    return Data(edge_index=edge_index, num_nodes=num_nodes, y=labels)
def _load_facebook_snap(root: str) -> Data:
    """Build the Facebook Page-Page graph from the SNAP ``facebook_large.zip`` archive.

    The archive is downloaded into ``<root>/facebook_snap`` when the edge CSV
    is not on disk. The first two CSV columns are read as edge endpoints, rows
    whose endpoints coincide are dropped, and each remaining edge is stored in
    both directions. ``num_nodes`` is the largest node id appearing in an edge
    plus one. When the target CSV exists, ``y`` holds integer class ids for
    the ``page_type`` column (ids assigned in sorted order of the distinct
    values), ordered by ``id``; otherwise ``y`` is ``None``.
    """
    snap_dir = os.path.join(root, "facebook_snap")
    archive_dir = os.path.join(snap_dir, "facebook_large")

    edges_path = os.path.join(archive_dir, "musae_facebook_edges.csv")
    if not os.path.exists(edges_path):
        _download_and_extract_zip("https://snap.stanford.edu/data/facebook_large.zip", snap_dir)

    edge_table = pd.read_csv(edges_path)
    heads = torch.tensor(edge_table.iloc[:, 0].values, dtype=torch.long)
    tails = torch.tensor(edge_table.iloc[:, 1].values, dtype=torch.long)
    keep = heads != tails
    heads, tails = heads[keep], tails[keep]

    # Forward edges followed by their reversed copies.
    edge_index = torch.cat(
        [torch.stack([heads, tails]), torch.stack([tails, heads])],
        dim=1,
    )
    num_nodes = int(edge_index.max().item()) + 1

    labels = None
    target_path = os.path.join(archive_dir, "musae_facebook_target.csv")
    if os.path.exists(target_path):
        target_table = pd.read_csv(target_path)
        distinct_types = sorted(target_table["page_type"].unique())
        class_ids = {page_type: idx for idx, page_type in enumerate(distinct_types)}
        encoded = target_table.sort_values("id")["page_type"].map(class_ids)
        labels = torch.tensor(encoded.values, dtype=torch.long)

    return Data(edge_index=edge_index, num_nodes=num_nodes, y=labels)
def _load_wikipedia_snap(name: str, root: str) -> Data:
    """Build a Wikipedia page graph (e.g. chameleon) from the SNAP ``wikipedia.zip`` archive.

    The archive is downloaded into ``<root>/wikipedia_snap`` when the edge CSV
    for ``name`` is not on disk. The first two CSV columns are read as edge
    endpoints, rows whose endpoints coincide are dropped, and each remaining
    edge is stored in both directions. ``num_nodes`` is the largest node id
    appearing in an edge plus one. When the target CSV exists and has a
    ``target`` column, its values (ordered by ``id``) are bucketed into five
    classes at the 20/40/60/80 % quantiles and stored in ``y``; otherwise
    ``y`` is ``None``.
    """
    snap_dir = os.path.join(root, "wikipedia_snap")
    graph_dir = os.path.join(snap_dir, "wikipedia", name)

    edges_path = os.path.join(graph_dir, f"musae_{name}_edges.csv")
    if not os.path.exists(edges_path):
        _download_and_extract_zip("https://snap.stanford.edu/data/wikipedia.zip", snap_dir)

    edge_table = pd.read_csv(edges_path)
    heads = torch.tensor(edge_table.iloc[:, 0].values, dtype=torch.long)
    tails = torch.tensor(edge_table.iloc[:, 1].values, dtype=torch.long)
    keep = heads != tails
    heads, tails = heads[keep], tails[keep]

    # Forward edges followed by their reversed copies.
    edge_index = torch.cat(
        [torch.stack([heads, tails]), torch.stack([tails, heads])],
        dim=1,
    )
    num_nodes = int(edge_index.max().item()) + 1

    labels = None
    target_path = os.path.join(graph_dir, f"musae_{name}_target.csv")
    if os.path.exists(target_path):
        target_table = pd.read_csv(target_path)
        if "target" in target_table.columns:
            raw_targets = target_table.sort_values("id")["target"].values
            # Four quantile cut points -> class ids 0..4.
            cut_points = np.quantile(raw_targets, [0.2, 0.4, 0.6, 0.8])
            labels = torch.tensor(np.digitize(raw_targets, cut_points), dtype=torch.long)

    return Data(edge_index=edge_index, num_nodes=num_nodes, y=labels)
def _load_ogbl(name: str, root: str) -> Data:
    """Return the first graph of an OGB link-prediction dataset.

    The OGB link-prediction class is imported inside the function, and files
    are stored under ``<root>/OGBL``.
    """
    from ogb.linkproppred import PygLinkPropPredDataset

    storage_dir = os.path.join(root, "OGBL")
    return PygLinkPropPredDataset(name=name, root=storage_dir)[0]
def _load_ppi(root: str, split: str = "train") -> list[Data]:
    """Return every graph of the requested PPI split as a list.

    Files are stored under ``<root>/PPI``; ``split`` is forwarded unchanged to
    the ``PPI`` dataset class.
    """
    storage_dir = os.path.join(root, "PPI")
    ppi_split = PPI(root=storage_dir, split=split)
    return list(ppi_split)
def _load_dblp(root: str) -> Data:
    """Return the first graph of the ``DBLP`` dataset from ``CitationFull``.

    Files are stored under ``<root>/CitationFull``.
    """
    storage_dir = os.path.join(root, "CitationFull")
    return CitationFull(root=storage_dir, name="DBLP")[0]
def load_dataset(name: str, root: str = "./data") -> list[Data]:
    """Load the dataset registered under ``name`` as a list of PyG graphs.

    ``name`` is looked up in the merged pre-training and hold-out registries
    and the entry's ``loader`` tag selects the loading routine. Single-graph
    loaders are wrapped in a one-element list; the PPI loaders return their
    list of graphs directly.

    Raises:
        ValueError: if ``name`` is not registered, or its loader tag is not
            one of the tags handled here.
    """
    registry = {**PRETRAIN_DATASETS, **HOLDOUT_DATASETS}
    config = registry.get(name)
    if config is None:
        raise ValueError(f"Unknown dataset: {name}")

    loader = config["loader"]
    # Thunks keep each loader lazy: only the selected one runs.
    dispatch = {
        "planetoid": lambda: [_load_planetoid(name, root)],
        "ogb": lambda: [_load_ogb(name, root)],
        # Twitch keys look like "twitch-<REGION>"; the suffix is the region.
        "twitch_snap": lambda: [_load_twitch_snap(name.split("-")[1], root)],
        "lastfm_snap": lambda: [_load_lastfm_snap(root)],
        "facebook_snap": lambda: [_load_facebook_snap(root)],
        "wikipedia_snap": lambda: [_load_wikipedia_snap(name, root)],
        "ogbl": lambda: [_load_ogbl(name, root)],
        "ppi": lambda: _load_ppi(root, split="train"),
        "ppi_test": lambda: _load_ppi(root, split="test"),
        "dblp": lambda: [_load_dblp(root)],
    }
    handler = dispatch.get(loader)
    if handler is None:
        raise ValueError(f"Unknown loader: {loader}")
    return handler()
| def _remove_self_loops(edge_index: torch.Tensor) -> torch.Tensor: | |
| """Remove self-loops from edge_index.""" | |
| mask = edge_index[0] != edge_index[1] | |
| return edge_index[:, mask] | |
def prepare_graph(data: Data) -> Data:
    """Clean a graph's edges and swap its node features for structural ones.

    The passed object is modified in place and also returned. Self-loops are
    removed, the edge list is symmetrised if the graph is not already
    undirected, and normalised structural features are computed. The previous
    ``x`` is kept as ``x_original``, while both ``structural_features`` and
    ``x`` refer to the new features.
    """
    loop_free = _remove_self_loops(data.edge_index)
    data.edge_index = loop_free
    if not data.is_undirected():
        data.edge_index = to_undirected(data.edge_index)

    # Features are computed on the cleaned, undirected edge list.
    raw_features = compute_structural_features(data)
    data.structural_features = normalize_features(raw_features)

    # Stash the original features before overwriting ``x``.
    data.x_original = data.x
    data.x = data.structural_features
    return data
def load_and_prepare(name: str, root: str = "./data") -> list[GraphDataset]:
    """Load ``name``, run ``prepare_graph`` on each graph and wrap the results.

    Returns one ``GraphDataset`` per loaded graph, carrying the dataset key,
    the registry's domain label, the prepared graph and its node/edge counts.
    An unregistered ``name`` raises ``ValueError`` from ``load_dataset``.
    """
    raw_graphs = load_dataset(name, root)
    config = {**PRETRAIN_DATASETS, **HOLDOUT_DATASETS}[name]
    # ``map`` is lazy, so each graph is prepared right before it is wrapped.
    return [
        GraphDataset(
            name=name,
            domain=config["domain"],
            data=graph,
            num_nodes=graph.num_nodes,
            num_edges=graph.num_edges,
        )
        for graph in map(prepare_graph, raw_graphs)
    ]