"""

HDF5 分片生成脚本:读取 MP4 与 JSON,生成符合规范的 shard_XXXX.h5



层级设计(示例):



  shard_XXXX.h5

  ├── /dataset_name_0/

  │   ├── @dataset_source: "AgiBot World"

  │   ├── @dataset_version: "alpha"

  │   ├── @num_trajectories: <N>

  │   │

  │   ├── /traj_0000/

  │   │   ├── @task: "Pickup items in the supermarket"

  │   │   ├── @task_id: "327"

  │   │   ├── @episode_id: "648642"

  │   │   ├── @scene_id: <init_scene_text>

  │   │   ├── @robot_type: "unknown"

  │   │   ├── @success: 1

  │   │   ├── @num_frames: T

  │   │   ├── @fps: F

  │   │   ├── @duration_sec: T/F

  │   │   ├── @camera_views: ["head", "left", "right", ...]

  │   │   │

  │   │   ├── images_head:  [T, H, W, 3] uint8

  │   │   ├── images_left:  [T, H, W, 3] uint8

  │   │   ├── images_right: [T, H, W, 3] uint8

  │   │   │

  │   │   ├── progress: [T] float32

  │   │   ├── done:     [T] bool

  │   │   └── value:    [T] float32



使用方法(示例):



  1) 安装依赖(Windows):

     pip install h5py numpy opencv-python



  2) 运行脚本(你的分段目录作为根,例如 648642-684757):

     python build_h5_shard.py \

       --dataset-name agibot_world \

       --task-json e:/trae_code/20251111data/database/AgiBot_World/task_327.json \

       --obs-root e:/trae_code/20251111data/OpenDriveLab___AgiBot-World/raw/main/observations/327/648642-684757 \

       --task-id 327 \

       --output e:/trae_code/20251111data/shard_327.h5



  3) 可选参数:

     --dataset-source "AgiBot World" --dataset-version "alpha" --robot-type "franka"



脚本会在 <obs-root>/<episode_id>/videos 下查找 MP4,并固定映射:

  head_color → images_head,hand_left_color → images_left,hand_right_color → images_right。

若 obs-root 指向上层目录(如 observations),也会在子目录中递归查找 `<episode_id>/videos`。



注意:该脚本按时间维度进行流式写入,避免一次性加载整段视频到内存。



分片规则:

- 单个 H5 文件最多写入 150 条轨迹(可通过 `--max-traj-per-shard` 配置)。

- 当达到上限时,自动创建新的 H5 文件,文件名基于 `--output` 增加 `_part_XXXX` 后缀。

"""

import argparse
import json
import os
import sys
from typing import Dict, List, Tuple

import h5py
import numpy as np

try:
    import cv2  # type: ignore
except Exception:  # give a clear hint when the dependency is missing
    print("[ERROR] Missing dependency opencv-python; run: pip install opencv-python")
    raise


def string_array(lst: List[str]):
    """将 Python 字符串列表转换为 h5py 兼容的字符串数组。"""
    dt = h5py.string_dtype(encoding="utf-8")
    return np.array(lst, dtype=dt)


def find_episode_videos(obs_root: str, task_id: int, episode_id: int) -> Dict[str, str]:
    """

    在 <obs-root>/<episode_id>/videos 或其子目录中查找 MP4。

    固定只返回 head_color、hand_left_color、hand_right_color 三路(若存在)。

    返回: {raw_camera_key: mp4_path}

    """
    candidates: Dict[str, str] = {}

    # Direct path: <obs-root>/<episode_id>/videos
    direct_dir = os.path.join(obs_root, str(episode_id), "videos")
    if os.path.isdir(direct_dir):
        for fn in os.listdir(direct_dir):
            if fn.lower().endswith(".mp4"):
                key = os.path.splitext(fn)[0]
                candidates[key] = os.path.join(direct_dir, fn)

    # If nothing was found, search obs_root recursively for `<episode_id>/videos`
    if not candidates:
        for root, dirs, files in os.walk(obs_root):
            base = os.path.basename(root)
            if base == str(episode_id) and "videos" in dirs:
                vdir = os.path.join(root, "videos")
                for fn in os.listdir(vdir):
                    if fn.lower().endswith(".mp4"):
                        key = os.path.splitext(fn)[0]
                        candidates[key] = os.path.join(vdir, fn)
                break

    # Keep only the three expected camera streams
    filtered: Dict[str, str] = {}
    for k in ["head_color", "hand_left_color", "hand_right_color"]:
        if k in candidates:
            filtered[k] = candidates[k]
    return filtered


def read_video_meta(path: str) -> Tuple[int, int, int, int, float]:
    """读取视频的基础元信息:(frame_count, width, height, channels, fps)。channels 固定为 3。"""
    cap = cv2.VideoCapture(path)
    if not cap.isOpened():
        raise RuntimeError(f"无法打开视频: {path}")
    frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
    width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
    height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
    fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
    if fps <= 0:
        # Fallback: default to 30 when fps cannot be read
        fps = 30.0
    cap.release()
    return frame_count, width, height, 3, fps


def write_video_slice_to_dataset(mp4_path: str, dset: h5py.Dataset, start_idx: int, count: int) -> int:
    """

    将 mp4 指定区间 [start_idx, start_idx+count) 按帧流式写入 HDF5 dset。

    返回实际写入帧数。

    """
    cap = cv2.VideoCapture(mp4_path)
    if not cap.isOpened():
        raise RuntimeError(f"无法打开视频: {mp4_path}")
    cap.set(cv2.CAP_PROP_POS_FRAMES, max(0, int(start_idx)))
    t = 0
    while t < count:
        ok, frame_bgr = cap.read()
        if not ok:
            break
        frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
        if frame_rgb.dtype != np.uint8:
            frame_rgb = frame_rgb.astype(np.uint8)
        dset[t, ...] = frame_rgb
        t += 1
    cap.release()
    if t < count:
        print(f"[WARN] {os.path.basename(mp4_path)} 仅写入 {t}/{count} 帧 (start={start_idx})")
    return t


def build_h5_shard(
    output_path: str,
    dataset_name: str,
    task_json_path: str,
    obs_root: str,
    task_id_filter: int,
    dataset_source: str = "AgiBot World",
    dataset_version: str = "alpha",
    default_robot_type: str = "unknown",
    max_traj_per_shard: int = 150,
) -> None:
    """主流程:读取 JSON 和 MP4,生成 HDF5 分片。"""
    with open(task_json_path, "r", encoding="utf-8") as f:
        episodes = json.load(f)
    if not isinstance(episodes, list):
        raise ValueError("task_json 内容应为列表(list)")

    # Each action slice is written as one trajectory.
    # First collect a list of (episode_json, videos_dict, cam_metas, actions).
    ep_pool = []
    for ep in episodes:
        try:
            ep_id = int(ep.get("episode_id"))
            t_id = int(ep.get("task_id"))
        except Exception:
            continue
        if t_id != task_id_filter:
            continue
        vids = find_episode_videos(obs_root, task_id_filter, ep_id)
        if not vids:
            # Silently skip episodes without videos
            continue
        # Keep metadata for the three streams only
        cam_metas = {}
        for k, mp4 in vids.items():
            fc, w, h, ch, fps = read_video_meta(mp4)
            cam_metas[k] = (fc, w, h, ch, fps, mp4)
        # Report which camera views were found
        camera_order = ["head_color", "hand_left_color", "hand_right_color"]
        present_cams = [c for c in camera_order if c in cam_metas]
        view_names = []
        for c in present_cams:
            if c == "head_color":
                view_names.append("head")
            elif c == "hand_left_color":
                view_names.append("left")
            elif c == "hand_right_color":
                view_names.append("right")
        if present_cams:
            print(f"[FOUND] episode {ep_id} 找到视频视角: {', '.join(view_names)}")
        actions = (ep.get("label_info") or {}).get("action_config", [])
        if not actions:
            print(f"[INFO] episode {ep_id} 无 action_config,跳过")
            continue
        ep_pool.append((ep, vids, cam_metas, actions))

    if not ep_pool:
        raise RuntimeError("未找到任何包含动作切片的 episode,请检查 JSON 与目录。")

    # Create the HDF5 file and keep a running trajectory count.
    # Pre-compute the total number of valid actions (for overall progress output).
    total_actions_valid = 0
    for ep, vids, cam_metas, actions in ep_pool:
        camera_order = ["head_color", "hand_left_color", "hand_right_color"]
        present_cams = [c for c in camera_order if c in cam_metas]
        for act in actions:
            try:
                s = int(act.get("start_frame", 0))
                e = int(act.get("end_frame", 0))
            except Exception:
                continue
            per_cam_len = []
            for c in present_cams:
                total = cam_metas[c][0]
                if s >= total:
                    length = 0
                else:
                    length = max(0, min(e, total - 1) - s + 1)
                per_cam_len.append(length)
            slice_len = min(per_cam_len) if per_cam_len else 0
            if slice_len > 0:
                total_actions_valid += 1

    # Build the output path for shard number idx
    def _make_shard_path(base: str, idx: int) -> str:
        base = os.path.abspath(base)
        d = os.path.dirname(base)
        stem = os.path.splitext(os.path.basename(base))[0]
        return os.path.join(d, f"{stem}_part_{idx:04d}.h5")

    # Open a new shard file
    def _open_shard(idx: int):
        path = _make_shard_path(output_path, idx)
        h5 = h5py.File(path, "w")
        grp = h5.create_group(f"/{dataset_name}_0")
        grp.attrs["dataset_source"] = dataset_source
        grp.attrs["dataset_version"] = dataset_version
        print(f"[SHARD] 开始写入分片 {idx} -> {path}")
        return h5, grp, path

    shard_idx = 0
    h5, grp_dataset, current_shard_path = _open_shard(shard_idx)
    traj_count_in_shard = 0
    total_traj_written = 0
    processed_actions = 0

    try:
        for ep, vids, cam_metas, actions in ep_pool:
            ep_id = int(ep.get("episode_id"))
            scene_text = (ep.get("init_scene_text") or "")

            # Fixed camera-view mapping
            camera_order = ["head_color", "hand_left_color", "hand_right_color"]
            present_cams = [c for c in camera_order if c in cam_metas]
            view_names = []
            for c in present_cams:
                if c == "head_color":
                    view_names.append("head")
                elif c == "hand_left_color":
                    view_names.append("left")
                elif c == "hand_right_color":
                    view_names.append("right")

            # Use the fps of the first available camera as the reference
            ref_fps = cam_metas[present_cams[0]][4] if present_cams else 30.0

            for aidx, act in enumerate(actions):
                try:
                    s = int(act.get("start_frame", 0))
                    e = int(act.get("end_frame", 0))
                except Exception:
                    continue
                action_text = (act.get("action_text") or "")
                skill = (act.get("skill") or "")

                # Align the usable frame range across cameras and truncate to
                # the minimum available length. end_frame is inclusive, so
                # slice_len = e - s + 1.
                per_cam_len = []
                for c in present_cams:
                    total = cam_metas[c][0]
                    if s >= total:
                        length = 0
                    else:
                        length = max(0, min(e, total - 1) - s + 1)
                    per_cam_len.append(length)
                slice_len = min(per_cam_len) if per_cam_len else 0
                if slice_len <= 0:
                    print(f"[WARN] episode {ep_id} action[{aidx}]({s}-{e}) 无有效帧,跳过")
                    continue

                # Name trajectory groups by their index within the current shard
                traj_grp = grp_dataset.create_group(f"traj_{traj_count_in_shard:04d}")
                traj_grp.attrs["task"] = action_text
                # Auto-generated label: <task_id>_act_<aidx>
                traj_grp.attrs["task_id"] = f"{task_id_filter}_act_{aidx:03d}"
                traj_grp.attrs["episode_id"] = str(ep_id)
                traj_grp.attrs["scene_id"] = scene_text
                traj_grp.attrs["robot_type"] = default_robot_type
                traj_grp.attrs["success"] = 1
                traj_grp.attrs["num_frames"] = int(slice_len)
                traj_grp.attrs["fps"] = float(ref_fps)
                traj_grp.attrs["duration_sec"] = float(slice_len) / float(ref_fps)
                traj_grp.attrs["camera_views"] = string_array(view_names)

                # Write the image streams (when present)
                for c in present_cams:
                    _, w, h, _, _, mp4_path = cam_metas[c]
                    # Target dataset name
                    if c == "head_color":
                        dname = "images_head"
                    elif c == "hand_left_color":
                        dname = "images_left"
                    else:
                        dname = "images_right"

                    dset = traj_grp.create_dataset(
                        name=dname,
                        shape=(slice_len, h, w, 3),
                        dtype=np.uint8,
                        chunks=(1, h, w, 3),
                        compression="gzip",
                        compression_opts=4,
                    )
                    written = write_video_slice_to_dataset(mp4_path, dset, start_idx=s, count=slice_len)
                    if written < slice_len:
                        # If the slice was not fully written, keep the dataset anyway; progress/duration are based on slice_len
                        pass

                # Write progress / done / value
                prog = np.linspace(0.0, 1.0, num=slice_len, dtype=np.float32)
                done = np.zeros((slice_len,), dtype=np.bool_)
                done[-1] = True
                value = np.zeros((slice_len,), dtype=np.float32)

                traj_grp.create_dataset("progress", data=prog, dtype=np.float32)
                traj_grp.create_dataset("done", data=done, dtype=np.bool_)
                traj_grp.create_dataset("value", data=value, dtype=np.float32)

                traj_count_in_shard += 1
                total_traj_written += 1
                processed_actions += 1
                # Print overall progress (single-line refresh)
                sys.stdout.write(
                    f"\r[PROGRESS] 已写入轨迹 {processed_actions}/{total_actions_valid} (episode {ep_id}, action {aidx})"
                )
                sys.stdout.flush()

                # Roll over to a new shard once the limit is reached
                if traj_count_in_shard >= max_traj_per_shard:
                    grp_dataset.attrs["num_trajectories"] = traj_count_in_shard
                    h5.close()
                    shard_idx += 1
                    h5, grp_dataset, current_shard_path = _open_shard(shard_idx)
                    traj_count_in_shard = 0

        # Finalize: set the trajectory count on the last shard and close the file
        grp_dataset.attrs["num_trajectories"] = traj_count_in_shard
        h5.close()
        # End the progress line with a newline
        if total_actions_valid > 0:
            sys.stdout.write("\n")
    finally:
        # Make sure the file is closed even on error
        try:
            if h5 and h5.id:
                grp_dataset.attrs["num_trajectories"] = traj_count_in_shard
                h5.close()
        except Exception:
            pass

    print(f"✅ 生成完成,共写入轨迹 {total_traj_written},分片数 {shard_idx + 1}")


def parse_args() -> argparse.Namespace:
    p = argparse.ArgumentParser(description="AgiBot World: MP4 + JSON → HDF5 shard builder")
    p.add_argument("--dataset-name", required=True, help="Top-level HDF5 dataset name prefix (e.g. droid, bridge, agibot_world)")
    p.add_argument("--task-json", required=True, help="Path to task_[id].json")
    p.add_argument("--obs-root", required=True, help="observations root directory (contains <task_id>/<episode_id>/videos)")
    p.add_argument("--task-id", type=int, required=True, help="Task ID (e.g. 327)")
    p.add_argument("--output", required=True, help="Base output HDF5 path (shards are written as <stem>_part_XXXX.h5)")
    p.add_argument("--max-traj-per-shard", type=int, default=150, help="Maximum trajectories per H5 shard (default 150)")
    p.add_argument("--dataset-source", default="AgiBot World", help="Value of the @dataset_source attribute")
    p.add_argument("--dataset-version", default="alpha", help="Value of the @dataset_version attribute")
    p.add_argument("--robot-type", default="unknown", help="Default value of the @robot_type attribute")
    return p.parse_args()


def main():
    args = parse_args()
    build_h5_shard(
        output_path=args.output,
        dataset_name=args.dataset_name,
        task_json_path=args.task_json,
        obs_root=args.obs_root,
        task_id_filter=args.task_id,
        dataset_source=args.dataset_source,
        dataset_version=args.dataset_version,
        default_robot_type=args.robot_type,
        max_traj_per_shard=args.max_traj_per_shard,
    )


if __name__ == "__main__":
    main()
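
As a companion sketch (not part of the builder itself), the layout documented in the module docstring can be read back like this; `summarize_shard` is a hypothetical helper, and the attribute names simply mirror those written by `build_h5_shard` above.

```python
# Companion sketch (hypothetical helper, not used by the builder): walk one
# generated shard and summarize its trajectories, following the attribute
# layout documented in the module docstring above.
import h5py


def summarize_shard(path):
    """Return {traj_name: (task, num_frames, camera_views)} for one shard file."""
    out = {}
    with h5py.File(path, "r") as h5:
        for _, ds_grp in h5.items():                # e.g. "agibot_world_0"
            for traj_name, traj in ds_grp.items():  # e.g. "traj_0000"
                # camera_views may come back as bytes or str depending on h5py version
                views = [v.decode("utf-8") if isinstance(v, bytes) else str(v)
                         for v in traj.attrs["camera_views"]]
                out[traj_name] = (
                    traj.attrs["task"],
                    int(traj.attrs["num_frames"]),
                    views,
                )
    return out
```

For example, `summarize_shard("shard_327_part_0000.h5")` would map each `traj_XXXX` group to its task text, frame count, and camera views.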