"""
Export a PyG dataset (Planetoid or Heterophilous) to a minimal Graphviz DOT.
- Nodes are colored by class.
- If --filter {train|val|test} is set, nodes in that split are colored red,
  other nodes are left uncolored.
- Undirected edges are deduplicated; directed edges are written as-is.
  Self-loops are dropped in both modes.
- For HeterophilousGraphDataset with multiple splits (e.g., Amazon-ratings has 10),
  use --split-index (default 0).
"""
| import argparse | |
| from torch_geometric.datasets import Planetoid, HeterophilousGraphDataset | |
# Base class colors as "#rrggbb" strings; one entry per class index.
PALETTE = [
    "#1f77b4",  # blue
    "#ff7f0e",  # orange
    "#2ca02c",  # green
    "#d62728",  # red
    "#9467bd",  # purple
    "#8c564b",  # brown
    "#e377c2",  # pink
    "#7f7f7f",  # gray
    "#bcbd22",  # olive
    "#17becf",  # cyan
]

# Fill color for nodes that belong to the split selected with --filter.
HIGHLIGHT_RED = "#ff0000"
| def _extend_palette(base, need): | |
| if need <= len(base): | |
| return base[:need] | |
| def brighten(hex_color, factor): | |
| c = int(hex_color[1:], 16) | |
| r = (c >> 16) & 255 | |
| g = (c >> 8) & 255 | |
| b = c & 255 | |
| r = int(min(255, r + (255 - r) * factor)) | |
| g = int(min(255, g + (255 - g) * factor)) | |
| b = int(min(255, b + (255 - b) * factor)) | |
| return f"#{r:02x}{g:02x}{b:02x}" | |
| out = [] | |
| for i in range(need): | |
| base_hex = base[i % len(base)] | |
| factor = 0.18 * (i // len(base)) | |
| out.append(brighten(base_hex, factor)) | |
| return out | |
| def _infer_num_classes(dataset, data): | |
| num_classes = getattr(dataset, "num_classes", None) | |
| if not isinstance(num_classes, int) or num_classes <= 0: | |
| num_classes = int(data.y.max().item()) + 1 | |
| return num_classes | |
| def _get_split_mask(data, which): | |
| name = {"train": "train_mask", "val": "val_mask", "test": "test_mask"}[which] | |
| m = getattr(data, name, None) | |
| if m is None and which == "val": | |
| m = getattr(data, "valid_mask", None) or getattr(data, "validation_mask", None) | |
| return m | |
def load_graph(root: str, use_hetero: bool, name: str, split_index: int):
    """Load a single graph and its class count from a PyG dataset.

    Args:
        root: Dataset root folder. Planetoid datasets are stored under
            ``<root>/Planetoid``; heterophilous datasets directly under
            ``root``.
        use_hetero: Load via ``HeterophilousGraphDataset`` when true,
            ``Planetoid`` otherwise.
        name: Dataset name passed through to the PyG dataset class.
        split_index: Which split to keep when the heterophilous masks carry
            several splits. Clamped to the available range. Ignored for
            Planetoid.

    Returns:
        ``(data, num_classes)``. For heterophilous datasets, any 2-D
        ``train_mask`` / ``val_mask`` / ``test_mask`` is reduced to the 1-D
        column of the selected split.
    """
    if use_hetero:
        ds = HeterophilousGraphDataset(root=root, name=name)
        # The dataset holds one graph; its splits are columns of the masks
        # (shape [num_nodes, num_splits]), not separate graphs. Indexing the
        # dataset by split_index would therefore always land on graph 0 and
        # leave the masks 2-D, so the split is selected on the masks instead.
        data = ds[0]
        for mask_name in ("train_mask", "val_mask", "test_mask"):
            mask = getattr(data, mask_name, None)
            if mask is None or mask.dim() != 2:
                continue
            column = max(0, min(split_index, mask.size(1) - 1))
            setattr(data, mask_name, mask[:, column])
    else:
        ds = Planetoid(root=f"{root}/Planetoid", name=name)
        data = ds[0]
    return data, _infer_num_classes(ds, data)
def write_dot(path: str, data, num_classes: int, directed: bool, filter_split: str | None):
    """Write ``data`` to ``path`` as a minimal Graphviz DOT graph.

    Node styling:
        * ``filter_split`` is None: every node is filled with the palette
          color of its class (labels outside ``[0, num_classes)`` wrap around
          via modulo).
        * ``filter_split`` is given: nodes inside that split are filled red
          and every other node is written without attributes. If ``data`` has
          no mask for the split, no node is highlighted.

    Edge handling:
        Self-loops are always dropped. With ``directed`` each remaining edge
        is written as ``s -> t``; otherwise ``(s, t)`` and ``(t, s)`` collapse
        into a single ``a -- b`` line with ``a <= b``.

    Args:
        path: Output file path (overwritten, UTF-8).
        data: Graph object providing ``num_nodes``, ``y`` and ``edge_index``.
        num_classes: Number of classes used to size the palette.
        directed: Emit a ``digraph`` instead of a ``graph``.
        filter_split: ``"train"``, ``"val"``, ``"test"`` or None.
    """
    highlight_mask = None
    if filter_split is None:
        # Class colors are only needed when no split filter is active.
        colors = _extend_palette(PALETTE, num_classes)
        labels = data.y
    else:
        mask = _get_split_mask(data, filter_split)
        if mask is not None:
            mask = mask.bool()
            if mask.dim() > 1:
                # Multi-split masks ([num_nodes, num_splits]) would make
                # bool(mask[i]) ambiguous and crash; use the first split.
                mask = mask[:, 0]
            highlight_mask = mask

    # Materialize the edge list up front so a malformed graph fails before
    # the output file is created.
    edges = data.edge_index.t().tolist()
    gtype = "digraph" if directed else "graph"
    eop = "->" if directed else "--"

    with open(path, "w", encoding="utf-8") as f:
        f.write(f"{gtype} {{\n")

        for i in range(data.num_nodes):
            if filter_split is None:
                cls = int(labels[i])
                if not 0 <= cls < num_classes:
                    cls %= num_classes
                col = colors[cls]
            elif highlight_mask is not None and bool(highlight_mask[i]):
                col = HIGHLIGHT_RED
            else:
                # Outside the highlighted split: bare node, default styling.
                f.write(f" {i} ;\n")
                continue
            f.write(f' {i} [color="{col}", style="filled", fillcolor="{col}", fontcolor="white"];\n')

        seen = set()
        for s, t in edges:
            if s == t:
                continue
            if not directed:
                # Normalize the endpoint order so both directions dedupe.
                if s > t:
                    s, t = t, s
                if (s, t) in seen:
                    continue
                seen.add((s, t))
            f.write(f" {s} {eop} {t};\n")

        f.write("}\n")
def main():
    """Command-line entry point: parse arguments, load the graph, write the DOT file.

    When ``--name`` is omitted the dataset defaults to ``Amazon-ratings`` with
    ``--heterophilous`` and to ``Cora`` otherwise. A one-line summary of what
    was written is printed to stdout.
    """
    parser = argparse.ArgumentParser(
        description="Export PyG datasets to minimal DOT with class colors and optional split highlighting."
    )
    parser.add_argument("-o", "--output", default="graph.dot", help="Output .dot file (default: graph.dot)")
    parser.add_argument("--directed", action="store_true", help="Write directed edges (default: undirected)")
    parser.add_argument("--heterophilous", action="store_true", help="Use HeterophilousGraphDataset")
    parser.add_argument(
        "--name",
        default=None,
        help="Dataset name (Planetoid: Cora/CiteSeer/PubMed; Heterophilous: Amazon-ratings, Roman-empire, etc.)",
    )
    parser.add_argument("--root", default="data", help="Root folder for datasets (default: data)")
    parser.add_argument(
        "--split-index",
        type=int,
        default=0,
        help="Split index for heterophilous datasets (default: 0)",
    )
    # Help text matches what write_dot does: nodes outside the split are
    # emitted without any color attributes.
    parser.add_argument(
        "--filter",
        choices=["train", "val", "test"],
        default=None,
        help="Highlight nodes in the selected split as red (other nodes are left uncolored)",
    )
    args = parser.parse_args()

    if args.name is not None:
        dataset_name = args.name
    else:
        dataset_name = "Amazon-ratings" if args.heterophilous else "Cora"

    data, num_classes = load_graph(args.root, args.heterophilous, dataset_name, args.split_index)
    write_dot(args.output, data, num_classes, args.directed, args.filter)

    loader = "HeterophilousGraphDataset" if args.heterophilous else "Planetoid"
    suffix = f", highlight={args.filter}" if args.filter else ""
    print(
        f"Wrote {args.output} using {loader}('{dataset_name}') | "
        f"nodes={data.num_nodes}, edges={data.edge_index.size(1)}, "
        f"classes={num_classes}{suffix}"
    )


if __name__ == "__main__":
    main()