#!/usr/bin/env python3 """Replay the 44 yam-stack-cube demos through the MolmoAct2 fine-tune server to collect (M, 2560) prefix hidden-state shards for RLT Stage-1 encoder training. This is the ⏭️ "Collect prefix data" step. It does NOT touch the model directly — it talks to molmoact2_finetune_karma_server.py over ws:8112, which must be running in PREFIX-CACHE mode: MOLMOACT2_FT_ENCODER_CACHE_DIR=./encoder_cache_prefix \ MOLMOACT2_FT_ENCODER_CACHE_TARGET=prefix \ MOLMOACT2_FT_ENCODER_CACHE_STRIDE=1 \ ./run_finetune_server.sh --default-task "stack the cubes" The server caches EVERY inference it receives (server STRIDE=1); this script does the frame subsampling (--stride, default 5) so disk stays bounded. Each demo is sent over its OWN ws connection so the server's next_episode() bumps episode_id → shards land as ep{NNNN}_*.npz, grouped per demo for the later success/fail gate. Data path (read directly from the HF v3.0 dataset cache, no `datasets` extra): * observation.state / task ← data parquet * camera frames ← h264 mp4s, decoded sequentially with cv2 and sliced per episode by `length` (exact alignment). Run (after the server is up in prefix mode): ./lerobot/.venv/bin/python collect_prefix.py --stride 5 """ from __future__ import annotations import argparse import asyncio import os import time import cv2 import msgpack import numpy as np import pyarrow.parquet as pq import websockets # Reuse the server's exact msgpack wire codecs so the encoding matches 1:1. 
from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text

# Dataset camera keys, in the order frames are read and sent for every step.
CAMS = ["observation.images.top", "observation.images.left", "observation.images.right"]
CAM_TO_OBS = {  # dataset key -> karma obs key the server expects
    "observation.images.top": "top_camera-images-rgb",
    "observation.images.left": "left_camera-images-rgb",
    "observation.images.right": "right_camera-images-rgb",
}
# Pinned local HF hub snapshot of the dataset; read directly, no `datasets` needed.
SNAPSHOT = os.path.expanduser(
    "~/.cache/huggingface/hub/datasets--atharva-pantheon--yam-stack-cube/"
    "snapshots/3ac38351db6b4f6924263cdc33fec08514a7fc96"
)


def _decode_response(raw: bytes) -> dict:
    """Decode one msgpack reply from the server with the server's own numpy/text codecs."""
    obj = msgpack.unpackb(raw, raw=False, object_hook=_decode_numpy_object)
    return _decode_text(obj)


def video_path(cam: str, file_index: int) -> str:
    """Return the mp4 path for camera ``cam`` / video file ``file_index`` (chunk-000 only)."""
    return os.path.join(SNAPSHOT, "videos", cam, "chunk-000", f"file-{file_index:03d}.mp4")


def load_metadata():
    """Load per-frame states and per-episode metadata from the dataset snapshot.

    Only ``chunk-000/file-000.parquet`` of the data and episodes-meta tables is read.

    Returns:
        ``(states, episodes)``:
          * ``states`` – float32 array of ``observation.state`` rows, one per
            dataset frame (indexed by ``episode["from"] + i``).
          * ``episodes`` – list of dicts with keys ``episode_index``, ``length``,
            ``from`` / ``to`` (dataset row indices), ``task`` (first listed task,
            falling back to "stack the cubes"), ``file_index`` (cam -> mp4 file
            index) and ``from_ts`` (cam -> episode start timestamp in that mp4).
    """
    data = pq.read_table(os.path.join(SNAPSHOT, "data", "chunk-000", "file-000.parquet")).to_pydict()
    states = np.asarray(data["observation.state"], dtype=np.float32)  # (N, 14)

    ep = pq.read_table(
        os.path.join(SNAPSHOT, "meta", "episodes", "chunk-000", "file-000.parquet")
    ).to_pydict()
    episodes = []
    for i in range(len(ep["episode_index"])):
        cam_file = {c: int(ep[f"videos/{c}/file_index"][i]) for c in CAMS}
        cam_ts = {c: float(ep[f"videos/{c}/from_timestamp"][i]) for c in CAMS}
        # All cameras are expected to share the same file split; frames are read
        # sequentially, so only the per-file ordering (by from_timestamp) must
        # agree. _group_by_video_file enforces both before any decoding.
        tasks = ep["tasks"][i]
        episodes.append({
            "episode_index": int(ep["episode_index"][i]),
            "length": int(ep["length"][i]),
            "from": int(ep["dataset_from_index"][i]),
            "to": int(ep["dataset_to_index"][i]),
            "task": tasks[0] if tasks else "stack the cubes",
            "file_index": cam_file,
            "from_ts": cam_ts,
        })
    return states, episodes


def chw_rgb(frame_bgr: np.ndarray) -> np.ndarray:
    """cv2 BGR HWC uint8 -> contiguous CHW RGB uint8 (what the server's _chw_to_hwc_uint8 wants)."""
    rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    return np.ascontiguousarray(np.transpose(rgb, (2, 0, 1)))


async def send_episode(uri, ep, states, caps, stride, dtype_note):
    """Stream one demo's frames over a fresh ws connection (→ one server episode).

    Advances every camera capture by exactly ``ep["length"]`` frames, so the
    captures end positioned at the next episode in the same mp4. Only every
    ``stride``-th frame is decoded and sent; the others are ``grab()``-ed.

    Args:
        uri: websocket URI of the server.
        ep: one episode dict from ``load_metadata``.
        states: per-frame state array from ``load_metadata``.
        caps: cam key -> ``cv2.VideoCapture`` positioned at this episode's first frame.
        stride: send every ``stride``-th frame (must be >= 1).
        dtype_note: currently unused by this function.

    Returns:
        Number of observations sent (and answered) for this episode.

    Raises:
        RuntimeError: on video underrun or when the server replies with ``error``.
    """
    length, base = ep["length"], ep["from"]
    sent = 0
    async with websockets.connect(uri, max_size=None, compression=None) as ws:
        await ws.recv()  # discard the server's metadata frame
        for i in range(length):
            send_this = (i % stride == 0)
            frames = {}
            for c in CAMS:
                if send_this:
                    ok, fr = caps[c].read()  # grab + decode
                else:
                    ok = caps[c].grab()  # advance only, no decode
                    fr = None
                if not ok:
                    raise RuntimeError(f"video underrun ep{ep['episode_index']} cam {c} frame {i}")
                frames[c] = fr
            if not send_this:
                continue
            obs = {"state": states[base + i], "task": ep["task"]}
            for c in CAMS:
                obs[CAM_TO_OBS[c]] = chw_rgb(frames[c])
            await ws.send(pack(obs))
            resp = _decode_response(await ws.recv())
            if "error" in resp:
                raise RuntimeError(f"server error ep{ep['episode_index']} frame {i}: {resp['error']}")
            sent += 1
    return sent


def _group_by_video_file(episodes):
    """Group episodes by mp4 file (first camera's file_index), each group sorted by start time.

    Each group is later decoded sequentially with one capture per camera opened
    from the group's FIRST episode. That is only correct if every camera uses a
    single mp4 for the whole group and agrees on episode order; otherwise frames
    would be silently misaligned, so a RuntimeError is raised instead.

    Returns:
        dict file_index -> list of episode dicts, iterated in ascending file_index.
    """
    by_file: dict[int, list] = {}
    for ep in episodes:
        by_file.setdefault(ep["file_index"][CAMS[0]], []).append(ep)
    for fidx, eps in by_file.items():
        eps.sort(key=lambda e: e["from_ts"][CAMS[0]])
        for c in CAMS[1:]:
            files = {e["file_index"][c] for e in eps}
            if len(files) != 1:
                raise RuntimeError(
                    f"cam {c} spreads file group {fidx} over mp4 files {sorted(files)}; "
                    "sequential decoding would misalign frames"
                )
            starts = [e["from_ts"][c] for e in eps]
            if starts != sorted(starts):
                raise RuntimeError(
                    f"cam {c} episode order disagrees with {CAMS[0]} in file group {fidx}"
                )
    return dict(sorted(by_file.items()))


async def main_async(args):
    """Replay every (or the first ``args.max_episodes``) demo through the server."""
    states, episodes = load_metadata()
    if args.max_episodes:
        episodes = episodes[: args.max_episodes]
    n_frames_total = sum(ep["length"] for ep in episodes)
    est_shards = sum(len(range(0, ep["length"], args.stride)) for ep in episodes)
    print(f"episodes: {len(episodes)}  frames: {n_frames_total}  stride: {args.stride}  "
          f"→ ~{est_shards} shards")

    # Group by camera file_index, read each mp4 in episode order (= from_timestamp order).
    groups = _group_by_video_file(episodes)

    t0 = time.monotonic()
    done = 0
    for eps in groups.values():
        caps = {}
        try:
            # Open per-camera captures for THIS file group. Each capture is
            # registered before its isOpened() check so `finally` releases it
            # (and any earlier ones) even when a later open fails.
            for c in CAMS:
                path = video_path(c, eps[0]["file_index"][c])
                cap = cv2.VideoCapture(path)
                caps[c] = cap
                if not cap.isOpened():
                    raise RuntimeError(f"cannot open {path}")
            for ep in eps:
                sent = await send_episode(args.uri, ep, states, caps, args.stride, args.dtype)
                done += 1
                rate = done / (time.monotonic() - t0)
                print(f"[{done}/{len(episodes)}] demo ep{ep['episode_index']:02d} "
                      f"len={ep['length']} sent={sent}  ({rate*60:.1f} demos/min)")
        finally:
            for cap in caps.values():
                cap.release()
    print(f"DONE. {done} demos replayed in {(time.monotonic()-t0)/60:.1f} min. "
          f"shards in the server's MOLMOACT2_FT_ENCODER_CACHE_DIR.")


def main():
    """Parse CLI arguments, validate them, and run the async replay."""
    p = argparse.ArgumentParser()
    p.add_argument("--uri", default="ws://127.0.0.1:8112")
    p.add_argument("--stride", type=int, default=5,
                   help="send every Nth demo frame (disk control)")
    p.add_argument("--max-episodes", type=int, default=0,
                   help="smoke test: only first N demos (0=all)")
    # NOTE: forwarded to send_episode as `dtype_note`, which does not use it.
    p.add_argument("--dtype", default="float16")
    args = p.parse_args()
    # stride 0 would crash range()/%, and a negative stride would send nothing.
    if args.stride < 1:
        p.error("--stride must be >= 1")
    # A negative value would silently slice episodes off the END of the list.
    if args.max_episodes < 0:
        p.error("--max-episodes must be >= 0 (0 = all)")
    asyncio.run(main_async(args))


if __name__ == "__main__":
    main()