rltoken-encoder / code /collect_fail_replay.py
atharva-pantheon's picture
Upload code/collect_fail_replay.py with huggingface_hub
edbd124 verified
#!/usr/bin/env python3
"""Replay karma-recorded baseline rollouts (mcap+mp4) through the prefix server
to collect CLEAN failure prefix shards for the success-vs-fail gate.
Karma episode dir (recordings/.../episode_*/):
left.mcap / right.mcap JSON /{side}/joint_state -> joint_pos(6)+gripper_pos(1)=7/arm
camera_{top,left,right}-images-rgb.mp4 + *-rgb-timestamp.npy
State (14) = [left joint_pos(6), left gripper(1), right joint_pos(6), right gripper(1)]
(verified against demo observation.state stats — see --check).
Server must be in prefix mode pointed at the FAILURE cache dir. One ws connection
per episode (clean episode_id grouping). Subsample camera frames by --stride.
./lerobot/.venv/bin/python collect_fail_replay.py --episodes-dir recordings_fail/20260602 --stride 30
"""
from __future__ import annotations
import argparse, asyncio, glob, json, os
import cv2, numpy as np, msgpack, websockets
from mcap.reader import make_reader
from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text
# Camera names; each is substituted into the per-episode video file name
# camera_{name}-images-rgb.mp4 when an episode is replayed.
CAMS = ["top", "left", "right"]
# Maps a camera name to the key under which that camera's frame is stored in
# the observation dict sent to the server.
CAM_OBS = {"top": "top_camera-images-rgb", "left": "left_camera-images-rgb", "right": "right_camera-images-rgb"}
def read_arm(mcap_path):
    """Load one arm's recorded joint states from an MCAP file.

    Every message in the file is parsed as JSON. ``joint_pos`` is required;
    ``gripper_pos`` defaults to 0.0 and, when it is a list/tuple, only its
    first element is used.

    Returns:
        (timestamps, states): ``timestamps`` is an int64 array of message
        log times; ``states`` is a float32 array whose rows are
        ``joint_pos`` followed by the gripper value (the module docstring
        describes this as 6 joints + 1 gripper).
    """
    stamps = []
    states = []
    with open(mcap_path, "rb") as fh:
        reader = make_reader(fh)
        for _schema, _channel, message in reader.iter_messages():
            payload = json.loads(message.data)
            joints = payload["joint_pos"]
            grip = payload.get("gripper_pos", 0.0)
            if isinstance(grip, (list, tuple)):
                # Gripper may be published as a 1-element sequence; unwrap it.
                grip = grip[0]
            stamps.append(message.log_time)
            states.append([*joints, grip])
    return np.array(stamps, dtype=np.int64), np.array(states, dtype=np.float32)
def nearest(ts_sorted, vals, t_ns):
    """Return the entry of ``vals`` whose timestamp is closest to ``t_ns``.

    ``ts_sorted`` must be sorted ascending and index-aligned with ``vals``.
    Queries outside the covered range clamp to the first/last sample. When
    the two neighbouring samples are equally distant, the later one wins.
    """
    last = len(ts_sorted) - 1
    idx = int(np.searchsorted(ts_sorted, t_ns))
    idx = max(0, min(idx, last))
    if idx > 0:
        gap_before = abs(ts_sorted[idx - 1] - t_ns)
        gap_here = abs(ts_sorted[idx] - t_ns)
        # Step back only when the preceding sample is strictly closer.
        if gap_before < gap_here:
            idx -= 1
    return vals[idx]
def chw_rgb(bgr):
    """Convert an OpenCV BGR frame to a contiguous RGB array with the channel axis first."""
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
    # Move the last (channel) axis to the front: axes (0, 1, 2) -> (2, 0, 1).
    channels_first = np.transpose(rgb, (2, 0, 1))
    return np.ascontiguousarray(channels_first)
def build_episode_states(ep):
    """Load both arms' state tracks and the top-camera timestamps for one episode.

    Returns:
        (left_ts, left_states, right_ts, right_states, cam_ts), where the
        first four come from ``read_arm`` on ``left.mcap`` / ``right.mcap``
        and ``cam_ts`` is the array stored in
        ``camera_top-rgb-timestamp.npy`` (seconds on the reference clock,
        per the original author's note).
    """
    left_track = read_arm(f"{ep}/left.mcap")
    right_track = read_arm(f"{ep}/right.mcap")
    camera_ts = np.load(f"{ep}/camera_top-rgb-timestamp.npy")
    return (*left_track, *right_track, camera_ts)
def state_at(lt, lp, rt, rp, t_sec):
    """Build the combined robot state at camera time ``t_sec`` (seconds).

    The time is converted to integer nanoseconds, the nearest left-arm and
    right-arm samples are looked up, and the two are concatenated (left
    first) into one float32 vector. The module docstring describes the
    result as 14-dimensional.
    """
    query_ns = int(t_sec * 1e9)
    left_state = nearest(lt, lp, query_ns)
    right_state = nearest(rt, rp, query_ns)
    combined = np.concatenate((left_state, right_state))
    return combined.astype(np.float32)
async def replay_episode(uri, ep, stride, task, send=True):
    """Replay one recorded episode through the server, one observation per ``stride`` frames.

    For every ``stride``-th top-camera frame, the three camera frames are
    decoded, the robot state nearest to that frame's timestamp is looked up,
    and the observation is sent over a single websocket connection that is
    kept for the whole episode.

    Args:
        uri: Websocket URI of the server.
        ep: Path of the episode directory.
        stride: Keep every ``stride``-th camera frame; must be >= 1.
        task: Task string placed in each observation under ``"task"``.
        send: When False, nothing is connected or sent (dry run); frames are
            still decoded and counted.

    Returns:
        Number of observations processed (all of which were sent when
        ``send`` is True).

    Raises:
        ValueError: If ``stride`` is smaller than 1.
        RuntimeError: If the server's reply contains an ``"error"`` entry.
    """
    # Validate before touching files or the network, so a bad stride cannot
    # open a server session and then die with ZeroDivisionError.
    if stride < 1:
        raise ValueError(f"stride must be >= 1, got {stride}")

    lt, lp, rt, rp, cam_ts = build_episode_states(ep)
    caps = {c: cv2.VideoCapture(f"{ep}/camera_{c}-images-rgb.mp4") for c in CAMS}
    ws = None
    sent = 0
    # Everything after the captures are opened lives inside the try block, so
    # a failed connect/handshake still releases the video handles and socket.
    try:
        # Timestamps are indexed by top-camera frame number; never run past
        # whichever of the two is shorter.
        n = min(int(caps["top"].get(cv2.CAP_PROP_FRAME_COUNT)), len(cam_ts))
        if send:
            ws = await websockets.connect(uri, max_size=None, compression=None)
            await ws.recv()  # metadata
        for i in range(n):
            if i % stride != 0:
                # Advance every stream without decoding; read() is
                # grab() + retrieve(), and skipped frames are never used.
                for cap in caps.values():
                    cap.grab()
                continue
            grabbed = {c: caps[c].read() for c in CAMS}
            if not all(ok for ok, _ in grabbed.values()):
                # At least one camera ran out of frames: stop this episode.
                break
            state = state_at(lt, lp, rt, rp, float(cam_ts[i]))
            obs = {"state": state, "task": task}
            for c in CAMS:
                obs[CAM_OBS[c]] = chw_rgb(grabbed[c][1])
            if ws is not None:
                await ws.send(pack(obs))
                resp = _decode_text(msgpack.unpackb(await ws.recv(), raw=False, object_hook=_decode_numpy_object))
                if "error" in resp:
                    raise RuntimeError(resp["error"])
            sent += 1
    finally:
        for cap in caps.values():
            cap.release()
        if ws is not None:
            await ws.close()
    return sent
async def main_async(args):
    """Replay every ``episode_*`` directory under ``args.episodes_dir``, or run the state check.

    With ``args.check`` set, nothing is sent: the first episode's state
    vectors are sampled at every 50th camera timestamp and their per-dimension
    min/max are printed for manual comparison against the demo statistics.
    Otherwise each episode is replayed in sorted order and a per-episode and
    total frame count is printed.

    Args:
        args: Parsed CLI namespace providing ``episodes_dir``, ``stride``,
            ``check``, ``uri`` and ``task``.

    Raises:
        FileNotFoundError: If ``args.check`` is set but no episode directory
            matches (previously a bare IndexError on ``eps[0]``).
    """
    eps = sorted(glob.glob(os.path.join(args.episodes_dir, "episode_*")))
    print(f"episodes: {len(eps)} stride: {args.stride}")
    if args.check:
        if not eps:
            # The check needs one episode to sample; fail with a message that
            # names the directory instead of an opaque IndexError.
            raise FileNotFoundError(
                f"no episode_* directories found under {args.episodes_dir!r}"
            )
        # verify state convention vs demo stats
        lt, lp, rt, rp, cam_ts = build_episode_states(eps[0])
        S = np.stack([state_at(lt, lp, rt, rp, float(t)) for t in cam_ts[::50]])
        print("constructed state (14) — per-dim min/max:")
        print(" min:", np.round(S.min(0), 3))
        print(" max:", np.round(S.max(0), 3))
        print("compare to demo observation.state stats (from meta) to confirm ordering.")
        return
    total = 0
    for k, ep in enumerate(eps, start=1):
        s = await replay_episode(args.uri, ep, args.stride, args.task)
        total += s
        print(f"[{k}/{len(eps)}] {os.path.basename(ep)} -> {s} frames sent")
    print(f"DONE. {total} failure frames replayed -> server cache dir.")
def main():
    """Parse the command-line options and run the async driver to completion."""
    # (flag, add_argument keyword arguments), registered in this order.
    option_specs = (
        ("--episodes-dir", {"default": "recordings_fail/20260602"}),
        ("--uri", {"default": "ws://127.0.0.1:8112"}),
        ("--stride", {"type": int, "default": 30, "help": "camera frames: 30fps/30 ≈ 1Hz"}),
        ("--task", {"default": "stack the cubes"}),
        (
            "--check",
            {
                "action": "store_true",
                "help": "just print constructed-state ranges to verify convention",
            },
        ),
    )
    parser = argparse.ArgumentParser()
    for flag, options in option_specs:
        parser.add_argument(flag, **options)
    parsed = parser.parse_args()
    asyncio.run(main_async(parsed))


# Run the CLI only when executed as a script, not when imported.
if __name__ == "__main__":
    main()