atharva-pantheon committed on
Commit
20b27fa
·
verified ·
1 Parent(s): 371dfea

Upload code/collect_prefix.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. code/collect_prefix.py +176 -0
code/collect_prefix.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """Replay the 44 yam-stack-cube demos through the MolmoAct2 fine-tune server to
3
+ collect (M, 2560) prefix hidden-state shards for RLT Stage-1 encoder training.
4
+
5
+ This is the ⏭️ "Collect prefix data" step. It does NOT touch the model directly —
6
+ it talks to molmoact2_finetune_karma_server.py over ws:8112, which must be
7
+ running in PREFIX-CACHE mode:
8
+
9
+ MOLMOACT2_FT_ENCODER_CACHE_DIR=./encoder_cache_prefix \
10
+ MOLMOACT2_FT_ENCODER_CACHE_TARGET=prefix \
11
+ MOLMOACT2_FT_ENCODER_CACHE_STRIDE=1 \
12
+ ./run_finetune_server.sh --default-task "stack the cubes"
13
+
14
+ The server caches EVERY inference it receives (server STRIDE=1); this script does
15
+ the frame subsampling (--stride, default 5) so disk stays bounded. Each demo is
16
+ sent over its OWN ws connection so the server's next_episode() bumps episode_id
17
+ → shards land as ep{NNNN}_*.npz, grouped per demo for the later success/fail gate.
18
+
19
+ Data path (read directly from the HF v3.0 dataset cache, no `datasets` extra):
20
+ * observation.state / task ← data parquet
21
+ * camera frames ← h264 mp4s, decoded sequentially with cv2 and
22
+ sliced per episode by `length` (exact alignment).
23
+
24
+ Run (after the server is up in prefix mode):
25
+ ./lerobot/.venv/bin/python collect_prefix.py --stride 5
26
+ """
27
+ from __future__ import annotations
28
+
29
+ import argparse
30
+ import asyncio
31
+ import os
32
+ import time
33
+
34
+ import cv2
35
+ import msgpack
36
+ import numpy as np
37
+ import pyarrow.parquet as pq
38
+ import websockets
39
+
40
+ # Reuse the server's exact msgpack wire codecs so the encoding matches 1:1.
41
+ from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text
42
+
43
# Dataset camera key -> karma observation key the server expects.
# Insertion order is significant: CAMS is derived from it.
CAM_TO_OBS = {
    "observation.images.top": "top_camera-images-rgb",
    "observation.images.left": "left_camera-images-rgb",
    "observation.images.right": "right_camera-images-rgb",
}
# Dataset camera keys, in the order they are read and packed into each observation.
CAMS = list(CAM_TO_OBS)
# Dataset snapshot directory inside the local Hugging Face hub cache.
SNAPSHOT = os.path.expanduser(
    "~/.cache/huggingface/hub/datasets--atharva-pantheon--yam-stack-cube/"
    "snapshots/3ac38351db6b4f6924263cdc33fec08514a7fc96"
)
53
+
54
+
55
def _decode_response(raw: bytes) -> dict:
    """Decode one msgpack reply from the server.

    The payload is unpacked with the server's ``_decode_numpy_object`` hook and
    the result is then passed through the server's ``_decode_text`` helper.
    """
    return _decode_text(
        msgpack.unpackb(raw, raw=False, object_hook=_decode_numpy_object)
    )
58
+
59
+
60
def video_path(cam: str, file_index: int) -> str:
    """Return the path of camera ``cam``'s mp4 number ``file_index`` under SNAPSHOT.

    Only ``chunk-000`` is addressed; the file number is zero-padded to 3 digits.
    """
    filename = f"file-{file_index:03d}.mp4"
    return os.path.join(SNAPSHOT, "videos", cam, "chunk-000", filename)
62
+
63
+
64
def load_metadata():
    """Load per-frame states and per-episode metadata from the dataset snapshot.

    Reads ``data/chunk-000/file-000.parquet`` and
    ``meta/episodes/chunk-000/file-000.parquet`` under ``SNAPSHOT``.
    NOTE(review): only ``file-000`` of ``chunk-000`` is read; a dataset split
    across several data/meta files would be loaded only partially — confirm.

    Returns:
        A 2-tuple ``(states, episodes)``:

        * ``states`` — float32 array built from the ``observation.state``
          column, one row per dataset frame (expected shape ``(N, 14)`` per
          the original author's note).
        * ``episodes`` — list of dicts, one per row of the episodes table,
          with keys ``episode_index``, ``length``, ``from``, ``to``, ``task``,
          ``file_index`` (camera key -> mp4 file index) and ``from_ts``
          (camera key -> start timestamp of the episode inside that mp4).

    Raises:
        KeyError: if an expected column is missing from either table.
    """
    data_table = pq.read_table(
        os.path.join(SNAPSHOT, "data", "chunk-000", "file-000.parquet")
    )
    # Convert only the column we need instead of the whole table.
    states = np.asarray(
        data_table.column("observation.state").to_pylist(), dtype=np.float32
    )  # (N, 14)

    ep_table = pq.read_table(
        os.path.join(SNAPSHOT, "meta", "episodes", "chunk-000", "file-000.parquet")
    )

    def column(name):
        # Materialise a single episodes-table column as a Python list.
        return ep_table.column(name).to_pylist()

    episode_indices = column("episode_index")
    lengths = column("length")
    starts = column("dataset_from_index")
    stops = column("dataset_to_index")
    tasks = column("tasks")
    file_indices = {c: column(f"videos/{c}/file_index") for c in CAMS}
    from_timestamps = {c: column(f"videos/{c}/from_timestamp") for c in CAMS}

    episodes = []
    for i, episode_index in enumerate(episode_indices):
        episode_tasks = tasks[i]
        # All cameras should share the same file split; frames are read
        # sequentially, so only the ordering (by from_timestamp within a file)
        # needs to agree.
        episodes.append({
            "episode_index": int(episode_index),
            "length": int(lengths[i]),
            "from": int(starts[i]),
            "to": int(stops[i]),
            # Fall back to the default task when the episode lists none.
            "task": episode_tasks[0] if episode_tasks else "stack the cubes",
            "file_index": {c: int(file_indices[c][i]) for c in CAMS},
            "from_ts": {c: float(from_timestamps[c][i]) for c in CAMS},
        })
    return states, episodes
88
+
89
+
90
def chw_rgb(frame_bgr: np.ndarray) -> np.ndarray:
    """Convert a cv2 BGR HWC frame into a contiguous CHW RGB array.

    Per the original author's note, this is the layout the server's
    ``_chw_to_hwc_uint8`` expects.
    """
    rgb_hwc = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    # Move channels first, then force a C-contiguous copy of the strided view.
    chw_view = rgb_hwc.transpose(2, 0, 1)
    return np.ascontiguousarray(chw_view)
94
+
95
+
96
async def send_episode(uri, ep, states, caps, stride, dtype_note):
    """Stream one demo's frames over a fresh ws connection (→ one server episode).

    Every capture in ``caps`` is advanced by exactly ``ep["length"]`` frames so
    the next episode in the same mp4 starts at the right position. Only frames
    whose index is a multiple of ``stride`` are decoded and sent; the others
    are grabbed without decoding.

    Args:
        uri: websocket URI of the server.
        ep: episode dict; ``length``, ``from``, ``task`` and ``episode_index``
            are read.
        states: per-frame state rows, indexed by ``ep["from"] + frame``.
        caps: camera key -> opened capture, positioned at this episode's start.
        stride: send every ``stride``-th frame.
        dtype_note: unused; kept for interface compatibility.

    Returns:
        Number of observations sent (and acknowledged) for this episode.

    Raises:
        RuntimeError: if a video runs out of frames, or the server's reply
            contains an ``"error"`` key.
    """
    n_frames = ep["length"]
    offset = ep["from"]
    n_sent = 0
    async with websockets.connect(uri, max_size=None, compression=None) as ws:
        await ws.recv()  # discard the server's metadata frame
        for idx in range(n_frames):
            wanted = idx % stride == 0
            decoded = {}
            for cam in CAMS:
                capture = caps[cam]
                if wanted:
                    ok, image = capture.read()  # grab + decode
                else:
                    ok, image = capture.grab(), None  # advance only, no decode
                if not ok:
                    raise RuntimeError(f"video underrun ep{ep['episode_index']} cam {cam} frame {idx}")
                decoded[cam] = image
            if wanted:
                obs = {"state": states[offset + idx], "task": ep["task"]}
                obs.update((CAM_TO_OBS[cam], chw_rgb(decoded[cam])) for cam in CAMS)
                await ws.send(pack(obs))
                resp = _decode_response(await ws.recv())
                if "error" in resp:
                    raise RuntimeError(f"server error ep{ep['episode_index']} frame {idx}: {resp['error']}")
                n_sent += 1
    return n_sent
124
+
125
+
126
async def main_async(args):
    """Replay every selected demo through the server, one ws connection per demo.

    Episodes are grouped by the first camera's mp4 file index and, within each
    group, ordered by their start timestamp, so each mp4 can be decoded
    sequentially from its beginning while episodes are sliced off by length.

    Args:
        args: parsed CLI namespace; ``uri``, ``stride``, ``max_episodes`` and
            ``dtype`` are read.

    Raises:
        RuntimeError: if the cameras of a file group do not share the same
            file split, a video cannot be opened, or ``send_episode`` fails.
    """
    states, episodes = load_metadata()
    if args.max_episodes:
        episodes = episodes[: args.max_episodes]
    n_frames_total = sum(ep["length"] for ep in episodes)
    est_shards = sum(len(range(0, ep["length"], args.stride)) for ep in episodes)
    print(f"episodes: {len(episodes)} frames: {n_frames_total} stride: {args.stride} "
          f"→ ~{est_shards} shards")

    # Group by camera file_index, read each mp4 in episode order (= from_timestamp order).
    by_file: dict[int, list] = {}
    for ep in episodes:
        by_file.setdefault(ep["file_index"][CAMS[0]], []).append(ep)
    for group in by_file.values():
        group.sort(key=lambda e: e["from_ts"][CAMS[0]])

    t0 = time.monotonic()  # monotonic clock: elapsed time is immune to wall-clock jumps
    done = 0
    for _, eps in sorted(by_file.items()):
        cam_files = eps[0]["file_index"]
        # Sequential decoding is only aligned if every camera of every episode
        # in this group lives in the same mp4 as the group's first episode.
        for ep in eps:
            if ep["file_index"] != cam_files:
                raise RuntimeError(
                    f"camera file split mismatch ep{ep['episode_index']}: "
                    f"{ep['file_index']} != {cam_files}"
                )
        caps = {}
        try:
            # Open per-camera captures for THIS file group. This happens inside
            # the try block so already-opened captures are released when a
            # later one fails to open.
            for c in CAMS:
                path = video_path(c, cam_files[c])
                caps[c] = cv2.VideoCapture(path)
                if not caps[c].isOpened():
                    raise RuntimeError(f"cannot open {path}")
            for ep in eps:
                sent = await send_episode(args.uri, ep, states, caps, args.stride, args.dtype)
                done += 1
                elapsed = time.monotonic() - t0
                rate = done / elapsed if elapsed > 0 else 0.0
                print(f"[{done}/{len(episodes)}] demo ep{ep['episode_index']:02d} "
                      f"len={ep['length']} sent={sent} ({rate*60:.1f} demos/min)")
        finally:
            for cap in caps.values():
                cap.release()
    print(f"DONE. {done} demos replayed in {(time.monotonic()-t0)/60:.1f} min. "
          f"shards in the server's MOLMOACT2_FT_ENCODER_CACHE_DIR.")
163
+
164
+
165
def main():
    """Parse the command line, validate it, and run the async replay.

    Exits with argparse's usage error (status 2) when ``--stride`` is below 1
    or ``--max-episodes`` is negative, instead of failing later with an opaque
    error or silently dropping trailing demos.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--uri", default="ws://127.0.0.1:8112")
    p.add_argument("--stride", type=int, default=5, help="send every Nth demo frame (disk control)")
    p.add_argument("--max-episodes", type=int, default=0, help="smoke test: only first N demos (0=all)")
    p.add_argument("--dtype", default="float16")
    args = p.parse_args()
    if args.stride < 1:
        p.error("--stride must be >= 1")
    if args.max_episodes < 0:
        p.error("--max-episodes must be >= 0")
    asyncio.run(main_async(args))


if __name__ == "__main__":
    main()