"""
HDF5 分片生成脚本:读取 MP4 与 JSON,生成符合规范的 shard_XXXX.h5
层级设计(示例):
shard_XXXX.h5
├── /dataset_name_0/
│ ├── @dataset_source: "AgiBot World"
│ ├── @dataset_version: "alpha"
│ ├── @num_trajectories: <N>
│ │
│ ├── /traj_0000/
│ │ ├── @task: <action_text>(该动作切片的描述文本,如 "Pickup items in the supermarket")
│ │ ├── @task_id: "327_act_000"(格式 <task_id>_act_<动作序号>)
│ │ ├── @episode_id: "648642"
│ │ ├── @scene_id: <init_scene_text>
│ │ ├── @robot_type: "unknown"
│ │ ├── @success: 1
│ │ ├── @num_frames: T
│ │ ├── @fps: F
│ │ ├── @duration_sec: T/F
│ │ ├── @camera_views: ["head", "left", "right", ...]
│ │ │
│ │ ├── images_head: [T, H, W, 3] uint8
│ │ ├── images_left: [T, H, W, 3] uint8
│ │ ├── images_right: [T, H, W, 3] uint8
│ │ │
│ │ ├── progress: [T] float32
│ │ ├── done: [T] bool
│ │ └── value: [T] float32
使用方法(示例):
1) 安装依赖(Windows):
pip install h5py numpy opencv-python
2) 运行脚本(你的分段目录作为根,例如 648642-684757):
python build_h5_shard.py \
--dataset-name agibot_world \
--task-json e:/trae_code/20251111data/database/AgiBot_World/task_327.json \
--obs-root e:/trae_code/20251111data/OpenDriveLab___AgiBot-World/raw/main/observations/327/648642-684757 \
--task-id 327 \
--output e:/trae_code/20251111data/shard_327.h5
3) 可选参数:
--dataset-source "AgiBot World" --dataset-version "alpha" --robot-type "franka"
脚本会在 <obs-root>/<episode_id>/videos 下查找 MP4,并固定映射:
head_color → images_head,hand_left_color → images_left,hand_right_color → images_right。
若 obs-root 指向上层目录(如 observations),也会在子目录中递归查找 `<episode_id>/videos`。
注意:该脚本按时间维度进行流式写入,避免一次性加载整段视频到内存。
分片规则:
- 单个 H5 文件最多写入 150 条轨迹(可通过 `--max-traj-per-shard` 配置)。
- 当达到上限时,自动创建新的 H5 文件,文件名基于 `--output` 增加 `_part_XXXX` 后缀。
"""
import argparse
import json
import os
import sys
from typing import Dict, List, Tuple
import h5py
import numpy as np
try:
    import cv2  # type: ignore
except ImportError:
    # Fail fast with an actionable hint when OpenCV is not installed; other
    # import-time errors propagate unchanged so they are not misreported.
    print("[ERROR] 缺少依赖 opencv-python,请先运行: pip install opencv-python", file=sys.stderr)
    raise
def string_array(lst: List[str]):
    """Return *lst* as a NumPy array using h5py's variable-length UTF-8 string dtype.

    Suitable for storing a list of strings as an HDF5 attribute or dataset.
    """
    utf8_str = h5py.string_dtype(encoding="utf-8")
    converted = np.array(lst, dtype=utf8_str)
    return converted
def find_episode_videos(obs_root: str, task_id: int, episode_id: int) -> Dict[str, str]:
    """Locate the MP4 files of one episode.

    Directories are tried in this order; the first one that yields at least
    one ``.mp4`` file (case-insensitive extension) wins:

    1. ``<obs_root>/<episode_id>/videos`` (obs_root is a segment directory);
    2. ``<obs_root>/<task_id>/<episode_id>/videos`` (obs_root is ``observations``);
    3. the first ``<episode_id>/videos`` found by walking ``obs_root``, with
       sub-directories visited in sorted order so the result is deterministic.

    Only the ``head_color``, ``hand_left_color`` and ``hand_right_color``
    streams are returned, when present.

    Args:
        obs_root: Directory to search.
        task_id: Task ID, used to probe layout 2 before the recursive walk so
            that an identical episode id under another task is not picked.
        episode_id: Episode whose ``videos`` directory is wanted.

    Returns:
        ``{raw_camera_key: mp4_path}`` for the wanted cameras that were found
        (empty if none).
    """
    wanted = ("head_color", "hand_left_color", "hand_right_color")
    episode_dir_name = str(episode_id)

    def _list_mp4s(vdir: str) -> Dict[str, str]:
        # Map file stem -> path for every .mp4 in vdir (sorted for determinism).
        found: Dict[str, str] = {}
        for fn in sorted(os.listdir(vdir)):
            if fn.lower().endswith(".mp4"):
                found[os.path.splitext(fn)[0]] = os.path.join(vdir, fn)
        return found

    candidates: Dict[str, str] = {}
    direct_dirs = (
        os.path.join(obs_root, episode_dir_name, "videos"),
        os.path.join(obs_root, str(task_id), episode_dir_name, "videos"),
    )
    for vdir in direct_dirs:
        if os.path.isdir(vdir):
            candidates = _list_mp4s(vdir)
            if candidates:
                break

    if not candidates:
        for root, dirs, _files in os.walk(obs_root):
            # Sorting in place fixes the order in which os.walk descends.
            dirs.sort()
            if os.path.basename(root) == episode_dir_name and "videos" in dirs:
                candidates = _list_mp4s(os.path.join(root, "videos"))
                break

    return {key: candidates[key] for key in wanted if key in candidates}
def read_video_meta(path: str) -> Tuple[int, int, int, int, float]:
    """Read basic stream metadata from a video file via OpenCV.

    Args:
        path: Path to the video file.

    Returns:
        ``(frame_count, width, height, channels, fps)``. ``channels`` is
        hard-coded to 3. ``fps`` falls back to 30.0 when the reported value is
        not a positive number (including 0 and NaN).

    Raises:
        RuntimeError: If OpenCV cannot open the file.
    """
    cap = cv2.VideoCapture(path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"无法打开视频: {path}")
        frame_count = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
        width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
        height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
        fps = float(cap.get(cv2.CAP_PROP_FPS) or 0.0)
    finally:
        # Release the capture on every path, including errors above.
        cap.release()
    # `not fps > 0` (instead of `fps <= 0`) also rejects NaN.
    if not fps > 0:
        fps = 30.0
    return frame_count, width, height, 3, fps
def write_video_slice_to_dataset(mp4_path: str, dset: h5py.Dataset, start_idx: int, count: int) -> int:
    """Stream frames ``[start_idx, start_idx + count)`` of a video into an HDF5 dataset.

    Frames are decoded one at a time, converted BGR -> RGB (uint8) and written
    to ``dset[0]``, ``dset[1]``, ... so the clip is never fully held in memory.

    Args:
        mp4_path: Video file to read.
        dset: Destination dataset; ``dset[t, ...]`` must accept one RGB frame.
        start_idx: First frame to read; negative values are clamped to 0.
        count: Maximum number of frames to write.

    Returns:
        Number of frames actually written. It is less than ``count`` when the
        video ends early, in which case a warning is printed.

    Raises:
        RuntimeError: If OpenCV cannot open the file.
    """
    start = max(0, int(start_idx))
    cap = cv2.VideoCapture(mp4_path)
    try:
        if not cap.isOpened():
            raise RuntimeError(f"无法打开视频: {mp4_path}")
        if start > 0 and not cap.set(cv2.CAP_PROP_POS_FRAMES, start):
            # The backend rejected the seek: skip frames by grabbing them rather
            # than silently writing from frame 0. Assumes the stream is still
            # at its first frame after the rejected seek.
            for _ in range(start):
                if not cap.grab():
                    break
        written = 0
        while written < count:
            ok, frame_bgr = cap.read()
            if not ok:
                break
            frame_rgb = cv2.cvtColor(frame_bgr, cv2.COLOR_BGR2RGB)
            if frame_rgb.dtype != np.uint8:
                frame_rgb = frame_rgb.astype(np.uint8)
            dset[written, ...] = frame_rgb
            written += 1
    finally:
        # Release the capture even if a dataset write raises (e.g. shape mismatch).
        cap.release()
    if written < count:
        print(f"[WARN] {os.path.basename(mp4_path)} 仅写入 {written}/{count} 帧 (start={start_idx})")
    return written
# Raw camera key -> (name used in @camera_views, image dataset name), in output order.
_CAMERA_LAYOUT: Tuple[Tuple[str, str, str], ...] = (
    ("head_color", "head", "images_head"),
    ("hand_left_color", "left", "images_left"),
    ("hand_right_color", "right", "images_right"),
)


def _present_cameras(cam_metas: Dict[str, tuple]) -> List[Tuple[str, str, str]]:
    """Return the ``_CAMERA_LAYOUT`` entries whose raw key is in *cam_metas*, in layout order."""
    return [layout for layout in _CAMERA_LAYOUT if layout[0] in cam_metas]


def _action_slice(act, cam_metas: Dict[str, tuple], present: List[Tuple[str, str, str]]):
    """Parse one ``action_config`` entry and align its frame range across cameras.

    ``end_frame`` is inclusive. Each camera's range is clipped to its own frame
    count and the shortest length wins, so all views stay frame-aligned.
    ``start_frame`` is clamped to 0, matching what the video reader does, so the
    computed length agrees with the frames actually read.

    Returns:
        ``(start, end, slice_len)``, or ``None`` if the frame range cannot be
        parsed. ``slice_len`` is 0 when any camera has no frame in range.
    """
    try:
        start = int(act.get("start_frame", 0))
        end = int(act.get("end_frame", 0))
    except (AttributeError, TypeError, ValueError, OverflowError):
        return None
    start = max(0, start)
    if not present:
        return start, end, 0
    lengths = []
    for raw_key, _view, _dname in present:
        total = cam_metas[raw_key][0]
        lengths.append(0 if start >= total else max(0, min(end, total - 1) - start + 1))
    return start, end, min(lengths)


def _make_shard_path(base: str, idx: int) -> str:
    """Return ``<dir of base>/<stem of base>_part_<idx:04d>.h5`` as an absolute path."""
    base = os.path.abspath(base)
    stem = os.path.splitext(os.path.basename(base))[0]
    return os.path.join(os.path.dirname(base), f"{stem}_part_{idx:04d}.h5")


def _open_shard(output_path: str, idx: int, dataset_name: str, dataset_source: str, dataset_version: str):
    """Create shard file *idx* with its ``/<dataset_name>_0`` group; return ``(h5_file, group)``."""
    path = _make_shard_path(output_path, idx)
    os.makedirs(os.path.dirname(path), exist_ok=True)
    h5 = h5py.File(path, "w")
    try:
        grp = h5.create_group(f"/{dataset_name}_0")
        grp.attrs["dataset_source"] = dataset_source
        grp.attrs["dataset_version"] = dataset_version
    except BaseException:
        h5.close()
        raise
    print(f"[SHARD] 开始写入分片 {idx} -> {path}")
    return h5, grp


def _close_shard(h5, grp_dataset, num_trajectories: int) -> None:
    """Record ``num_trajectories`` on the dataset group, then always close the file."""
    try:
        grp_dataset.attrs["num_trajectories"] = num_trajectories
    finally:
        h5.close()


def _collect_episodes(episodes: list, obs_root: str, task_id_filter: int) -> list:
    """Select episodes of the requested task that have videos and an ``action_config``.

    Episodes with unparseable ids, another task id, or no videos are skipped
    silently. Episodes without actions are skipped with an ``[INFO]`` message.

    Returns:
        List of ``(episode_json, videos, cam_metas, actions)``, where
        ``cam_metas`` maps raw camera key to
        ``(frame_count, width, height, channels, fps, mp4_path)``.
    """
    pool = []
    for ep in episodes:
        try:
            ep_id = int(ep.get("episode_id"))
            t_id = int(ep.get("task_id"))
        except (AttributeError, TypeError, ValueError, OverflowError):
            continue
        if t_id != task_id_filter:
            continue
        vids = find_episode_videos(obs_root, task_id_filter, ep_id)
        if not vids:
            continue
        cam_metas = {}
        for key, mp4 in vids.items():
            fc, w, h, ch, fps = read_video_meta(mp4)
            cam_metas[key] = (fc, w, h, ch, fps, mp4)
        view_names = [view for _raw, view, _dname in _present_cameras(cam_metas)]
        if view_names:
            print(f"[FOUND] episode {ep_id} 找到视频视角: {', '.join(view_names)}")
        label_info = ep.get("label_info")
        actions = label_info.get("action_config", []) if isinstance(label_info, dict) else []
        if not actions:
            print(f"[INFO] episode {ep_id} 无 action_config,跳过")
            continue
        pool.append((ep, vids, cam_metas, actions))
    return pool


def _write_trajectory(
    traj_grp,
    *,
    act,
    aidx: int,
    ep_id: int,
    scene_text: str,
    task_id_filter: int,
    robot_type: str,
    start: int,
    slice_len: int,
    ref_fps: float,
    present: List[Tuple[str, str, str]],
    cam_metas: Dict[str, tuple],
) -> None:
    """Fill one trajectory group with attributes, per-camera images and progress/done/value."""
    traj_grp.attrs["task"] = act.get("action_text") or ""
    # Trajectory ids are auto-numbered per action: <task_id>_act_<aidx>.
    traj_grp.attrs["task_id"] = f"{task_id_filter}_act_{aidx:03d}"
    traj_grp.attrs["episode_id"] = str(ep_id)
    traj_grp.attrs["scene_id"] = scene_text
    traj_grp.attrs["robot_type"] = robot_type
    traj_grp.attrs["success"] = 1
    traj_grp.attrs["num_frames"] = int(slice_len)
    traj_grp.attrs["fps"] = float(ref_fps)
    traj_grp.attrs["duration_sec"] = float(slice_len) / float(ref_fps)
    traj_grp.attrs["camera_views"] = string_array([view for _raw, view, _dname in present])

    for raw_key, _view, dname in present:
        _, w, h, _, _, mp4_path = cam_metas[raw_key]
        dset = traj_grp.create_dataset(
            name=dname,
            shape=(slice_len, h, w, 3),
            dtype=np.uint8,
            chunks=(1, h, w, 3),
            compression="gzip",
            compression_opts=4,
        )
        # A short read is reported by the writer. The dataset keeps its full
        # slice_len shape (unwritten frames keep the fill value), so num_frames
        # and progress stay based on slice_len.
        write_video_slice_to_dataset(mp4_path, dset, start_idx=start, count=slice_len)

    done = np.zeros((slice_len,), dtype=np.bool_)
    done[-1] = True
    traj_grp.create_dataset("progress", data=np.linspace(0.0, 1.0, num=slice_len, dtype=np.float32), dtype=np.float32)
    traj_grp.create_dataset("done", data=done, dtype=np.bool_)
    traj_grp.create_dataset("value", data=np.zeros((slice_len,), dtype=np.float32), dtype=np.float32)


def build_h5_shard(
    output_path: str,
    dataset_name: str,
    task_json_path: str,
    obs_root: str,
    task_id_filter: int,
    dataset_source: str = "AgiBot World",
    dataset_version: str = "alpha",
    default_robot_type: str = "unknown",
    max_traj_per_shard: int = 150,
) -> None:
    """Read the task JSON and episode MP4s and write trajectories into HDF5 shards.

    Every action entry with at least one aligned frame, in every matching
    episode, becomes one trajectory group ``/<dataset_name>_0/traj_XXXX``.
    Shards are named ``<output stem>_part_XXXX.h5`` next to ``output_path``
    and hold at most ``max_traj_per_shard`` trajectories each. A shard file is
    only created when there is a trajectory to put in it, so no empty trailing
    shard is produced.

    Args:
        output_path: Base output path; shard suffixes are derived from it.
        dataset_name: Prefix of the top-level group (``<dataset_name>_0``).
        task_json_path: JSON file whose top level is a list of episodes.
        obs_root: Root directory searched for ``<episode_id>/videos``.
        task_id_filter: Only episodes with this ``task_id`` are used.
        dataset_source: Value of the group's ``@dataset_source``.
        dataset_version: Value of the group's ``@dataset_version``.
        default_robot_type: Value of every trajectory's ``@robot_type``.
        max_traj_per_shard: Maximum trajectories per shard file (>= 1).

    Raises:
        ValueError: If ``max_traj_per_shard`` < 1 or the JSON top level is not a list.
        RuntimeError: If no episode of the task has both videos and actions.
    """
    if max_traj_per_shard < 1:
        raise ValueError(f"max_traj_per_shard 必须 >= 1,当前为 {max_traj_per_shard}")
    with open(task_json_path, "r", encoding="utf-8") as f:
        episodes = json.load(f)
    if not isinstance(episodes, list):
        raise ValueError("task_json 内容应为列表(list)")

    ep_pool = _collect_episodes(episodes, obs_root, task_id_filter)
    if not ep_pool:
        raise RuntimeError("未找到任何包含动作切片的 episode,请检查 JSON 与目录。")

    # Pre-count writable actions for the overall progress display.
    total_actions_valid = 0
    for _ep, _vids, cam_metas, actions in ep_pool:
        present = _present_cameras(cam_metas)
        for act in actions:
            parsed = _action_slice(act, cam_metas, present)
            if parsed is not None and parsed[2] > 0:
                total_actions_valid += 1

    h5 = None  # currently open shard file; None between shards
    grp_dataset = None
    shard_idx = -1  # index of the most recently opened shard
    traj_count_in_shard = 0
    total_traj_written = 0
    try:
        for ep, _vids, cam_metas, actions in ep_pool:
            ep_id = int(ep.get("episode_id"))
            scene_text = ep.get("init_scene_text") or ""
            present = _present_cameras(cam_metas)
            # The first present camera's fps is the reference for the trajectory.
            ref_fps = cam_metas[present[0][0]][4] if present else 30.0
            for aidx, act in enumerate(actions):
                parsed = _action_slice(act, cam_metas, present)
                if parsed is None:
                    continue
                s, e, slice_len = parsed
                if slice_len <= 0:
                    print(f"[WARN] episode {ep_id} action[{aidx}]({s}-{e}) 无有效帧,跳过")
                    continue

                if h5 is None:
                    if total_traj_written > 0:
                        sys.stdout.write("\n")  # finish the in-place progress line
                    shard_idx += 1
                    h5, grp_dataset = _open_shard(
                        output_path, shard_idx, dataset_name, dataset_source, dataset_version
                    )
                    traj_count_in_shard = 0

                _write_trajectory(
                    grp_dataset.create_group(f"traj_{traj_count_in_shard:04d}"),
                    act=act,
                    aidx=aidx,
                    ep_id=ep_id,
                    scene_text=scene_text,
                    task_id_filter=task_id_filter,
                    robot_type=default_robot_type,
                    start=s,
                    slice_len=slice_len,
                    ref_fps=ref_fps,
                    present=present,
                    cam_metas=cam_metas,
                )
                traj_count_in_shard += 1
                total_traj_written += 1
                sys.stdout.write(
                    f"\r[PROGRESS] 已写入轨迹 {total_traj_written}/{total_actions_valid} (episode {ep_id}, action {aidx})"
                )
                sys.stdout.flush()

                if traj_count_in_shard >= max_traj_per_shard:
                    # Close the full shard now; the next one is opened lazily,
                    # only if another trajectory actually needs to be written.
                    full_shard, h5 = h5, None
                    _close_shard(full_shard, grp_dataset, traj_count_in_shard)

        if h5 is not None:
            last_shard, h5 = h5, None
            _close_shard(last_shard, grp_dataset, traj_count_in_shard)
        if total_traj_written > 0:
            sys.stdout.write("\n")
        else:
            print("[WARN] 没有写入任何轨迹,未生成分片文件")
    finally:
        # Only reached with an open shard when an exception is propagating:
        # close it best-effort without masking the original error.
        if h5 is not None:
            try:
                _close_shard(h5, grp_dataset, traj_count_in_shard)
            except Exception as close_err:
                print(f"[WARN] 关闭分片失败: {close_err}")
    print(f"✅ 生成完成,共写入轨迹 {total_traj_written},分片数 {shard_idx + 1}")
def parse_args() -> argparse.Namespace:
    """Parse ``sys.argv`` into the options used by ``build_h5_shard``.

    Required: ``--dataset-name``, ``--task-json``, ``--obs-root``,
    ``--task-id`` (int), ``--output``. Optional: ``--max-traj-per-shard``
    (int, default 150), ``--dataset-source``, ``--dataset-version`` and
    ``--robot-type`` (string attribute values with defaults).
    """
    parser = argparse.ArgumentParser(description="AgiBot World: MP4 + JSON → HDF5 分片生成")
    arg_specs = [
        ("--dataset-name", dict(required=True, help="HDF5 顶层数据集名前缀(如 droid、bridge、agibot_world)")),
        ("--task-json", dict(required=True, help="task_[id].json 路径")),
        ("--obs-root", dict(required=True, help="observations 根目录(包含 <task_id>/<episode_id>/videos)")),
        ("--task-id", dict(type=int, required=True, help="任务 ID(如 327)")),
        ("--output", dict(required=True, help="输出 HDF5 基础文件路径(会生成 _part_XXXX.h5 分片)")),
        ("--max-traj-per-shard", dict(type=int, default=150, help="单个 H5 分片的最大轨迹数(默认 150)")),
        ("--dataset-source", dict(default="AgiBot World", help="@dataset_source 属性值")),
        ("--dataset-version", dict(default="alpha", help="@dataset_version 属性值")),
        ("--robot-type", dict(default="unknown", help="@robot_type 属性默认值")),
    ]
    for flag, options in arg_specs:
        parser.add_argument(flag, **options)
    return parser.parse_args()
def main() -> None:
    """Command-line entry point: parse arguments and build the HDF5 shards."""
    args = parse_args()
    build_h5_shard(
        output_path=args.output,
        dataset_name=args.dataset_name,
        task_json_path=args.task_json,
        obs_root=args.obs_root,
        task_id_filter=args.task_id,
        dataset_source=args.dataset_source,
        dataset_version=args.dataset_version,
        default_robot_type=args.robot_type,
        max_traj_per_shard=args.max_traj_per_shard,
    )


if __name__ == "__main__":
    main()