| |
| """Replay karma-recorded baseline rollouts (mcap+mp4) through the prefix server |
| to collect CLEAN failure prefix shards for the success-vs-fail gate. |
| |
| Karma episode dir (recordings/.../episode_*/): |
| left.mcap / right.mcap JSON /{side}/joint_state -> joint_pos(6)+gripper_pos(1)=7/arm |
| camera_{top,left,right}-images-rgb.mp4 + *-rgb-timestamp.npy |
| State (14) = [left joint_pos(6), left gripper(1), right joint_pos(6), right gripper(1)] |
| (verified against demo observation.state stats — see --check). |
| |
| Server must be in prefix mode pointed at the FAILURE cache dir. One ws connection |
| per episode (clean episode_id grouping). Subsample camera frames by --stride. |
| |
| ./lerobot/.venv/bin/python collect_fail_replay.py --episodes-dir recordings_fail/20260602 --stride 30 |
| """ |
| from __future__ import annotations |
| import argparse, asyncio, glob, json, os |
| import cv2, numpy as np, msgpack, websockets |
| from mcap.reader import make_reader |
| from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text |
|
|
# Camera name -> observation key its CHW RGB frame is sent under. Insertion
# order fixes the camera iteration order used everywhere below.
CAM_OBS = {
    "top": "top_camera-images-rgb",
    "left": "left_camera-images-rgb",
    "right": "right_camera-images-rgb",
}
CAMS = list(CAM_OBS)
|
|
|
|
def read_arm(mcap_path):
    """Load one arm's recorded state stream from an mcap file.

    Each message payload is JSON with a ``joint_pos`` list and an optional
    ``gripper_pos`` (scalar or one-element list, defaulting to 0.0).

    Returns:
        (stamps, states): ``stamps`` is an int64 array of message log times
        (ns); ``states`` is a float32 array with one row per message laid out
        as [joint_pos..., gripper_pos], i.e. state7 = [joint_pos(6), gripper_pos(1)].

    NOTE(review): every message in the file is read regardless of topic; this
    assumes the file only holds the joint_state stream — confirm.
    """
    stamps = []
    rows = []
    with open(mcap_path, "rb") as stream:
        for _schema, _channel, message in make_reader(stream).iter_messages():
            payload = json.loads(message.data)
            joints = payload["joint_pos"]
            grip = payload.get("gripper_pos", 0.0)
            if isinstance(grip, (list, tuple)):
                grip = grip[0]
            stamps.append(message.log_time)
            rows.append([*joints, grip])
    return np.array(stamps, dtype=np.int64), np.array(rows, dtype=np.float32)
|
|
|
|
def nearest(ts_sorted, vals, t_ns):
    """Return the element of ``vals`` whose timestamp is closest to ``t_ns``.

    ``ts_sorted`` must be in ascending order and index-aligned with ``vals``.
    Queries before the first or after the last timestamp clamp to the first or
    last sample. On an exact tie between two neighbours, the later sample wins.

    Raises:
        ValueError: if ``ts_sorted`` is empty (nothing to match against).
    """
    if len(ts_sorted) == 0:
        raise ValueError("nearest(): empty timestamp array, nothing to match against")
    # searchsorted returns the insertion point in [0, len]; clamp so `i`
    # always indexes a real sample (len -> last sample).
    i = min(int(np.searchsorted(ts_sorted, t_ns)), len(ts_sorted) - 1)
    # The closest sample is either at i or just before it; step back only when
    # the previous one is strictly closer, so ties resolve to the later sample.
    if i > 0 and abs(ts_sorted[i - 1] - t_ns) < abs(ts_sorted[i] - t_ns):
        i -= 1
    return vals[i]
|
|
|
|
def chw_rgb(bgr):
    """Convert an OpenCV BGR frame (H, W, C) into a contiguous RGB (C, H, W) array."""
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    channels_first = rgb.transpose(2, 0, 1)
    # transpose only returns a strided view; materialize a C-contiguous copy
    # before it is serialized.
    return np.ascontiguousarray(channels_first)
|
|
|
|
def build_episode_states(ep):
    """Load everything needed to build per-frame states for episode dir ``ep``.

    Returns ``(left_ts, left_state, right_ts, right_state, top_cam_ts)``: the
    two arms' (log-time, state) streams from ``left.mcap``/``right.mcap`` and
    the top camera's per-frame timestamps from its ``.npy`` sidecar.
    """
    (left_ts, left_state), (right_ts, right_state) = (
        read_arm(f"{ep}/{side}.mcap") for side in ("left", "right")
    )
    top_cam_ts = np.load(f"{ep}/camera_top-rgb-timestamp.npy")
    return left_ts, left_state, right_ts, right_state, top_cam_ts
|
|
|
|
def state_at(lt, lp, rt, rp, t_sec):
    """Build the bimanual state closest in time to ``t_sec``.

    ``t_sec`` (seconds) is truncated to integer nanoseconds. Each arm is
    looked up independently in its own (timestamps, states) stream, and the
    result is ``[left state, right state]`` concatenated as float32.
    """
    query_ns = int(t_sec * 1e9)
    left = nearest(lt, lp, query_ns)
    right = nearest(rt, rp, query_ns)
    return np.concatenate((left, right)).astype(np.float32)
|
|
|
|
async def replay_episode(uri, ep, stride, task, send=True):
    """Stream one recorded episode to the prefix server; return frames sent.

    Every ``stride``-th frame triple (top/left/right cameras) is paired with
    the arm state nearest the top camera's timestamp for that frame. It is
    then sent as one observation over a single websocket connection, so the
    server sees one connection per episode. With ``send=False`` no connection
    is opened, and the return value is the number of frames that would have
    been sent.

    Raises:
        ValueError: if ``stride`` is less than 1.
        RuntimeError: if the server answers a frame with an ``"error"`` key.
    """
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride!r}")
    lt, lp, rt, rp, cam_ts = build_episode_states(ep)
    caps = {}
    ws = None
    sent = 0
    # Acquire every resource inside the try so that a failed open, connect or
    # handshake still releases the video readers and closes the socket.
    try:
        for c in CAMS:
            caps[c] = cv2.VideoCapture(f"{ep}/camera_{c}-images-rgb.mp4")
        # Only the top camera bounds the loop. Shorter side streams end it via
        # a failed grab. Streams are assumed frame-aligned by index (TODO confirm).
        n = min(int(caps["top"].get(cv2.CAP_PROP_FRAME_COUNT)), len(cam_ts))
        if send:
            ws = await websockets.connect(uri, max_size=None, compression=None)
            # Consume the message the server sends on connect (presumably
            # metadata; its content is not used here — confirm with the server).
            await ws.recv()
        for i in range(n):
            # Advance every camera on every frame so the streams stay in
            # lockstep. Per the OpenCV docs, retrieve() is the slower decode
            # step, so dropped frames only pay for grab().
            grabbed = [caps[c].grab() for c in CAMS]
            if i % stride != 0:
                continue
            if not all(grabbed):
                break
            frames = {c: caps[c].retrieve() for c in CAMS}
            if not all(ok for ok, _ in frames.values()):
                break
            state = state_at(lt, lp, rt, rp, float(cam_ts[i]))
            obs = {"state": state, "task": task}
            for c in CAMS:
                obs[CAM_OBS[c]] = chw_rgb(frames[c][1])
            if ws is not None:
                await ws.send(pack(obs))
                reply = await ws.recv()
                resp = _decode_text(msgpack.unpackb(reply, raw=False, object_hook=_decode_numpy_object))
                if "error" in resp:
                    raise RuntimeError(resp["error"])
            sent += 1
    finally:
        for cap in caps.values():
            cap.release()
        if ws is not None:
            await ws.close()
    return sent
|
|
|
|
async def main_async(args):
    """Replay every ``episode_*`` directory under ``args.episodes_dir``.

    With ``args.check`` set, nothing is sent. Instead, the per-dimension
    min/max of the constructed state for the first episode is printed, sampled
    at every 50th top-camera timestamp. This lets the joint ordering be
    compared by hand against the demo ``observation.state`` stats. Otherwise
    each episode is streamed to ``args.uri`` on its own connection, and a
    per-episode frame count is printed. If no episode directories are found,
    prints a message and returns.
    """
    # Only directories are episodes; a stray file named episode_* would
    # otherwise crash inside read_arm.
    eps = sorted(
        p for p in glob.glob(os.path.join(args.episodes_dir, "episode_*")) if os.path.isdir(p)
    )
    print(f"episodes: {len(eps)} stride: {args.stride}")
    if not eps:
        # Also guards the eps[0] access in --check mode below.
        print(f"no episode_* directories found under {args.episodes_dir!r}; nothing to replay.")
        return
    if args.check:
        lt, lp, rt, rp, cam_ts = build_episode_states(eps[0])
        S = np.stack([state_at(lt, lp, rt, rp, float(t)) for t in cam_ts[::50]])
        print("constructed state (14) — per-dim min/max:")
        print(" min:", np.round(S.min(0), 3))
        print(" max:", np.round(S.max(0), 3))
        print("compare to demo observation.state stats (from meta) to confirm ordering.")
        return
    total = 0
    for k, ep in enumerate(eps):
        s = await replay_episode(args.uri, ep, args.stride, args.task)
        total += s
        print(f"[{k+1}/{len(eps)}] {os.path.basename(ep)} -> {s} frames sent")
    print(f"DONE. {total} failure frames replayed -> server cache dir.")
|
|
|
|
def main():
    """Command-line entry point: parse flags and run ``main_async`` to completion."""
    parser = argparse.ArgumentParser()
    add = parser.add_argument
    add("--episodes-dir", default="recordings_fail/20260602")
    add("--uri", default="ws://127.0.0.1:8112")
    add("--stride", type=int, default=30, help="camera frames: 30fps/30 ≈ 1Hz")
    add("--task", default="stack the cubes")
    add(
        "--check",
        action="store_true",
        help="just print constructed-state ranges to verify convention",
    )
    asyncio.run(main_async(parser.parse_args()))
|
|
|
|
if __name__ == "__main__":
    # main() returns None, so this exits with status 0 on success, exactly
    # like falling off the end of the script.
    raise SystemExit(main())
|
|