File size: 7,388 Bytes
20b27fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
#!/usr/bin/env python3
"""Replay the 44 yam-stack-cube demos through the MolmoAct2 fine-tune server to
collect (M, 2560) prefix hidden-state shards for RLT Stage-1 encoder training.

This is the ⏭️ "Collect prefix data" step. It does NOT touch the model directly —
it talks to molmoact2_finetune_karma_server.py over ws:8112, which must be
running in PREFIX-CACHE mode:

    MOLMOACT2_FT_ENCODER_CACHE_DIR=./encoder_cache_prefix \
    MOLMOACT2_FT_ENCODER_CACHE_TARGET=prefix \
    MOLMOACT2_FT_ENCODER_CACHE_STRIDE=1 \
    ./run_finetune_server.sh --default-task "stack the cubes"

The server caches EVERY inference it receives (server STRIDE=1); this script does
the frame subsampling (--stride, default 5) so disk stays bounded. Each demo is
sent over its OWN ws connection so the server's next_episode() bumps episode_id
→ shards land as ep{NNNN}_*.npz, grouped per demo for the later success/fail gate.

Data path (read directly from the HF v3.0 dataset cache, no `datasets` extra):
  * observation.state / task  ← data parquet
  * camera frames             ← h264 mp4s, decoded sequentially with cv2 and
                                sliced per episode by `length` (exact alignment).

Run (after the server is up in prefix mode):
    ./lerobot/.venv/bin/python collect_prefix.py --stride 5
"""
from __future__ import annotations

import argparse
import asyncio
import os
import time

import cv2
import msgpack
import numpy as np
import pyarrow.parquet as pq
import websockets

# Reuse the server's exact msgpack wire codecs so the encoding matches 1:1.
from molmoact2_finetune_karma_server import pack, _decode_numpy_object, _decode_text

# Dataset video keys of the three cameras, in the order frames are read and sent.
CAMS = [
    "observation.images.top",
    "observation.images.left",
    "observation.images.right",
]
# dataset key -> karma obs key the server expects (paired positionally with CAMS)
CAM_TO_OBS = dict(zip(CAMS, [
    "top_camera-images-rgb",
    "left_camera-images-rgb",
    "right_camera-images-rgb",
]))
# Local Hugging Face hub cache of the dataset, pinned to one snapshot revision.
SNAPSHOT = os.path.expanduser(
    "~/.cache/huggingface/hub/"
    "datasets--atharva-pantheon--yam-stack-cube/"
    "snapshots/3ac38351db6b4f6924263cdc33fec08514a7fc96"
)


def _decode_response(raw: bytes) -> dict:
    """Unpack one msgpack reply from the server.

    Decoding goes through the server's own numpy object hook and text decoder,
    so replies are read with exactly the codecs the server module defines.
    """
    unpacked = msgpack.unpackb(raw, object_hook=_decode_numpy_object, raw=False)
    decoded = _decode_text(unpacked)
    return decoded


def video_path(cam: str, file_index: int) -> str:
    """Return the snapshot path of camera ``cam``'s mp4 number ``file_index`` (chunk-000 only)."""
    filename = "file-{:03d}.mp4".format(file_index)
    return os.path.join(SNAPSHOT, "videos", cam, "chunk-000", filename)


def load_metadata():
    """Load per-frame states and per-demo metadata from the pinned HF snapshot.

    Returns:
        ``(states, episodes)``:
          * ``states`` -- float32 array with one row per dataset frame, taken from
            the ``observation.state`` column of ``data/chunk-000/file-000.parquet``
            (expected to be (N, 14) for this dataset).
          * ``episodes`` -- one dict per row of the episodes metadata parquet, with
            keys ``episode_index``, ``length``, ``from`` / ``to`` (dataset frame
            indices), ``task``, and per-camera ``file_index`` / ``from_ts``.

    Raises:
        RuntimeError: if a demo's state rows ``[from, from + length)`` are not all
            inside ``states`` -- e.g. the dataset spans more than the single data
            parquet file read here.
    """
    data_path = os.path.join(SNAPSHOT, "data", "chunk-000", "file-000.parquet")
    # Convert only the column that is used; to_pydict() would turn every other
    # column of the data file into Python lists as well.
    states = np.asarray(
        pq.read_table(data_path).column("observation.state").to_pylist(),
        dtype=np.float32,
    )

    meta = pq.read_table(
        os.path.join(SNAPSHOT, "meta", "episodes", "chunk-000", "file-000.parquet")
    ).to_pydict()
    episodes = []
    for i, ep_index in enumerate(meta["episode_index"]):
        start = int(meta["dataset_from_index"][i])
        length = int(meta["length"][i])
        # send_episode reads states[from + k] for every k < length; fail here,
        # before anything is replayed, instead of part-way through the run
        # (a negative start would otherwise silently wrap around).
        if start < 0 or start + length > len(states):
            raise RuntimeError(
                f"episode {int(ep_index)} needs state rows [{start}, {start + length}) "
                f"but only {len(states)} rows were loaded from {data_path}"
            )
        tasks = meta["tasks"][i]
        episodes.append({
            "episode_index": int(ep_index),
            "length": length,
            "from": start,
            "to": int(meta["dataset_to_index"][i]),
            "task": tasks[0] if tasks else "stack the cubes",
            # All cameras should share the same file split; frames are read
            # sequentially, so only the ordering (by from_timestamp within a
            # file) needs to agree across cameras.
            "file_index": {c: int(meta[f"videos/{c}/file_index"][i]) for c in CAMS},
            "from_ts": {c: float(meta[f"videos/{c}/from_timestamp"][i]) for c in CAMS},
        })
    return states, episodes


def chw_rgb(frame_bgr: np.ndarray) -> np.ndarray:
    """Convert a cv2 BGR HWC uint8 frame to a contiguous CHW RGB array.

    CHW RGB uint8 is the layout the server's _chw_to_hwc_uint8 expects.
    """
    hwc_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
    chw = hwc_rgb.transpose(2, 0, 1)
    # transpose() only returns a strided view; make it contiguous for packing.
    return np.ascontiguousarray(chw)


async def send_episode(uri, ep, states, caps, stride, dtype_note):
    """Replay one demo over its own websocket connection (one server-side episode).

    All ``ep["length"]`` frames are pulled from every camera capture so the
    captures end up positioned right after this demo. Only frames whose index is
    a multiple of ``stride`` are decoded and sent. Each sent observation holds the
    frame's ``states`` row, the demo's task string and one CHW RGB image per
    camera. Its reply is awaited before the next observation is sent.

    Args:
        uri: websocket URI of the fine-tune server.
        ep: one episode dict as built by load_metadata().
        states: per-frame state rows indexed by dataset frame index.
        caps: camera key -> cv2.VideoCapture, expected to sit on this demo's
            first frame.
        stride: send every ``stride``-th frame.
        dtype_note: accepted but not used here.

    Returns:
        Number of observations sent.

    Raises:
        RuntimeError: a capture ran out of frames, or the server replied with
            an "error" entry.
    """
    total = ep["length"]
    row0 = ep["from"]
    n_sent = 0
    async with websockets.connect(uri, max_size=None, compression=None) as conn:
        # The server's first message is a metadata frame; it is not needed here.
        await conn.recv()
        for frame_no in range(total):
            wanted = frame_no % stride == 0
            images = {}
            for cam in CAMS:
                cap = caps[cam]
                if wanted:
                    ok, img = cap.read()  # grab + decode
                else:
                    ok, img = cap.grab(), None  # advance only, skip decoding
                if not ok:
                    raise RuntimeError(f"video underrun ep{ep['episode_index']} cam {cam} frame {frame_no}")
                images[cam] = img
            if not wanted:
                continue
            obs = {"state": states[row0 + frame_no], "task": ep["task"]}
            obs.update((CAM_TO_OBS[cam], chw_rgb(images[cam])) for cam in CAMS)
            await conn.send(pack(obs))
            reply = _decode_response(await conn.recv())
            if "error" in reply:
                raise RuntimeError(f"server error ep{ep['episode_index']} frame {frame_no}: {reply['error']}")
            n_sent += 1
    return n_sent


def _check_file_group(eps: list) -> None:
    """Raise RuntimeError if one mp4 group cannot be decoded sequentially in lockstep.

    ``eps`` must already be sorted by CAMS[0]'s from_timestamp. main_async opens
    each camera's file taken from ``eps[0]`` and send_episode consumes exactly
    ``length`` frames per demo starting from frame 0. So for every camera the
    demos must:
      (a) live in one file,
      (b) be in the same from_timestamp order,
      (c) start at t=0 for the first demo.
    Otherwise frames would silently be paired with the wrong states.
    """
    ids = [e["episode_index"] for e in eps]
    for c in CAMS:
        files = {e["file_index"][c] for e in eps}
        if len(files) != 1:
            raise RuntimeError(
                f"cam {c}: demos {ids} span video files {sorted(files)}; cannot decode sequentially"
            )
        ts = [e["from_ts"][c] for e in eps]
        if any(later <= earlier for earlier, later in zip(ts, ts[1:])):
            raise RuntimeError(f"cam {c}: demos {ids} are not in from_timestamp order")
        if abs(ts[0]) > 1e-3:
            raise RuntimeError(
                f"cam {c}: first demo ep{ids[0]} starts at t={ts[0]:.3f}s, not 0; "
                f"sequential decode would misalign frames"
            )


async def main_async(args):
    """Replay every selected demo through the server, one ws connection per demo.

    Demos are grouped by the mp4 file that holds them (keyed on CAMS[0]). Each
    group's per-camera files are decoded sequentially, so demos within a group
    are replayed in from_timestamp order. All groups are validated before the
    first observation is sent.
    """
    states, episodes = load_metadata()
    if args.max_episodes:
        episodes = episodes[: args.max_episodes]
    n_frames_total = sum(ep["length"] for ep in episodes)
    est_shards = sum(len(range(0, ep["length"], args.stride)) for ep in episodes)
    print(f"episodes: {len(episodes)}  frames: {n_frames_total}  stride: {args.stride}  "
          f"→ ~{est_shards} shards")

    # Group by camera file_index, read each mp4 in episode order (= from_timestamp order).
    by_file: dict[int, list] = {}
    for ep in episodes:
        by_file.setdefault(ep["file_index"][CAMS[0]], []).append(ep)
    for eps in by_file.values():
        eps.sort(key=lambda e: e["from_ts"][CAMS[0]])
        _check_file_group(eps)

    t0 = time.time()
    done = 0
    for fidx in sorted(by_file):
        eps = by_file[fidx]
        caps = {}
        try:
            # Open inside the try so a failure on a later camera still releases
            # the captures that were already opened.
            for c in CAMS:
                path = video_path(c, eps[0]["file_index"][c])
                caps[c] = cv2.VideoCapture(path)
                if not caps[c].isOpened():
                    raise RuntimeError(f"cannot open {path}")
            for ep in eps:
                sent = await send_episode(args.uri, ep, states, caps, args.stride, args.dtype)
                done += 1
                rate = done / (time.time() - t0)
                print(f"[{done}/{len(episodes)}] demo ep{ep['episode_index']:02d} "
                      f"len={ep['length']} sent={sent}  ({rate*60:.1f} demos/min)")
        finally:
            for cap in caps.values():
                cap.release()
    print(f"DONE. {done} demos replayed in {(time.time()-t0)/60:.1f} min. "
          f"shards in the server's MOLMOACT2_FT_ENCODER_CACHE_DIR.")


def main():
    """Parse CLI arguments, validate them, and run the replay loop."""
    p = argparse.ArgumentParser()
    p.add_argument("--uri", default="ws://127.0.0.1:8112")
    p.add_argument("--stride", type=int, default=5, help="send every Nth demo frame (disk control)")
    p.add_argument("--max-episodes", type=int, default=0, help="smoke test: only first N demos (0=all)")
    # NOTE: forwarded to send_episode as `dtype_note`, which does not use it.
    p.add_argument("--dtype", default="float16")
    args = p.parse_args()
    # stride 0 would crash deep in the run (range(..., 0) / `i % 0`), and a
    # negative stride makes the shard estimate disagree with what is sent.
    if args.stride < 1:
        p.error(f"--stride must be >= 1 (got {args.stride})")
    # A negative value would slice off the LAST demos instead of limiting the count.
    if args.max_episodes < 0:
        p.error(f"--max-episodes must be >= 0 (got {args.max_episodes})")
    asyncio.run(main_async(args))


if __name__ == "__main__":
    main()