# clique/src/generate_dataset_dot.py
# Hosting-page residue (not part of the script): uploader "qingy2024";
# commit "Upload folder using huggingface_hub" (f74dd01, verified).
"""
Export a PyG dataset (Planetoid or Heterophilous) to a minimal Graphviz DOT.
- Nodes are colored by class.
- If --filter {train|val|test} is set, only nodes in that split are colored (red);
  all other nodes are left uncolored.
- Undirected edges are deduplicated; directed edges are written as-is
  (self-loops are skipped in both modes).
- For HeterophilousGraphDataset with multiple splits (e.g., Amazon-ratings has 10),
  use --split-index (default 0).
"""
import argparse
from torch_geometric.datasets import Planetoid, HeterophilousGraphDataset
# Base colors (hex "#rrggbb") assigned to class indices 0..9; when a dataset
# has more classes, _extend_palette() cycles through these and brightens them.
PALETTE = [
    "#1f77b4", "#ff7f0e", "#2ca02c", "#d62728",
    "#9467bd", "#8c564b", "#e377c2", "#7f7f7f",
    "#bcbd22", "#17becf"
]
# Fill color used for nodes that belong to the split selected with --filter.
HIGHLIGHT_RED = "#ff0000"
def _extend_palette(base, need):
    """Return a list of ``need`` hex colors derived from ``base``.

    When ``base`` already has enough entries, a prefix slice of it is
    returned unchanged. Otherwise the base colors are cycled, and every
    additional pass over ``base`` blends each channel 18% further towards
    white (clamped at 255) so repeated colors stay distinguishable.
    """
    size = len(base)
    if need <= size:
        return base[:need]

    def _lighten(hex_color, amount):
        # Unpack "#rrggbb" into its three 8-bit channels, then move each
        # channel towards 255 by the given fraction.
        packed = int(hex_color[1:], 16)
        channels = ((packed >> shift) & 255 for shift in (16, 8, 0))
        mixed = [int(min(255, ch + (255 - ch) * amount)) for ch in channels]
        return "#{:02x}{:02x}{:02x}".format(*mixed)

    palette = []
    for position in range(need):
        # cycle = how many full passes over the base palette came before
        # this position; slot = which base color to start from.
        cycle, slot = divmod(position, size)
        palette.append(_lighten(base[slot], 0.18 * cycle))
    return palette
def _infer_num_classes(dataset, data):
    """Return the number of classes for ``data``.

    Uses ``dataset.num_classes`` when it is a positive ``int``; otherwise
    falls back to ``max(data.y) + 1`` computed from the label tensor.
    """
    declared = getattr(dataset, "num_classes", None)
    if isinstance(declared, int) and declared > 0:
        return declared
    # No usable declared value: derive the count from the largest label.
    return int(data.y.max().item()) + 1
def _get_split_mask(data, which):
    """Return the split mask stored on ``data`` for ``which``, or ``None``.

    Args:
        data: Graph object whose split masks are exposed as attributes.
        which: One of ``"train"``, ``"val"`` or ``"test"``.

    Returns:
        The mask object exactly as stored on ``data``, or ``None`` when no
        matching attribute exists. For ``"val"`` the alternative attribute
        names ``valid_mask`` and ``validation_mask`` are tried as well.

    Raises:
        KeyError: If ``which`` is not one of the three supported names.
    """
    primary = {"train": "train_mask", "val": "val_mask", "test": "test_mask"}[which]
    candidates = [primary]
    if which == "val":
        candidates += ["valid_mask", "validation_mask"]
    for attr in candidates:
        mask = getattr(data, attr, None)
        # Compare against None explicitly: truth-testing a multi-element
        # tensor (as `a or b` would) raises instead of returning the mask.
        if mask is not None:
            return mask
    return None
def load_graph(root: str, use_hetero: bool, name: str, split_index: int):
    """Load one graph and its class count from a PyG dataset.

    Args:
        root: Root folder for dataset files. Planetoid datasets are stored
            under ``{root}/Planetoid``; heterophilous ones directly in ``root``.
        use_hetero: Load ``HeterophilousGraphDataset`` when true, otherwise
            ``Planetoid``.
        name: Dataset name passed through to the dataset class.
        split_index: Which split to use for heterophilous datasets. It is
            clamped to the valid range; ignored for Planetoid.

    Returns:
        A ``(data, num_classes)`` tuple.
    """
    if not use_hetero:
        ds = Planetoid(root=f"{root}/Planetoid", name=name)
        data = ds[0]
        return data, _infer_num_classes(ds, data)

    ds = HeterophilousGraphDataset(root=root, name=name)
    idx = max(0, min(split_index, len(ds) - 1))
    data = ds[idx]

    # Heterophilous datasets may ship several splits as columns of a 2-D
    # mask (one column per split) on a single graph. Reduce each such mask
    # to the requested column so callers can index it per node; without
    # this, split_index has no effect and per-node bool() checks fail.
    for mask_name in ("train_mask", "val_mask", "test_mask"):
        mask = getattr(data, mask_name, None)
        if mask is not None and mask.dim() == 2 and mask.size(1) > 0:
            column = max(0, min(split_index, mask.size(1) - 1))
            setattr(data, mask_name, mask[:, column])

    return data, _infer_num_classes(ds, data)
def write_dot(path: str, data, num_classes: int, directed: bool, filter_split: str | None):
    """Write ``data`` to ``path`` as a minimal Graphviz DOT file.

    Args:
        path: Output file path (written as UTF-8, overwritten if present).
        data: Graph object providing ``y``, ``edge_index`` and ``num_nodes``.
        num_classes: Number of classes; labels outside ``[0, num_classes)``
            are wrapped with modulo.
        directed: Emit a ``digraph`` with ``->`` edges when true; otherwise a
            ``graph`` with ``--`` edges where each unordered pair appears once.
        filter_split: ``None`` to fill every node with its class color, or a
            split name; then only nodes in that split are filled (red) and
            all other nodes are written without attributes.

    Self-loops are never written.
    """
    labels = data.y
    edges = data.edge_index
    palette = _extend_palette(PALETTE, num_classes)

    highlighting = filter_split is not None
    split_mask = None
    if highlighting:
        found = _get_split_mask(data, filter_split)
        split_mask = found.bool() if found is not None else None

    graph_keyword, connector = ("digraph", "->") if directed else ("graph", "--")

    def filled(node, colour):
        # One DOT node statement with outline and fill in the same color.
        return f' {node} [color="{colour}", style="filled", fillcolor="{colour}", fontcolor="white"];\n'

    with open(path, "w", encoding="utf-8") as out:
        out.write(f"{graph_keyword} {{\n")

        for node in range(data.num_nodes):
            label = int(labels[node])
            if not 0 <= label < num_classes:
                label %= num_classes
            class_colour = palette[label]

            if not highlighting:
                out.write(filled(node, class_colour))
            elif split_mask is not None and bool(split_mask[node]):
                out.write(filled(node, HIGHLIGHT_RED))
            else:
                # Highlight mode, node outside the split (or no mask found).
                out.write(f' {node} ;\n')

        # Only consulted in undirected mode, where (u, v) and (v, u) are
        # treated as the same edge.
        emitted = set()
        for src, dst in edges.t().tolist():
            if src == dst:
                continue
            if directed:
                out.write(f" {src} {connector} {dst};\n")
                continue
            pair = (src, dst) if src <= dst else (dst, src)
            if pair in emitted:
                continue
            emitted.add(pair)
            out.write(f" {pair[0]} {connector} {pair[1]};\n")

        out.write("}\n")
def main():
    """Command-line entry point: load a dataset, export it to DOT, print a summary.

    When ``--name`` is omitted, the dataset defaults to ``Amazon-ratings``
    with ``--heterophilous`` and to ``Cora`` otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Export PyG datasets to minimal DOT with class colors and optional split highlighting."
    )
    parser.add_argument("-o", "--output", default="graph.dot", help="Output .dot file (default: graph.dot)")
    parser.add_argument("--directed", action="store_true", help="Write directed edges (default: undirected)")
    parser.add_argument("--heterophilous", action="store_true", help="Use HeterophilousGraphDataset")
    parser.add_argument("--name", default=None, help="Dataset name (Planetoid: Cora/CiteSeer/PubMed; Heterophilous: Amazon-ratings, Roman-empire, etc.)")
    parser.add_argument("--root", default="data", help="Root folder for datasets (default: data)")
    parser.add_argument("--split-index", type=int, default=0, help="Split index for heterophilous datasets (default: 0)")
    # The help text states what write_dot actually does: nodes outside the
    # selected split are emitted without any color attributes.
    parser.add_argument("--filter", choices=["train", "val", "test"], default=None, help="Highlight nodes in the selected split as red (other nodes are left uncolored)")
    args = parser.parse_args()

    dataset_name = args.name if args.name is not None else ("Amazon-ratings" if args.heterophilous else "Cora")
    data, num_classes = load_graph(args.root, args.heterophilous, dataset_name, args.split_index)
    write_dot(args.output, data, num_classes, args.directed, args.filter)

    suffix = f", highlight={args.filter}" if args.filter else ""
    print(f"Wrote {args.output} using {'HeterophilousGraphDataset' if args.heterophilous else 'Planetoid'}('{dataset_name}') | nodes={data.num_nodes}, edges={data.edge_index.size(1)}, classes={num_classes}{suffix}")
# Run the CLI only when executed as a script, not when imported as a module.
if __name__ == "__main__":
    main()