Spaces:
Running
Running
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| import collections | |
| import collections.abc | |
| from collections import defaultdict | |
| from pathlib import Path | |
| from typing import Optional | |
| import numpy as np | |
| import pytorch_lightning as pl | |
| import torch | |
| import torch.utils.data as torchdata | |
| from omegaconf import OmegaConf | |
| from scipy.spatial.transform import Rotation | |
| from ... import DATASETS_PATH, logger | |
| from ...osm.tiling import TileManager | |
| from ..dataset import MapLocDataset | |
| from ..sequential import chunk_sequence | |
| from ..torch import collate, worker_init_fn | |
| from .utils import get_camera_calibration, parse_gps_file, parse_split_file | |
class KittiDataModule(pl.LightningDataModule):
    """Lightning data module serving the KITTI dataset for map localization.

    Parses split files, loads camera calibration and GPS/IMU (OXTS) poses per
    frame, packs all per-frame attributes into tensors, and exposes
    ``MapLocDataset`` instances backed by a shared ``TileManager`` of map tiles.
    KITTI is presented to ``MapLocDataset`` as a single dummy scene.
    """

    default_cfg = {
        **MapLocDataset.default_cfg,
        "name": "kitti",
        # paths and fetch
        "data_dir": DATASETS_PATH / "kitti",
        "tiles_filename": "tiles.pkl",
        "splits": {
            "train": "train_files.txt",
            "val": "test1_files.txt",
            "test": "test2_files.txt",
        },
        "loading": {
            # "???" is OmegaConf's mandatory-value marker; "${.test}" makes
            # the val loading config an interpolation of the test config.
            "train": "???",
            "val": "${.test}",
            "test": {"batch_size": 1, "num_workers": 0},
        },
        "max_num_val": 500,
        "selection_subset_val": "furthest",
        "drop_train_too_close_to_val": 5.0,
        "skip_frames": 1,
        "camera_index": 2,
        # overwrite
        "crop_size_meters": 64,
        "max_init_error": 20,
        "max_init_error_rotation": 10,
        "add_map_mask": True,
        "mask_pad": 2,
        "target_focal_length": 256,
    }
    # All of KITTI is exposed as one scene with this name.
    dummy_scene_name = "kitti"

    def __init__(self, cfg, tile_manager: Optional[TileManager] = None):
        """Merge the user config into the defaults and validate it.

        Args:
            cfg: user configuration (dict or OmegaConf) overriding default_cfg.
            tile_manager: optional pre-loaded tile manager; if None, it is
                loaded from disk in setup().
        """
        super().__init__()
        default_cfg = OmegaConf.create(self.default_cfg)
        OmegaConf.set_struct(default_cfg, True)  # cannot add new keys
        self.cfg = OmegaConf.merge(default_cfg, cfg)
        self.root = Path(self.cfg.data_dir)
        self.tile_manager = tile_manager
        # The crop must be large enough to contain the true location given
        # the maximum perturbation of the initial pose.
        if self.cfg.crop_size_meters < self.cfg.max_init_error:
            raise ValueError("The ground truth location can be outside the map.")
        assert self.cfg.selection_subset_val in ["random", "furthest"]
        self.splits = {}
        self.shifts = {}
        self.calibrations = {}
        self.data = {}
        self.image_paths = {}

    def prepare_data(self):
        """Check that the dataset has been downloaded (marker file exists)."""
        if not (self.root.exists() and (self.root / ".downloaded").exists()):
            raise FileNotFoundError(
                "Cannot find the KITTI dataset, run maploc.data.kitti.prepare"
            )

    def parse_split(self, split_arg):
        """Resolve a split definition into frame names and optional shifts.

        Args:
            split_arg: either the name of a split file (str) or a sequence of
                "date/drive" directories to enumerate.

        Returns:
            Tuple of (names, shifts) where names is a list of
            (date, drive, filename) tuples and shifts is an array or None.
        """
        if isinstance(split_arg, str):
            names, shifts = parse_split_file(self.root / split_arg)
        elif isinstance(split_arg, collections.abc.Sequence):
            names = []
            shifts = None
            for date_drive in split_arg:
                data_dir = (
                    self.root / date_drive / f"image_{self.cfg.camera_index:02}/data"
                )
                assert data_dir.exists(), data_dir
                date_drive = tuple(date_drive.split("/"))
                n = sorted(date_drive + (p.name,) for p in data_dir.glob("*.png"))
                # Optionally subsample the sequence.
                names.extend(n[:: self.cfg.skip_frames])
        else:
            raise ValueError(split_arg)
        return names, shifts

    def setup(self, stage: Optional[str] = None):
        """Parse splits, load calibrations/tiles, pack data, select subsets."""
        if stage == "fit":
            stages = ["train", "val"]
        elif stage is None:
            stages = ["train", "val", "test"]
        else:
            stages = [stage]
        for stage in stages:
            self.splits[stage], self.shifts[stage] = self.parse_split(
                self.cfg.splits[stage]
            )
        do_val_subset = "val" in stages and self.cfg.max_num_val is not None
        if do_val_subset and self.cfg.selection_subset_val == "random":
            # Random validation subset, seeded for reproducibility.
            select = np.random.RandomState(self.cfg.seed).choice(
                len(self.splits["val"]), self.cfg.max_num_val, replace=False
            )
            self.splits["val"] = [self.splits["val"][i] for i in select]
            if self.shifts["val"] is not None:
                self.shifts["val"] = self.shifts["val"][select]
        dates = {d for ns in self.splits.values() for d, _, _ in ns}
        for d in dates:
            self.calibrations[d] = get_camera_calibration(
                self.root / d, self.cfg.camera_index
            )
        if self.tile_manager is None:
            logger.info("Loading the tile manager...")
            self.tile_manager = TileManager.load(self.root / self.cfg.tiles_filename)
        self.cfg.num_classes = {k: len(g) for k, g in self.tile_manager.groups.items()}
        self.cfg.pixel_per_meter = self.tile_manager.ppm
        # pack all attributes in a single tensor to optimize memory access
        self.pack_data(stages)
        dists = None
        if do_val_subset and self.cfg.selection_subset_val == "furthest":
            # Keep the validation images that are furthest from any training
            # image (in 2D), to minimize spatial overlap between the splits.
            dists = torch.cdist(
                self.data["val"]["t_c2w"][:, :2].double(),
                self.data["train"]["t_c2w"][:, :2].double(),
            )
            min_dists = dists.min(1).values
            select = torch.argsort(min_dists)[-self.cfg.max_num_val :]
            dists = dists[select]
            self.splits["val"] = [self.splits["val"][i] for i in select]
            if self.shifts["val"] is not None:
                self.shifts["val"] = self.shifts["val"][select]
            for k in list(self.data["val"]):
                if k != "cameras":
                    self.data["val"][k] = self.data["val"][k][select]
            self.image_paths["val"] = self.image_paths["val"][select]
        if "train" in stages and self.cfg.drop_train_too_close_to_val is not None:
            if dists is None:
                dists = torch.cdist(
                    self.data["val"]["t_c2w"][:, :2].double(),
                    self.data["train"]["t_c2w"][:, :2].double(),
                )
            # Drop training images within the distance threshold (meters) of
            # any validation image.
            drop = torch.any(dists < self.cfg.drop_train_too_close_to_val, 0)
            select = torch.where(~drop)[0]
            logger.info(
                "Dropping %d (%f %%) images that are too close to validation images.",
                drop.sum(),
                # Convert the fraction to a percentage to match the "%" unit
                # in the message.
                100 * drop.float().mean(),
            )
            self.splits["train"] = [self.splits["train"][i] for i in select]
            if self.shifts["train"] is not None:
                self.shifts["train"] = self.shifts["train"][select]
            for k in list(self.data["train"]):
                if k != "cameras":
                    self.data["train"][k] = self.data["train"][k][select]
            self.image_paths["train"] = self.image_paths["train"][select]

    def pack_data(self, stages):
        """Stack per-frame attributes of each stage into contiguous tensors."""
        for stage in stages:
            names = []
            data = {}
            for i, (date, drive, index) in enumerate(self.splits[stage]):
                d = self.get_frame_data(date, drive, index)
                for k, v in d.items():
                    if i == 0:
                        data[k] = []
                    data[k].append(v)
                path = f"{date}/{drive}/image_{self.cfg.camera_index:02}/data/{index}"
                names.append((self.dummy_scene_name, f"{date}/{drive}", path))
            for k in list(data):
                data[k] = torch.from_numpy(np.stack(data[k]))
            data["camera_id"] = np.full(len(names), self.cfg.camera_index)
            sequences = {date_drive for _, date_drive, _ in names}
            # One camera per sequence; calibration is keyed by the date.
            data["cameras"] = {
                self.dummy_scene_name: {
                    seq: {
                        self.cfg.camera_index: self.calibrations[seq.split("/")[0]][0]
                    }
                    for seq in sequences
                }
            }
            shifts = self.shifts[stage]
            if shifts is not None:
                data["shifts"] = torch.from_numpy(shifts.astype(np.float32))
            self.data[stage] = data
            self.image_paths[stage] = np.array(names)

    def get_frame_data(self, date, drive, index):
        """Compute the camera pose of one frame from its OXTS (GPS/IMU) file.

        Returns:
            Dict with the camera-to-world translation ``t_c2w`` (float32),
            ``roll_pitch_yaw`` angles in degrees (float32), and the integer
            frame ``index``.
        """
        _, R_cam_gps, t_cam_gps = self.calibrations[date]
        # Transform the GPS pose to the camera pose
        gps_path = (
            self.root / date / drive / "oxts/data" / Path(index).with_suffix(".txt")
        )
        _, R_world_gps, t_world_gps = parse_gps_file(
            gps_path, self.tile_manager.projection
        )
        R_world_cam = R_world_gps @ R_cam_gps.T
        t_world_cam = t_world_gps - R_world_gps @ R_cam_gps.T @ t_cam_gps
        # Some voodoo to extract correct Euler angles from R_world_cam:
        # re-express the rotation in an x-y-z convention before decomposing.
        R_cv_xyz = Rotation.from_euler("YX", [-90, 90], degrees=True).as_matrix()
        R_world_cam_xyz = R_world_cam @ R_cv_xyz
        y, p, r = Rotation.from_matrix(R_world_cam_xyz).as_euler("ZYX", degrees=True)
        roll, pitch, yaw = r, -p, 90 - y
        roll_pitch_yaw = np.array([-roll, -pitch, yaw], np.float32)  # for some reason
        return {
            "t_c2w": t_world_cam.astype(np.float32),
            "roll_pitch_yaw": roll_pitch_yaw,
            "index": int(index.split(".")[0]),
        }

    def dataset(self, stage: str):
        """Build the MapLocDataset for a given stage."""
        return MapLocDataset(
            stage,
            self.cfg,
            self.image_paths[stage],
            self.data[stage],
            {self.dummy_scene_name: self.root},
            {self.dummy_scene_name: self.tile_manager},
        )

    def dataloader(
        self,
        stage: str,
        shuffle: bool = False,
        num_workers: Optional[int] = None,
        sampler: Optional[torchdata.Sampler] = None,
    ):
        """Build a DataLoader for a stage, using the stage's loading config.

        Args:
            stage: one of "train", "val", "test".
            shuffle: force shuffling; training always shuffles regardless.
            num_workers: overrides the configured number of workers if given.
            sampler: optional sampler passed through to the DataLoader.
        """
        dataset = self.dataset(stage)
        cfg = self.cfg["loading"][stage]
        num_workers = cfg["num_workers"] if num_workers is None else num_workers
        loader = torchdata.DataLoader(
            dataset,
            batch_size=cfg["batch_size"],
            num_workers=num_workers,
            shuffle=shuffle or (stage == "train"),
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
            sampler=sampler,
        )
        return loader

    def train_dataloader(self, **kwargs):
        return self.dataloader("train", **kwargs)

    def val_dataloader(self, **kwargs):
        return self.dataloader("val", **kwargs)

    def test_dataloader(self, **kwargs):
        return self.dataloader("test", **kwargs)

    def sequence_dataset(self, stage: str, **kwargs):
        """Build a dataset with frames grouped into fixed-length chunks.

        Returns:
            Tuple of (dataset, chunk2indices) mapping (sequence, chunk index)
            to the dataset indices belonging to that chunk.
        """
        keys = self.image_paths[stage]
        # group images by sequence (date/drive)
        seq2indices = defaultdict(list)
        for index, (_, date_drive, _) in enumerate(keys):
            seq2indices[date_drive].append(index)
        # chunk the sequences to the required length
        chunk2indices = {}
        for seq, indices in seq2indices.items():
            chunks = chunk_sequence(
                self.data[stage], indices, names=self.image_paths[stage], **kwargs
            )
            for i, sub_indices in enumerate(chunks):
                chunk2indices[seq, i] = sub_indices
        # store the index of each chunk in its sequence
        chunk_indices = torch.full((len(keys),), -1)
        for (_, chunk_index), idx in chunk2indices.items():
            chunk_indices[idx] = chunk_index
        self.data[stage]["chunk_index"] = chunk_indices
        dataset = self.dataset(stage)
        return dataset, chunk2indices

    def sequence_dataloader(self, stage: str, shuffle: bool = False, **kwargs):
        """Build a DataLoader that iterates chunks sequentially.

        Chunks (not individual frames) are optionally shuffled; within a
        chunk, frames keep their temporal order via the index sampler.
        """
        dataset, chunk2idx = self.sequence_dataset(stage, **kwargs)
        seq_keys = sorted(chunk2idx)
        if shuffle:
            perm = torch.randperm(len(seq_keys))
            seq_keys = [seq_keys[i] for i in perm]
        key_indices = [i for key in seq_keys for i in chunk2idx[key]]
        num_workers = self.cfg["loading"][stage]["num_workers"]
        loader = torchdata.DataLoader(
            dataset,
            batch_size=None,
            sampler=key_indices,
            num_workers=num_workers,
            shuffle=False,
            pin_memory=True,
            persistent_workers=num_workers > 0,
            worker_init_fn=worker_init_fn,
            collate_fn=collate,
        )
        return loader, seq_keys, chunk2idx