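r"""Batch-visualize image + graph pairs for an episode.

For every ``frame_NNNNNN_graph.json`` in the graphs directory, the matching
RGB frame ``<episode>/<view>/rgb/frame_NNNNNN.png`` is drawn next to the
graph, and the result is saved as ``<output>/frame_NNNNNN_viz.png``.

Usage:
    python visualize_graphs_batch.py \
        --episode session_0408_162129/episode_00 \
        --graphs session_0408_162129/episode_00/graph_per_side_image \
        --output session_0408_162129/episode_00/graph_viz \
        [--view side] [--stride 1] [--limit N]
"""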
|
|
| import argparse |
| import json |
| import re |
| from pathlib import Path |
|
|
| import numpy as np |
| import matplotlib |
| matplotlib.use("Agg") |
| import matplotlib.pyplot as plt |
| import matplotlib.patches as mpatches |
| import networkx as nx |
| from PIL import Image |
|
|
|
|
def load_side_graph_positions(episode: Path) -> dict:
    """Return the stored node positions for an episode's side-view graph.

    The lookup order is ``<episode>/annotations/side_graph.json``, then
    ``<episode>/annotations/graph.json``. Only the first file that exists is
    read. Positions come from a top-level ``node_positions`` entry or, failing
    that, from ``side_graph["node_positions"]`` when ``side_graph`` is a dict.

    Returns:
        The positions mapping, or ``{}`` when no annotation file exists or
        the file holds no positions.
    """
    annotations_dir = episode / "annotations"
    candidates = (
        annotations_dir / "side_graph.json",
        annotations_dir / "graph.json",
    )
    # Lazily probe the candidates so the fallback is only checked when needed.
    chosen = next((path for path in candidates if path.exists()), None)
    if chosen is None:
        return {}

    with open(chosen) as handle:
        payload = json.load(handle)

    if "node_positions" in payload:
        return payload["node_positions"]

    nested = payload["side_graph"] if "side_graph" in payload else None
    if isinstance(nested, dict):
        return nested.get("node_positions", {})
    return {}
|
|
|
|
def _build_graph(graph_dict: dict) -> "nx.DiGraph":
    """Build a directed graph from the serialized ``nodes``/``edges`` lists.

    Every key of a node record becomes a node attribute. Each edge keeps only
    its ``is_locked`` flag.
    """
    G = nx.DiGraph()
    for node in graph_dict["nodes"]:
        G.add_node(node["id"], **node)
    for edge in graph_dict["edges"]:
        G.add_edge(edge["src"], edge["dst"], is_locked=edge["is_locked"])
    return G


def _layout_positions(G, stored_pos: dict) -> dict:
    """Return a 2-D position for every node of ``G``.

    Stored positions are used with their y-coordinate negated. Presumably they
    are image-style coordinates where y grows downward (TODO confirm). A
    ``robot`` node without a stored position is pinned at (450, 0). Any
    remaining node is placed by a spring layout (seed 42) that keeps the
    already-placed nodes fixed.
    """
    pos = {}
    for nid in G.nodes:
        if nid in stored_pos:
            x, y = stored_pos[nid]
            pos[nid] = (x, -y)
        elif nid == "robot":
            pos[nid] = (450, 0)

    missing = [n for n in G.nodes if n not in pos]
    if missing:
        if pos:
            sub = nx.spring_layout(G, pos=pos, fixed=list(pos.keys()), seed=42)
        else:
            sub = nx.spring_layout(G, seed=42)
        for n in missing:
            pos[n] = sub[n]
    return pos


def _draw_rgb(ax, frame_idx: int, rgb_path: Path) -> None:
    """Show the RGB frame on ``ax``, or a placeholder text if the file is missing."""
    if rgb_path.exists():
        # Context manager guarantees the file handle is released per frame.
        with Image.open(rgb_path) as img:
            ax.imshow(np.array(img))
    else:
        ax.text(0.5, 0.5, "RGB not found", ha="center", va="center")
    ax.set_title(f"Frame {frame_idx} — RGB")
    ax.axis("off")


def _draw_graph(ax, frame_idx: int, G, pos: dict, graph_dict: dict) -> None:
    """Draw nodes, locked/unlocked edges, labels and a legend onto ``ax``."""
    locked = [(e["src"], e["dst"]) for e in graph_dict["edges"] if e["is_locked"]]
    unlocked = [(e["src"], e["dst"]) for e in graph_dict["edges"] if not e["is_locked"]]

    # Both lists follow G.nodes order. That is also the default nodelist
    # networkx uses to look up a per-node size for each edge endpoint.
    node_colors, node_sizes = [], []
    for nid in G.nodes:
        nd = G.nodes[nid]
        node_colors.append(nd.get("color", "#888"))
        t = nd.get("type", "")
        node_sizes.append(800 if t == "robot" else 1200 if t == "motherboard" else 600)

    nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors,
                           node_size=node_sizes, edgecolors="black", linewidths=1.5)

    # node_size must also be given to the edge drawer. networkx uses it to
    # shrink each arrow so the head stops at the node's border. With the
    # default (300), the heads would be hidden under these larger markers.
    edge_kwargs = dict(ax=ax, node_size=node_sizes, width=2.0, alpha=0.8,
                       arrows=True, arrowsize=15, arrowstyle="-|>",
                       connectionstyle="arc3,rad=0.1")
    if locked:
        nx.draw_networkx_edges(G, pos, edgelist=locked, edge_color="#E74C3C",
                               **edge_kwargs)
    if unlocked:
        nx.draw_networkx_edges(G, pos, edgelist=unlocked, edge_color="#2ECC71",
                               style="dashed", **edge_kwargs)

    labels = {nid: f"{nid}{'' if G.nodes[nid].get('visible', True) else ' (hidden)'}"
              for nid in G.nodes}
    nx.draw_networkx_labels(G, pos, labels, ax=ax, font_size=7, font_weight="bold")

    # Legend: edge states first, then one swatch per node type. A type's
    # swatch uses the color of the first node seen with that type.
    legend = [mpatches.Patch(color="#E74C3C", label="locked"),
              mpatches.Patch(color="#2ECC71", label="unlocked")]
    seen = {}
    for n in graph_dict["nodes"]:
        seen.setdefault(n.get("type", ""), n.get("color", "#888"))
    for t, c in seen.items():
        legend.append(mpatches.Patch(facecolor=c, edgecolor="black", label=t))
    ax.legend(handles=legend, loc="upper left", fontsize=8, framealpha=0.9)
    ax.set_title(f"Frame {frame_idx} — Graph ({len(locked)} locked, {len(unlocked)} unlocked)")
    ax.axis("off")


def visualize_pair(frame_idx: int, rgb_path: Path, graph_dict: dict,
                   stored_pos: dict, out_path: Path):
    """Render one frame's RGB image next to its graph and save it as a PNG.

    Args:
        frame_idx: Frame number, used only in the panel titles.
        rgb_path: RGB image to show in the left panel. A placeholder text is
            drawn if the file does not exist.
        graph_dict: Parsed graph JSON with ``nodes`` (each having ``id`` and
            optionally ``type``/``color``/``visible``) and ``edges`` (each
            having ``src``, ``dst`` and ``is_locked``).
        stored_pos: Mapping of node id to an ``(x, y)`` pair, used where
            available to place nodes.
        out_path: Destination PNG path.
    """
    G = _build_graph(graph_dict)
    pos = _layout_positions(G, stored_pos)

    fig, axes = plt.subplots(1, 2, figsize=(20, 8))
    try:
        _draw_rgb(axes[0], frame_idx, rgb_path)
        _draw_graph(axes[1], frame_idx, G, pos, graph_dict)
        fig.tight_layout()
        fig.savefig(out_path, dpi=120, bbox_inches="tight")
    finally:
        # Always release the figure; pyplot otherwise keeps it alive, and
        # figures would accumulate across a long batch run.
        plt.close(fig)
|
|
|
|
def main(argv=None):
    """Render one RGB + graph visualization per graph file of an episode.

    Args:
        argv: Optional argument list to parse instead of ``sys.argv[1:]``,
            so the tool can also be driven programmatically.
    """
    p = argparse.ArgumentParser(
        description="Batch-visualize image + graph pairs for an episode.")
    p.add_argument("--episode", required=True,
                   help="Episode directory (holds <view>/rgb and annotations/)")
    p.add_argument("--graphs", required=True, help="Directory of frame_NNNNNN_graph.json")
    p.add_argument("--output", required=True, help="Directory for visualization PNGs")
    p.add_argument("--view", default="side",
                   help="View subdirectory of the episode whose rgb/ frames are shown")
    p.add_argument("--stride", type=int, default=1, help="Keep every N-th graph file")
    p.add_argument("--limit", type=int, default=None,
                   help="Maximum number of frames to render (after striding)")
    args = p.parse_args(argv)

    # Reject values that would silently misbehave when slicing below. A
    # stride of 0 raises a bare ValueError, a negative stride reverses the
    # frame order, and a negative limit would drop frames from the end.
    if args.stride < 1:
        p.error("--stride must be >= 1")
    if args.limit is not None and args.limit < 0:
        p.error("--limit must be >= 0")

    episode = Path(args.episode)
    graphs_dir = Path(args.graphs)
    if not graphs_dir.is_dir():
        p.error(f"--graphs directory not found: {graphs_dir}")
    out_dir = Path(args.output)
    out_dir.mkdir(parents=True, exist_ok=True)
    rgb_dir = episode / args.view / "rgb"

    stored_pos = load_side_graph_positions(episode)

    # Take the frame index from the digits right after "frame_", and order
    # numerically. Lexicographic path order breaks for non-padded indices.
    pairs = []
    for gf in graphs_dir.glob("frame_*_graph.json"):
        m = re.match(r"frame_(\d+)", gf.stem)
        if not m:
            continue
        pairs.append((int(m.group(1)), gf))
    pairs.sort()

    pairs = pairs[::args.stride]
    # Compare with None explicitly so that --limit 0 means "render nothing"
    # rather than "no limit".
    if args.limit is not None:
        pairs = pairs[:args.limit]

    for idx, gf in pairs:
        with open(gf, encoding="utf-8") as f:
            graph_dict = json.load(f)
        rgb_path = rgb_dir / f"frame_{idx:06d}.png"
        out_path = out_dir / f"frame_{idx:06d}_viz.png"
        visualize_pair(idx, rgb_path, graph_dict, stored_pos, out_path)

    print(f"Saved {len(pairs)} visualizations to {out_dir}")
|
|
|
|
if __name__ == "__main__":
    # Run the command-line tool only when this file is executed as a script.
    main()
|
|