File size: 7,388 Bytes
20b27fa | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 | #!/usr/bin/env python3
"""Replay the 44 yam-stack-cube demos through the MolmoAct2 fine-tune server to
collect (M, 2560) prefix hidden-state shards for RLT Stage-1 encoder training.
This is the ⏭️ "Collect prefix data" step. It does NOT touch the model directly —
it talks to molmoact2_finetune_karma_server.py over ws:8112, which must be
running in PREFIX-CACHE mode:
MOLMOACT2_FT_ENCODER_CACHE_DIR=./encoder_cache_prefix \
MOLMOACT2_FT_ENCODER_CACHE_TARGET=prefix \
MOLMOACT2_FT_ENCODER_CACHE_STRIDE=1 \
./run_finetune_server.sh --default-task "stack the cubes"
The server caches EVERY inference it receives (server STRIDE=1); this script does
the frame subsampling (--stride, default 5) so disk stays bounded. Each demo is
sent over its OWN ws connection so the server's next_episode() bumps episode_id
→ shards land as ep{NNNN}_*.npz, grouped per demo for the later success/fail gate.
Data path (read directly from the HF v3.0 dataset cache, no `datasets` extra):
* observation.state / task ← data parquet
* camera frames ← h264 mp4s, decoded sequentially with cv2 and
sliced per episode by `length` (exact alignment).
Run (after the server is up in prefix mode):
./lerobot/.venv/bin/python collect_prefix.py --stride 5
"""
from __future__ import annotations
import argparse
import asyncio
import os
import time
import cv2
import msgpack
import numpy as np
import pyarrow.parquet as pq
import websockets
# Reuse the server's exact msgpack wire codecs so the encoding matches 1:1.
from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text
# Dataset feature keys of the three camera streams, in the order they are read.
CAMS = [
    "observation.images.top",
    "observation.images.left",
    "observation.images.right",
]

# Dataset camera key -> karma observation key the server expects.
CAM_TO_OBS = dict(
    zip(
        CAMS,
        [
            "top_camera-images-rgb",
            "left_camera-images-rgb",
            "right_camera-images-rgb",
        ],
    )
)

# Local Hugging Face cache directory of the pinned dataset snapshot.
_DATASET_CACHE_DIR = "~/.cache/huggingface/hub/datasets--atharva-pantheon--yam-stack-cube"
_SNAPSHOT_REVISION = "3ac38351db6b4f6924263cdc33fec08514a7fc96"
SNAPSHOT = os.path.expanduser(f"{_DATASET_CACHE_DIR}/snapshots/{_SNAPSHOT_REVISION}")
def _decode_response(raw: bytes) -> dict:
    """Decode one msgpack reply from the server.

    The payload is unpacked with the server's ``_decode_numpy_object`` hook and
    the result is then passed through the server's ``_decode_text`` helper.
    """
    return _decode_text(
        msgpack.unpackb(raw, raw=False, object_hook=_decode_numpy_object)
    )
def video_path(cam: str, file_index: int) -> str:
    """Return the path of camera ``cam``'s mp4 number ``file_index`` under SNAPSHOT."""
    chunk_dir = os.path.join(SNAPSHOT, "videos", cam, "chunk-000")
    return os.path.join(chunk_dir, f"file-{file_index:03d}.mp4")
def load_metadata():
    """Load per-frame robot states and per-episode metadata from the snapshot.

    Reads ``data/chunk-000/file-000.parquet`` (for ``observation.state``) and
    ``meta/episodes/chunk-000/file-000.parquet`` (for the episode table), both
    under ``SNAPSHOT``. Only the columns that are actually used are converted
    to Python objects; the remaining columns of either table are left alone.

    Returns:
        A ``(states, episodes)`` tuple:

        * ``states`` -- float32 array built from the ``observation.state``
          column, one row per dataset frame (expected shape (N, 14)).
        * ``episodes`` -- list of dicts, one per row of the episode table, with
          keys ``episode_index``, ``length``, ``from``, ``to``, ``task``,
          ``file_index`` and ``from_ts``. The last two are dicts keyed by the
          entries of ``CAMS``.

    Raises:
        KeyError: if a required column is missing from either parquet file.
    """
    data_table = pq.read_table(
        os.path.join(SNAPSHOT, "data", "chunk-000", "file-000.parquet")
    )
    states = np.asarray(
        data_table.column("observation.state").to_pylist(), dtype=np.float32
    )  # (N, 14)

    ep_table = pq.read_table(
        os.path.join(SNAPSHOT, "meta", "episodes", "chunk-000", "file-000.parquet")
    )
    needed = [
        "episode_index",
        "length",
        "dataset_from_index",
        "dataset_to_index",
        "tasks",
    ]
    needed += [f"videos/{c}/file_index" for c in CAMS]
    needed += [f"videos/{c}/from_timestamp" for c in CAMS]
    # Convert just these columns instead of to_pydict() on the whole table.
    ep = {name: ep_table.column(name).to_pylist() for name in needed}

    episodes = []
    for i, episode_index in enumerate(ep["episode_index"]):
        cam_file = {c: int(ep[f"videos/{c}/file_index"][i]) for c in CAMS}
        cam_ts = {c: float(ep[f"videos/{c}/from_timestamp"][i]) for c in CAMS}
        tasks = ep["tasks"][i]
        # all cameras should share the same file split; we read sequentially so
        # only the ordering (by from_timestamp within a file) needs to agree.
        episodes.append({
            "episode_index": int(episode_index),
            "length": int(ep["length"][i]),
            "from": int(ep["dataset_from_index"][i]),
            "to": int(ep["dataset_to_index"][i]),
            # Fall back to the default prompt when the episode lists no task.
            "task": tasks[0] if tasks else "stack the cubes",
            "file_index": cam_file,
            "from_ts": cam_ts,
        })
    return states, episodes
def chw_rgb(frame_bgr: np.ndarray) -> np.ndarray:
    """Convert a cv2 BGR HWC frame into a contiguous CHW RGB array.

    This is the layout the server's ``_chw_to_hwc_uint8`` is fed with.
    """
    rgb_hwc = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    # Move the channel axis to the front: (H, W, C) -> (C, H, W).
    rgb_chw = rgb_hwc.transpose(2, 0, 1)
    return np.ascontiguousarray(rgb_chw)
async def send_episode(uri, ep, states, caps, stride, dtype_note):
    """Replay a single demo to the server over its own websocket connection.

    Every frame index of the episode is visited. When the index is a multiple
    of ``stride`` all cameras are decoded, one observation (state row, task
    string and one CHW RGB image per camera) is sent and the reply is awaited.
    On all other indices the captures are only advanced, so they stay in step
    with the frame index.

    Args:
        uri: websocket address of the server.
        ep: episode dict; ``length``, ``from``, ``task`` and ``episode_index``
            are read.
        states: per-frame state rows; row ``ep["from"] + i`` goes with frame ``i``.
        caps: mapping from each ``CAMS`` key to an opened ``cv2.VideoCapture``
            (assumed to be positioned at this episode's first frame).
        stride: send every ``stride``-th frame.
        dtype_note: accepted for the caller's convenience; not used here.

    Returns:
        The number of observations sent.

    Raises:
        RuntimeError: if a capture cannot be advanced/decoded, or if a server
            reply contains an ``"error"`` key.
    """
    n_frames = ep["length"]
    first_row = ep["from"]
    n_sent = 0
    async with websockets.connect(uri, max_size=None, compression=None) as ws:
        # The server speaks first with a metadata frame; it is not needed here.
        await ws.recv()
        for i in range(n_frames):
            decode = i % stride == 0
            images = {}
            for c in CAMS:
                cap = caps[c]
                # read() grabs and decodes; grab() alone only advances the stream.
                ok, img = cap.read() if decode else (cap.grab(), None)
                if not ok:
                    raise RuntimeError(f"video underrun ep{ep['episode_index']} cam {c} frame {i}")
                images[c] = img
            if decode:
                obs = {"state": states[first_row + i], "task": ep["task"]}
                obs.update((CAM_TO_OBS[c], chw_rgb(images[c])) for c in CAMS)
                await ws.send(pack(obs))
                resp = _decode_response(await ws.recv())
                if "error" in resp:
                    raise RuntimeError(f"server error ep{ep['episode_index']} frame {i}: {resp['error']}")
                n_sent += 1
    return n_sent
async def main_async(args):
    """Replay the selected demos, one websocket connection per demo.

    Episodes are grouped by the video file index of the first camera in
    ``CAMS`` and, within a group, ordered by that camera's ``from_timestamp``.
    For each group one capture per camera is opened and read sequentially
    across the group's episodes, so every episode must start where the
    previous one ended.

    Args:
        args: parsed CLI namespace; ``uri``, ``stride``, ``max_episodes`` and
            ``dtype`` are read.

    Raises:
        ValueError: if ``args.stride`` is smaller than 1.
        RuntimeError: if a video file cannot be opened, or propagated from
            ``send_episode``.
    """
    if args.stride < 1:
        raise ValueError(f"--stride must be >= 1, got {args.stride}")

    states, episodes = load_metadata()
    if args.max_episodes:
        episodes = episodes[: args.max_episodes]
    n_frames_total = sum(ep["length"] for ep in episodes)
    est_shards = sum(len(range(0, ep["length"], args.stride)) for ep in episodes)
    print(f"episodes: {len(episodes)} frames: {n_frames_total} stride: {args.stride} "
          f"→ ~{est_shards} shards")

    # Group by camera file_index, read each mp4 in episode order (= from_timestamp order).
    by_file: dict[int, list] = {}
    for ep in episodes:
        by_file.setdefault(ep["file_index"][CAMS[0]], []).append(ep)
    for eps in by_file.values():
        eps.sort(key=lambda e: e["from_ts"][CAMS[0]])

    t0 = time.time()
    done = 0
    for fidx in sorted(by_file):
        eps = by_file[fidx]
        caps = {}
        try:
            # Open per-camera captures for THIS file group. Opening happens
            # inside the try so that a failure on a later camera still releases
            # the captures that were already opened.
            for c in CAMS:
                path = video_path(c, eps[0]["file_index"][c])
                cap = cv2.VideoCapture(path)
                caps[c] = cap
                if not cap.isOpened():
                    raise RuntimeError(f"cannot open {path}")
            for ep in eps:
                sent = await send_episode(args.uri, ep, states, caps, args.stride, args.dtype)
                done += 1
                # Guard against a zero elapsed time on coarse clocks.
                elapsed = max(time.time() - t0, 1e-9)
                rate = done / elapsed
                print(f"[{done}/{len(episodes)}] demo ep{ep['episode_index']:02d} "
                      f"len={ep['length']} sent={sent} ({rate*60:.1f} demos/min)")
        finally:
            for cap in caps.values():
                cap.release()
    print(f"DONE. {done} demos replayed in {(time.time()-t0)/60:.1f} min. "
          f"shards in the server's MOLMOACT2_FT_ENCODER_CACHE_DIR.")
def main():
    """Parse the command line, validate it, and run the replay.

    Options:
        --uri: websocket address of the server.
        --stride: send every Nth demo frame; must be at least 1.
        --max-episodes: only replay the first N demos; 0 means all.
        --dtype: forwarded to the replay unchanged.

    Raises:
        SystemExit: raised by ``argparse`` (exit status 2) when ``--stride`` is
            below 1 or ``--max-episodes`` is negative.
    """
    p = argparse.ArgumentParser()
    p.add_argument("--uri", default="ws://127.0.0.1:8112")
    p.add_argument("--stride", type=int, default=5, help="send every Nth demo frame (disk control)")
    p.add_argument("--max-episodes", type=int, default=0, help="smoke test: only first N demos (0=all)")
    p.add_argument("--dtype", default="float16")
    args = p.parse_args()
    # Reject values that would otherwise fail deep inside the replay (stride 0)
    # or silently change which frames/episodes are used (negative values).
    if args.stride < 1:
        p.error(f"--stride must be >= 1 (got {args.stride})")
    if args.max_episodes < 0:
        p.error(f"--max-episodes must be >= 0 (got {args.max_episodes})")
    asyncio.run(main_async(args))
# Run the replay only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()
|