# gnn_wm/sampled_data/gnn_disassembly_loader.py
# Uploaded by EndeavourDD ("Add files using upload-large-folder tool"), commit 4ee0c8c.
"""Self-contained PyG loader for the GNN Disassembly dataset.
Two loader variants:
- load_pyg_frame_products_only(ep, frame) → constraint graph only, no robot
- load_pyg_frame_with_robot(ep, frame) → constraint graph + robot agent node
Both return torch_geometric.data.Data with:
x (N, 269) node features [256D embedding, 3D centroid,
9D type one-hot, 1D visibility]
edge_index (2, N*(N-1)) fully connected directed message-passing edges
edge_attr (N*(N-1), 3) [has_constraint, is_locked, src_blocks_dst]
y (1,) the frame index
num_nodes N
Notes on the edge feature design:
- The graph is FULLY CONNECTED and structurally symmetric.
Both (i, j) and (j, i) exist in edge_index for every node pair i != j.
- Direction is NOT encoded in the graph structure. It is encoded as
a feature: `src_blocks_dst`.
- `has_constraint` and `is_locked` are symmetric per pair (same value
for both (i, j) and (j, i)).
- `src_blocks_dst` is asymmetric: it is 1 if the edge's src node
physically blocks its dst node, 0 otherwise.
"""
import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
from torch_geometric.data import Data
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def list_labeled_frames(episode_dir: Path) -> List[int]:
    """Return the frame indices that have a saved side-mask annotation.

    Scans ``episode_dir/annotations/side_masks`` for ``frame_*.npz`` files and
    parses the integer after the first underscore of each file stem. Files
    whose stem does not parse as an integer are ignored. Returns an ascending
    list, or ``[]`` when the mask directory does not exist.
    """
    mask_dir = episode_dir / "annotations" / "side_masks"
    if not mask_dir.exists():
        return []

    def _frame_index(npz_path: Path) -> Optional[int]:
        # None marks a file whose name does not encode a frame number.
        try:
            return int(npz_path.stem.split("_")[1])
        except (ValueError, IndexError):
            return None

    parsed = (_frame_index(candidate) for candidate in mask_dir.glob("frame_*.npz"))
    return sorted(idx for idx in parsed if idx is not None)
def resolve_frame_state(graph_json: dict, frame_idx: int) -> Tuple[Dict[str, bool], Dict[str, bool]]:
    """Resolve delta-encoded constraints + visibility at a frame.

    Starts from the defaults (every component visible, every edge
    ``"src->dst"`` locked), then applies the deltas stored in
    ``graph_json["frame_states"]`` in ascending numeric frame order for every
    recorded frame ``<= frame_idx``. Later frames override earlier ones.

    Args:
        graph_json: Parsed side_graph.json with "components", "edges" and an
            optional "frame_states" mapping of frame-number string -> delta.
        frame_idx: Frame whose state should be reconstructed.

    Returns:
        (constraints_dict, visibility_dict)

    Raises:
        ValueError: if a frame_states key is not an integer string.
    """
    visibility: Dict[str, bool] = {c["id"]: True for c in graph_json["components"]}
    constraints: Dict[str, bool] = {
        f"{e['src']}->{e['dst']}": True for e in graph_json["edges"]
    }

    fs_dict = graph_json.get("frame_states") or {}
    # Sort numerically but keep the ORIGINAL key for the lookup, so keys that
    # are not canonical integer strings (e.g. zero-padded "0005") still resolve.
    for frame, key in sorted((int(k), k) for k in fs_dict):
        if frame > frame_idx:
            break
        delta = fs_dict[key]
        constraints.update(delta.get("constraints", {}))
        visibility.update(delta.get("visibility", {}))
    return constraints, visibility
def type_one_hot(comp_type: str, type_vocab: List[str]) -> List[float]:
    """One-hot encode ``comp_type`` against ``type_vocab``.

    The result has ``len(type_vocab)`` entries. It is 1.0 at every position
    whose vocab entry equals ``comp_type`` and 0.0 elsewhere. An unknown type
    yields an all-zero vector.
    """
    encoding = [0.0] * len(type_vocab)
    for position, vocab_type in enumerate(type_vocab):
        if vocab_type == comp_type:
            encoding[position] = 1.0
    return encoding
# ─────────────────────────────────────────────────────────────────────────────
# Raw data loader (NumPy only, no torch)
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class FrameData:
    """All annotation data loaded for a single frame by ``load_frame_data``."""
    graph: dict  # parsed annotations/side_graph.json (one file per episode, not per frame)
    masks: Dict[str, np.ndarray]  # arrays from side_masks/frame_XXXXXX.npz keyed by npz entry name; {} if file missing
    embeddings: Dict[str, np.ndarray]  # arrays from side_embeddings/frame_XXXXXX.npz; the PyG loaders look these up by component id
    depth_info: dict  # arrays from side_depth_info/frame_XXXXXX.npz; loaders read "<cid>_depth_valid" and "<cid>_centroid"
    robot: Optional[dict]  # arrays from side_robot/frame_XXXXXX.npz, or None if the file is missing or visible[0] != 1
    constraints: Dict[str, bool]  # "src->dst" -> is-locked flag at this frame (from resolve_frame_state)
    visibility: Dict[str, bool]  # component id -> visible flag at this frame (from resolve_frame_state)
def load_frame_data(episode_dir: Path, frame_idx: int) -> FrameData:
    """Load all v3 annotation files for one frame.

    Reads, under ``episode_dir / "annotations"``:
      - side_graph.json                   (episode-level graph, required)
      - side_masks/frame_XXXXXX.npz       (optional, {} if missing)
      - side_embeddings/frame_XXXXXX.npz  (optional, {} if missing)
      - side_depth_info/frame_XXXXXX.npz  (optional, {} if missing)
      - side_robot/frame_XXXXXX.npz       (optional, robot=None if missing)
    and resolves the delta-encoded constraint/visibility state at
    ``frame_idx``. ``robot`` is None unless the robot file exists and its
    ``visible[0] == 1``.

    Raises:
        FileNotFoundError: if side_graph.json does not exist.
    """
    anno = episode_dir / "annotations"
    frame_file = f"frame_{frame_idx:06d}.npz"

    with open(anno / "side_graph.json", encoding="utf-8") as f:
        graph = json.load(f)

    def _load_npz_dict(path: Path) -> Dict[str, np.ndarray]:
        # Materialize every array, then close the archive: an NpzFile keeps
        # its file handle open until closed, which leaks descriptors when
        # iterating over many frames.
        if not path.exists():
            return {}
        with np.load(path) as d:
            return {k: d[k] for k in d.files}

    masks = _load_npz_dict(anno / "side_masks" / frame_file)
    embeddings = _load_npz_dict(anno / "side_embeddings" / frame_file)
    depth_info = _load_npz_dict(anno / "side_depth_info" / frame_file)

    robot: Optional[dict] = None
    robot_path = anno / "side_robot" / frame_file
    if robot_path.exists():
        with np.load(robot_path) as r:
            # Only materialize the robot arrays when the robot is visible.
            if r["visible"][0] == 1:
                robot = {k: r[k] for k in r.files}

    constraints, visibility = resolve_frame_state(graph, frame_idx)
    return FrameData(graph, masks, embeddings, depth_info, robot, constraints, visibility)
# ─────────────────────────────────────────────────────────────────────────────
# PyG loader — products only
# ─────────────────────────────────────────────────────────────────────────────
def load_pyg_frame_products_only(episode_dir: Path, frame_idx: int) -> Data:
    """Fully connected PyG graph WITHOUT robot.

    Returns Data(
        x=[N, 256 + 3 + len(type_vocab) + 1],  # 269 with the 9-entry type vocab
        edge_index=[2, N*(N-1)],
        edge_attr=[N*(N-1), 3],  # [has_constraint, is_locked, src_blocks_dst]
        y=[1],                   # the frame index
        num_nodes=N,
    )
    where N = number of product components (robot excluded).
    ``edge_attr`` always has 3 columns, including the N <= 1 case where
    there are no edges.
    """
    fd = load_frame_data(episode_dir, frame_idx)
    graph = fd.graph
    type_vocab = graph["type_vocab"]  # per dataset spec: 9 entries incl. robot
    nodes = graph["components"]       # robot already excluded per spec
    N = len(nodes)

    x = _product_node_features(fd, nodes, type_vocab)
    src_idx, dst_idx, edge_attr = _product_edges(fd, nodes, graph["edges"])

    return Data(
        x=x,
        edge_index=torch.tensor([src_idx, dst_idx], dtype=torch.long),
        # reshape keeps the (E, 3) layout when E == 0 (torch.tensor([]) is 1-D).
        edge_attr=torch.tensor(edge_attr, dtype=torch.float32).reshape(-1, 3),
        y=torch.tensor([frame_idx], dtype=torch.long),
        num_nodes=N,
    )


def _product_node_features(fd: "FrameData", nodes: List[dict], type_vocab: List[str]) -> torch.Tensor:
    """Stack per-component rows [256D embedding, 3D centroid, type one-hot, 1D visibility]."""
    feat_dim = 256 + 3 + len(type_vocab) + 1
    rows = []
    for node in nodes:
        cid = node["id"]
        emb = fd.embeddings.get(cid, np.zeros(256, dtype=np.float32))
        # The centroid is only trusted when the depth estimate is flagged valid.
        depth_valid = fd.depth_info.get(f"{cid}_depth_valid")
        if depth_valid is not None and int(depth_valid[0]) == 1:
            pos = fd.depth_info[f"{cid}_centroid"].astype(np.float32)
        else:
            pos = np.zeros(3, dtype=np.float32)
        vis = 1.0 if fd.visibility.get(cid, True) else 0.0
        rows.append(np.concatenate([
            emb.astype(np.float32),
            pos,
            np.array(type_one_hot(node["type"], type_vocab), dtype=np.float32),
            np.array([vis], dtype=np.float32),
        ]))
    if not rows:
        return torch.empty((0, feat_dim), dtype=torch.float32)
    return torch.tensor(np.stack(rows), dtype=torch.float32)


def _product_edges(
    fd: "FrameData", nodes: List[dict], constraint_edges: List[dict]
) -> Tuple[List[int], List[int], List[List[float]]]:
    """Build fully connected product->product edges with [has_constraint, is_locked, src_blocks_dst]."""
    # Map each unordered pair to its (blocker, blocked) orientation. Built from
    # the JSON list (not a set) so the result is deterministic across runs:
    # if a pair is listed in both directions, the later entry wins.
    pair_forward: Dict[frozenset, Tuple[str, str]] = {
        frozenset((e["src"], e["dst"])): (e["src"], e["dst"]) for e in constraint_edges
    }
    src_idx: List[int] = []
    dst_idx: List[int] = []
    edge_attr: List[List[float]] = []
    for i, src_node in enumerate(nodes):
        src_id = src_node["id"]
        for j, dst_node in enumerate(nodes):
            if i == j:
                continue
            src_idx.append(i)
            dst_idx.append(j)
            forward = pair_forward.get(frozenset((src_id, dst_node["id"])))
            if forward is None:
                edge_attr.append([0.0, 0.0, 0.0])  # message passing only
                continue
            blocker, blocked = forward
            # has_constraint / is_locked are symmetric per pair; direction is
            # carried only by src_blocks_dst.
            is_locked = fd.constraints.get(f"{blocker}->{blocked}", True)
            edge_attr.append([
                1.0,
                1.0 if is_locked else 0.0,
                1.0 if src_id == blocker else 0.0,
            ])
    return src_idx, dst_idx, edge_attr
# ─────────────────────────────────────────────────────────────────────────────
# PyG loader — with robot agent node
# ─────────────────────────────────────────────────────────────────────────────
def load_pyg_frame_with_robot(episode_dir: Path, frame_idx: int) -> Data:
    """Fully connected PyG graph WITH robot appended as agent node.

    The robot is node N_prod (the last node), where N_prod is the number of
    product components. Product<->product edges are exactly those of
    ``load_pyg_frame_products_only``, in the same order. After them, each
    product i gets a robot->i edge and an i->robot edge, in that order. All
    robot edges have features [0, 0, 0] because the robot has no physical
    constraints. ``edge_attr`` always has 3 columns.

    If the robot is not visible at this frame (or its annotation file is
    missing), returns the products-only graph.

    Additional attached tensors when robot is visible:
        data.robot_point_cloud   (M, 3) float32
        data.robot_pixel_coords  (M, 2) int32
        data.robot_mask          (H, W) uint8
    """
    data = load_pyg_frame_products_only(episode_dir, frame_idx)
    # NOTE(review): this re-reads the same annotation files that
    # load_pyg_frame_products_only just loaded; sharing the FrameData would
    # halve the per-frame I/O.
    fd = load_frame_data(episode_dir, frame_idx)
    if fd.robot is None:
        return data

    robot = fd.robot
    type_vocab = fd.graph["type_vocab"]
    n_prod = len(fd.graph["components"])
    robot_idx = n_prod  # robot is appended after every product node

    # ── Robot node features (same layout as product nodes) ──
    robot_emb = robot["embedding"].astype(np.float32)
    if int(robot["depth_valid"][0]) == 1:
        robot_pos = robot["centroid"].astype(np.float32)
    else:
        robot_pos = np.zeros(3, dtype=np.float32)
    robot_feat = np.concatenate([
        robot_emb,
        robot_pos,
        np.array(type_one_hot("robot", type_vocab), dtype=np.float32),
        np.array([1.0], dtype=np.float32),  # robot is only loaded when visible
    ])
    x = torch.cat([data.x, torch.tensor(robot_feat, dtype=torch.float32).unsqueeze(0)], dim=0)

    # ── Edges: reuse product×product edges, then append robot ↔ product ──
    # Reusing guarantees both loaders agree on the product edge features.
    robot_src: List[int] = []
    robot_dst: List[int] = []
    for i in range(n_prod):
        robot_src.extend((robot_idx, i))
        robot_dst.extend((i, robot_idx))
    edge_index = torch.cat(
        [data.edge_index, torch.tensor([robot_src, robot_dst], dtype=torch.long)],
        dim=1,
    )
    # reshape(-1, 3) also normalizes an empty (0,) edge_attr to (0, 3).
    edge_attr = torch.cat(
        [
            data.edge_attr.reshape(-1, 3),
            torch.zeros((2 * n_prod, 3), dtype=torch.float32),  # message passing only
        ],
        dim=0,
    )

    out = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([frame_idx], dtype=torch.long),
        num_nodes=n_prod + 1,
    )
    out.robot_point_cloud = torch.tensor(robot["point_cloud"], dtype=torch.float32)
    out.robot_pixel_coords = torch.tensor(robot["pixel_coords"], dtype=torch.int32)
    out.robot_mask = torch.tensor(robot["mask"], dtype=torch.uint8)
    return out
# ─────────────────────────────────────────────────────────────────────────────
# Episode iterator
# ─────────────────────────────────────────────────────────────────────────────
def iterate_episode(episode_dir: Path, with_robot: bool = True):
    """Lazily yield ``(frame_idx, Data)`` for every labeled frame of an episode.

    Frames come from ``list_labeled_frames`` in ascending order. Each graph is
    built on demand by ``load_pyg_frame_with_robot`` when ``with_robot`` is
    true, otherwise by ``load_pyg_frame_products_only``.
    """
    build_graph = load_pyg_frame_with_robot if with_robot else load_pyg_frame_products_only
    yield from (
        (idx, build_graph(episode_dir, idx))
        for idx in list_labeled_frames(episode_dir)
    )