| import os
|
| import numpy as np
|
| import torch.utils.data as data
|
| import umsgpack
|
| from PIL import Image
|
| import json
|
| import torchvision.transforms as tvf
|
|
|
| from .transform import BEVTransform
|
| from ..schema import KITTIDataConfiguration
|
|
|
class BEVKitti360Dataset(data.Dataset):
    """Dataset pairing front-view RGB frames with BEV and front-view masks.

    Split lists, metadata files and mask images are read from
    ``cfg.seam_root_dir``. RGB images are resolved relative to
    ``cfg.dataset_root_dir`` through the per-camera JSON map stored in the
    ``img`` directory.
    """

    # Directory / file names, all relative to ``cfg.seam_root_dir``.
    _IMG_DIR = "img"
    _BEV_MSK_DIR = "bev_msk"
    _BEV_PLABEL_DIR = "bev_plabel_dynamic"
    _FV_MSK_DIR = "front_msk_seam"
    _BEV_DIR = "bev_ortho"
    _LST_DIR = "split"
    _PERCENTAGES_DIR = "percentages"
    _BEV_METADATA_FILE = "metadata_ortho.bin"
    _FV_METADATA_FILE = "metadata_front.bin"

    def __init__(self, cfg: "KITTIDataConfiguration", split_name="train"):
        """Resolve all dataset paths, load the split and build the transform.

        Args:
            cfg: Data configuration. This class reads ``seam_root_dir``,
                ``dataset_root_dir``, ``bev_percentage``, ``percentage`` and
                (for the train split) ``augmentations`` from it.
            split_name: Name of the split list to load (``<split_name>.txt``).
                Colour augmentations are only built when it is ``"train"``.
        """
        super().__init__()
        self.cfg = cfg
        self.seam_root_dir = cfg.seam_root_dir
        self.kitti_root_dir = cfg.dataset_root_dir
        self.split_name = split_name

        self.rgb_cameras = ['front']

        # Fractional values (< 1) are kept as-is, everything else is truncated
        # to an int; the value ends up in the percentage file name.
        if cfg.bev_percentage < 1:
            self.bev_percentage = cfg.bev_percentage
        else:
            self.bev_percentage = int(cfg.bev_percentage)

        root = self.seam_root_dir
        self._img_dir = os.path.join(root, BEVKitti360Dataset._IMG_DIR)
        self._bev_msk_dir = os.path.join(
            root, BEVKitti360Dataset._BEV_MSK_DIR, BEVKitti360Dataset._BEV_DIR
        )
        self._bev_plabel_dir = os.path.join(
            root, BEVKitti360Dataset._BEV_PLABEL_DIR, BEVKitti360Dataset._BEV_DIR
        )
        self._fv_msk_dir = os.path.join(root, BEVKitti360Dataset._FV_MSK_DIR, "front")
        self._lst_dir = os.path.join(root, BEVKitti360Dataset._LST_DIR)
        self._percentages_dir = os.path.join(
            root, BEVKitti360Dataset._LST_DIR, BEVKitti360Dataset._PERCENTAGES_DIR
        )

        (
            self._bev_meta,
            self._bev_images,
            self._bev_images_all,
            self._fv_meta,
            self._fv_images,
            self._fv_images_all,
            self._img_map,
            self.bev_percent_split,
        ) = self._load_split()

        # Ids of every frame in the "all" list, built once so that
        # ``_load_item`` can validate a sample in O(1) instead of scanning
        # the whole descriptor list for each item.
        self._fv_all_ids = frozenset(
            img_desc["id"] for img_desc in (self._fv_images_all or ())
        )

        self.tfs = self.get_augmentations() if split_name == "train" else tvf.Compose([])
        self.transform = BEVTransform(cfg, self.tfs)

    def get_augmentations(self):
        """Build the composed image augmentations from ``cfg.augmentations``.

        A ``ColorJitter`` is always included; the remaining transforms are
        appended only when their respective config switch is truthy.

        Returns:
            A ``tvf.Compose`` wrapping the selected transforms.
        """
        print("Augmentation!", "\n" * 10)

        aug_cfg = self.cfg.augmentations
        augmentations = [
            tvf.ColorJitter(
                brightness=aug_cfg.brightness,
                contrast=aug_cfg.contrast,
                saturation=aug_cfg.saturation,
                hue=aug_cfg.hue,
            )
        ]

        if aug_cfg.random_resized_crop:
            # NOTE(review): torchvision's RandomResizedCrop documents ``size``
            # as a required argument; this call omits it and will presumably
            # raise TypeError when the option is enabled. The intended output
            # size is not visible from this file - confirm and pass it here.
            augmentations.append(
                tvf.RandomResizedCrop(scale=(0.8, 1.0))
            )

        if aug_cfg.gaussian_noise.enabled:
            # NOTE(review): ``torchvision.transforms`` (v1) does not appear to
            # provide GaussianNoise; ``torchvision.transforms.v2`` does, with a
            # ``sigma`` keyword rather than ``std``. Confirm which API is
            # expected before enabling this option.
            augmentations.append(
                tvf.GaussianNoise(
                    mean=aug_cfg.gaussian_noise.mean,
                    std=aug_cfg.gaussian_noise.std,
                )
            )

        if aug_cfg.brightness_contrast.enabled:
            augmentations.append(
                tvf.ColorJitter(
                    brightness=aug_cfg.brightness_contrast.brightness_factor,
                    contrast=aug_cfg.brightness_contrast.contrast_factor,
                    saturation=0,
                    hue=0,
                )
            )

        return tvf.Compose(augmentations)

    @staticmethod
    def _read_lines(path):
        """Return the whitespace-stripped lines of the text file at ``path``."""
        with open(path, "r") as fid:
            return [line.strip() for line in fid]

    @staticmethod
    def _read_metadata(path):
        """Unpack and return the msgpack metadata stored at ``path``."""
        with open(path, "rb") as fid:
            return umsgpack.unpack(fid, encoding="utf-8")

    def _load_split(self):
        """Load metadata, split lists and the camera image map from disk.

        Returns:
            An 8-tuple ``(bev_meta, bev_images, bev_images_all, fv_meta,
            fv_images, fv_images_all, img_map, lst_percent)`` where the
            ``*_images`` lists hold the descriptors selected for this split,
            the ``*_images_all`` lists hold the descriptors of the "all" list,
            and ``lst_percent`` is the set of ids from the percentage file
            (or from the split list itself for non-train splits).

        Raises:
            AssertionError: If the selected BEV and front-view descriptors do
                not cover the same ids.
        """
        bev_metadata = self._read_metadata(
            os.path.join(self.seam_root_dir, BEVKitti360Dataset._BEV_METADATA_FILE)
        )
        fv_metadata = self._read_metadata(
            os.path.join(self.seam_root_dir, BEVKitti360Dataset._FV_METADATA_FILE)
        )

        lst = self._read_lines(os.path.join(self._lst_dir, self.split_name + ".txt"))

        if self.split_name == "train":
            lst_all = self._read_lines(
                os.path.join(self._lst_dir, "{}_all.txt".format(self.split_name))
            )
            percentage_file = os.path.join(
                self._percentages_dir,
                "{}_{}.txt".format(self.split_name, self.bev_percentage),
            )
            print("Loading {}% file".format(self.bev_percentage))
            lst_percent = self._read_lines(percentage_file)
        else:
            lst_all = lst
            lst_percent = lst

        # Keep only frames for which a front-view mask file exists on disk.
        fv_msk_frames = {frame.split(".")[0] for frame in os.listdir(self._fv_msk_dir)}
        lst = [entry for entry in lst if entry in fv_msk_frames]
        lst_all = [entry for entry in lst_all if entry in fv_msk_frames]

        # Sets make every membership test below O(1); the lists can hold one
        # entry per frame, so list lookups here would be quadratic.
        lst_percent = set(lst_percent)
        if self.bev_percentage < 100:
            lst = [entry for entry in lst if entry in lst_percent]
        lst = set(lst)

        img_map = {}
        for camera in self.rgb_cameras:
            with open(os.path.join(self._img_dir, "{}.json".format(camera))) as fp:
                map_list = json.load(fp)
            # The JSON file holds a list of dicts; flatten it into one dict.
            img_map[camera] = {k: v for d in map_list for k, v in d.items()}

        bev_meta = bev_metadata["meta"]
        bev_images = [d for d in bev_metadata["images"] if d["id"] in lst]
        fv_meta = fv_metadata["meta"]
        fv_images = [d for d in fv_metadata['images'] if d['id'] in lst]

        bev_images_ids = [bev_img["id"] for bev_img in bev_images]
        fv_images_ids = [fv_img["id"] for fv_img in fv_images]
        # Raised explicitly (instead of ``assert``) so the check survives -O.
        if set(bev_images_ids) != set(fv_images_ids) or len(bev_images_ids) != len(fv_images_ids):
            raise AssertionError('Inconsistency between fv_images and bev_images detected')

        # Both lists are indexed with the same position in ``_load_item``, so
        # they must share one order; re-align the front-view list if needed.
        if bev_images_ids != fv_images_ids:
            fv_by_id = {fv_img["id"]: fv_img for fv_img in fv_images}
            fv_images = [fv_by_id[img_id] for img_id in bev_images_ids]

        if lst_all is not None:
            all_ids = set(lst_all)
            bev_images_all = [d for d in bev_metadata['images'] if d['id'] in all_ids]
            fv_images_all = [d for d in fv_metadata['images'] if d['id'] in all_ids]
        else:
            bev_images_all, fv_images_all = None, None

        return (bev_meta, bev_images, bev_images_all, fv_meta, fv_images,
                fv_images_all, img_map, lst_percent)

    def _find_index(self, list, key, value):
        """Return the index of the first dict in ``list`` with ``dic[key] == value``.

        Returns ``None`` when no entry matches. The parameter keeps its
        original name ``list`` (shadowing the builtin) for compatibility.
        """
        return next((i for i, dic in enumerate(list) if dic[key] == value), None)

    def _load_item(self, item_idx):
        """Load the images and annotations of the sample at ``item_idx``.

        Returns:
            A 12-tuple ``(img, bev_msk, bev_plabel, fv_msk,
            bev_weights_msk_combined, bev_cat, bev_iscrowd, fv_cat,
            fv_iscrowd, fv_intrinsics, ego_pose, frame_ids)``. ``bev_plabel``
            and ``bev_weights_msk_combined`` are always ``None``. The three
            images are returned open; the caller is responsible for closing
            them.

        Raises:
            IOError: If the frame is not part of the "all" list or its RGB
                image does not exist on disk.
        """
        bev_img_desc = self._bev_images[item_idx]
        fv_img_desc = self._fv_images[item_idx]

        if fv_img_desc['id'] not in self._fv_all_ids:
            raise IOError("Required index not found!")

        # Ids have the form "<scene>;<frame_id>".
        scene, frame_id = bev_img_desc["id"].split(";")

        img_file = os.path.join(
            self.kitti_root_dir,
            self._img_map["front"]["{}.png".format(bev_img_desc['id'])],
        )
        if not os.path.exists(img_file):
            raise IOError(
                "RGB image not found! Scene: {}, Frame: {}".format(scene, frame_id)
            )

        bev_msk_file = os.path.join(self._bev_msk_dir, "{}.png".format(bev_img_desc['id']))
        fv_msk_file = os.path.join(self._fv_msk_dir, "{}.png".format(fv_img_desc['id']))

        # ``convert`` returns a new image, so the source file can be closed
        # right away instead of lingering until garbage collection.
        with Image.open(img_file) as src:
            img = src.convert(mode="RGB")

        opened = [img]
        try:
            bev_msk = Image.open(bev_msk_file)
            opened.append(bev_msk)
            fv_msk = Image.open(fv_msk_file)
        except Exception:
            # Do not leak the images that were opened before the failure.
            for handle in opened:
                handle.close()
            raise

        bev_plabel = None
        bev_weights_msk_combined = None

        bev_cat = bev_img_desc["cat"]
        bev_iscrowd = bev_img_desc["iscrowd"]
        fv_cat = fv_img_desc['cat']
        fv_iscrowd = fv_img_desc['iscrowd']
        fv_intrinsics = fv_img_desc["cam_intrinsic"]
        ego_pose = fv_img_desc['ego_pose']

        frame_ids = bev_img_desc["id"]

        return (img, bev_msk, bev_plabel, fv_msk, bev_weights_msk_combined, bev_cat,
                bev_iscrowd, fv_cat, fv_iscrowd, fv_intrinsics, ego_pose, frame_ids)

    @property
    def fv_categories(self):
        """Category names"""
        return self._fv_meta["categories"]

    @property
    def fv_num_categories(self):
        """Number of categories"""
        return len(self.fv_categories)

    @property
    def fv_num_stuff(self):
        """Number of "stuff" categories"""
        return self._fv_meta["num_stuff"]

    @property
    def fv_num_thing(self):
        """Number of "thing" categories"""
        return self.fv_num_categories - self.fv_num_stuff

    @property
    def bev_categories(self):
        """Category names"""
        return self._bev_meta["categories"]

    @property
    def bev_num_categories(self):
        """Number of categories"""
        return len(self.bev_categories)

    @property
    def bev_num_stuff(self):
        """Number of "stuff" categories"""
        return self._bev_meta["num_stuff"]

    @property
    def bev_num_thing(self):
        """Number of "thing" categories"""
        return self.bev_num_categories - self.bev_num_stuff

    @property
    def original_ids(self):
        """Original class id of each category"""
        return self._fv_meta["original_ids"]

    @property
    def palette(self):
        """Default palette to be used when color-coding semantic labels"""
        return np.array(self._fv_meta["palette"], dtype=np.uint8)

    @property
    def img_sizes(self):
        """Size of each image of the dataset"""
        return [img_desc["size"] for img_desc in self._fv_images]

    @property
    def img_categories(self):
        """Categories present in each image of the dataset"""
        return [img_desc["cat"] for img_desc in self._fv_images]

    @property
    def dataset_name(self):
        """Fixed dataset identifier string"""
        return "Kitti360"

    def __len__(self):
        """Number of samples, scaled down when ``cfg.percentage`` is below 1."""
        if self.cfg.percentage < 1:
            return int(len(self._fv_images) * self.cfg.percentage)

        return len(self._fv_images)

    def __getitem__(self, item):
        """Load sample ``item``, run the transform and return the record dict.

        The record returned by ``self.transform`` is extended with ``"index"``
        and ``"name"`` (both the frame id) and ``"size"`` (height, width of
        the RGB image). The loaded images are closed even if the transform
        raises.
        """
        (img, bev_msk, bev_plabel, fv_msk, bev_weights_msk, bev_cat, bev_iscrowd,
         fv_cat, fv_iscrowd, fv_intrinsics, ego_pose, idx) = self._load_item(item)

        try:
            rec = self.transform(
                img=img, bev_msk=bev_msk, bev_plabel=bev_plabel, fv_msk=fv_msk,
                bev_weights_msk=bev_weights_msk, bev_cat=bev_cat,
                bev_iscrowd=bev_iscrowd, fv_cat=fv_cat, fv_iscrowd=fv_iscrowd,
                fv_intrinsics=fv_intrinsics, ego_pose=ego_pose,
            )
            # PIL reports (width, height); the record stores (height, width).
            size = (img.size[1], img.size[0])
        finally:
            img.close()
            bev_msk.close()
            fv_msk.close()

        rec["index"] = idx
        rec["size"] = size
        rec['name'] = idx

        return rec

    def get_image_desc(self, idx):
        """Look up an image descriptor given the id

        The front-view descriptors are searched, matching ``img_sizes`` and
        ``img_categories``.

        Raises:
            ValueError: If no descriptor, or more than one, has id ``idx``.
        """
        matching = [img_desc for img_desc in self._fv_images if img_desc["id"] == idx]
        if len(matching) == 1:
            return matching[0]
        if not matching:
            raise ValueError("No image found with id %s" % idx)
        raise ValueError("Multiple images found with id %s" % idx)