"""Replay the 44 yam-stack-cube demos through the MolmoAct2 fine-tune server to
collect (M, 2560) prefix hidden-state shards for RLT Stage-1 encoder training.

This is the ⏭️ "Collect prefix data" step. It does NOT touch the model directly —
it talks to molmoact2_finetune_karma_server.py over ws:8112, which must be
running in PREFIX-CACHE mode:

    MOLMOACT2_FT_ENCODER_CACHE_DIR=./encoder_cache_prefix \
    MOLMOACT2_FT_ENCODER_CACHE_TARGET=prefix \
    MOLMOACT2_FT_ENCODER_CACHE_STRIDE=1 \
    ./run_finetune_server.sh --default-task "stack the cubes"

The server caches EVERY inference it receives (server STRIDE=1); this script does
the frame subsampling (--stride, default 5) so disk stays bounded. Each demo is
sent over its OWN ws connection so the server's next_episode() bumps episode_id
→ shards land as ep{NNNN}_*.npz, grouped per demo for the later success/fail gate.

Data path (read directly from the HF v3.0 dataset cache, no `datasets` extra):
  * observation.state       ← data parquet
  * task / episode metadata ← meta/episodes parquet
  * camera frames           ← h264 mp4s, decoded sequentially with cv2 and
                              sliced per episode by `length` (exact alignment).

Run (after the server is up in prefix mode):
    ./lerobot/.venv/bin/python collect_prefix.py --stride 5
"""
| from __future__ import annotations |
|
|
| import argparse |
| import asyncio |
| import os |
| import time |
|
|
| import cv2 |
| import msgpack |
| import numpy as np |
| import pyarrow.parquet as pq |
| import websockets |
|
|
| |
| from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text |
|
|
# Camera feature keys as named in the dataset; they are used both as video
# directory names and inside the episode-metadata column names.
CAMS = [
    "observation.images.top",
    "observation.images.left",
    "observation.images.right",
]

# Dataset camera key -> key under which that camera's frame is placed in the
# observation sent to the server.
CAM_TO_OBS = {
    "observation.images.top": "top_camera-images-rgb",
    "observation.images.left": "left_camera-images-rgb",
    "observation.images.right": "right_camera-images-rgb",
}

# Pinned snapshot of the dataset inside the local Hugging Face hub cache.
_DATASET_CACHE_DIR = "~/.cache/huggingface/hub/datasets--atharva-pantheon--yam-stack-cube"
_SNAPSHOT_REVISION = "3ac38351db6b4f6924263cdc33fec08514a7fc96"
SNAPSHOT = os.path.expanduser(f"{_DATASET_CACHE_DIR}/snapshots/{_SNAPSHOT_REVISION}")
|
|
|
|
def _decode_response(raw: bytes) -> dict:
    """Unpack one msgpack-encoded reply received from the server.

    The payload is unpacked with ``_decode_numpy_object`` installed as the
    msgpack ``object_hook``; the unpacked object is then passed through
    ``_decode_text``. Both helpers are imported from the server module.
    """
    unpacked = msgpack.unpackb(raw, raw=False, object_hook=_decode_numpy_object)
    decoded = _decode_text(unpacked)
    return decoded
|
|
|
|
def video_path(cam: str, file_index: int) -> str:
    """Return the mp4 path for camera ``cam`` and video file number ``file_index``.

    The path is rooted at ``SNAPSHOT`` and always points into ``chunk-000``;
    the file number is zero-padded to three digits.
    """
    filename = f"file-{file_index:03d}.mp4"
    camera_dir = os.path.join(SNAPSHOT, "videos", cam, "chunk-000")
    return os.path.join(camera_dir, filename)
|
|
|
|
def load_metadata():
    """Load per-frame states and per-episode metadata from the dataset snapshot.

    Reads ``data/chunk-000/file-000.parquet`` (for ``observation.state``) and
    ``meta/episodes/chunk-000/file-000.parquet`` (for the episode table), both
    located under ``SNAPSHOT``.

    Returns:
        A ``(states, episodes)`` pair.

        * ``states``: float32 array with one row per row of the data parquet
          (the state width is whatever the dataset stores; 14 is expected for
          this dataset, TODO confirm).
        * ``episodes``: list with one dict per episode, holding
          ``episode_index``, ``length``, ``from`` / ``to`` (taken from
          ``dataset_from_index`` / ``dataset_to_index``), ``task`` (first entry
          of the episode's ``tasks``, or ``"stack the cubes"`` when that list
          is empty), and the per-camera dicts ``file_index`` and ``from_ts``
          keyed by the entries of ``CAMS``.
    """
    data_path = os.path.join(SNAPSHOT, "data", "chunk-000", "file-000.parquet")
    # Only the state column is used, so convert just that column to Python
    # objects instead of materialising every column of the table.
    state_rows = pq.read_table(data_path).column("observation.state").to_pylist()
    states = np.asarray(state_rows, dtype=np.float32)

    meta_path = os.path.join(SNAPSHOT, "meta", "episodes", "chunk-000", "file-000.parquet")
    meta = pq.read_table(meta_path).to_pydict()

    episodes = []
    for row, episode_index in enumerate(meta["episode_index"]):
        tasks = meta["tasks"][row]
        episodes.append({
            "episode_index": int(episode_index),
            "length": int(meta["length"][row]),
            "from": int(meta["dataset_from_index"][row]),
            "to": int(meta["dataset_to_index"][row]),
            "task": tasks[0] if tasks else "stack the cubes",
            "file_index": {c: int(meta[f"videos/{c}/file_index"][row]) for c in CAMS},
            "from_ts": {c: float(meta[f"videos/{c}/from_timestamp"][row]) for c in CAMS},
        })
    return states, episodes
|
|
|
|
def chw_rgb(frame_bgr: np.ndarray) -> np.ndarray:
    """Convert an OpenCV-style BGR HWC frame to a contiguous RGB CHW array.

    This is the layout the server's ``_chw_to_hwc_uint8`` expects. The input
    dtype is preserved (uint8 for frames read through ``cv2.VideoCapture``).

    Args:
        frame_bgr: array of shape (H, W, 3) with channels in BGR order.

    Returns:
        A new C-contiguous array of shape (3, H, W) with channels in RGB order.

    Raises:
        ValueError: if ``frame_bgr`` is not a three-dimensional array with
            exactly three channels on its last axis.
    """
    frame = np.asarray(frame_bgr)
    if frame.ndim != 3 or frame.shape[2] != 3:
        raise ValueError(f"expected an (H, W, 3) BGR frame, got shape {frame.shape}")
    # Reversing the channel axis turns BGR into RGB. The reversal and the
    # transpose are both views, so the only full-frame copy is the final one.
    return np.ascontiguousarray(frame[:, :, ::-1].transpose(2, 0, 1))
|
|
|
|
async def send_episode(uri, ep, states, caps, stride, dtype_note):
    """Stream one demo's frames over a fresh ws connection (→ one server episode).

    Every frame of the episode is consumed from each capture in ``caps``, so
    after a successful return each capture has advanced by exactly
    ``ep["length"]`` frames. Only frames whose in-episode index is a multiple
    of ``stride`` are decoded and sent; for each one the reply is awaited
    before the next frame is sent.

    Args:
        uri: websocket URI of the server.
        ep: episode dict as built by ``load_metadata``; ``length``, ``from``,
            ``task`` and ``episode_index`` are read.
        states: per-frame state array, indexed by ``ep["from"] + i``.
        caps: mapping from each key in ``CAMS`` to an opened
            ``cv2.VideoCapture`` positioned at this episode's first frame.
        stride: send every ``stride``-th frame; must be >= 1.
        dtype_note: accepted for call compatibility; not used by this function.

    Returns:
        The number of observations sent.

    Raises:
        ValueError: if ``stride`` is smaller than 1.
        RuntimeError: if a capture runs out of frames, or the server's reply
            contains an ``"error"`` key.
    """
    # Validate before connecting: per the module docstring every connection
    # starts a new server-side episode, so failing later (``i % 0``) would
    # leave an empty episode behind.
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    length, base = ep["length"], ep["from"]
    episode_index = ep["episode_index"]
    sent = 0
    async with websockets.connect(uri, max_size=None, compression=None) as ws:
        # The first message the server sends after connect is read and dropped.
        await ws.recv()
        for i in range(length):
            frames = _advance_captures(caps, episode_index, i, decode=(i % stride == 0))
            if frames is None:
                continue
            obs = {"state": states[base + i], "task": ep["task"]}
            for c in CAMS:
                obs[CAM_TO_OBS[c]] = chw_rgb(frames[c])
            await ws.send(pack(obs))
            resp = _decode_response(await ws.recv())
            if "error" in resp:
                raise RuntimeError(f"server error ep{episode_index} frame {i}: {resp['error']}")
            sent += 1
    return sent


def _advance_captures(caps, episode_index, frame_no, decode):
    """Step every camera forward one frame; return the frames only when ``decode``.

    Skipped frames use ``grab()``, which advances the capture without handing
    back an image, so all cameras stay in lockstep with the episode index.
    Returns a dict camera key -> BGR frame when ``decode`` is true, else None.
    """
    frames = {}
    for c in CAMS:
        if decode:
            ok, frame = caps[c].read()
        else:
            ok, frame = caps[c].grab(), None
        if not ok:
            raise RuntimeError(f"video underrun ep{episode_index} cam {c} frame {frame_no}")
        frames[c] = frame
    return frames if decode else None
|
|
|
|
async def main_async(args):
    """Replay the selected demos through the server, one connection per demo.

    Episodes are grouped by the video file of the first camera in ``CAMS`` and
    ordered by their start timestamp inside that file, because the videos are
    decoded strictly sequentially: each episode consumes exactly ``length``
    frames and leaves the captures at the start of the next one.

    Args:
        args: parsed CLI namespace; ``uri``, ``stride``, ``max_episodes`` and
            ``dtype`` are read.

    Raises:
        RuntimeError: if a video file cannot be opened, if the cameras of one
            group do not all map to a single video file each, or on any error
            raised while streaming an episode.
    """
    states, episodes = load_metadata()
    if args.max_episodes:
        episodes = episodes[: args.max_episodes]
    n_frames_total = sum(ep["length"] for ep in episodes)
    est_shards = sum(len(range(0, ep["length"], args.stride)) for ep in episodes)
    print(f"episodes: {len(episodes)} frames: {n_frames_total} stride: {args.stride} "
          f"→ ~{est_shards} shards")

    by_file: dict[int, list] = {}
    for ep in episodes:
        by_file.setdefault(ep["file_index"][CAMS[0]], []).append(ep)
    for eps in by_file.values():
        eps.sort(key=lambda e: e["from_ts"][CAMS[0]])

    # Monotonic clock: elapsed time must not jump if the wall clock is adjusted.
    t0 = time.monotonic()
    done = 0
    for _, eps in sorted(by_file.items()):
        first = eps[0]
        # One capture per camera is opened from the first episode's file
        # indices and shared by the whole group. Fail before sending anything
        # if a later episode lives in a different file for some camera, since
        # that could only end in an underrun or misaligned frames.
        for ep in eps:
            if ep["file_index"] != first["file_index"]:
                raise RuntimeError(
                    f"ep{ep['episode_index']} uses video files {ep['file_index']} but its "
                    f"group was opened with {first['file_index']}"
                )

        caps = {}
        try:
            # Opened inside the try so captures that did open are released
            # even when a later camera fails to open.
            for c in CAMS:
                path = video_path(c, first["file_index"][c])
                caps[c] = cv2.VideoCapture(path)
                if not caps[c].isOpened():
                    raise RuntimeError(f"cannot open {path}")
            for ep in eps:
                sent = await send_episode(args.uri, ep, states, caps, args.stride, args.dtype)
                done += 1
                elapsed = max(time.monotonic() - t0, 1e-9)
                rate = done / elapsed
                print(f"[{done}/{len(episodes)}] demo ep{ep['episode_index']:02d} "
                      f"len={ep['length']} sent={sent} ({rate*60:.1f} demos/min)")
        finally:
            for cap in caps.values():
                cap.release()
    print(f"DONE. {done} demos replayed in {(time.monotonic()-t0)/60:.1f} min. "
          f"shards in the server's MOLMOACT2_FT_ENCODER_CACHE_DIR.")
|
|
|
|
def main():
    """Parse the command line and run the replay.

    Options: ``--uri`` (server websocket URI), ``--stride`` (send every Nth
    demo frame), ``--max-episodes`` (only the first N demos, 0 = all) and
    ``--dtype`` (forwarded to ``send_episode`` as its ``dtype_note`` argument).

    Exits with argparse's usage error (status 2) when ``--stride`` is smaller
    than 1 or ``--max-episodes`` is negative.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--uri", default="ws://127.0.0.1:8112")
    p.add_argument("--stride", type=int, default=5, help="send every Nth demo frame (disk control)")
    p.add_argument("--max-episodes", type=int, default=0, help="smoke test: only first N demos (0=all)")
    p.add_argument("--dtype", default="float16")
    args = p.parse_args()
    # A stride of 0 would divide by zero mid-replay, and a negative
    # --max-episodes would silently drop trailing demos through slicing.
    if args.stride < 1:
        p.error("--stride must be >= 1")
    if args.max_episodes < 0:
        p.error("--max-episodes must be >= 0")
    asyncio.run(main_async(args))
|
|
|
|
# Script entry point: run the collector only when this file is executed
# directly, not when it is imported.
if __name__ == "__main__":
    main()
|
|