File size: 5,660 Bytes
f74dd01
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
"""
Export a PyG dataset (Planetoid or Heterophilous) to a minimal Graphviz DOT.
- Nodes are colored by class.
- If --filter {train|val|test} is set, nodes in that split are colored red,
  other nodes are left uncolored.
- Self-loops are skipped. Undirected edges are deduplicated; directed edges are
  otherwise written as-is.
- For HeterophilousGraphDataset with multiple splits (e.g., Amazon-ratings has 10),
  use --split-index (default 0).
"""

import argparse
from torch_geometric.datasets import Planetoid, HeterophilousGraphDataset

# Base node colors, indexed by class id. When a dataset has more classes than
# entries here, _extend_palette() derives extra, brighter variants.
PALETTE = [
    "#1f77b4",
    "#ff7f0e",
    "#2ca02c",
    "#d62728",
    "#9467bd",
    "#8c564b",
    "#e377c2",
    "#7f7f7f",
    "#bcbd22",
    "#17becf",
]

# Outline/fill color for nodes that belong to the split chosen with --filter.
HIGHLIGHT_RED = "#ff0000"

def _extend_palette(base, need):
    if need <= len(base):
        return base[:need]
    def brighten(hex_color, factor):
        c = int(hex_color[1:], 16)
        r = (c >> 16) & 255
        g = (c >> 8) & 255
        b = c & 255
        r = int(min(255, r + (255 - r) * factor))
        g = int(min(255, g + (255 - g) * factor))
        b = int(min(255, b + (255 - b) * factor))
        return f"#{r:02x}{g:02x}{b:02x}"
    out = []
    for i in range(need):
        base_hex = base[i % len(base)]
        factor = 0.18 * (i // len(base))
        out.append(brighten(base_hex, factor))
    return out

def _infer_num_classes(dataset, data):
    num_classes = getattr(dataset, "num_classes", None)
    if not isinstance(num_classes, int) or num_classes <= 0:
        num_classes = int(data.y.max().item()) + 1
    return num_classes

def _get_split_mask(data, which):
    name = {"train": "train_mask", "val": "val_mask", "test": "test_mask"}[which]
    m = getattr(data, name, None)
    if m is None and which == "val":
        m = getattr(data, "valid_mask", None) or getattr(data, "validation_mask", None)
    return m

def load_graph(root: str, use_hetero: bool, name: str, split_index: int):
    """Load a dataset and return ``(data, num_classes)``.

    Args:
        root: Root folder for dataset files. Planetoid datasets are stored
            under ``<root>/Planetoid``; heterophilous ones directly in
            ``root``.
        use_hetero: Load ``HeterophilousGraphDataset`` when true, otherwise
            ``Planetoid``.
        name: Dataset name passed through to the dataset class.
        split_index: Which split to expose for heterophilous datasets.
            Out-of-range values are clamped to the available splits.
            Ignored for Planetoid.

    Returns:
        The graph object of the dataset and its number of classes.
    """
    if not use_hetero:
        ds = Planetoid(root=f"{root}/Planetoid", name=name)
        data = ds[0]
        return data, _infer_num_classes(ds, data)

    ds = HeterophilousGraphDataset(root=root, name=name)
    # NOTE(review): per the PyG docs this dataset contains a single graph and
    # stores its splits as columns of 2-D masks (nodes x splits) - confirm
    # for the installed version. Indexing the dataset by split_index would
    # therefore always clamp to graph 0 and ignore the requested split.
    data = ds[0]
    for attr in ("train_mask", "val_mask", "test_mask"):
        mask = getattr(data, attr, None)
        if mask is None or mask.dim() != 2:
            continue
        column = max(0, min(split_index, mask.size(1) - 1))
        # Reduce to a 1-D per-node mask so callers can index it by node id.
        setattr(data, attr, mask[:, column])
    return data, _infer_num_classes(ds, data)

def write_dot(path: str, data, num_classes: int, directed: bool, filter_split: str | None):
    """Write ``data`` to ``path`` as a minimal Graphviz DOT file.

    Node statements come first, then edge statements.

    * Without ``filter_split`` every node is filled with its class color;
      labels outside ``[0, num_classes)`` are wrapped with a modulo.
    * With ``filter_split`` only nodes inside that split's mask are styled
      (in the highlight color); all other nodes are written without
      attributes. If ``data`` has no such mask, no node is styled.
    * Self-loops are never written. For an undirected graph each node pair
      is written once, smaller id first; for a directed graph every
      remaining edge is written as stored.
    """
    labels = data.y
    edges = data.edge_index
    class_colors = _extend_palette(PALETTE, num_classes)

    split_mask = None
    if filter_split is not None:
        found = _get_split_mask(data, filter_split)
        if found is not None:
            split_mask = found.bool()

    if directed:
        graph_kw, edge_op = "digraph", "->"
    else:
        graph_kw, edge_op = "graph", "--"

    def styled_node(node, colour):
        # One filled node statement; outline and fill share the same color.
        return f'  {node} [color="{colour}", style="filled", fillcolor="{colour}", fontcolor="white"];\n'

    with open(path, "w", encoding="utf-8") as out:
        out.write(f"{graph_kw} {{\n")

        for node in range(data.num_nodes):
            label = int(labels[node])
            if not 0 <= label < num_classes:
                label = label % num_classes
            class_colour = class_colors[label]

            if filter_split is None:
                out.write(styled_node(node, class_colour))
            elif split_mask is not None and bool(split_mask[node]):
                out.write(styled_node(node, HIGHLIGHT_RED))
            else:
                out.write(f'  {node} ;\n')

        emitted = set()  # canonical (low, high) pairs; undirected mode only
        for src, dst in edges.t().tolist():
            if src == dst:
                continue
            if directed:
                out.write(f"  {src} {edge_op} {dst};\n")
                continue
            pair = (src, dst) if src <= dst else (dst, src)
            if pair in emitted:
                continue
            emitted.add(pair)
            low, high = pair
            out.write(f"  {low} {edge_op} {high};\n")

        out.write("}\n")

def main(argv=None):
    """Parse command-line options, export the chosen dataset and print a summary.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``.

    When ``--name`` is omitted the dataset defaults to ``Amazon-ratings``
    with ``--heterophilous`` and to ``Cora`` otherwise.
    """
    parser = argparse.ArgumentParser(
        description="Export PyG datasets to minimal DOT with class colors and optional split highlighting."
    )
    parser.add_argument("-o", "--output", default="graph.dot", help="Output .dot file (default: graph.dot)")
    parser.add_argument("--directed", action="store_true", help="Write directed edges (default: undirected)")
    parser.add_argument("--heterophilous", action="store_true", help="Use HeterophilousGraphDataset")
    parser.add_argument("--name", default=None, help="Dataset name (Planetoid: Cora/CiteSeer/PubMed; Heterophilous: Amazon-ratings, Roman-empire, etc.)")
    parser.add_argument("--root", default="data", help="Root folder for datasets (default: data)")
    parser.add_argument("--split-index", type=int, default=0, help="Split index for heterophilous datasets (default: 0)")
    # The help text states what write_dot actually does: nodes outside the
    # selected split are written without any color attributes.
    parser.add_argument("--filter", choices=["train", "val", "test"], default=None, help="Highlight nodes in the selected split as red (other nodes are left uncolored)")
    args = parser.parse_args(argv)

    if args.name is not None:
        dataset_name = args.name
    else:
        dataset_name = "Amazon-ratings" if args.heterophilous else "Cora"

    data, num_classes = load_graph(args.root, args.heterophilous, dataset_name, args.split_index)
    write_dot(args.output, data, num_classes, args.directed, args.filter)

    suffix = f", highlight={args.filter}" if args.filter else ""
    # The edge count below is the size of the stored edge_index, not the
    # number of edge statements written (self-loops/duplicates are skipped).
    print(f"Wrote {args.output} using {'HeterophilousGraphDataset' if args.heterophilous else 'Planetoid'}('{dataset_name}') | nodes={data.num_nodes}, edges={data.edge_index.size(1)}, classes={num_classes}{suffix}")

# Run the exporter only when executed as a script, not when imported.
if __name__ == "__main__":
    main()