| """Cosmos-Drive-Dreams 数据集加载器(真实实现)。 |
| |
| 期待目录结构(从 NVIDIA 提供的 .tar 解压): |
| |
| data_root/ |
| synthetic/single_view/ |
| generation/{clip_id}_{chunk_id}_{weather}.mp4 # 121 帧合成视频 |
| labels/{clip_id}/ |
| vehicle_pose/000000.vehicle_pose.npy ... # 30 FPS, FLU |
| pose/000000.pose.{camera}.npy # 30 FPS, OpenCV |
| ftheta_intrinsic/ftheta_intrinsic.{camera}.npy |
| all_object_info/000000.all_object_info.json |
| lidar_raw/000000.lidar_raw.npz # 10 FPS |
| |
| 每段 clip 提供: |
| - 视频按 `_chunk_id` 分块。chunk_id=0 对应 label idx 0..120;chunk_id=1 对应 label idx 121..241。 |
| - 每个样本:8 帧不重叠窗口 t∈[7, 96],输入 8 帧(t-7..t)+ 未来 24 帧标签。 |
| """ |
|
|
| from __future__ import annotations |
|
|
| import json |
| from dataclasses import dataclass |
| from pathlib import Path |
| from typing import Sequence |
|
|
| import cv2 |
| import numpy as np |
| import torch |
| from torch.utils.data import Dataset |
|
|
| from ..modules.normalization import symlog |
| from ..modules.rays import FThetaCamera |
| from .label_paths import resolve_clip_file |
| from .hdmap import parse_hdmap_clip |
| from .se3 import matrix_to_6d |
| from .targets import ( |
| ObjectTrackInfo, |
| build_detection_targets, |
| build_ego_future_target, |
| ) |
| from .transforms import DINOV3_MEAN, DINOV3_STD |
|
|
|
|
| |
# Default dynamic-agent categories forwarded to ``build_detection_targets`` as
# ``dynamic_classes``. NOTE(review): presumably matched against the
# ``object_type`` strings read from all_object_info JSON — confirm.
DEFAULT_DYNAMIC_CLASSES = [
    # Vehicles.
    "Automobile",
    "Heavy_truck",
    "Bus",
    "Train_or_tram_car",
    "Trolley_bus",
    "Other_vehicle",
    "Trailer",
    # People and riders.
    "Person",
    "Stroller",
    "Rider",
    # Remaining movable categories.
    "Animal",
    "Protruding_object",
]


# Default static / map-element categories forwarded to
# ``build_detection_targets`` as ``structured_classes``. NOTE(review):
# presumably the object types produced by ``parse_hdmap_clip`` — confirm.
DEFAULT_STRUCTURED_CLASSES = [
    # Lane geometry.
    "lane",
    "laneline",
    "road_boundary",
    # Road-surface markings.
    "wait_line",
    "crosswalk",
    "road_marking",
    # Roadside furniture.
    "pole",
    "traffic_light",
    "traffic_sign",
]
|
|
|
|
@dataclass
class ClipSample:
    """One index entry: a single anchor frame of one (clip, chunk, weather) video.

    ``anchor_t`` is chunk-local (it indexes frames of ``video_path``); the
    matching label frame index is ``chunk_offset + anchor_t``.
    """

    # Clip identifier (``build_clip_index`` also uses it as the label sub-directory name).
    clip_id: str
    # Chunk number of the video this sample comes from.
    chunk_id: int
    # Weather tag of the synthetic video, e.g. "Sunny".
    weather: str
    # Synthetic .mp4 file holding the chunk's frames.
    video_path: Path
    # Per-clip label directory.
    labels_dir: Path
    # Chunk-local anchor frame (the last history frame in the dataset).
    anchor_t: int
    # Added to chunk-local frame indices to obtain label frame indices.
    chunk_offset: int
|
|
|
|
def build_clip_index(
    data_root: str | Path,
    weathers: Sequence[str] = ("Sunny",),
    chunk_ids: Sequence[int] = (0, 1),
    camera_name: str = "camera_front_wide_120fov",
    stride: int = 8,
    anchor_min: int = 7,
    anchor_max: int = 96,
    max_clips: int | None = None,
) -> list[ClipSample]:
    """Enumerate every available ``(clip, chunk, weather, anchor_t)`` sample.

    Videos are discovered as ``<data_root>/synthetic/single_view/generation/*.mp4``
    whose stem is ``{clip_id}_{chunk_id}_{weather}`` (``clip_id`` may itself
    contain underscores). A video is used only if its chunk id is in
    ``chunk_ids``, its weather is in ``weathers`` and
    ``<data_root>/labels/{clip_id}`` exists.

    The anchor ``t`` is chunk-local: it addresses video frame ``t`` and label
    frame ``chunk_offset + t`` where ``chunk_offset = chunk_id * 121``.

    Args:
        data_root: Dataset root directory.
        weathers: Weather tags to keep.
        chunk_ids: Chunk ids to keep.
        camera_name: Not used by this function; kept for API compatibility.
        stride: Step between consecutive anchors.
        anchor_min: First anchor (inclusive).
        anchor_max: Upper bound for anchors (inclusive).
        max_clips: If given, keep at most this many distinct clips (the first
            ones in sorted file-name order). Every selected chunk/weather of a
            kept clip is included; ``0`` yields no samples.

    Returns:
        The list of samples; empty if the generation directory does not exist.
    """
    root = Path(data_root)
    syn_dir = root / "synthetic" / "single_view" / "generation"
    labels_dir = root / "labels"

    samples: list[ClipSample] = []
    if not syn_dir.exists():
        return samples

    frames_per_chunk = 121  # label frames covered by one chunk video
    anchors = range(anchor_min, anchor_max + 1, stride)
    # Distinct clips selected so far, maintained incrementally.
    kept_clips: set[str] = set()

    for video_path in sorted(syn_dir.glob("*.mp4")):
        # Stem is "{clip_id}_{chunk_id}_{weather}"; split from the right so
        # underscores inside clip_id survive.
        parts = video_path.stem.rsplit("_", 2)
        if len(parts) != 3:
            continue
        clip_id, chunk_str, weather = parts
        try:
            chunk_id = int(chunk_str)
        except ValueError:
            continue
        if chunk_id not in chunk_ids or weather not in weathers:
            continue
        # Once the clip budget is spent, only further chunks/weathers of
        # already-selected clips are accepted (instead of stopping outright,
        # which dropped the remaining chunks of the last clip).
        if (
            max_clips is not None
            and clip_id not in kept_clips
            and len(kept_clips) >= max_clips
        ):
            continue

        clip_label_dir = labels_dir / clip_id
        if not clip_label_dir.exists():
            continue

        chunk_offset = chunk_id * frames_per_chunk
        samples.extend(
            ClipSample(
                clip_id=clip_id,
                chunk_id=chunk_id,
                weather=weather,
                video_path=video_path,
                labels_dir=clip_label_dir,
                anchor_t=t,
                chunk_offset=chunk_offset,
            )
            for t in anchors
        )
        # A clip only counts toward the budget if it contributed samples.
        if anchors:
            kept_clips.add(clip_id)

    return samples
|
|
|
|
def _load_video_frames(
    video_path: Path,
    frame_indices: Sequence[int],
    target_h: int,
    target_w: int,
) -> torch.Tensor:
    """Read the requested frames from an .mp4 as ``[T, 3, H, W]`` float32 in ``[0, 1]``.

    Each frame is converted BGR -> RGB and resized to ``(target_w, 2 * target_h)``.
    Only the bottom ``target_h`` rows are then kept, so every output frame is
    ``target_h x target_w``.

    Args:
        video_path: Video file to read.
        frame_indices: Frame indices to read, in output order.
        target_h: Output height.
        target_w: Output width.

    Returns:
        ``[len(frame_indices), 3, target_h, target_w]`` float32 tensor (an empty
        ``[0, 3, target_h, target_w]`` tensor when no indices are given).

    Raises:
        FileNotFoundError: If the video cannot be opened.
        RuntimeError: If a requested frame cannot be read.
    """
    frames: list[torch.Tensor] = []
    cap = cv2.VideoCapture(str(video_path))
    try:
        if not cap.isOpened():
            raise FileNotFoundError(f"无法打开视频: {video_path}")
        # Index the decoder will return on the next read() without seeking.
        next_pos: int | None = None
        for idx in frame_indices:
            # Seeking restarts decoding from a keyframe; skip it when the
            # requested frame directly follows the previous one.
            if idx != next_pos:
                cap.set(cv2.CAP_PROP_POS_FRAMES, idx)
            ok, bgr = cap.read()
            if not ok:
                raise RuntimeError(f"读取帧 {idx} 失败: {video_path}")
            next_pos = idx + 1
            rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)
            rgb = cv2.resize(rgb, (target_w, target_h * 2), interpolation=cv2.INTER_AREA)
            # Keep only the bottom half of the 2x-tall resize.
            # NOTE(review): presumably to drop the upper (sky) region — confirm.
            rgb = rgb[target_h:, :, :]
            rgb = rgb.astype(np.float32) / 255.0
            frames.append(torch.from_numpy(rgb).permute(2, 0, 1))
    finally:
        # Always release the capture, also when decoding/resizing raises.
        cap.release()
    if not frames:
        return torch.empty((0, 3, target_h, target_w), dtype=torch.float32)
    return torch.stack(frames, dim=0)
|
|
|
|
| def _load_npy(path: Path) -> np.ndarray: |
| return np.load(path, allow_pickle=False) |
|
|
|
|
| def _load_object_info(path: Path) -> list[ObjectTrackInfo]: |
| """解析单帧 all_object_info JSON。""" |
| if not path.exists(): |
| return [] |
| data = json.loads(path.read_text()) |
| out = [] |
| for tid, info in data.items(): |
| T = torch.tensor(info["object_to_world"], dtype=torch.float32) |
| lwh = torch.tensor(info["object_lwh"], dtype=torch.float32) |
| out.append( |
| ObjectTrackInfo( |
| tracking_id=tid, |
| object_to_world=T, |
| lwh=lwh, |
| is_moving=bool(info.get("object_is_moving", False)), |
| object_type=str(info.get("object_type", "")), |
| ) |
| ) |
| return out |
|
|
|
|
def _load_lidar_self_frame(
    labels_dir: Path,
    label_idx: int,
    vehicle_pose: torch.Tensor,
    max_history: int = 3,
) -> torch.Tensor | None:
    """Load the LiDAR sweep nearest to ``label_idx`` with its points in the ego frame.

    LiDAR runs at 10 FPS (one sweep per 3 camera frames), and sweeps are stored
    under indices that are multiples of 3 (``000000``, ``000003``, ``000006``,
    ...). ``label_idx`` is rounded down to such an index. If that sweep is
    missing, up to ``max_history`` earlier sweeps are tried.

    Args:
        labels_dir: Per-clip label directory passed to ``resolve_clip_file``.
        label_idx: Camera label frame index.
        vehicle_pose: Square homogeneous vehicle pose at ``label_idx``. Its
            inverse maps world points into the ego frame (presumably a 4x4
            vehicle-to-world matrix — confirm).
        max_history: Number of earlier sweeps to fall back to.

    Returns:
        ``[N, 3]`` float32 points in the ego frame, or ``None`` if no sweep was found.
    """
    base_idx = (label_idx // 3) * 3
    npz_path: Path | None = None
    for back in range(max_history + 1):
        idx_try = base_idx - back * 3
        if idx_try < 0:
            break
        try:
            npz_path = resolve_clip_file(labels_dir, "lidar_raw", f"{idx_try:06d}.lidar_raw.npz")
        except FileNotFoundError:
            continue
        break
    if npz_path is None:
        return None

    # np.load on an .npz returns an NpzFile that keeps the file open; the
    # context manager closes it once the arrays have been read.
    with np.load(npz_path, allow_pickle=False) as archive:
        xyz_lidar = torch.from_numpy(np.asarray(archive["xyz"], dtype=np.float64))
        lidar_to_world = torch.from_numpy(
            np.asarray(archive["lidar_to_world"], dtype=np.float64)
        )

    # Compose lidar -> world -> ego once, in float64. World coordinates can be
    # large, so rounding world-frame points to float32 before moving them
    # back next to the ego would lose precision. Only [:3, :3] / [:3, 3] of
    # lidar_to_world are used, so a 3x4 or 4x4 matrix both work.
    world_to_ego = torch.linalg.inv(vehicle_pose.to(torch.float64))
    rot = world_to_ego[:3, :3] @ lidar_to_world[:3, :3]
    trans = world_to_ego[:3, :3] @ lidar_to_world[:3, 3] + world_to_ego[:3, 3]
    pts_self = xyz_lidar @ rot.T + trans
    return pts_self.float()
|
|
|
|
class CosmosDriveDreamsDataset(Dataset):
    """End-to-end samples: history images + ego/intrinsic/extrinsic + detection,
    ego-future and object-future targets (see ``__getitem__`` for the keys)."""

    def __init__(
        self,
        data_root: str | Path,
        samples: list[ClipSample] | None = None,
        weathers: Sequence[str] = ("Sunny",),
        camera_name: str = "camera_front_wide_120fov",
        image_h: int = 384,
        image_w: int = 1024,
        num_history: int = 8,
        future_horizon: int = 24,
        max_distance_m: float = 48.0,
        occlusion_tol: float = 0.5,
        dynamic_classes: Sequence[str] = DEFAULT_DYNAMIC_CLASSES,
        structured_classes: Sequence[str] = DEFAULT_STRUCTURED_CLASSES,
        do_normalize: bool = True,
        use_lidar_occlusion: bool = True,
        use_hdmap: bool = True,
    ) -> None:
        """
        Args:
            data_root: Dataset root directory.
            samples: Pre-built index. When ``None`` it is built with
                ``build_clip_index(data_root, weathers=..., camera_name=...)``.
            weathers: Weather tags used when building the index.
            camera_name: Camera whose pose and f-theta intrinsics are loaded.
            image_h: Output image height.
            image_w: Output image width.
            num_history: Number of history frames, ending at the anchor.
            future_horizon: Number of future label frames after the anchor.
            max_distance_m: Forwarded to ``build_detection_targets``.
            occlusion_tol: Forwarded as ``occlusion_depth_tolerance``.
            dynamic_classes: Copied and forwarded to ``build_detection_targets``.
            structured_classes: Copied and forwarded to ``build_detection_targets``.
            do_normalize: Apply DINOv3 mean/std normalization to the images.
            use_lidar_occlusion: Load the nearest LiDAR sweep and pass it to
                ``build_detection_targets``.
            use_hdmap: Append the clip's HD-map objects to every frame's objects.
        """
        super().__init__()
        self.data_root = Path(data_root)
        self.samples = samples if samples is not None else build_clip_index(
            data_root, weathers=weathers, camera_name=camera_name
        )
        self.camera_name = camera_name
        self.image_h = image_h
        self.image_w = image_w
        self.num_history = num_history
        self.future_horizon = future_horizon
        self.max_distance_m = max_distance_m
        self.occlusion_tol = occlusion_tol
        self.dynamic_classes = list(dynamic_classes)
        self.structured_classes = list(structured_classes)
        self.do_normalize = do_normalize
        self.use_lidar_occlusion = use_lidar_occlusion
        self.use_hdmap = use_hdmap
        # Per-clip HD-map objects keyed by label dir. The dict's insertion
        # order is kept as recency order, giving a bounded LRU.
        self._hdmap_cache: dict[str, list[ObjectTrackInfo]] = {}
        self._hdmap_cache_max = 32

    def __len__(self) -> int:
        return len(self.samples)

    def _load_intrinsic(self, sample: ClipSample) -> torch.Tensor:
        """Load the clip's f-theta intrinsic vector for ``self.camera_name`` as float32."""
        p = resolve_clip_file(
            sample.labels_dir,
            "ftheta_intrinsic",
            f"ftheta_intrinsic.{self.camera_name}.npy",
        )
        return torch.from_numpy(_load_npy(p)).float()

    def _load_pose_camera(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
        """Load the camera pose matrix of ``self.camera_name`` at ``label_idx`` as float32."""
        p = resolve_clip_file(
            sample.labels_dir,
            "pose",
            f"{label_idx:06d}.pose.{self.camera_name}.npy",
        )
        return torch.from_numpy(_load_npy(p)).float()

    def _load_pose_vehicle(self, sample: ClipSample, label_idx: int) -> torch.Tensor:
        """Load the vehicle pose matrix at ``label_idx`` as float32."""
        p = resolve_clip_file(
            sample.labels_dir,
            "vehicle_pose",
            f"{label_idx:06d}.vehicle_pose.npy",
        )
        return torch.from_numpy(_load_npy(p)).float()

    def _load_hdmap_static(self, clip_dir: Path) -> list[ObjectTrackInfo]:
        """Return the clip's HD-map objects (``[]`` when ``use_hdmap`` is off), LRU-cached.

        The cached list is returned as-is, so callers must not mutate it.
        """
        if not self.use_hdmap:
            return []
        key = str(clip_dir)
        objs = self._hdmap_cache.pop(key, None)
        if objs is None:
            objs = parse_hdmap_clip(clip_dir)
            # Evict least-recently-used entries; the emptiness check keeps a
            # cache size of 0 from calling next() on an empty dict.
            while self._hdmap_cache and len(self._hdmap_cache) >= self._hdmap_cache_max:
                self._hdmap_cache.pop(next(iter(self._hdmap_cache)))
        # (Re)insert at the end so the entry becomes the most recently used.
        self._hdmap_cache[key] = objs
        return objs

    def _load_objects(
        self,
        sample: ClipSample,
        label_idx: int,
        missing_ok: bool = False,
    ) -> list[ObjectTrackInfo]:
        """Return the dynamic objects of frame ``label_idx`` plus the clip's HD-map objects.

        Args:
            sample: Sample whose label directory is read.
            label_idx: Label frame index.
            missing_ok: If True, a missing all_object_info file contributes no
                dynamic objects instead of raising ``FileNotFoundError``.
        """
        try:
            p = resolve_clip_file(
                sample.labels_dir,
                "all_object_info",
                f"{label_idx:06d}.all_object_info.json",
            )
        except FileNotFoundError:
            if not missing_ok:
                raise
            dynamic: list[ObjectTrackInfo] = []
        else:
            dynamic = _load_object_info(p)
        # Concatenation builds a new list, so the cached HD-map list is not mutated.
        return dynamic + self._load_hdmap_static(sample.labels_dir)

    def __getitem__(self, idx: int) -> dict:
        """Build one training sample.

        Returns:
            dict with
              - ``images``: ``[num_history, 3, image_h, image_w]`` float32
                (DINOv3-normalized when ``do_normalize``).
              - ``ego_6d``: ``matrix_to_6d`` of each history vehicle pose,
                stacked along dim 0.
              - ``intr_vec``: f-theta intrinsic vector.
              - ``extr_6d``: ``matrix_to_6d`` of the camera-to-vehicle transform
                at the anchor.
              - ``ego_future`` / ``ego_future_valid``: ``build_ego_future_target`` output.
              - ``targets``: ``build_detection_targets`` output.
              - ``meta``: clip id, chunk id, weather and anchor.

        Raises:
            ValueError: If the anchor cannot supply ``num_history`` frames
                inside its chunk video.
        """
        s = self.samples[idx]
        t = s.anchor_t
        first_frame = t - self.num_history + 1
        if self.num_history < 1 or first_frame < 0:
            # Negative video indices would not map to the labels used below.
            raise ValueError(
                f"sample {idx}: anchor_t={t} cannot provide num_history="
                f"{self.num_history} history frames inside its chunk"
            )
        anchor_label_idx = s.chunk_offset + t

        # History uses chunk-local video frames first_frame..t; labels live at
        # chunk_offset + frame.
        history_frames = list(range(first_frame, t + 1))
        history_label_idx = [s.chunk_offset + f for f in history_frames]
        future_label_idx = [anchor_label_idx + 1 + k for k in range(self.future_horizon)]

        images = _load_video_frames(s.video_path, history_frames, self.image_h, self.image_w)
        if self.do_normalize:
            images = (images - DINOV3_MEAN) / DINOV3_STD

        intr_vec = self._load_intrinsic(s)

        # Camera extrinsic relative to the vehicle at the anchor:
        # cam2veh = inv(vehicle pose) @ camera pose.
        pose_cam_world = self._load_pose_camera(s, anchor_label_idx)
        pose_veh_world = self._load_pose_vehicle(s, anchor_label_idx)
        cam2veh = torch.linalg.inv(pose_veh_world) @ pose_cam_world
        extr_6d = matrix_to_6d(cam2veh)

        # History ego poses; the anchor pose is already loaded above.
        ego_rows = []
        for li in history_label_idx:
            T_vw = pose_veh_world if li == anchor_label_idx else self._load_pose_vehicle(s, li)
            ego_rows.append(matrix_to_6d(T_vw))
        ego_6d = torch.stack(ego_rows, dim=0)

        objs_t = self._load_objects(s, anchor_label_idx)
        # Future vehicle poses stop at the first missing frame.
        veh_pose_future = []
        for li in future_label_idx:
            try:
                veh_pose_future.append(self._load_pose_vehicle(s, li))
            except FileNotFoundError:
                break
        # Keep one entry per horizon step. A missing future frame contributes
        # no dynamic objects instead of failing the whole sample; previously
        # this made the pose truncation above unreachable.
        objs_future = [self._load_objects(s, li, missing_ok=True) for li in future_label_idx]

        cam = FThetaCamera.from_vector(intr_vec)
        lidar_self = None
        if self.use_lidar_occlusion:
            try:
                lidar_self = _load_lidar_self_frame(
                    s.labels_dir,
                    anchor_label_idx,
                    pose_veh_world,
                )
            except Exception:
                # Deliberate best-effort: LiDAR only refines occlusion, so any
                # failure falls back to building targets without it.
                lidar_self = None

        det_targets = build_detection_targets(
            objects_t=objs_t,
            objects_future=objs_future,
            vehicle_pose_t=pose_veh_world,
            vehicle_pose_future=veh_pose_future,
            cam_intrinsic=cam,
            cam2vehicle=cam2veh,
            image_h=self.image_h,
            image_w=self.image_w,
            max_distance_m=self.max_distance_m,
            occlusion_depth_tolerance=self.occlusion_tol,
            lidar_points_self=lidar_self,
            dynamic_classes=self.dynamic_classes,
            structured_classes=self.structured_classes,
            future_horizon=self.future_horizon,
        )

        ego_future, ego_future_valid = build_ego_future_target(
            pose_veh_world, veh_pose_future, horizon=self.future_horizon
        )

        return {
            "images": images,
            "ego_6d": ego_6d,
            "intr_vec": intr_vec,
            "extr_6d": extr_6d,
            "ego_future": ego_future,
            "ego_future_valid": ego_future_valid,
            "targets": det_targets,
            "meta": {
                "clip_id": s.clip_id,
                "chunk_id": s.chunk_id,
                "weather": s.weather,
                "anchor_t": s.anchor_t,
            },
        }
|
|
|
|
def collate_samples(batch: list[dict]) -> dict:
    """Collate dataset samples into a batch.

    Dense tensors (``images``, ``ego_6d``, ``intr_vec``, ``extr_6d``,
    ``ego_future``, ``ego_future_valid``) are stacked along a new leading
    batch dimension. ``targets`` (a variable number of objects per sample,
    consumed by Hungarian matching) and ``meta`` stay plain Python lists.
    """
    stacked_keys = (
        "images",
        "ego_6d",
        "intr_vec",
        "extr_6d",
        "ego_future",
        "ego_future_valid",
    )
    collated = {
        key: torch.stack([item[key] for item in batch], dim=0) for key in stacked_keys
    }
    collated["targets"] = [item["targets"] for item in batch]
    collated["meta"] = [item["meta"] for item in batch]
    return collated
|
|