Instructions to use EndeavourDD/gnn_wm with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Diffusers
How to use EndeavourDD/gnn_wm with Diffusers:
pip install -U diffusers transformers accelerate
import torch
from diffusers import DiffusionPipeline

# switch to "mps" for apple devices
pipe = DiffusionPipeline.from_pretrained("EndeavourDD/gnn_wm", dtype=torch.bfloat16, device_map="cuda")

prompt = "Astronaut in a jungle, cold color palette, muted colors, detailed, 8k"
image = pipe(prompt).images[0]
- Notebooks
- Google Colab
- Kaggle
"""Self-contained PyG loader for the GNN Disassembly dataset.

Two loader variants:
  - load_pyg_frame_products_only(ep, frame) — constraint graph only, no robot
  - load_pyg_frame_with_robot(ep, frame)    — constraint graph + robot agent node

Both return torch_geometric.data.Data with:
  x           (N, 269)       node features (256 + 3 + len(type_vocab) + 1;
                             269 for the 9-entry type vocab)
  edge_index  (2, N*(N-1))   fully connected directed message-passing edges
  edge_attr   (N*(N-1), 3)   [has_constraint, is_locked, src_blocks_dst]
  y           (1,)           the frame index
  num_nodes   N

Notes on the edge feature design:
  - The graph is FULLY CONNECTED and structurally symmetric.
    Both (i, j) and (j, i) exist in edge_index for every node pair i != j.
  - Direction is NOT encoded in the graph structure. It is encoded as
    a feature: `src_blocks_dst`.
  - `has_constraint` and `is_locked` are symmetric per pair (same value
    for both (i, j) and (j, i)).
  - `src_blocks_dst` is asymmetric: it is 1 if the edge's src node
    physically blocks its dst node, 0 otherwise.
"""
| import json | |
| from dataclasses import dataclass | |
| from pathlib import Path | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| from torch_geometric.data import Data | |
# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────
def list_labeled_frames(episode_dir: Path) -> List[int]:
    """List the frames of an episode that have a saved side-mask annotation.

    Scans ``<episode_dir>/annotations/side_masks`` for ``frame_*.npz`` files
    and parses the integer after the first underscore of each file stem.
    Files whose index cannot be parsed are ignored.

    Returns:
        The frame indices in ascending order, or ``[]`` when the mask
        directory does not exist.
    """
    masks_root = episode_dir / "annotations" / "side_masks"
    if not masks_root.exists():
        return []

    def _index_of(npz_path: Path) -> Optional[int]:
        # "frame_000042" -> 42; anything unparsable maps to None.
        try:
            return int(npz_path.stem.split("_")[1])
        except (ValueError, IndexError):
            return None

    candidates = (_index_of(p) for p in masks_root.glob("frame_*.npz"))
    return sorted(idx for idx in candidates if idx is not None)
def resolve_frame_state(graph_json: dict, frame_idx: int) -> Tuple[Dict[str, bool], Dict[str, bool]]:
    """Resolve delta-encoded constraints + visibility at a frame.

    Starts from the defaults: every component in ``graph_json["components"]``
    is visible, and every edge in ``graph_json["edges"]`` is locked (keyed as
    ``"src->dst"``). It then applies the deltas in
    ``graph_json["frame_states"]`` in ascending numeric frame order, up to and
    including ``frame_idx``. A later frame overwrites earlier values for the
    same key.

    Args:
        graph_json: Parsed ``side_graph.json`` contents.
        frame_idx: Frame to resolve the state for.

    Returns:
        ``(constraints, visibility)``: constraint key -> locked flag, and
        component id -> visible flag.

    Raises:
        ValueError: if a ``frame_states`` key is not an integer string.
    """
    # Defaults: every component visible, every edge locked.
    visibility: Dict[str, bool] = {c["id"]: True for c in graph_json["components"]}
    constraints: Dict[str, bool] = {
        f"{e['src']}->{e['dst']}": True for e in graph_json["edges"]
    }

    # JSON object keys are strings. Sort them numerically, but look each delta
    # up by its ORIGINAL key, so zero-padded keys such as "000012" work
    # (str(int("000012")) == "12" would not be found).
    fs_dict = graph_json.get("frame_states", {})
    for f, key in sorted((int(k), k) for k in fs_dict):
        if f > frame_idx:
            break
        fs = fs_dict[key]
        constraints.update(fs.get("constraints", {}))
        visibility.update(fs.get("visibility", {}))
    return constraints, visibility
def type_one_hot(comp_type: str, type_vocab: List[str]) -> List[float]:
    """One-hot encode ``comp_type`` against ``type_vocab``.

    Returns one float per vocabulary entry (``len(type_vocab)`` values, i.e.
    9 for the 9-entry vocab used by this dataset). Each value is 1.0 where the
    entry equals ``comp_type`` and 0.0 elsewhere. A type missing from the
    vocabulary yields all zeros.
    """
    return [float(entry == comp_type) for entry in type_vocab]
# ─────────────────────────────────────────────────────────────────────────────
# Raw data loader (NumPy only, no torch)
# ─────────────────────────────────────────────────────────────────────────────
@dataclass
class FrameData:
    """All annotation data loaded for a single frame of one episode.

    ``load_frame_data`` constructs this positionally, in field order, so it
    must be a dataclass: a bare annotated class has no ``__init__``
    accepting these arguments.
    """

    graph: dict                        # parsed annotations/side_graph.json (one file per episode)
    masks: Dict[str, np.ndarray]       # arrays from side_masks/frame_XXXXXX.npz ({} if missing)
    embeddings: Dict[str, np.ndarray]  # arrays from side_embeddings/frame_XXXXXX.npz, looked up by component id
    depth_info: dict                   # arrays from side_depth_info/frame_XXXXXX.npz ({} if missing)
    robot: Optional[dict]              # side_robot npz arrays, or None if the file is missing or visible[0] != 1
    constraints: Dict[str, bool]       # "src->dst" -> locked flag, resolved at this frame
    visibility: Dict[str, bool]        # component id -> visible flag, resolved at this frame
def load_frame_data(episode_dir: Path, frame_idx: int) -> FrameData:
    """Load all v3 annotation files for one frame.

    Reads from ``<episode_dir>/annotations``:
      - ``side_graph.json`` (required; ``open`` raises FileNotFoundError if missing)
      - ``side_masks/frame_XXXXXX.npz``       -> masks      ({} if missing)
      - ``side_embeddings/frame_XXXXXX.npz``  -> embeddings ({} if missing)
      - ``side_depth_info/frame_XXXXXX.npz``  -> depth_info ({} if missing)
      - ``side_robot/frame_XXXXXX.npz``       -> robot, kept only when its
        ``visible[0] == 1``; otherwise (or if missing) robot is None
    XXXXXX is ``frame_idx`` zero-padded to 6 digits. Constraints and
    visibility are resolved at ``frame_idx`` via ``resolve_frame_state``.

    Returns:
        A ``FrameData`` holding all of the above.
    """
    anno = episode_dir / "annotations"
    # Explicit encoding: the platform default can differ from UTF-8 and
    # mis-decode non-ASCII component ids.
    with open(anno / "side_graph.json", encoding="utf-8") as f:
        graph = json.load(f)

    frame_name = f"frame_{frame_idx:06d}.npz"

    def _load_npz_dict(path: Path) -> Dict[str, np.ndarray]:
        """Return every array stored in ``path``, or {} if the file does not exist."""
        if not path.exists():
            return {}
        # np.load on an .npz returns an NpzFile that reads members lazily and
        # keeps the file open; materialize everything, then close it.
        with np.load(path) as d:
            return {k: d[k] for k in d.files}

    masks = _load_npz_dict(anno / "side_masks" / frame_name)
    embeddings = _load_npz_dict(anno / "side_embeddings" / frame_name)
    depth_info = _load_npz_dict(anno / "side_depth_info" / frame_name)

    robot: Optional[dict] = None
    robot_path = anno / "side_robot" / frame_name
    if robot_path.exists():
        with np.load(robot_path) as r:
            # Only materialize the robot arrays when the robot is visible.
            if r["visible"][0] == 1:
                robot = {k: r[k] for k in r.files}

    constraints, visibility = resolve_frame_state(graph, frame_idx)
    return FrameData(graph, masks, embeddings, depth_info, robot, constraints, visibility)
# ─────────────────────────────────────────────────────────────────────────────
# PyG loader — products only
# ─────────────────────────────────────────────────────────────────────────────
def load_pyg_frame_products_only(episode_dir: Path, frame_idx: int) -> Data:
    """Fully connected PyG graph WITHOUT robot.

    Returns Data(
        x=[N, F],                 # F = 256 + 3 + len(type_vocab) + 1 (269 for the 9-entry vocab)
        edge_index=[2, N*(N-1)],
        edge_attr=[N*(N-1), 3],   # [has_constraint, is_locked, src_blocks_dst]
        y=[1],                    # frame_idx
        num_nodes=N,
    )
    where N = number of entries in graph["components"] (robot excluded).

    Node feature layout, per component:
      [0:256]    SAM2 embedding (assumes stored embeddings are 256-D; zeros if missing)
      [256:259]  depth centroid (zeros unless ``<id>_depth_valid[0] == 1``)
      [259:-1]   one-hot of the component type over ``type_vocab``
      [-1]       visibility flag at this frame (1.0 if unspecified)

    Edge features for the ordered pair (src, dst), src != dst:
      has_constraint  1.0 if the pair is a declared edge in either direction
      is_locked       lock state of that declared edge at this frame
                      (same value for both directions; 1.0 if unspecified)
      src_blocks_dst  1.0 if src is the declared blocker (the edge's "src")
    Pairs without a declared edge get [0, 0, 0] (message passing only).
    Edges are emitted in row-major (i, j) order, skipping i == j.
    """
    fd = load_frame_data(episode_dir, frame_idx)
    graph = fd.graph
    type_vocab = graph["type_vocab"]  # 9 entries incl. robot
    nodes = graph["components"]       # robot already excluded per spec
    N = len(nodes)

    # ── Node features ──
    # [256D SAM2 embedding, 3D position, one-hot type, 1D visibility]
    feat_dim = 256 + 3 + len(type_vocab) + 1
    x_list = []
    for node in nodes:
        cid = node["id"]
        emb = fd.embeddings.get(cid)
        if emb is None:
            emb = np.zeros(256, dtype=np.float32)

        depth_valid = fd.depth_info.get(f"{cid}_depth_valid")
        if depth_valid is not None and int(depth_valid[0]) == 1:
            pos = fd.depth_info[f"{cid}_centroid"].astype(np.float32)
        else:
            pos = np.zeros(3, dtype=np.float32)

        vis = 1.0 if fd.visibility.get(cid, True) else 0.0
        x_list.append(np.concatenate([
            emb.astype(np.float32),
            pos,
            np.asarray(type_one_hot(node["type"], type_vocab), dtype=np.float32),
            np.array([vis], dtype=np.float32),
        ]))
    # Derive the empty-graph width from the same layout instead of hard-coding it.
    x = (torch.tensor(np.stack(x_list), dtype=torch.float32)
         if x_list else torch.empty((0, feat_dim), dtype=torch.float32))

    # ── Fully connected edges with 3D features ──
    # Forward (blocker, blocked) direction per unordered pair. Walk the edge
    # LIST, not a set, so that a pair declared in both directions resolves
    # deterministically (first declaration wins). Set iteration order over
    # string tuples changes between processes because of hash randomization.
    pair_forward: Dict[frozenset, Tuple[str, str]] = {}
    for e in graph["edges"]:
        pair_forward.setdefault(frozenset((e["src"], e["dst"])), (e["src"], e["dst"]))

    src_idx: List[int] = []
    dst_idx: List[int] = []
    edge_attr: List[List[float]] = []
    for i, src_node in enumerate(nodes):
        src_id = src_node["id"]
        for j, dst_node in enumerate(nodes):
            if i == j:
                continue
            src_idx.append(i)
            dst_idx.append(j)
            forward = pair_forward.get(frozenset((src_id, dst_node["id"])))
            if forward is None:
                edge_attr.append([0.0, 0.0, 0.0])  # message passing only
                continue
            # has_constraint / is_locked are symmetric per pair; direction
            # lives only in src_blocks_dst.
            is_locked = fd.constraints.get(f"{forward[0]}->{forward[1]}", True)
            edge_attr.append([
                1.0,
                1.0 if is_locked else 0.0,
                1.0 if src_id == forward[0] else 0.0,
            ])

    return Data(
        x=x,
        edge_index=torch.tensor([src_idx, dst_idx], dtype=torch.long),
        # reshape keeps the (E, 3) shape when E == 0 (N < 2): torch.tensor([]) is 1-D.
        edge_attr=torch.tensor(edge_attr, dtype=torch.float32).reshape(-1, 3),
        y=torch.tensor([frame_idx], dtype=torch.long),
        num_nodes=N,
    )
# ─────────────────────────────────────────────────────────────────────────────
# PyG loader — with robot agent node
# ─────────────────────────────────────────────────────────────────────────────
def load_pyg_frame_with_robot(episode_dir: Path, frame_idx: int) -> Data:
    """Fully connected PyG graph WITH robot appended as agent node.

    Builds on ``load_pyg_frame_products_only``. Its product node features and
    its product-product edges (with their 3-D edge features) are reused as-is.
    The robot is appended as the LAST node, at index N_prod (the number of
    products). The result therefore has N = N_prod + 1 nodes and N*(N-1)
    edges. The robot is connected to every product in both directions. All of
    those edges have features [0, 0, 0] because the robot has no physical
    constraints.

    Robot node features follow the product layout:
      - robot embedding;
      - robot centroid (zeros unless ``depth_valid[0] == 1``);
      - one-hot of "robot" over ``type_vocab``;
      - visibility 1.0.

    If the robot is not visible at this frame (``fd.robot is None``), the
    products-only graph is returned unchanged.

    Additional attached tensors when robot is visible:
        data.robot_point_cloud   (M, 3) float32
        data.robot_pixel_coords  (M, 2) int32
        data.robot_mask          (H, W) uint8
    """
    data = load_pyg_frame_products_only(episode_dir, frame_idx)
    # NOTE(review): this re-reads the frame's annotation files that
    # load_pyg_frame_products_only already read. Only fd.robot and the type
    # vocab are needed here.
    fd = load_frame_data(episode_dir, frame_idx)
    if fd.robot is None:
        return data

    graph = fd.graph
    type_vocab = graph["type_vocab"]
    N_prod = len(graph["components"])
    N = N_prod + 1
    robot_idx = N_prod  # the robot is the last node

    # ── Robot node features ──
    robot_emb = fd.robot["embedding"].astype(np.float32)
    if int(fd.robot["depth_valid"][0]) == 1:
        robot_pos = fd.robot["centroid"].astype(np.float32)
    else:
        robot_pos = np.zeros(3, dtype=np.float32)
    robot_feat = np.concatenate([
        robot_emb,
        robot_pos,
        np.asarray(type_one_hot("robot", type_vocab), dtype=np.float32),
        np.array([1.0], dtype=np.float32),
    ])
    x = torch.cat([data.x, torch.tensor(robot_feat, dtype=torch.float32).unsqueeze(0)], dim=0)

    # ── Edges: product×product (reused) + robot↔product ──
    # The products-only graph already holds every product×product edge in
    # row-major (i, j) order with its features. Append the robot edges after
    # them instead of rebuilding the product edges here.
    robot_src: List[int] = []
    robot_dst: List[int] = []
    for i in range(N_prod):
        robot_src.extend((robot_idx, i))  # robot -> i, then i -> robot
        robot_dst.extend((i, robot_idx))

    edge_index = torch.cat(
        [data.edge_index, torch.tensor([robot_src, robot_dst], dtype=torch.long)],
        dim=1,
    )
    edge_attr = torch.cat(
        [
            # reshape guards against a 1-D empty tensor when N_prod < 2
            data.edge_attr.reshape(-1, 3),
            torch.zeros((len(robot_src), 3), dtype=torch.float32),  # message passing only
        ],
        dim=0,
    )

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([frame_idx], dtype=torch.long),
        num_nodes=N,
    )
    data.robot_point_cloud = torch.tensor(fd.robot["point_cloud"], dtype=torch.float32)
    data.robot_pixel_coords = torch.tensor(fd.robot["pixel_coords"], dtype=torch.int32)
    data.robot_mask = torch.tensor(fd.robot["mask"], dtype=torch.uint8)
    return data
# ─────────────────────────────────────────────────────────────────────────────
# Episode iterator
# ─────────────────────────────────────────────────────────────────────────────
def iterate_episode(episode_dir: Path, with_robot: bool = True):
    """Generate ``(frame_idx, Data)`` for every labeled frame of an episode.

    Frame indices come from ``list_labeled_frames`` in ascending order. Each
    frame is loaded lazily, only when the generator is advanced. Frames are
    loaded with ``load_pyg_frame_with_robot`` when ``with_robot`` is truthy,
    and with ``load_pyg_frame_products_only`` otherwise.
    """
    if with_robot:
        build_graph = load_pyg_frame_with_robot
    else:
        build_graph = load_pyg_frame_products_only
    for idx in list_labeled_frames(episode_dir):
        graph_data = build_graph(episode_dir, idx)
        yield idx, graph_data