| """SHIFT dataset.""" |
|
|
| from __future__ import annotations |
|
|
| import json |
| import multiprocessing |
| import os |
| from collections.abc import Sequence |
| from functools import partial |
|
|
| import numpy as np |
| from tqdm import tqdm |
|
|
| from vis4d.common.imports import SCALABEL_AVAILABLE |
| from vis4d.common.logging import rank_zero_info |
| from vis4d.common.typing import NDArrayF32, NDArrayI64, NDArrayNumber |
| from vis4d.data.const import CommonKeys as K |
| from vis4d.data.datasets.base import VideoDataset |
| from vis4d.data.datasets.util import im_decode, npy_decode |
| from vis4d.data.io import DataBackend, FileBackend, HDF5Backend, ZipBackend |
| from vis4d.data.typing import DictData |
|
|
| from .base import VideoDataset, VideoMapping |
| from .scalabel import Scalabel |
|
|
# Category name -> integer class index used for detection labels.
shift_det_map = {
    "pedestrian": 0,
    "car": 1,
    "truck": 2,
    "bus": 3,
    "motorcycle": 4,
    "bicycle": 5,
}

# Category name -> integer class index used for tracking labels.
shift_track_map = {
    "pedestrian": 0,
    "car": 1,
    "truck": 2,
    "bus": 3,
    "motorcycle": 4,
    "bicycle": 5,
}
# Backward-compatible alias: the constant was originally published under this
# misspelled name, so existing imports must keep working.
shfit_track_map = shift_track_map

# Semantic segmentation category name -> integer class index.
shift_seg_map = {
    "unlabeled": 0,
    "building": 1,
    "fence": 2,
    "other": 3,
    "pedestrian": 4,
    "pole": 5,
    "road line": 6,
    "road": 7,
    "sidewalk": 8,
    "vegetation": 9,
    "vehicle": 10,
    "wall": 11,
    "traffic sign": 12,
    "sky": 13,
    "ground": 14,
    "bridge": 15,
    "rail track": 16,
    "guard rail": 17,
    "traffic light": 18,
    "static": 19,
    "dynamic": 20,
    "water": 21,
    "terrain": 22,
}

# Segmentation categories listed as "ignore"; every entry is a key of
# ``shift_seg_map``.
shift_seg_ignore = [
    "unlabeled",
    "other",
    "ground",
    "bridge",
    "rail track",
    "guard rail",
    "static",
    "dynamic",
    "water",
]
|
|
| if SCALABEL_AVAILABLE: |
| from scalabel.label.io import parse |
| from scalabel.label.typing import Config |
| from scalabel.label.typing import Dataset as ScalabelData |
| else: |
| raise ImportError("scalabel is not installed.") |
|
|
|
|
def _get_extension(backend: DataBackend) -> str:
    """Return the archive suffix that matches the type of ``backend``.

    Args:
        backend (DataBackend): Backend instance used to read the data.

    Returns:
        str: ".hdf5" for an HDF5Backend, ".zip" for a ZipBackend and an empty
            string for a FileBackend.

    Raises:
        ValueError: If the backend is none of the supported types.
    """
    # Evaluated top to bottom, so the first matching type wins.
    suffix_by_backend = (
        (HDF5Backend, ".hdf5"),
        (ZipBackend, ".zip"),
        (FileBackend, ""),
    )
    for backend_type, suffix in suffix_by_backend:
        if isinstance(backend, backend_type):
            return suffix
    raise ValueError(f"Unsupported backend {backend}.")
|
|
|
|
class _SHIFTScalabelLabels(Scalabel):
    """Helper class for labels in SHIFT that are stored in Scalabel format."""

    # Values accepted by the ``view`` argument of ``__init__``.
    VIEWS = [
        "front",
        "center",
        "left_45",
        "left_90",
        "right_45",
        "right_90",
        "left_stereo",
    ]

    def __init__(
        self,
        data_root: str,
        split: str,
        data_file: str = "",
        keys_to_load: Sequence[str] = (K.images, K.boxes2d),
        attributes_to_load: Sequence[dict[str, str | float]] | None = None,
        annotation_file: str = "",
        view: str = "front",
        framerate: str = "images",
        shift_type: str = "discrete",
        skip_empty_frames: bool = False,
        backend: DataBackend = HDF5Backend(),
        verbose: bool = False,
        num_workers: int = 1,
    ) -> None:
        """Initialize SHIFT dataset for one view.

        Both the annotation file and the data archive are looked up in
        ``<data_root>/discrete/<framerate>/<split>/<view>/`` or, for
        continuous shifts, in
        ``<data_root>/continuous/<framerate>/<speed>/<split>/<view>/``.

        Args:
            data_root (str): Path to the root directory of the dataset.
            split (str): Which data split to load. One of "train", "val" and
                "test".
            data_file (str): Name of the data archive file, without the
                backend specific extension. Default: "".
            keys_to_load (Sequence[str]): List of keys to load.
                Default: (K.images, K.boxes2d).
            attributes_to_load (Sequence[dict[str, str | float]] | None):
                List of attributes to load. Default: None.
            annotation_file (str): Name of the annotation file. Default: "".
            view (str): Which view to load. Default: "front". Options: "front",
                "center", "left_45", "left_90", "right_45", "right_90", and
                "left_stereo".
            framerate (str): Which framerate to load. Default: "images".
            shift_type (str): Which shift type to load. Default: "discrete".
                Options: "discrete", "continuous/1x", "continuous/10x", and
                "continuous/100x".
            skip_empty_frames (bool): Whether to skip frames with no
                instance annotations. Default: False.
            backend (DataBackend): Backend to use for loading data. Default:
                HDF5Backend().
            verbose (bool): Whether to print verbose logs. Default: False.
            num_workers (int): Number of workers to use for loading data.
                Default: 1.

        Raises:
            AssertionError: If ``split`` or ``view`` is not a supported value.
            ValueError: If ``backend`` has an unsupported type.
        """
        # Both attributes are read by _generate_mapping() / _load().
        # NOTE(review): presumably the Scalabel base constructor triggers
        # _generate_mapping(), which is why they are set before
        # super().__init__() -- confirm against the base class.
        self.verbose = verbose
        self.num_workers = num_workers

        assert split in {"train", "val", "test"}, f"Invalid split '{split}'"
        assert view in _SHIFTScalabelLabels.VIEWS, f"Invalid view '{view}'"

        ext = _get_extension(backend)
        if shift_type.startswith("continuous"):
            # "continuous/10x" -> "10x"
            shift_speed = shift_type.split("/")[-1]
            base_dir = os.path.join(
                data_root, "continuous", framerate, shift_speed, split, view
            )
        else:
            base_dir = os.path.join(
                data_root, "discrete", framerate, split, view
            )
        annotation_path = os.path.join(base_dir, annotation_file)
        data_path = os.path.join(base_dir, f"{data_file}{ext}")

        super().__init__(
            data_path,
            annotation_path,
            data_backend=backend,
            keys_to_load=keys_to_load,
            attributes_to_load=attributes_to_load,
            skip_empty_samples=skip_empty_frames,
        )

    def _generate_mapping(self) -> ScalabelData:
        """Load the labels referenced by ``self.annotation_path``."""
        if self.verbose:
            rank_zero_info(
                "Loading annotation from '%s' ...", self.annotation_path
            )
        return self._load(self.annotation_path)

    def _load(self, filepath: str) -> ScalabelData:
        """Load labels from a single Scalabel JSON file.

        The file may contain either a dict with a "frames" entry (and an
        optional "config" entry) or a plain list of frames. Directories are
        not supported and are rejected with a TypeError.

        Args:
            filepath (str): Path of the JSON annotation file.

        Returns:
            ScalabelData: Parsed frames in the same order as stored in the
                file, plus the dataset config if the file provides one.
                ``groups`` is always None.

        Raises:
            FileNotFoundError: If ``filepath`` does not exist.
            TypeError: If ``filepath`` is not a file ending in "json", or if
                the file content is neither a dict nor a list.
        """
        if not os.path.exists(filepath):
            raise FileNotFoundError(f"{filepath} does not exist.")
        if not (os.path.isfile(filepath) and filepath.endswith("json")):
            raise TypeError("Inputs must be a folder or a JSON file.")

        with open(filepath, mode="r", encoding="utf-8") as fp:
            content = json.load(fp)

        cfg = None
        if isinstance(content, dict):
            raw_frames: list[DictData] = list(content["frames"])
            # A missing or null "config" entry both mean "no config".
            cfg = content.get("config")
        elif isinstance(content, list):
            raw_frames = list(content)
        else:
            raise TypeError("The input file contains neither dict nor list.")

        # Only log when asked to, consistent with _generate_mapping().
        if self.verbose:
            rank_zero_info(
                "Loading SHIFT annotation from '%s' Done.", filepath
            )

        config = Config(**cfg) if cfg is not None else None

        parse_func = partial(parse, validate_frames=False)
        if self.num_workers > 1:
            frames = []
            with multiprocessing.Pool(self.num_workers) as pool:
                with tqdm(total=len(raw_frames)) as pbar:
                    # imap() yields results in input order, so the frame
                    # order does not depend on worker scheduling and matches
                    # the single-worker path below. Several label files are
                    # indexed with the same sample index, so a stable order
                    # matters.
                    for parsed in pool.imap(
                        parse_func, raw_frames, chunksize=1000
                    ):
                        frames.append(parsed)
                        pbar.update()
        else:
            frames = [parse_func(frame) for frame in raw_frames]
        return ScalabelData(frames=frames, config=config, groups=None)
|
|
|
|
class SHIFT(VideoDataset):
    """SHIFT dataset class, supporting multiple tasks and views.

    Instance-level labels are read through one ``_SHIFTScalabelLabels``
    dataset per (view, data group) pair. Segmentation masks, depth maps and
    optical flow are read directly from per-view archives through the
    configured backend.
    """

    DESCRIPTION = """SHIFT Dataset, a synthetic driving dataset for continuous
        multi-task domain adaptation"""
    HOMEPAGE = "https://www.vis.xyz/shift/"
    PAPER = "https://arxiv.org/abs/2206.08367"
    LICENSE = "CC BY-NC-SA 4.0"

    # All keys accepted by ``keys_to_load`` (checked by ``validate_keys``).
    KEYS = [
        # Inputs
        K.images,
        K.original_hw,
        K.input_hw,
        K.points3d,
        # Calibration, metadata and instance-level annotations
        K.intrinsics,
        K.extrinsics,
        K.timestamp,
        K.axis_mode,
        K.boxes2d,
        K.boxes2d_classes,
        K.boxes2d_track_ids,
        K.instance_masks,
        K.boxes3d,
        K.boxes3d_classes,
        K.boxes3d_track_ids,
        # Dense annotations loaded by _load_semseg / _load_depth / _load_flow
        K.seg_masks,
        K.depth_maps,
        K.optical_flows,
    ]

    VIEWS = [
        "front",
        "center",
        "left_45",
        "left_90",
        "right_45",
        "right_90",
        "left_stereo",
    ]

    # Data group name -> keys provided by that group. The group name is also
    # used as the annotation / archive file stem.
    DATA_GROUPS = {
        "img": [
            K.images,
            K.original_hw,
            K.input_hw,
            K.intrinsics,
        ],
        "det_2d": [
            K.timestamp,
            K.axis_mode,
            K.extrinsics,
            K.boxes2d,
            K.boxes2d_classes,
            K.boxes2d_track_ids,
        ],
        "det_3d": [
            K.boxes3d,
            K.boxes3d_classes,
            K.boxes3d_track_ids,
        ],
        "det_insseg_2d": [
            K.instance_masks,
        ],
        "semseg": [
            K.seg_masks,
        ],
        "depth": [
            K.depth_maps,
        ],
        "flow": [
            K.optical_flows,
        ],
        "lidar": [
            K.points3d,
        ],
    }

    # Data groups whose labels are read from "<group>.json" Scalabel files.
    GROUPS_IN_SCALABEL = ["det_2d", "det_3d", "det_insseg_2d"]

    def __init__(
        self,
        data_root: str,
        split: str,
        keys_to_load: Sequence[str] = (K.images, K.boxes2d),
        views_to_load: Sequence[str] = ("front",),
        attributes_to_load: Sequence[dict[str, str | float]] | None = None,
        framerate: str = "images",
        shift_type: str = "discrete",
        skip_empty_frames: bool = False,
        backend: DataBackend = HDF5Backend(),
        num_workers: int = 1,
        verbose: bool = False,
    ) -> None:
        """Initialize SHIFT dataset.

        Args:
            data_root (str): Path to the root directory of the dataset.
            split (str): Which data split to load. One of "train", "val" and
                "test".
            keys_to_load (Sequence[str]): Keys to load; each must be listed
                in ``KEYS``. Default: (K.images, K.boxes2d).
            views_to_load (Sequence[str]): Views to load. For the "center"
                view only the "center/lidar" label dataset is created.
                Default: ("front",).
            attributes_to_load (Sequence[dict[str, str | float]] | None):
                Forwarded to the Scalabel label datasets. Default: None.
            framerate (str): "images" or "videos". Default: "images".
            shift_type (str): "discrete", "continuous/1x", "continuous/10x"
                or "continuous/100x". Default: "discrete".
            skip_empty_frames (bool): Forwarded to the Scalabel label
                datasets. Default: False.
            backend (DataBackend): Backend used to read the data; its type
                determines the archive file extension. Default:
                HDF5Backend().
            num_workers (int): Number of workers used to parse the
                annotation files. Default: 1.
            verbose (bool): Whether to print verbose logs. Default: False.

        Raises:
            AssertionError: If ``split``, ``framerate`` or ``shift_type`` is
                not a supported value.
            ValueError: If a key is not supported, the backend type is not
                supported, or no label dataset could be created.
        """
        super().__init__(data_backend=backend)

        # Validate input
        assert split in {"train", "val", "test"}, f"Invalid split '{split}'."
        assert framerate in {
            "images",
            "videos",
        }, f"Invalid framerate '{framerate}'. Must be 'images' or 'videos'."
        assert shift_type in {
            "discrete",
            "continuous/1x",
            "continuous/10x",
            "continuous/100x",
        }, (
            f"Invalid shift_type '{shift_type}'. "
            "Must be one of 'discrete', 'continuous/1x', 'continuous/10x', "
            "or 'continuous/100x'."
        )
        self.validate_keys(keys_to_load)

        # Store configuration
        self.data_root = data_root
        self.split = split
        self.keys_to_load = keys_to_load
        self.views_to_load = views_to_load
        self.attributes_to_load = attributes_to_load
        self.framerate = framerate
        self.shift_type = shift_type
        self.backend = backend
        self.verbose = verbose
        self.ext = _get_extension(backend)
        if self.shift_type.startswith("continuous"):
            # "continuous/10x" -> "10x"
            shift_speed = self.shift_type.split("/")[-1]
            self.annotation_base = os.path.join(
                self.data_root,
                "continuous",
                self.framerate,
                shift_speed,
                self.split,
            )
        else:
            self.annotation_base = os.path.join(
                self.data_root, self.shift_type, self.framerate, self.split
            )
        if self.verbose:
            print(f"Base: {self.annotation_base}. Backend: {self.backend}")

        # Determine which Scalabel label files are needed
        self._data_groups_to_load = self._get_data_groups(keys_to_load)
        if "det_2d" not in self._data_groups_to_load:
            raise ValueError(
                "In current implementation, the 'det_2d' data group must be "
                "loaded to load any other data group."
            )

        # Arguments shared by every per-view / per-group label dataset.
        make_labels = partial(
            _SHIFTScalabelLabels,
            data_root=self.data_root,
            split=self.split,
            framerate=self.framerate,
            shift_type=self.shift_type,
            attributes_to_load=self.attributes_to_load,
            skip_empty_frames=skip_empty_frames,
            backend=backend,
            num_workers=num_workers,
            verbose=verbose,
        )

        self.scalabel_datasets: dict[str, _SHIFTScalabelLabels] = {}
        for view in self.views_to_load:
            if view == "center":
                # The center view is served by the lidar archive together
                # with the 3D detection labels.
                self.scalabel_datasets["center/lidar"] = make_labels(
                    data_file="lidar",
                    annotation_file="det_3d.json",
                    view=view,
                    keys_to_load=(K.points3d, *self.DATA_GROUPS["det_3d"]),
                )
                continue

            # Camera views: one label dataset per data group. The image
            # keys are attached to the first group only so that images are
            # not requested once per group.
            image_loaded = False
            for group in self._data_groups_to_load:
                group_keys = list(self.DATA_GROUPS[group])
                if not image_loaded:
                    group_keys.extend(self.DATA_GROUPS["img"])
                    image_loaded = True
                self.scalabel_datasets[f"{view}/{group}"] = make_labels(
                    data_file="img",
                    annotation_file=f"{group}.json",
                    view=view,
                    keys_to_load=group_keys,
                )

        self.video_mapping = self._generate_video_mapping()

    def validate_keys(self, keys_to_load: Sequence[str]) -> None:
        """Validate that all keys to load are supported.

        Args:
            keys_to_load (Sequence[str]): Keys requested by the user.

        Raises:
            ValueError: If a key is not listed in ``KEYS``.
        """
        for k in keys_to_load:
            if k not in self.KEYS:
                raise ValueError(f"Key '{k}' is not supported!")

    def _get_data_groups(self, keys_to_load: Sequence[str]) -> list[str]:
        """Get the data groups that need to be loaded from Scalabel.

        Args:
            keys_to_load (Sequence[str]): Keys requested by the user.

        Returns:
            list[str]: Unique group names. "det_2d" is always present and
                always first; the remaining groups follow the order of
                ``DATA_GROUPS``. The order is deterministic because it
                decides which label dataset carries the image keys and
                which one serves as the index reference.
        """
        data_groups = ["det_2d"]
        for data_group, group_keys in self.DATA_GROUPS.items():
            if data_group not in self.GROUPS_IN_SCALABEL:
                continue
            if data_group in data_groups:
                continue
            # Needed as soon as one requested key belongs to the group.
            if any(key in group_keys for key in keys_to_load):
                data_groups.append(data_group)
        return data_groups

    def _load(
        self, view: str, data_group: str, file_ext: str, video: str, frame: str
    ) -> NDArrayNumber:
        """Load data from the given data group.

        Args:
            view (str): View name, used as directory and file name suffix.
            data_group (str): One of "semseg", "depth" and "flow".
            file_ext (str): File extension without the leading dot.
            video (str): Video (sequence) name.
            frame (str): Frame name; the part before the first "_" is used
                as the frame number.

        Returns:
            NDArrayNumber: The decoded array for the requested data group.

        Raises:
            ValueError: If ``data_group`` is not supported.
        """
        frame_number = frame.split("_")[0]
        filepath = os.path.join(
            self.annotation_base,
            view,
            f"{data_group}{self.ext}",
            video,
            f"{frame_number}_{data_group}_{view}.{file_ext}",
        )
        loaders = {
            "semseg": self._load_semseg,
            "depth": self._load_depth,
            "flow": self._load_flow,
        }
        loader = loaders.get(data_group)
        if loader is None:
            raise ValueError(f"Invalid data group '{data_group}'")
        return loader(filepath)

    def _load_semseg(self, filepath: str) -> NDArrayI64:
        """Load semantic segmentation data.

        Args:
            filepath (str): Path passed to the backend.

        Returns:
            NDArrayI64: First channel of the decoded image, cast to int64.
        """
        im_bytes = self.backend.get(filepath)
        first_channel = im_decode(im_bytes)[..., 0]
        return first_channel.astype(np.int64)

    def _load_depth(
        self, filepath: str, depth_factor: float = 16777.216
    ) -> NDArrayF32:
        """Load depth data.

        Args:
            filepath (str): Path passed to the backend.
            depth_factor (float): Divisor applied to the decoded 24-bit
                value. Must be positive. Default: 16777.216.

        Returns:
            NDArrayF32: Contiguous float32 depth map.
        """
        assert depth_factor > 0, "Max depth value must be greater than 0."

        im_bytes = self.backend.get(filepath)
        image = im_decode(im_bytes)
        if image.shape[2] > 3:
            # Drop any channel beyond the first three.
            image = image[:, :, :3]
        image = image.astype(np.float32)

        # The three channels hold one 24-bit integer; channel 0 is the
        # least significant byte.
        depth = (
            image[:, :, 2] * 256 * 256 + image[:, :, 1] * 256 + image[:, :, 0]
        )
        return np.ascontiguousarray(depth / depth_factor, dtype=np.float32)

    def _load_flow(self, filepath: str) -> NDArrayF32:
        """Load optical flow data.

        Args:
            filepath (str): Path passed to the backend.

        Returns:
            NDArrayF32: Flow array with its two channels swapped and scaled,
                cast to float32.
        """
        npy_bytes = self.backend.get(filepath)
        flow = npy_decode(npy_bytes, key="flow")
        # Swap the two flow channels (fancy indexing returns a copy).
        flow = flow[:, :, [1, 0]]
        # Scale by the size of the second array axis.
        # NOTE(review): presumably converts normalized flow to pixels --
        # confirm.
        flow *= flow.shape[1]
        if self.framerate == "images":
            # Extra factor applied only for the "images" framerate.
            flow *= 10.0
        return flow.astype(np.float32)

    def _first_scalabel_dataset(self) -> _SHIFTScalabelLabels:
        """Return the first loaded label dataset, used as index reference.

        Raises:
            ValueError: If no Scalabel file has been loaded.
        """
        if not self.scalabel_datasets:
            raise ValueError("No Scalabel file has been loaded.")
        return next(iter(self.scalabel_datasets.values()))

    def _get_frame_key(self, idx: int) -> tuple[str, str]:
        """Get the frame identifier (video name, frame name) by index.

        Raises:
            ValueError: If no Scalabel file has been loaded.
        """
        frames = self._first_scalabel_dataset().frames
        return frames[idx].videoName, frames[idx].name

    def __len__(self) -> int:
        """Get the number of samples in the dataset.

        Raises:
            ValueError: If no Scalabel file has been loaded.
        """
        return len(self._first_scalabel_dataset())

    def _generate_video_mapping(self) -> VideoMapping:
        """Group all dataset sample indices (int) by their video ID (str).

        Returns:
            VideoMapping: Mapping of video IDs to sample indices and frame IDs.

        Raises:
            ValueError: If no Scalabel file has been loaded.
        """
        return self._first_scalabel_dataset().video_mapping

    def __getitem__(self, idx: int) -> DictData:
        """Get single sample.

        Args:
            idx (int): Index of sample.

        Returns:
            DictData: sample at index in Vis4D input format. Besides the
                sample name, sequence name and frame id, it contains one
                entry per loaded view.
        """
        data_dict = {}

        # Metadata shared by all views
        video_name, frame_name = self._get_frame_key(idx)
        data_dict[K.sample_names] = frame_name
        data_dict[K.sequence_names] = video_name
        data_dict[K.frame_ids] = frame_name.split("_")[0]

        # (key, data group, file extension) of the dense annotations
        dense_labels = (
            (K.seg_masks, "semseg", "png"),
            (K.depth_maps, "depth", "png"),
            (K.optical_flows, "flow", "npz"),
        )

        for view in self.views_to_load:
            data_dict_view = {}

            if view == "center":
                # Only the lidar label dataset exists for the center view.
                if K.points3d in self.keys_to_load:
                    data_dict_view.update(
                        self.scalabel_datasets["center/lidar"][idx]
                    )
            else:
                # Merge the per-group label datasets of this view.
                for group in self._data_groups_to_load:
                    data_dict_view.update(
                        self.scalabel_datasets[f"{view}/{group}"][idx]
                    )

                # NOTE(review): dense annotations are loaded for camera
                # views only -- confirm this matches the intended nesting.
                for key, group, file_ext in dense_labels:
                    if key in self.keys_to_load:
                        data_dict_view[key] = self._load(
                            view, group, file_ext, video_name, frame_name
                        )
            data_dict[view] = data_dict_view

        return data_dict
|
|