File size: 13,535 Bytes
4ee0c8c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
"""Self-contained PyG loader for the GNN Disassembly dataset.

Two loader variants:
  - load_pyg_frame_products_only(ep, frame)  → constraint graph only, no robot
  - load_pyg_frame_with_robot(ep, frame)     → constraint graph + robot agent node

Both return torch_geometric.data.Data with:
  x            (N, 269)      node features: 256-D embedding, 3-D position,
                             one-hot over type_vocab (9 entries), 1-D visibility
  edge_index   (2, N*(N-1))  fully connected directed message-passing edges
  edge_attr    (N*(N-1), 3)  [has_constraint, is_locked, src_blocks_dst]
  y            (1,)          the frame index
  num_nodes    N

Notes on the edge feature design:
- The graph is FULLY CONNECTED and structurally symmetric.
  Both (i, j) and (j, i) exist in edge_index for every node pair i != j.
- Direction is NOT encoded in the graph structure. It is encoded as
  a feature: `src_blocks_dst`.
- `has_constraint` and `is_locked` are symmetric per pair (same value
  for both (i, j) and (j, i)).
- `src_blocks_dst` is asymmetric: it is 1 if the edge's src node
  physically blocks its dst node, 0 otherwise.
"""

import json
from dataclasses import dataclass
from pathlib import Path
from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
from torch_geometric.data import Data


# ─────────────────────────────────────────────────────────────────────────────
# Helpers
# ─────────────────────────────────────────────────────────────────────────────

def list_labeled_frames(episode_dir: Path) -> List[int]:
    """Collect the indices of all frames that have a saved side mask.

    A frame counts as labeled when ``annotations/side_masks/frame_<idx>.npz``
    exists. The index is the second ``_``-separated token of the file stem.
    Files whose token is not an integer are ignored.

    Args:
        episode_dir: Episode root directory.

    Returns:
        Frame indices in ascending order. The list is empty when the mask
        directory is missing.
    """
    mask_root = episode_dir / "annotations" / "side_masks"
    if not mask_root.exists():
        return []

    def _frame_index(npz_path: Path) -> Optional[int]:
        # None marks a file whose name does not parse to a frame index.
        try:
            return int(npz_path.stem.split("_")[1])
        except (ValueError, IndexError):
            return None

    parsed = (_frame_index(p) for p in mask_root.glob("frame_*.npz"))
    return sorted(idx for idx in parsed if idx is not None)


def resolve_frame_state(graph_json: dict, frame_idx: int) -> Tuple[Dict[str, bool], Dict[str, bool]]:
    """Replay the delta-encoded frame states up to and including ``frame_idx``.

    Starting state:
    - every component in ``graph_json["components"]`` is visible;
    - every constraint ``"<src>-><dst>"`` from ``graph_json["edges"]`` is locked.

    The entries of ``graph_json["frame_states"]`` are then applied in ascending
    numeric key order, stopping after ``frame_idx``. Each entry may carry a
    ``"constraints"`` and/or ``"visibility"`` mapping that overwrites values.

    Returns:
        ``(constraints, visibility)``, both mapping names to booleans.
    """
    visibility: Dict[str, bool] = {comp["id"]: True for comp in graph_json["components"]}
    constraints: Dict[str, bool] = {
        f"{edge['src']}->{edge['dst']}": True for edge in graph_json["edges"]
    }

    frame_states = graph_json.get("frame_states", {})
    for step in sorted(int(key) for key in frame_states):
        if step > frame_idx:
            break
        delta = frame_states[str(step)]
        # Constraints are overwritten before visibility, as in a flat replay.
        for target, section in ((constraints, "constraints"), (visibility, "visibility")):
            for name, value in delta.get(section, {}).items():
                target[name] = value
    return constraints, visibility


def type_one_hot(comp_type: str, type_vocab: List[str]) -> List[float]:
    """One-hot encode ``comp_type`` against ``type_vocab``.

    The result has ``len(type_vocab)`` entries: 1.0 wherever the vocabulary
    entry equals ``comp_type``, else 0.0. A type missing from the vocabulary
    yields all zeros.
    """
    return [float(entry == comp_type) for entry in type_vocab]


# ─────────────────────────────────────────────────────────────────────────────
# Raw data loader (NumPy only, no torch)
# ─────────────────────────────────────────────────────────────────────────────

@dataclass
class FrameData:
    """All annotation data for one frame, as assembled by ``load_frame_data``.

    ``load_frame_data`` constructs this positionally, so field order matters.
    """

    # Parsed ``annotations/side_graph.json`` (the loaders read its
    # "components", "edges", "type_vocab" and "frame_states" keys).
    graph: dict
    # Arrays from ``side_masks/frame_XXXXXX.npz`` keyed by npz entry name
    # (presumably component ids — TODO confirm). Empty if the file is absent.
    masks: Dict[str, np.ndarray]
    # Arrays from ``side_embeddings/frame_XXXXXX.npz``. The PyG loaders look
    # them up by component id. Empty if the file is absent.
    embeddings: Dict[str, np.ndarray]
    # Arrays from ``side_depth_info/frame_XXXXXX.npz``. Keys include
    # ``<id>_depth_valid`` and ``<id>_centroid``. Empty if the file is absent.
    depth_info: Dict[str, np.ndarray]
    # All arrays from ``side_robot/frame_XXXXXX.npz`` when that file exists
    # and its ``visible[0] == 1``; otherwise None.
    robot: Optional[dict]
    # Lock state per constraint key ``"<src>-><dst>"``, resolved at this frame.
    constraints: Dict[str, bool]
    # Visibility per component id, resolved at this frame.
    visibility: Dict[str, bool]


def load_frame_data(episode_dir: Path, frame_idx: int) -> FrameData:
    """Load all v3 annotation files for one frame.

    Reads from ``<episode_dir>/annotations``:
    - ``side_graph.json`` (required);
    - ``side_masks``, ``side_embeddings`` and ``side_depth_info``
      ``frame_XXXXXX.npz`` files (each optional; missing → empty dict);
    - ``side_robot/frame_XXXXXX.npz`` (optional; kept only when
      ``visible[0] == 1``).

    Every npz file is closed before returning. Its arrays are copied out
    first, so the result does not hold open file handles.

    Args:
        episode_dir: Episode root directory.
        frame_idx: Frame index; zero-padded to 6 digits in file names.

    Returns:
        A ``FrameData`` with constraints and visibility resolved at ``frame_idx``.
    """
    anno = episode_dir / "annotations"
    frame_name = f"frame_{frame_idx:06d}.npz"

    with open(anno / "side_graph.json", encoding="utf-8") as f:
        graph = json.load(f)

    def _load_npz_dict(path: Path) -> Dict[str, np.ndarray]:
        # Materialize every array, then close the archive (NpzFile keeps the
        # underlying file open until closed).
        if not path.exists():
            return {}
        with np.load(path) as d:
            return {k: d[k] for k in d.files}

    masks = _load_npz_dict(anno / "side_masks" / frame_name)
    embeddings = _load_npz_dict(anno / "side_embeddings" / frame_name)
    depth_info = _load_npz_dict(anno / "side_depth_info" / frame_name)

    robot: Optional[dict] = None
    robot_path = anno / "side_robot" / frame_name
    if robot_path.exists():
        with np.load(robot_path) as r:
            # Only materialize the (possibly large) robot arrays when visible.
            if r["visible"][0] == 1:
                robot = {k: r[k] for k in r.files}

    constraints, visibility = resolve_frame_state(graph, frame_idx)
    return FrameData(graph, masks, embeddings, depth_info, robot, constraints, visibility)


# ─────────────────────────────────────────────────────────────────────────────
# PyG loader — products only
# ─────────────────────────────────────────────────────────────────────────────

def load_pyg_frame_products_only(episode_dir: Path, frame_idx: int) -> Data:
    """Build a fully connected PyG graph for one frame WITHOUT the robot node.

    Nodes are ``graph["components"]`` in file order (the robot is not listed
    there). Each node row is the concatenation

        [embedding (256), depth centroid (3), type one-hot (len(type_vocab)), visibility (1)]

    which is 269 columns for the 9-entry type vocabulary. Missing embeddings
    fall back to 256 zeros. A centroid is used only when ``<id>_depth_valid``
    is 1, else 3 zeros. Visibility comes from the delta-resolved frame state
    (default visible).

    Every ordered pair ``(i, j)`` with ``i != j`` becomes an edge with
    ``edge_attr = [has_constraint, is_locked, src_blocks_dst]``:

    - ``has_constraint``: 1 if ``graph["edges"]`` links the pair in either
      direction.
    - ``is_locked``: resolved lock state of the pair's first-listed constraint
      (default locked). Same value for ``(i, j)`` and ``(j, i)``.
    - ``src_blocks_dst``: 1 if ``(src, dst)`` itself is listed in
      ``graph["edges"]``.

    Args:
        episode_dir: Episode root directory.
        frame_idx: Frame to load.

    Returns:
        Data(
            x=[N, 269],
            edge_index=[2, N*(N-1)],
            edge_attr=[N*(N-1), 3],   # (0, 3) when N <= 1
            y=[1],                    # frame_idx
            num_nodes=N,
        )
    """
    fd = load_frame_data(episode_dir, frame_idx)
    graph = fd.graph
    type_vocab = graph["type_vocab"]  # 9 entries incl. robot
    nodes = graph["components"]       # robot already excluded per spec
    N = len(nodes)

    # ── Node features ──
    x_list = []
    for node in nodes:
        cid = node["id"]
        emb = fd.embeddings.get(cid, np.zeros(256, dtype=np.float32))

        depth_valid_key = f"{cid}_depth_valid"
        if (depth_valid_key in fd.depth_info
                and int(fd.depth_info[depth_valid_key][0]) == 1):
            pos = fd.depth_info[f"{cid}_centroid"].astype(np.float32)
        else:
            pos = np.zeros(3, dtype=np.float32)

        vis = 1.0 if fd.visibility.get(cid, True) else 0.0
        x_list.append(np.concatenate([
            emb.astype(np.float32),
            pos,
            np.array(type_one_hot(node["type"], type_vocab), dtype=np.float32),
            np.array([vis], dtype=np.float32),
        ]))
    if x_list:
        x = torch.tensor(np.stack(x_list), dtype=torch.float32)
    else:
        # Keep the empty feature matrix as wide as a populated one would be.
        x = torch.empty((0, 256 + 3 + len(type_vocab) + 1), dtype=torch.float32)

    # ── Fully connected edges with 3D features ──
    directed = {(e["src"], e["dst"]) for e in graph["edges"]}
    # Walk the JSON list rather than a set: set order of str tuples varies with
    # PYTHONHASHSEED, which made the pair's "forward" constraint (and hence
    # is_locked for mutually-constrained pairs) nondeterministic across runs.
    pair_forward: Dict[frozenset, Tuple[str, str]] = {}
    for e in graph["edges"]:
        pair_forward.setdefault(frozenset((e["src"], e["dst"])), (e["src"], e["dst"]))

    src_idx, dst_idx, edge_attr = [], [], []
    for i, src_node in enumerate(nodes):
        src_id = src_node["id"]
        for j, dst_node in enumerate(nodes):
            if i == j:
                continue
            dst_id = dst_node["id"]
            src_idx.append(i)
            dst_idx.append(j)

            forward = pair_forward.get(frozenset((src_id, dst_id)))
            if forward is None:
                edge_attr.append([0.0, 0.0, 0.0])  # message passing only
                continue
            is_locked = fd.constraints.get(f"{forward[0]}->{forward[1]}", True)
            edge_attr.append([
                1.0,
                1.0 if is_locked else 0.0,
                1.0 if (src_id, dst_id) in directed else 0.0,
            ])

    return Data(
        x=x,
        edge_index=torch.tensor([src_idx, dst_idx], dtype=torch.long),
        # reshape keeps the (E, 3) layout when there are no edges (N <= 1).
        edge_attr=torch.tensor(edge_attr, dtype=torch.float32).reshape(-1, 3),
        y=torch.tensor([frame_idx], dtype=torch.long),
        num_nodes=N,
    )


# ─────────────────────────────────────────────────────────────────────────────
# PyG loader — with robot agent node
# ─────────────────────────────────────────────────────────────────────────────

def load_pyg_frame_with_robot(episode_dir: Path, frame_idx: int) -> Data:
    """Build the fully connected PyG graph WITH the robot appended as an agent node.

    The robot is node ``N_prod`` (the last node). Its feature row uses the same
    layout as product rows:
    - the robot embedding;
    - the centroid (zeros unless ``depth_valid`` is 1);
    - the one-hot of ``"robot"`` in ``type_vocab``;
    - visibility 1.0.

    All edges involving the robot have features [0, 0, 0] because the robot
    has no physical constraints.

    Product nodes and product×product edges are taken unchanged from
    ``load_pyg_frame_products_only``. Robot edges are appended after them as
    (robot → i, i → robot) pairs for i = 0..N_prod-1.

    If the robot is not visible at this frame, returns the products-only graph
    (without the attributes below). Additional attached tensors when the robot
    is visible:
        data.robot_point_cloud  (M, 3) float32
        data.robot_pixel_coords (M, 2) int32
        data.robot_mask         (H, W) uint8
    """
    data = load_pyg_frame_products_only(episode_dir, frame_idx)
    # NOTE: annotations are read a second time here; avoiding that would need
    # load_pyg_frame_products_only to accept preloaded FrameData.
    fd = load_frame_data(episode_dir, frame_idx)
    if fd.robot is None:
        return data

    type_vocab = fd.graph["type_vocab"]
    n_prod = len(fd.graph["components"])
    robot_idx = n_prod

    # ── Robot node features ──
    robot_pos = (fd.robot["centroid"].astype(np.float32)
                 if int(fd.robot["depth_valid"][0]) == 1
                 else np.zeros(3, dtype=np.float32))
    robot_feat = np.concatenate([
        fd.robot["embedding"].astype(np.float32),
        robot_pos,
        np.array(type_one_hot("robot", type_vocab), dtype=np.float32),
        np.array([1.0], dtype=np.float32),  # only reached when the robot is visible
    ])
    x = torch.cat([data.x, torch.tensor(robot_feat, dtype=torch.float32).unsqueeze(0)], dim=0)

    # ── Edges: reuse product×product edges, append robot ↔ product pairs ──
    robot_src: List[int] = []
    robot_dst: List[int] = []
    for i in range(n_prod):
        robot_src += [robot_idx, i]
        robot_dst += [i, robot_idx]
    robot_edge_index = torch.tensor([robot_src, robot_dst], dtype=torch.long)
    robot_edge_attr = torch.zeros((2 * n_prod, 3), dtype=torch.float32)

    edge_index = torch.cat([data.edge_index, robot_edge_index], dim=1)
    # reshape: tolerate a flat (0,) edge_attr from a graph with <= 1 product.
    edge_attr = torch.cat([data.edge_attr.reshape(-1, 3), robot_edge_attr], dim=0)

    data = Data(
        x=x,
        edge_index=edge_index,
        edge_attr=edge_attr,
        y=torch.tensor([frame_idx], dtype=torch.long),
        num_nodes=n_prod + 1,
    )
    data.robot_point_cloud = torch.tensor(fd.robot["point_cloud"], dtype=torch.float32)
    data.robot_pixel_coords = torch.tensor(fd.robot["pixel_coords"], dtype=torch.int32)
    data.robot_mask = torch.tensor(fd.robot["mask"], dtype=torch.uint8)
    return data


# ─────────────────────────────────────────────────────────────────────────────
# Episode iterator
# ─────────────────────────────────────────────────────────────────────────────

def iterate_episode(episode_dir: Path, with_robot: bool = True):
    """Lazily yield ``(frame_idx, Data)`` for every labeled frame of an episode.

    Frames come from ``list_labeled_frames`` in ascending order. With
    ``with_robot`` set, each graph is built by ``load_pyg_frame_with_robot``;
    otherwise ``load_pyg_frame_products_only`` is used.
    """
    if with_robot:
        build = load_pyg_frame_with_robot
    else:
        build = load_pyg_frame_products_only
    labeled = list_labeled_frames(episode_dir)
    for idx in labeled:
        yield idx, build(episode_dir, idx)