"""Self-contained PyG loader for the GNN Disassembly dataset.

Two loader variants:
  - load_pyg_frame_products_only(ep, frame) → constraint graph only, no robot
  - load_pyg_frame_with_robot(ep, frame)    → constraint graph + robot agent node

Both return torch_geometric.data.Data with:
    x           (N, 269)        node features
    edge_index  (2, N*(N-1))    fully connected directed message-passing edges
    edge_attr   (N*(N-1), 3)    [has_constraint, is_locked, src_blocks_dst]
    num_nodes   N

Node feature layout (in column order):
    256  embedding
      3  position (centroid, zeros when depth is invalid or missing)
      T  type one-hot, T = len(type_vocab) (9 in the current vocabulary)
      1  visibility flag
  i.e. 256 + 3 + 9 + 1 = 269 columns for the 9-entry vocabulary.

Notes on the edge feature design:
  - The graph is FULLY CONNECTED and structurally symmetric. Both (i, j) and
    (j, i) exist in edge_index for every node pair i != j.
  - Direction is NOT encoded in the graph structure. It is encoded as a
    feature: `src_blocks_dst`.
  - `has_constraint` and `is_locked` are symmetric per pair (same value for
    both (i, j) and (j, i)).
  - `src_blocks_dst` is asymmetric: it is 1 if the edge's src node physically
    blocks its dst node, 0 otherwise.
"""

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch_geometric.data import Data

# Widths of the fixed-size parts of the node / edge feature vectors.
_EMBED_DIM = 256  # fallback embedding width used when a component has none
_POS_DIM = 3      # centroid position
_EDGE_DIM = 3     # [has_constraint, is_locked, src_blocks_dst]


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def list_labeled_frames(episode_dir: Path) -> List[int]:
    """Return the sorted frame indices that have saved mask annotations.

    Scans ``<episode_dir>/annotations/side_masks`` for ``frame_*.npz`` files
    and parses the integer after the first underscore of each file stem.
    Files whose name does not parse are skipped. Returns ``[]`` when the
    directory does not exist.
    """
    mask_dir = episode_dir / "annotations" / "side_masks"
    if not mask_dir.exists():
        return []

    frames = []
    for p in mask_dir.glob("frame_*.npz"):
        try:
            frames.append(int(p.stem.split("_")[1]))
        except (ValueError, IndexError):
            # Not a "frame_<int>" file; ignore it rather than fail the scan.
            continue
    return sorted(frames)


def resolve_frame_state(graph_json: dict,
                        frame_idx: int) -> Tuple[Dict[str, bool], Dict[str, bool]]:
    """Resolve delta-encoded constraints + visibility at a frame.

    Starts from the defaults (every component visible, every edge locked) and
    applies the deltas stored under ``graph_json["frame_states"]`` in
    ascending frame order, up to and including ``frame_idx``.

    Args:
        graph_json: Parsed graph dict with ``components`` (each having ``id``),
            ``edges`` (each having ``src`` and ``dst``) and, optionally,
            ``frame_states`` keyed by frame number.
        frame_idx: Frame at which the state is wanted.

    Returns:
        ``(constraints, visibility)`` where ``constraints`` maps
        ``"<src>-><dst>"`` to its locked flag and ``visibility`` maps a
        component id to its visible flag.

    Raises:
        ValueError: If a ``frame_states`` key is not an integer literal.
    """
    # Defaults: every component visible, every edge locked.
    visibility: Dict[str, bool] = {c["id"]: True for c in graph_json["components"]}
    constraints: Dict[str, bool] = {
        f"{e['src']}->{e['dst']}": True for e in graph_json["edges"]
    }

    # Walk deltas up to frame_idx. Sort the ORIGINAL keys numerically and index
    # with them directly: converting to int and back to str would miss keys
    # that are zero-padded (e.g. "000010") and raise KeyError.
    fs_dict = graph_json.get("frame_states", {})
    for key in sorted(fs_dict, key=int):
        if int(key) > frame_idx:
            break
        fs = fs_dict[key]
        constraints.update(fs.get("constraints", {}))
        visibility.update(fs.get("visibility", {}))

    return constraints, visibility


def type_one_hot(comp_type: str, type_vocab: List[str]) -> List[float]:
    """One-hot encode ``comp_type`` against ``type_vocab``.

    The result has ``len(type_vocab)`` entries (9 for the current vocabulary).
    A type that is not in the vocabulary yields an all-zero vector.
    """
    return [1.0 if t == comp_type else 0.0 for t in type_vocab]


# ─────────────────────────────────────────────────────────────────────────────
# Raw data loader (NumPy only, no torch)
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class FrameData:
    # Parsed contents of annotations/side_graph.json.
    graph: dict
    # Arrays from side_masks/frame_XXXXXX.npz ({} when the file is missing).
    masks: Dict[str, np.ndarray]
    # Arrays from side_embeddings/frame_XXXXXX.npz, looked up by component id.
    embeddings: Dict[str, np.ndarray]
    # Arrays from side_depth_info/frame_XXXXXX.npz; read via the keys
    # "<id>_depth_valid" and "<id>_centroid".
    depth_info: dict
    # Arrays from side_robot/frame_XXXXXX.npz, or None when the file is
    # missing or its "visible" flag is not 1.
    robot: Optional[dict]
    # "<src>-><dst>" -> locked flag, resolved at this frame.
    constraints: Dict[str, bool]
    # component id -> visible flag, resolved at this frame.
    visibility: Dict[str, bool]


def _load_npz_dict(path: Path) -> Dict[str, np.ndarray]:
    """Read every array of an ``.npz`` archive into a dict ({} if missing)."""
    if not path.exists():
        return {}
    # NpzFile keeps the archive open until closed; materialise the arrays
    # inside the context manager so the file handle is released right away.
    with np.load(path) as archive:
        return {k: archive[k] for k in archive.files}


def load_frame_data(episode_dir: Path, frame_idx: int) -> FrameData:
    """Load all v3 annotation files for one frame.

    Reads ``side_graph.json`` plus the per-frame ``.npz`` archives under
    ``<episode_dir>/annotations``. Missing mask / embedding / depth archives
    produce empty dicts. The robot entry is ``None`` unless its archive exists
    and ``visible[0] == 1``.

    Raises:
        FileNotFoundError: If ``side_graph.json`` does not exist.
    """
    anno = episode_dir / "annotations"
    frame_file = f"frame_{frame_idx:06d}.npz"

    with open(anno / "side_graph.json", encoding="utf-8") as f:
        graph = json.load(f)

    masks = _load_npz_dict(anno / "side_masks" / frame_file)
    embeddings = _load_npz_dict(anno / "side_embeddings" / frame_file)
    depth_info = _load_npz_dict(anno / "side_depth_info" / frame_file)

    robot: Optional[dict] = None
    robot_path = anno / "side_robot" / frame_file
    if robot_path.exists():
        with np.load(robot_path) as r:
            if r["visible"][0] == 1:
                robot = {k: r[k] for k in r.files}

    constraints, visibility = resolve_frame_state(graph, frame_idx)
    return FrameData(graph, masks, embeddings, depth_info, robot,
                     constraints, visibility)


# ─────────────────────────────────────────────────────────────────────────────
# PyG loader — products only
# ─────────────────────────────────────────────────────────────────────────────

def load_pyg_frame_products_only(episode_dir: Path, frame_idx: int) -> Data:
    """Fully connected PyG graph WITHOUT robot.

    Returns Data(
        x=[N, 269],
        edge_index=[2, N*(N-1)],
        edge_attr=[N*(N-1), 3],   # [has_constraint, is_locked, src_blocks_dst]
        y=[1],                    # the frame index
        num_nodes=N,
    )
    where N = number of product components (robot excluded). The feature width
    is 256 + 3 + len(type_vocab) + 1, which is 269 for the 9-entry vocabulary.
    For N < 2 the edge tensors are empty but keep their 2-D shapes
    ((2, 0) and (0, 3)).
    """
    fd = load_frame_data(episode_dir, frame_idx)
    graph = fd.graph
    type_vocab = graph["type_vocab"]   # 9 entries incl. robot
    nodes = graph["components"]        # robot already excluded per spec
    N = len(nodes)

    # ── Node features ──
    # [256D SAM2 embedding, 3D position, 9D type one-hot, 1D visibility] = 269
    x_list = []
    for node in nodes:
        cid = node["id"]
        emb = fd.embeddings.get(cid, np.zeros(_EMBED_DIM, dtype=np.float32))

        # Position is only trusted when the depth sample was flagged valid.
        depth_valid_key = f"{cid}_depth_valid"
        if (depth_valid_key in fd.depth_info
                and int(fd.depth_info[depth_valid_key][0]) == 1):
            pos = fd.depth_info[f"{cid}_centroid"].astype(np.float32)
        else:
            pos = np.zeros(_POS_DIM, dtype=np.float32)

        type_oh = type_one_hot(node["type"], type_vocab)
        vis = 1.0 if fd.visibility.get(cid, True) else 0.0

        x_list.append(np.concatenate([
            emb.astype(np.float32),
            pos,
            np.array(type_oh, dtype=np.float32),
            np.array([vis], dtype=np.float32),
        ]))

    if x_list:
        x = torch.tensor(np.stack(x_list), dtype=torch.float32)
    else:
        # Keep the column count consistent with the non-empty case.
        feat_dim = _EMBED_DIM + _POS_DIM + len(type_vocab) + 1
        x = torch.empty((0, feat_dim), dtype=torch.float32)

    # ── Fully connected edges with 3D features ──
    # Edge feature: [has_constraint, is_locked, src_blocks_dst]
    #   - has_constraint & is_locked are SYMMETRIC for the pair (A, B)
    #   - src_blocks_dst is ASYMMETRIC: 1 if edge's src physically blocks dst
    #
    # unordered pair -> (blocker, blocked). Built in file order so the result
    # is reproducible; if both A->B and B->A are listed, the first one wins.
    pair_forward = {}
    for e in graph["edges"]:
        pair_forward.setdefault(frozenset((e["src"], e["dst"])),
                                (e["src"], e["dst"]))

    node_ids = [node["id"] for node in nodes]
    src_idx, dst_idx, edge_attr = [], [], []
    for i, src_id in enumerate(node_ids):
        for j, dst_id in enumerate(node_ids):
            if i == j:
                continue
            src_idx.append(i)
            dst_idx.append(j)

            forward = pair_forward.get(frozenset((src_id, dst_id)))
            if forward is None:
                edge_attr.append([0.0, 0.0, 0.0])   # message passing only
                continue

            blocker, blocked = forward
            is_locked = fd.constraints.get(f"{blocker}->{blocked}", True)
            edge_attr.append([
                1.0,
                1.0 if is_locked else 0.0,
                1.0 if src_id == blocker else 0.0,
            ])

    if edge_attr:
        edge_attr_t = torch.tensor(edge_attr, dtype=torch.float32)
    else:
        # torch.tensor([]) would be 1-D; keep the (E, 3) contract for E == 0.
        edge_attr_t = torch.empty((0, _EDGE_DIM), dtype=torch.float32)

    return Data(
        x=x,
        edge_index=torch.tensor([src_idx, dst_idx], dtype=torch.long),
        edge_attr=edge_attr_t,
        y=torch.tensor([frame_idx], dtype=torch.long),
        num_nodes=N,
    )


# ─────────────────────────────────────────────────────────────────────────────
# PyG loader — with robot agent node
# ─────────────────────────────────────────────────────────────────────────────

def load_pyg_frame_with_robot(episode_dir: Path, frame_idx: int) -> Data:
    """Fully connected PyG graph WITH robot appended as agent node.

    Robot is node N (the last node). All edges involving the robot have
    features [0, 0, 0] because the robot has no physical constraints.

    If the robot is not visible at this frame, returns the products-only graph.

    Edge order: all product→product edges exactly as produced by
    ``load_pyg_frame_products_only``, followed by, for each product ``i``,
    the pair (robot→i, i→robot).

    Additional attached tensors when robot is visible:
        data.robot_point_cloud   (M, 3) float32
        data.robot_pixel_coords  (M, 2) int32
        data.robot_mask          (H, W) uint8
    """
    data = load_pyg_frame_products_only(episode_dir, frame_idx)
    fd = load_frame_data(episode_dir, frame_idx)
    if fd.robot is None:
        return data

    robot = fd.robot
    type_vocab = fd.graph["type_vocab"]
    n_prod = len(fd.graph["components"])
    robot_idx = n_prod   # robot is appended after the last product

    # ── Build robot node features ──
    # Same column layout as the product nodes:
    # [embedding, position, type one-hot, visibility].
    if int(robot["depth_valid"][0]) == 1:
        robot_pos = robot["centroid"].astype(np.float32)
    else:
        robot_pos = np.zeros(3, dtype=np.float32)
    robot_feat = np.concatenate([
        robot["embedding"].astype(np.float32),
        robot_pos,
        np.array(type_one_hot("robot", type_vocab), dtype=np.float32),
        np.array([1.0], dtype=np.float32),   # visible by construction here
    ])
    x = torch.cat(
        [data.x, torch.tensor(robot_feat, dtype=torch.float32).unsqueeze(0)],
        dim=0,
    )

    # ── Edges ──
    # Product×product edges are reused from the products-only graph instead of
    # being rebuilt, so both loaders always agree on them.
    # Robot ↔ Products (both directions, message-passing only).
    robot_src, robot_dst = [], []
    for i in range(n_prod):
        robot_src += [robot_idx, i]
        robot_dst += [i, robot_idx]

    edge_index = torch.cat(
        [data.edge_index,
         torch.tensor([robot_src, robot_dst], dtype=torch.long)],
        dim=1,
    )
    edge_attr = torch.cat(
        # reshape guards against a 1-D empty tensor when there are no
        # product×product edges (fewer than two products).
        [data.edge_attr.reshape(-1, 3),
         torch.zeros((2 * n_prod, 3), dtype=torch.float32)],
        dim=0,
    )

    out = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([frame_idx], dtype=torch.long),
        num_nodes=n_prod + 1,
    )
    out.robot_point_cloud = torch.tensor(robot["point_cloud"], dtype=torch.float32)
    out.robot_pixel_coords = torch.tensor(robot["pixel_coords"], dtype=torch.int32)
    out.robot_mask = torch.tensor(robot["mask"], dtype=torch.uint8)
    return out


# ─────────────────────────────────────────────────────────────────────────────
# Episode iterator
# ─────────────────────────────────────────────────────────────────────────────

def iterate_episode(episode_dir: Path, with_robot: bool = True):
    """Yield (frame_idx, Data) pairs for all labeled frames in an episode.

    Frames are visited in ascending index order as returned by
    ``list_labeled_frames``. ``with_robot`` selects between
    ``load_pyg_frame_with_robot`` (default) and
    ``load_pyg_frame_products_only``. Graphs are built lazily, one per
    iteration step.
    """
    loader = load_pyg_frame_with_robot if with_robot else load_pyg_frame_products_only
    for frame_idx in list_labeled_frames(episode_dir):
        yield frame_idx, loader(episode_dir, frame_idx)