"""Export a PyG dataset (Planetoid or Heterophilous) to a minimal Graphviz DOT.

- Nodes are colored by class.
- If --filter {train|val|test} is set, nodes in that split are colored red,
  other nodes are left uncolored.
- Undirected edges are deduplicated; directed edges are written as-is.
  Self-loops are skipped in both modes.
- For HeterophilousGraphDataset with multiple splits (e.g., Amazon-ratings
  has 10), use --split-index (default 0).
"""

import argparse
import warnings

from torch_geometric.datasets import Planetoid, HeterophilousGraphDataset

PALETTE = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728", "#9467bd",
    "#8c564b", "#e377c2", "#7f7f7f", "#bcbd22", "#17becf"
]
HIGHLIGHT_RED = "#ff0000"

# Attribute names probed, in order, when looking up the mask of each split.
_MASK_ATTRS = {
    "train": ("train_mask",),
    "val": ("val_mask", "valid_mask", "validation_mask"),
    "test": ("test_mask",),
}


def _brighten(hex_color, factor):
    """Blend ``#rrggbb`` towards white by ``factor`` (0 = unchanged)."""
    value = int(hex_color[1:], 16)
    channels = ((value >> 16) & 255, (value >> 8) & 255, value & 255)
    r, g, b = (int(min(255, ch + (255 - ch) * factor)) for ch in channels)
    return f"#{r:02x}{g:02x}{b:02x}"


def _extend_palette(base, need):
    """Return ``need`` colors derived from ``base``.

    When ``base`` is long enough its prefix is returned. Otherwise the base
    colors are cycled, and every additional pass over the palette is
    brightened a little more so repeated hues stay distinguishable.
    """
    if need <= len(base):
        return base[:need]
    return [
        _brighten(base[i % len(base)], 0.18 * (i // len(base)))
        for i in range(need)
    ]


def _infer_num_classes(dataset, data):
    """Return the class count, falling back to ``max(y) + 1``.

    The dataset's ``num_classes`` attribute is used when it is a positive
    int; otherwise the count is derived from the labels in ``data.y``.
    """
    num_classes = getattr(dataset, "num_classes", None)
    if not isinstance(num_classes, int) or num_classes <= 0:
        num_classes = int(data.y.max().item()) + 1
    return num_classes


def _select_split(mask, split_index):
    """Reduce a multi-split mask to the 1-D mask of one split.

    A mask with fewer than two dimensions is returned unchanged. For a 2-D
    mask the column ``split_index`` is returned, clamped to the valid range.
    """
    # NOTE(review): assumes 2-D masks are laid out [num_nodes, num_splits],
    # as for PyG's HeterophilousGraphDataset -- confirm for other datasets.
    if mask.dim() < 2:
        return mask
    column = max(0, min(split_index, mask.size(1) - 1))
    return mask[:, column]


def _get_split_mask(data, which, split_index=0):
    """Return the 1-D mask for split ``which``, or ``None`` if absent.

    Args:
        data: Graph object exposing the masks as attributes.
        which: One of ``"train"``, ``"val"`` or ``"test"``.
        split_index: Column to use when the stored mask is 2-D.

    Raises:
        KeyError: If ``which`` is not a known split name.
    """
    for attr in _MASK_ATTRS[which]:
        mask = getattr(data, attr, None)
        # Compare against None explicitly: the truth value of a tensor with
        # more than one element is ambiguous and raises.
        if mask is not None:
            return _select_split(mask, split_index)
    return None


def load_graph(root: str, use_hetero: bool, name: str, split_index: int):
    """Load a dataset and return ``(data, num_classes)``.

    Args:
        root: Root folder for datasets. Planetoid datasets are stored under
            ``{root}/Planetoid``.
        use_hetero: Load a HeterophilousGraphDataset instead of Planetoid.
        name: Dataset name passed to the PyG loader.
        split_index: Which split to keep when the masks hold several splits.
            Out-of-range values are clamped.

    Returns:
        The first graph of the dataset, with any 2-D split mask reduced to
        the selected split, and the number of classes.
    """
    if use_hetero:
        ds = HeterophilousGraphDataset(root=root, name=name)
    else:
        ds = Planetoid(root=f"{root}/Planetoid", name=name)
    data = ds[0]

    # The splits are columns of the masks, not separate dataset entries, so
    # the split has to be selected on the masks rather than by indexing `ds`.
    for attr_names in _MASK_ATTRS.values():
        for attr in attr_names:
            mask = getattr(data, attr, None)
            if mask is not None and mask.dim() > 1:
                setattr(data, attr, _select_split(mask, split_index))

    return data, _infer_num_classes(ds, data)


def write_dot(path: str, data, num_classes: int, directed: bool, filter_split: str | None):
    """Write ``data`` to ``path`` as a Graphviz DOT graph.

    Args:
        path: Output file path.
        data: Graph exposing ``y``, ``edge_index`` and ``num_nodes``.
        num_classes: Number of classes; labels are wrapped into this range.
        directed: Write a ``digraph`` with ``->`` edges instead of an
            undirected ``graph`` with deduplicated ``--`` edges.
        filter_split: ``None`` to color every node by class, or a split name
            whose nodes are colored red while all others stay uncolored.
    """
    y = data.y
    colors = _extend_palette(PALETTE, num_classes)

    highlight_mask = None
    if filter_split is not None:
        mask = _get_split_mask(data, filter_split)
        if mask is None:
            # Keep writing the graph, but do not fail silently.
            warnings.warn(
                f"No {filter_split!r} split mask found on the dataset; "
                "no nodes will be highlighted.",
                stacklevel=2,
            )
        else:
            highlight_mask = mask.bool()

    gtype = "digraph" if directed else "graph"
    eop = "->" if directed else "--"

    with open(path, "w", encoding="utf-8") as f:
        f.write(f"{gtype} {{\n")

        for i in range(data.num_nodes):
            if filter_split is None:
                # Modulo keeps out-of-range labels inside the palette.
                col = colors[int(y[i]) % num_classes]
            elif highlight_mask is not None and bool(highlight_mask[i]):
                col = HIGHLIGHT_RED
            else:
                f.write(f' {i} ;\n')
                continue
            f.write(f' {i} [color="{col}", style="filled", fillcolor="{col}", fontcolor="white"];\n')

        seen = set()
        for s, t in data.edge_index.t().tolist():
            if s == t:
                continue  # self-loops are never written
            if not directed:
                # Normalise the endpoint order so (a, b) and (b, a) collapse.
                s, t = (s, t) if s <= t else (t, s)
                if (s, t) in seen:
                    continue
                seen.add((s, t))
            f.write(f" {s} {eop} {t};\n")

        f.write("}\n")


def main():
    """Parse the command line, load the dataset and write the DOT file."""
    parser = argparse.ArgumentParser(
        description="Export PyG datasets to minimal DOT with class colors and optional split highlighting."
    )
    parser.add_argument("-o", "--output", default="graph.dot",
                        help="Output .dot file (default: graph.dot)")
    parser.add_argument("--directed", action="store_true",
                        help="Write directed edges (default: undirected)")
    parser.add_argument("--heterophilous", action="store_true",
                        help="Use HeterophilousGraphDataset")
    parser.add_argument("--name", default=None,
                        help="Dataset name (Planetoid: Cora/CiteSeer/PubMed; Heterophilous: Amazon-ratings, Roman-empire, etc.)")
    parser.add_argument("--root", default="data",
                        help="Root folder for datasets (default: data)")
    parser.add_argument("--split-index", type=int, default=0,
                        help="Split index for heterophilous datasets (default: 0)")
    parser.add_argument("--filter", choices=["train", "val", "test"], default=None,
                        help="Highlight nodes in the selected split as red (other nodes are left uncolored)")
    args = parser.parse_args()

    if args.name is not None:
        dataset_name = args.name
    else:
        dataset_name = "Amazon-ratings" if args.heterophilous else "Cora"

    data, num_classes = load_graph(args.root, args.heterophilous, dataset_name, args.split_index)
    write_dot(args.output, data, num_classes, args.directed, args.filter)

    loader = "HeterophilousGraphDataset" if args.heterophilous else "Planetoid"
    suffix = f", highlight={args.filter}" if args.filter else ""
    print(f"Wrote {args.output} using {loader}('{dataset_name}') | nodes={data.num_nodes}, edges={data.edge_index.size(1)}, classes={num_classes}{suffix}")


if __name__ == "__main__":
    main()