| import os |
| import torch |
| import numpy as np |
| from pyquaternion import Quaternion |
| from nuscenes.nuscenes import NuScenes |
| from itertools import chain |
| from PIL import Image |
| from torchvision import transforms as T |
| import torchvision.transforms as tvf |
| from torchvision.transforms.functional import to_tensor |
|
|
| from .splits_roddick import create_splits_scenes_roddick |
| from ..image import pad_image, rectify_image, resize_image |
| from .utils import decode_binary_labels |
| from ..utils import decompose_rotmat |
| from ...utils.io import read_image |
| from ...utils.wrappers import Camera |
| from ..schema import NuScenesDataConfiguration |
|
|
|
|
class NuScenesDataset(torch.utils.data.Dataset):
    """nuScenes ``CAM_FRONT`` frames paired with per-frame semantic label PNGs.

    On construction the nuScenes tables are loaded, samples are filtered to
    the scenes of the requested split (as returned by
    ``create_splits_scenes_roddick``), and for every remaining sample the
    front-camera pose, intrinsics and file name are cached in ``self.data``.
    ``__getitem__`` then reads the image and the label PNG named after the
    sample-data token under ``cfg.map_dir``.

    Args:
        cfg: Dataset configuration. Fields read here: ``version``,
            ``data_dir``, ``map_dir``, ``percentage``, ``augmentations``,
            ``num_classes``, ``class_mapping``, ``gravity_align``,
            ``rectify_pitch``, ``resize_image`` and ``pad_to_square``.
        split: One of ``"train"``, ``"val"`` or ``"test"``.

    Raises:
        KeyError: If ``cfg.version`` / ``split`` is not a known combination.
    """

    # Dataset split -> key of the dict returned by
    # ``create_splits_scenes_roddick()``, for each supported nuScenes version.
    # "test" reuses the validation scenes for both versions.
    _SCENE_SPLITS = {
        'v1.0-trainval': {'train': 'train', 'val': 'val', 'test': 'val'},
        'v1.0-mini': {'train': 'mini_train', 'val': 'mini_val', 'test': 'mini_val'},
    }

    # Only this camera channel is used.
    _CAMERA_CHANNEL = "CAM_FRONT"

    def __init__(self, cfg: "NuScenesDataConfiguration", split="train"):
        # Resolve the split first so a bad version/split fails before the
        # (slow) nuScenes table load.
        try:
            scene_split = self._SCENE_SPLITS[cfg.version][split]
        except KeyError as err:
            raise KeyError(
                f"Unsupported nuScenes version/split: {cfg.version!r}/{split!r}"
            ) from err

        self.cfg = cfg
        self.nusc = NuScenes(version=cfg.version, dataroot=str(cfg.data_dir))
        self.map_data_root = cfg.map_dir
        self.split = split

        self.scenes = create_splits_scenes_roddick()[scene_split]

        # Set membership keeps the per-sample scene test O(1).
        scene_names = set(self.scenes)
        self.sample = [
            sample for sample in self.nusc.sample
            if self.nusc.get('scene', sample['scene_token'])['name'] in scene_names
        ]

        # Augmentations are only applied to training images.
        self.tfs = self.get_augmentations() if split == "train" else T.Compose([])

        self.data = []
        for sample in self.sample:
            channels = sample['data']
            if self._CAMERA_CHANNEL not in channels:
                continue
            self.data.append(self._build_record(channels[self._CAMERA_CHANNEL]))

        # Optionally train on only the leading fraction of the records.
        if self.cfg.percentage < 1.0 and split == "train":
            self.data = self.data[:int(len(self.data) * self.cfg.percentage)]

    @staticmethod
    def _pose_matrix(record):
        """Return the 4x4 float32 transform built from ``record``'s rotation and translation."""
        matrix = np.eye(4, dtype=np.float32)
        matrix[:3, :3] = Quaternion(record['rotation']).rotation_matrix
        matrix[:3, 3] = record['translation']
        return matrix

    def _build_record(self, sample_data_token):
        """Collect pose, intrinsics and identifiers for one sample-data token."""
        d = self.nusc.get('sample_data', sample_data_token)
        sample = self.nusc.get('sample', d['sample_token'])
        scene = self.nusc.get('scene', sample['scene_token'])
        location = self.nusc.get('log', scene['log_token'])['location']

        ego_pose = self.nusc.get('ego_pose', d['ego_pose_token'])
        calibrated_sensor = self.nusc.get(
            "calibrated_sensor", d['calibrated_sensor_token'])

        # Chain sensor -> ego -> global.
        sensor2global = self._pose_matrix(ego_pose) @ self._pose_matrix(calibrated_sensor)
        roll, pitch, yaw = decompose_rotmat(sensor2global[:3, :3])

        intrinsic = calibrated_sensor['camera_intrinsic']
        fx, fy = intrinsic[0][0], intrinsic[1][1]
        cx, cy = intrinsic[0][2], intrinsic[1][2]

        # NOTE(review): the 0.5 px shift of the principal point presumably
        # converts between pixel-center conventions - confirm against Camera.
        cam = Camera(torch.tensor(
            [d['width'], d['height'], fx, fy, cx - 0.5, cy - 0.5])).float()

        return {
            'filename': d['filename'],
            'yaw': yaw,
            'pitch': pitch,
            'roll': roll,
            'cam': cam,
            'sensor2global': sensor2global,
            'token': d['token'],
            'sample_token': d['sample_token'],
            'location': location
        }

    def _random_resized_crop(self, image):
        """Randomly crop 80-100% of ``image``'s area and resize back to its own size.

        ``RandomResizedCrop`` requires an output size; using the input's
        height and width keeps every sample's shape unchanged.
        NOTE(review): the crop is not propagated to the returned camera or
        validity mask - confirm this is acceptable for training.
        """
        crop = tvf.RandomResizedCrop(size=tuple(image.shape[-2:]), scale=(0.8, 1.0))
        return crop(image)

    def _add_gaussian_noise(self, image):
        """Add ``N(mean, std)`` noise from the config and clamp back to [0, 1].

        ``torchvision.transforms`` has no ``GaussianNoise`` transform, so the
        noise is applied directly with torch. The clamp mirrors the default
        clipping of ``torchvision.transforms.v2.GaussianNoise``.
        """
        noise_cfg = self.cfg.augmentations.gaussian_noise
        noisy = image + torch.randn_like(image) * noise_cfg.std + noise_cfg.mean
        return noisy.clamp_(0.0, 1.0)

    def get_augmentations(self):
        """Build the training-time image augmentation pipeline from ``cfg.augmentations``.

        Returns:
            A ``Compose`` applying, in order: colour jitter, then (each only
            if enabled in the config) random resized crop, Gaussian noise and
            a second brightness/contrast jitter.
        """
        aug_cfg = self.cfg.augmentations

        print("Augmentation!", "\n" * 10)
        augmentations = [
            tvf.ColorJitter(
                brightness=aug_cfg.brightness,
                contrast=aug_cfg.contrast,
                saturation=aug_cfg.saturation,
                hue=aug_cfg.hue,
            )
        ]

        if aug_cfg.random_resized_crop:
            augmentations.append(tvf.Lambda(self._random_resized_crop))

        if aug_cfg.gaussian_noise.enabled:
            augmentations.append(tvf.Lambda(self._add_gaussian_noise))

        if aug_cfg.brightness_contrast.enabled:
            augmentations.append(
                tvf.ColorJitter(
                    brightness=aug_cfg.brightness_contrast.brightness_factor,
                    contrast=aug_cfg.brightness_contrast.contrast_factor,
                    saturation=0,
                    hue=0,
                )
            )

        return tvf.Compose(augmentations)

    def __len__(self):
        """Return the number of cached front-camera records."""
        return len(self.data)

    @staticmethod
    def _normalize_confidence(visibility_mask):
        """Min-max scale a float copy of ``visibility_mask`` to [0, 1].

        A constant mask has zero range; it is returned unscaled instead of
        dividing by zero, which would fill the map with NaN.
        """
        confidence = visibility_mask.clone().float()
        low = confidence.min()
        span = confidence.max() - low
        if span > 0:
            confidence = (confidence - low) / span
        return confidence

    def __getitem__(self, idx):
        """Load, rectify and augment one sample.

        Args:
            idx: Index into ``self.data``.

        Returns:
            A dict with keys ``image``, ``roll_pitch_yaw``, ``camera``,
            ``valid``, ``seg_masks``, ``token``, ``sample_token``,
            ``location``, ``flood_masks``, ``confidence_map`` and ``name``.
        """
        d = self.data[idx]

        image = np.array(read_image(os.path.join(self.nusc.dataroot, d['filename'])))
        cam = d['cam']
        roll, pitch, yaw = d['roll'], d['pitch'], d['yaw']

        # The label PNG is named after the sample-data token.
        with Image.open(self.map_data_root / f"{d['token']}.png") as semantic_image:
            semantic_mask = to_tensor(semantic_image)

        # One extra channel beyond the semantic classes is decoded; the last
        # channel is split off below as the visibility mask.
        semantic_mask = decode_binary_labels(semantic_mask, self.cfg.num_classes + 1)
        # Halve the spatial resolution, keeping a cell set if any of its 2x2 inputs is set.
        semantic_mask = torch.nn.functional.max_pool2d(
            semantic_mask.float(), (2, 2), stride=2)
        # Channels last, then flip along the first (row) axis.
        semantic_mask = semantic_mask.permute(1, 2, 0)
        semantic_mask = torch.flip(semantic_mask, [0])

        visibility_mask = semantic_mask[..., -1]
        semantic_mask = semantic_mask[..., :-1]

        if self.cfg.class_mapping is not None:
            semantic_mask = semantic_mask[..., self.cfg.class_mapping]

        # Channels-last uint8-style array -> channels-first float in [0, 1].
        image = (
            torch.from_numpy(np.ascontiguousarray(image))
            .permute(2, 0, 1)
            .float()
            .div_(255)
        )

        if not self.cfg.gravity_align:
            # Gravity alignment disabled: rectify with zero angles and report
            # zero roll/pitch.
            roll = 0.0
            pitch = 0.0
            image, valid = rectify_image(image, cam, roll, pitch)
        else:
            image, valid = rectify_image(
                image, cam, roll, pitch if self.cfg.rectify_pitch else None
            )
            # Whatever was rectified away is reported as zero.
            roll = 0.0
            if self.cfg.rectify_pitch:
                pitch = 0.0

        if self.cfg.resize_image is not None:
            image, _, cam, valid = resize_image(
                image, self.cfg.resize_image, fn=max, camera=cam, valid=valid
            )
            if self.cfg.pad_to_square:
                image, valid, cam = pad_image(image, self.cfg.resize_image, cam, valid)

        image = self.tfs(image)

        confidence_map = self._normalize_confidence(visibility_mask)

        return {
            "image": image,
            "roll_pitch_yaw": torch.tensor([roll, pitch, yaw]).float(),
            "camera": cam,
            "valid": valid,
            "seg_masks": semantic_mask.float(),
            "token": d['token'],
            "sample_token": d['sample_token'],
            'location': d['location'],
            'flood_masks': visibility_mask.float(),
            "confidence_map": confidence_map,
            'name': d['sample_token']
        }
|
|