rltoken-encoder / code /collect_prefix.py
atharva-pantheon's picture
Upload code/collect_prefix.py with huggingface_hub
20b27fa verified
#!/usr/bin/env python3
"""Replay the 44 yam-stack-cube demos through the MolmoAct2 fine-tune server to
collect (M, 2560) prefix hidden-state shards for RLT Stage-1 encoder training.
This is the ⏭️ "Collect prefix data" step. It does NOT touch the model directly —
it talks to molmoact2_finetune_karma_server.py over ws:8112, which must be
running in PREFIX-CACHE mode:
MOLMOACT2_FT_ENCODER_CACHE_DIR=./encoder_cache_prefix \
MOLMOACT2_FT_ENCODER_CACHE_TARGET=prefix \
MOLMOACT2_FT_ENCODER_CACHE_STRIDE=1 \
./run_finetune_server.sh --default-task "stack the cubes"
The server caches EVERY inference it receives (server STRIDE=1); this script does
the frame subsampling (--stride, default 5) so disk stays bounded. Each demo is
sent over its OWN ws connection so the server's next_episode() bumps episode_id
→ shards land as ep{NNNN}_*.npz, grouped per demo for the later success/fail gate.
Data path (read directly from the HF v3.0 dataset cache; the `datasets` package is not required):
* observation.state / task ← data parquet
* camera frames ← h264 mp4s, decoded sequentially with cv2 and
sliced per episode by `length` (exact alignment).
Run (after the server is up in prefix mode):
./lerobot/.venv/bin/python collect_prefix.py --stride 5
"""
from __future__ import annotations
import argparse
import asyncio
import os
import time
import cv2
import msgpack
import numpy as np
import pyarrow.parquet as pq
import websockets
# Reuse the server's exact msgpack wire codecs so the encoding matches 1:1.
from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text
# Dataset camera streams, in the fixed order every per-frame loop visits them.
CAMS = ["observation.images.top", "observation.images.left", "observation.images.right"]

# dataset key -> karma obs key the server expects (paired positionally with CAMS)
CAM_TO_OBS = dict(
    zip(
        CAMS,
        ["top_camera-images-rgb", "left_camera-images-rgb", "right_camera-images-rgb"],
    )
)

# Local HF hub cache directory of the dataset, pinned to one snapshot directory.
SNAPSHOT = os.path.expanduser(
    "~/.cache/huggingface/hub/"
    "datasets--atharva-pantheon--yam-stack-cube/snapshots/"
    "3ac38351db6b4f6924263cdc33fec08514a7fc96"
)
def _decode_response(raw: bytes) -> dict:
    """Decode one msgpack reply frame received from the server.

    Uses the server's own ``_decode_numpy_object`` object hook and its
    ``_decode_text`` post-pass, so decoding mirrors the server's wire codec.
    """
    return _decode_text(
        msgpack.unpackb(raw, object_hook=_decode_numpy_object, raw=False)
    )
def video_path(cam: str, file_index: int) -> str:
    """Return the mp4 path for camera ``cam``, file ``file_index`` (always chunk-000)."""
    parts = ("videos", cam, "chunk-000", f"file-{file_index:03d}.mp4")
    return os.path.join(SNAPSHOT, *parts)
def load_metadata():
    """Load per-frame robot states and per-episode metadata from the local snapshot.

    Reads the ``observation.state`` column of ``data/chunk-000/file-000.parquet``
    and the episode table in ``meta/episodes/chunk-000/file-000.parquet``.

    Returns:
        ``(states, episodes)``:
          * ``states`` -- float32 array built from ``observation.state``, one row
            per dataset frame (expected (N, 14) for this dataset).
          * ``episodes`` -- list of dicts, one per episode row, with keys
            ``episode_index``, ``length``, ``from``/``to`` (row range into
            ``states``; validated so that ``to - from == length``), ``task``,
            ``file_index`` and ``from_ts`` (the last two keyed by camera name).

    Raises:
        RuntimeError: if an episode's ``[from, to)`` range disagrees with its
            ``length`` or lies outside the loaded ``states`` rows. ``send_episode``
            indexes ``states[from + i]`` for ``i < length``, so such an episode
            would otherwise crash or misalign only after frames were already sent.
    """
    data_path = os.path.join(SNAPSHOT, "data", "chunk-000", "file-000.parquet")
    # Only the state column is used; skip decoding the rest of the table.
    data = pq.read_table(data_path, columns=["observation.state"]).to_pydict()
    states = np.asarray(data["observation.state"], dtype=np.float32)  # (N, 14)
    n_rows = len(states)

    meta_path = os.path.join(SNAPSHOT, "meta", "episodes", "chunk-000", "file-000.parquet")
    ep = pq.read_table(meta_path).to_pydict()

    episodes = []
    for i in range(len(ep["episode_index"])):
        ep_index = int(ep["episode_index"][i])
        length = int(ep["length"][i])
        start = int(ep["dataset_from_index"][i])
        stop = int(ep["dataset_to_index"][i])
        # Only one data parquet is read; fail fast (before any frame is sent)
        # if an episode's rows are inconsistent or live in another data file.
        if stop - start != length or not 0 <= start <= stop <= n_rows:
            raise RuntimeError(
                f"episode {ep_index}: state rows [{start}, {stop}) do not match "
                f"length={length} within the {n_rows} rows loaded from {data_path} "
                "(datasets split across several data files are not supported)"
            )
        cam_file = {c: int(ep[f"videos/{c}/file_index"][i]) for c in CAMS}
        cam_ts = {c: float(ep[f"videos/{c}/from_timestamp"][i]) for c in CAMS}
        # All cameras should share the same file split; frames are decoded
        # sequentially per file, so only the episode ordering (by from_timestamp
        # within a file) needs to agree across cameras.
        tasks = ep["tasks"][i]
        episodes.append({
            "episode_index": ep_index,
            "length": length,
            "from": start,
            "to": stop,
            "task": tasks[0] if tasks else "stack the cubes",
            "file_index": cam_file,
            "from_ts": cam_ts,
        })
    return states, episodes
def chw_rgb(frame_bgr: np.ndarray) -> np.ndarray:
    """Convert a decoded cv2 frame (HWC, BGR, uint8) to a contiguous CHW RGB array.

    CHW is the layout the server's ``_chw_to_hwc_uint8`` expects on input.
    """
    channels_first = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB).transpose(2, 0, 1)
    return np.ascontiguousarray(channels_first)
async def send_episode(uri, ep, states, caps, stride, dtype_note):
    """Replay one demo over a fresh websocket connection (one server episode).

    Every frame of the episode is consumed from each camera capture, so the
    captures stay aligned for the next episode in the same video file. Only
    every ``stride``-th frame is decoded and sent; the others are just grabbed.

    Args:
        uri: websocket URI of the server.
        ep: one episode dict from ``load_metadata``.
        states: per-frame state array indexed by dataset row.
        caps: camera name -> open ``cv2.VideoCapture`` positioned at this episode.
        stride: send every ``stride``-th frame.
        dtype_note: accepted for the caller's signature; not used here.

    Returns:
        Number of observations sent to the server.

    Raises:
        RuntimeError: on a video underrun or an ``"error"`` key in a server reply.
    """
    n_frames, row0 = ep["length"], ep["from"]
    n_sent = 0

    def step_cameras(frame_no, decode):
        # Advance every camera by one frame; decode pixels only when needed.
        grabbed = {}
        for cam in CAMS:
            cap = caps[cam]
            if decode:
                ok, img = cap.read()
            else:
                ok, img = cap.grab(), None
            if not ok:
                raise RuntimeError(f"video underrun ep{ep['episode_index']} cam {cam} frame {frame_no}")
            grabbed[cam] = img
        return grabbed

    async with websockets.connect(uri, max_size=None, compression=None) as ws:
        await ws.recv()  # the server greets with a metadata frame; ignore it
        for frame_no in range(n_frames):
            decode = frame_no % stride == 0
            imgs = step_cameras(frame_no, decode)
            if not decode:
                continue
            payload = {"state": states[row0 + frame_no], "task": ep["task"]}
            payload.update({CAM_TO_OBS[cam]: chw_rgb(imgs[cam]) for cam in CAMS})
            await ws.send(pack(payload))
            reply = _decode_response(await ws.recv())
            if "error" in reply:
                raise RuntimeError(f"server error ep{ep['episode_index']} frame {frame_no}: {reply['error']}")
            n_sent += 1
    return n_sent
def _group_by_video_file(episodes):
    """Group episodes by their first camera's video file, each group in playback order.

    Frames are decoded sequentially, one capture per camera per file. That only
    stays aligned if every camera puts the group's episodes in one file and in
    the same from_timestamp order, so both conditions are checked here.

    Raises:
        RuntimeError: if another camera splits or orders the episodes differently.
    """
    ref_cam = CAMS[0]
    groups: dict[int, list] = {}
    for ep in episodes:
        groups.setdefault(ep["file_index"][ref_cam], []).append(ep)
    for fidx, eps in groups.items():
        eps.sort(key=lambda e: e["from_ts"][ref_cam])
        for cam in CAMS[1:]:
            files = sorted({e["file_index"][cam] for e in eps})
            if len(files) != 1:
                raise RuntimeError(
                    f"{ref_cam} file {fidx} spans {cam} files {files}; "
                    "per-camera file splits must match for sequential decoding"
                )
            starts = [e["from_ts"][cam] for e in eps]
            if starts != sorted(starts):
                raise RuntimeError(
                    f"{cam} orders the episodes of {ref_cam} file {fidx} differently; "
                    "sequential decoding would misalign frames"
                )
    return groups


async def main_async(args):
    """Replay the selected demos through the server, one ws connection per demo.

    Args:
        args: parsed CLI namespace with ``uri``, ``stride``, ``max_episodes`` and
            ``dtype``. ``dtype`` is forwarded to ``send_episode``, which ignores it.

    Raises:
        ValueError: if ``stride < 1`` or ``max_episodes < 0``.
        RuntimeError: on camera file-split mismatches, unopenable videos,
            video underruns or server errors.
    """
    if args.stride < 1:
        raise ValueError(f"stride must be >= 1, got {args.stride}")
    if args.max_episodes < 0:
        raise ValueError(f"max_episodes must be >= 0, got {args.max_episodes}")

    states, episodes = load_metadata()
    if args.max_episodes:
        episodes = episodes[: args.max_episodes]
    n_frames_total = sum(ep["length"] for ep in episodes)
    est_shards = sum(len(range(0, ep["length"], args.stride)) for ep in episodes)
    print(f"episodes: {len(episodes)} frames: {n_frames_total} stride: {args.stride} "
          f"→ ~{est_shards} shards")

    # Group by camera file_index, read each mp4 in episode order (= from_timestamp order).
    by_file = _group_by_video_file(episodes)
    t0 = time.monotonic()  # monotonic: elapsed time is immune to wall-clock jumps
    done = 0
    for _, eps in sorted(by_file.items(), key=lambda kv: kv[0]):
        caps = {}
        try:
            # Open inside the try so a failed open still releases every capture
            # created so far (including the one that failed to open).
            for c in CAMS:
                path = video_path(c, eps[0]["file_index"][c])
                caps[c] = cv2.VideoCapture(path)
                if not caps[c].isOpened():
                    raise RuntimeError(f"cannot open {path}")
            for ep in eps:
                sent = await send_episode(args.uri, ep, states, caps, args.stride, args.dtype)
                done += 1
                rate = done / max(time.monotonic() - t0, 1e-9)
                print(f"[{done}/{len(episodes)}] demo ep{ep['episode_index']:02d} "
                      f"len={ep['length']} sent={sent} ({rate*60:.1f} demos/min)")
        finally:
            for cap in caps.values():
                cap.release()
    print(f"DONE. {done} demos replayed in {(time.monotonic()-t0)/60:.1f} min. "
          f"shards in the server's MOLMOACT2_FT_ENCODER_CACHE_DIR.")
def main():
    """Parse CLI flags, validate them, and run the async replay loop.

    Exits with argparse's usage error (status 2) when ``--stride`` is below 1
    or ``--max-episodes`` is negative. A negative value would otherwise slice
    episodes from the end of the list instead of acting as a limit.
    """
    p = argparse.ArgumentParser(
        description="Replay yam-stack-cube demos through the MolmoAct2 fine-tune "
                    "server to collect prefix hidden-state shards."
    )
    p.add_argument("--uri", default="ws://127.0.0.1:8112")
    p.add_argument("--stride", type=int, default=5, help="send every Nth demo frame (disk control)")
    p.add_argument("--max-episodes", type=int, default=0, help="smoke test: only first N demos (0=all)")
    p.add_argument("--dtype", default="float16",
                   help="accepted for compatibility; not sent to the server by this script")
    args = p.parse_args()
    if args.stride < 1:
        p.error(f"--stride must be >= 1 (got {args.stride})")
    if args.max_episodes < 0:
        p.error(f"--max-episodes must be >= 0 (got {args.max_episodes})")
    asyncio.run(main_async(args))
# Script entry point: parse CLI flags and run the async replay via main().
if __name__ == "__main__":
    main()