#!/usr/bin/env python3 """Replay karma-recorded baseline rollouts (mcap+mp4) through the prefix server to collect CLEAN failure prefix shards for the success-vs-fail gate. Karma episode dir (recordings/.../episode_*/): left.mcap / right.mcap JSON /{side}/joint_state -> joint_pos(6)+gripper_pos(1)=7/arm camera_{top,left,right}-images-rgb.mp4 + *-rgb-timestamp.npy State (14) = [left joint_pos(6), left gripper(1), right joint_pos(6), right gripper(1)] (verified against demo observation.state stats — see --check). Server must be in prefix mode pointed at the FAILURE cache dir. One ws connection per episode (clean episode_id grouping). Subsample camera frames by --stride. ./lerobot/.venv/bin/python collect_fail_replay.py --episodes-dir recordings_fail/20260602 --stride 30 """ from __future__ import annotations import argparse, asyncio, glob, json, os import cv2, numpy as np, msgpack, websockets from mcap.reader import make_reader from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text CAMS = ["top", "left", "right"] CAM_OBS = {"top": "top_camera-images-rgb", "left": "left_camera-images-rgb", "right": "right_camera-images-rgb"} def read_arm(mcap_path): """-> (log_time_ns array, state7 array) where state7=[joint_pos(6), gripper_pos(1)].""" ts, st = [], [] with open(mcap_path, "rb") as f: for _sch, _ch, msg in make_reader(f).iter_messages(): d = json.loads(msg.data) jp = d["joint_pos"]; gp = d.get("gripper_pos", 0.0) gp = gp[0] if isinstance(gp, (list, tuple)) else gp ts.append(msg.log_time); st.append(list(jp) + [gp]) return np.array(ts, dtype=np.int64), np.array(st, dtype=np.float32) def nearest(ts_sorted, vals, t_ns): i = int(np.searchsorted(ts_sorted, t_ns)) i = max(0, min(i, len(ts_sorted) - 1)) if i > 0 and abs(ts_sorted[i - 1] - t_ns) < abs(ts_sorted[i] - t_ns): i -= 1 return vals[i] def chw_rgb(bgr): return np.ascontiguousarray(np.transpose(cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB), (2, 0, 1))) def build_episode_states(ep): lt, lp = read_arm(f"{ep}/left.mcap") rt, rp = read_arm(f"{ep}/right.mcap") cam_ts = np.load(f"{ep}/camera_top-rgb-timestamp.npy") # seconds (reference clock) return lt, lp, rt, rp, cam_ts def state_at(lt, lp, rt, rp, t_sec): t_ns = int(t_sec * 1e9) return np.concatenate([nearest(lt, lp, t_ns), nearest(rt, rp, t_ns)]).astype(np.float32) # (14,) async def replay_episode(uri, ep, stride, task, send=True): lt, lp, rt, rp, cam_ts = build_episode_states(ep) caps = {c: cv2.VideoCapture(f"{ep}/camera_{c}-images-rgb.mp4") for c in CAMS} n = min(int(caps["top"].get(cv2.CAP_PROP_FRAME_COUNT)), len(cam_ts)) sent = 0 ws = await websockets.connect(uri, max_size=None, compression=None) if send else None if ws: await ws.recv() # metadata try: for i in range(n): grabbed = {c: caps[c].read() for c in CAMS} if i % stride != 0: continue if not all(ok for ok, _ in grabbed.values()): break state = state_at(lt, lp, rt, rp, float(cam_ts[i])) obs = {"state": state, "task": task} for c in CAMS: obs[CAM_OBS[c]] = chw_rgb(grabbed[c][1]) if ws: await ws.send(pack(obs)) resp = _decode_text(msgpack.unpackb(await ws.recv(), raw=False, object_hook=_decode_numpy_object)) if "error" in resp: raise RuntimeError(resp["error"]) sent += 1 finally: for c in caps.values(): c.release() if ws: await ws.close() return sent async def main_async(args): eps = sorted(glob.glob(os.path.join(args.episodes_dir, "episode_*"))) print(f"episodes: {len(eps)} stride: {args.stride}") if args.check: # verify state convention vs demo stats lt, lp, rt, rp, cam_ts = build_episode_states(eps[0]) S = np.stack([state_at(lt, lp, rt, rp, float(t)) for t in cam_ts[::50]]) print("constructed state (14) — per-dim min/max:") print(" min:", np.round(S.min(0), 3)); print(" max:", np.round(S.max(0), 3)) print("compare to demo observation.state stats (from meta) to confirm ordering.") return total = 0 for k, ep in enumerate(eps): s = await replay_episode(args.uri, ep, args.stride, args.task) total += s print(f"[{k+1}/{len(eps)}] {os.path.basename(ep)} -> {s} frames sent") print(f"DONE. {total} failure frames replayed -> server cache dir.") def main(): p = argparse.ArgumentParser() p.add_argument("--episodes-dir", default="recordings_fail/20260602") p.add_argument("--uri", default="ws://127.0.0.1:8112") p.add_argument("--stride", type=int, default=30, help="camera frames: 30fps/30 ≈ 1Hz") p.add_argument("--task", default="stack the cubes") p.add_argument("--check", action="store_true", help="just print constructed-state ranges to verify convention") args = p.parse_args() asyncio.run(main_async(args)) if __name__ == "__main__": main()