"""Extract the graph for frame 42 as a plain dict, then reconstruct and visualize it.""" import json from pathlib import Path import numpy as np import matplotlib.pyplot as plt import matplotlib.patches as mpatches import networkx as nx from PIL import Image from gnn_disassembly_loader import load_frame_data # ───────────────────────────────────────────────────────────────────────────── # Step 1: Extract graph as a plain dict # ───────────────────────────────────────────────────────────────────────────── episode = Path("session_0408_162129/episode_00") frame_idx = 42 fd = load_frame_data(episode, frame_idx) graph_json = fd.graph # Build graph_dict — a self-contained dict describing this frame's graph graph_dict = { "frame_idx": frame_idx, "nodes": [], "edges": [], } # Nodes: one per product component + robot for comp in graph_json["components"]: cid = comp["id"] centroid_key = f"{cid}_centroid" depth_valid_key = f"{cid}_depth_valid" has_depth = (depth_valid_key in fd.depth_info and int(fd.depth_info[depth_valid_key][0]) == 1) graph_dict["nodes"].append({ "id": cid, "type": comp["type"], "color": comp["color"], "visible": fd.visibility.get(cid, True), "centroid_3d": fd.depth_info[centroid_key].tolist() if has_depth else [0, 0, 0], "embedding_norm": float(np.linalg.norm(fd.embeddings[cid])) if cid in fd.embeddings else 0.0, "mask_area": int(fd.depth_info[f"{cid}_area"][0]) if f"{cid}_area" in fd.depth_info else 0, }) # Robot node if fd.robot is not None: graph_dict["nodes"].append({ "id": "robot", "type": "robot", "color": "#F5F5F5", "visible": True, "centroid_3d": fd.robot["centroid"].tolist(), "embedding_norm": float(np.linalg.norm(fd.robot["embedding"])), "mask_area": int(fd.robot["area"][0]), }) # Edges: only physical constraints (the meaningful ones for visualization) for edge in graph_json["edges"]: constraint_key = f"{edge['src']}->{edge['dst']}" is_locked = fd.constraints.get(constraint_key, True) graph_dict["edges"].append({ "src": edge["src"], "dst": edge["dst"], "is_locked": is_locked, }) # Save to JSON dict_path = Path("frame_042_graph.json") with open(dict_path, "w") as f: json.dump(graph_dict, f, indent=2) print(f"Saved graph dict to {dict_path}") print(f" {len(graph_dict['nodes'])} nodes, {len(graph_dict['edges'])} constraint edges") # ───────────────────────────────────────────────────────────────────────────── # Step 2: Read the dict back and reconstruct as a networkx graph # ───────────────────────────────────────────────────────────────────────────── with open(dict_path) as f: gd = json.load(f) G = nx.DiGraph() for node in gd["nodes"]: G.add_node(node["id"], **node) for edge in gd["edges"]: G.add_edge(edge["src"], edge["dst"], is_locked=edge["is_locked"]) print(f"\nReconstructed graph: {G.number_of_nodes()} nodes, {G.number_of_edges()} directed edges") # ───────────────────────────────────────────────────────────────────────────── # Step 3: Visualize — two panels: RGB image + graph overlay # ───────────────────────────────────────────────────────────────────────────── # Use node_positions from side_graph.json for layout (if available), else spring layout stored_pos = graph_json.get("node_positions", {}) # Build positions dict — use stored positions, place missing nodes with spring layout pos = {} for nid in G.nodes: if nid in stored_pos: x, y = stored_pos[nid] pos[nid] = (x, -y) # flip y so it matches visual top-down convention elif nid == "robot": # Place robot off to the side pos[nid] = (450, 0) # For nodes without stored positions, use spring layout seeded by known positions missing = [n for n in G.nodes if n not in pos] if missing: sub = nx.spring_layout(G, pos=pos, fixed=list(pos.keys()), seed=42) for n in missing: pos[n] = sub[n] fig, axes = plt.subplots(1, 2, figsize=(20, 8)) # Panel 1: RGB image rgb_path = episode / "side" / "rgb" / f"frame_{frame_idx:06d}.png" if rgb_path.exists(): img = np.array(Image.open(rgb_path)) axes[0].imshow(img) axes[0].set_title(f"Frame {frame_idx} — RGB", fontsize=14) axes[0].axis("off") else: axes[0].text(0.5, 0.5, "RGB image not found", ha="center", va="center", fontsize=14) axes[0].set_title(f"Frame {frame_idx} — RGB", fontsize=14) # Panel 2: constraint graph ax = axes[1] # Separate constraint edges by lock state locked_edges = [(e["src"], e["dst"]) for e in gd["edges"] if e["is_locked"]] unlocked_edges = [(e["src"], e["dst"]) for e in gd["edges"] if not e["is_locked"]] # Node colors and sizes node_colors = [] node_sizes = [] for nid in G.nodes: ndata = G.nodes[nid] node_colors.append(ndata["color"]) if ndata["type"] == "robot": node_sizes.append(800) elif ndata["type"] == "motherboard": node_sizes.append(1200) else: node_sizes.append(600) # Draw nodes nx.draw_networkx_nodes(G, pos, ax=ax, node_color=node_colors, node_size=node_sizes, edgecolors="black", linewidths=1.5) # Draw locked constraint edges (solid red arrows) if locked_edges: nx.draw_networkx_edges(G, pos, edgelist=locked_edges, ax=ax, edge_color="#E74C3C", width=2.0, alpha=0.8, arrows=True, arrowsize=15, arrowstyle="-|>", connectionstyle="arc3,rad=0.1") # Draw unlocked constraint edges (dashed green arrows) if unlocked_edges: nx.draw_networkx_edges(G, pos, edgelist=unlocked_edges, ax=ax, edge_color="#2ECC71", width=2.0, alpha=0.8, style="dashed", arrows=True, arrowsize=15, arrowstyle="-|>", connectionstyle="arc3,rad=0.1") # Labels labels = {} for nid in G.nodes: ndata = G.nodes[nid] vis_marker = "" if ndata["visible"] else " (hidden)" labels[nid] = f"{nid}{vis_marker}" nx.draw_networkx_labels(G, pos, labels, ax=ax, font_size=7, font_weight="bold") # Legend legend_handles = [ mpatches.Patch(color="#E74C3C", label="Constraint (locked)"), mpatches.Patch(color="#2ECC71", label="Constraint (unlocked)"), ] # Add type color legend type_colors_seen = {} for node in gd["nodes"]: if node["type"] not in type_colors_seen: type_colors_seen[node["type"]] = node["color"] for t, c in type_colors_seen.items(): legend_handles.append(mpatches.Patch(facecolor=c, edgecolor="black", label=t)) ax.legend(handles=legend_handles, loc="upper left", fontsize=8, framealpha=0.9) ax.set_title(f"Frame {frame_idx} — Constraint Graph ({len(locked_edges)} locked, " f"{len(unlocked_edges)} unlocked)", fontsize=14) ax.axis("off") plt.tight_layout() out_path = Path("frame_042_graph_viz.png") plt.savefig(out_path, dpi=150, bbox_inches="tight") print(f"\nSaved visualization to {out_path}") plt.close()