import os
import time
from collections import namedtuple

import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset
from torchvision import transforms

from datasets.kitti_360.labels import trainId2label
# Schema for one semantic class, following the Cityscapes label convention,
# extended with `to_cs27`, an id in a shared 27-class label space.
Label = namedtuple(
    "Label",
    [
        "name",          # human-readable class name
        "id",            # BDD100K label id
        "trainId",       # Cityscapes-style training id (255 = ignore)
        "category",      # coarse category name
        "categoryId",    # coarse category id
        "hasInstances",  # whether the class carries instance annotations
        "ignoreInEval",  # whether the class is ignored during evaluation
        "color",         # RGB color used for visualization
        "to_cs27",       # id in the 27-class label space (255 = unmapped)
    ],
)
BDD_LABEL = [
    Label("unlabeled", 0, 255, "void", 0, False, True, (0, 0, 0), 255),
    Label("dynamic", 1, 255, "void", 0, False, True, (111, 74, 0), 255),
    Label("ego vehicle", 2, 255, "void", 0, False, True, (0, 0, 0), 255),
    Label("ground", 3, 255, "void", 0, False, True, (81, 0, 81), 255),
    Label("static", 4, 255, "void", 0, False, True, (0, 0, 0), 255),
    Label("parking", 5, 255, "flat", 1, False, True, (250, 170, 160), 2),
    Label("rail track", 6, 255, "flat", 1, False, True, (230, 150, 140), 3),
    Label("road", 7, 0, "flat", 1, False, False, (128, 64, 128), 0),
    Label("sidewalk", 8, 1, "flat", 1, False, False, (244, 35, 232), 1),
    Label("bridge", 9, 255, "construction", 2, False, True, (150, 100, 100), 8),
    Label("building", 10, 2, "construction", 2, False, False, (70, 70, 70), 4),
    Label("fence", 11, 4, "construction", 2, False, False, (190, 153, 153), 6),
    Label("garage", 12, 255, "construction", 2, False, True, (180, 100, 180), 255),
    Label("guard rail", 13, 255, "construction", 2, False, True, (180, 165, 180), 7),
    Label("tunnel", 14, 255, "construction", 2, False, True, (150, 120, 90), 9),
    Label("wall", 15, 3, "construction", 2, False, False, (102, 102, 156), 5),
    Label("banner", 16, 255, "object", 3, False, True, (250, 170, 100), 255),
    Label("billboard", 17, 255, "object", 3, False, True, (220, 220, 250), 255),
    Label("lane divider", 18, 255, "object", 3, False, True, (255, 165, 0), 255),
    Label("parking sign", 19, 255, "object", 3, False, False, (220, 20, 60), 255),
    Label("pole", 20, 5, "object", 3, False, False, (153, 153, 153), 10),
    Label("polegroup", 21, 255, "object", 3, False, True, (153, 153, 153), 11),
    Label("street light", 22, 255, "object", 3, False, True, (220, 220, 100), 255),
    Label("traffic cone", 23, 255, "object", 3, False, True, (255, 70, 0), 255),
    Label("traffic device", 24, 255, "object", 3, False, True, (220, 220, 220), 255),
    Label("traffic light", 25, 6, "object", 3, False, False, (250, 170, 30), 12),
    Label("traffic sign", 26, 7, "object", 3, False, False, (220, 220, 0), 13),
    Label("traffic sign frame", 27, 255, "object", 3, False, True, (250, 170, 250), 255),
    Label("terrain", 28, 9, "nature", 4, False, False, (152, 251, 152), 15),
    Label("vegetation", 29, 8, "nature", 4, False, False, (107, 142, 35), 14),
    Label("sky", 30, 10, "sky", 5, False, False, (70, 130, 180), 16),
    Label("person", 31, 11, "human", 6, True, False, (220, 20, 60), 17),
    Label("rider", 32, 12, "human", 6, True, False, (255, 0, 0), 18),
    Label("bicycle", 33, 18, "vehicle", 7, True, False, (119, 11, 32), 26),
    Label("bus", 34, 15, "vehicle", 7, True, False, (0, 60, 100), 21),
    Label("car", 35, 13, "vehicle", 7, True, False, (0, 0, 142), 19),
    Label("caravan", 36, 255, "vehicle", 7, True, True, (0, 0, 90), 22),
    Label("motorcycle", 37, 17, "vehicle", 7, True, False, (0, 0, 230), 25),
    Label("trailer", 38, 255, "vehicle", 7, True, True, (0, 0, 110), 23),
    Label("train", 39, 16, "vehicle", 7, True, False, (0, 80, 100), 24),
    Label("truck", 40, 14, "vehicle", 7, True, False, (0, 0, 70), 20),
]
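
# Example (not part of the original file): convenience lookups derived from
# BDD_LABEL, e.g. for colorizing a predicted semantic map. Indexing by `id`
# works because the list above is ordered by id 0..40.
BDD_ID2LABEL = {label.id: label for label in BDD_LABEL}
BDD_PALETTE = np.array([label.color for label in BDD_LABEL], dtype=np.uint8)  # (41, 3) RGB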
def resize_with_padding(img, target_size, padding_value, interpolation):
    """Resize a PIL image to fit inside target_size (h, w) while preserving
    its aspect ratio, then pad symmetrically to exactly target_size."""
    target_h, target_w = target_size
    width, height = img.size
    aspect = width / height
    if aspect > (target_w / target_h):
        # Image is wider than the target: fit the width, shrink the height.
        new_w = target_w
        new_h = int(target_w / aspect)
    else:
        # Image is taller than the target: fit the height, shrink the width.
        new_h = target_h
        new_w = int(target_h * aspect)
    img = transforms.functional.resize(img, (new_h, new_w), interpolation)
    pad_h = target_h - new_h
    pad_w = target_w - new_w
    # Padding order is (left, top, right, bottom); odd remainders go right/bottom.
    padding = (pad_w // 2, pad_h // 2, pad_w - pad_w // 2, pad_h - pad_h // 2)
    return transforms.functional.pad(img, padding, fill=padding_value)
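
# Usage sketch (illustrative, not part of the original file): letterbox an
# input frame to 192x640 without distorting its aspect ratio.
#
#   img = Image.open("frame.jpg").convert("RGB")  # "frame.jpg" is a placeholder
#   img = resize_with_padding(
#       img, (192, 640), padding_value=0,
#       interpolation=transforms.InterpolationMode.BILINEAR,
#   )
#   assert img.size == (640, 192)  # PIL reports (width, height)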
class BDDSeg(Dataset):
    """BDD100K panoptic-segmentation wrapper that yields images, placeholder
    camera poses/intrinsics, and per-pixel semantic ids."""

    def __init__(self, root, image_set, image_size=(192, 640)):
        super().__init__()
        self.split = image_set
        self.root = root
        self.image_transform = transforms.Compose([
            # transforms.Lambda(lambda img: resize_with_padding(img, image_size, padding_value=0, interpolation=transforms.InterpolationMode.BILINEAR)),
            transforms.Resize((320, 640), interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(image_size),
            transforms.ToTensor(),
        ])
        self.target_transform = transforms.Compose([
            # transforms.Lambda(lambda img: resize_with_padding(img, image_size, padding_value=-1, interpolation=transforms.InterpolationMode.NEAREST)),
            transforms.Resize((320, 640), interpolation=transforms.InterpolationMode.NEAREST),
            transforms.CenterCrop(image_size),
            transforms.PILToTensor(),
            transforms.Lambda(lambda x: x.long()),
        ])
        # Pair every image with the panoptic bitmask of the same basename.
        self.images, self.targets = [], []
        image_dir = os.path.join(self.root, "images/10k", self.split)
        target_dir = os.path.join(self.root, "labels/pan_seg/bitmasks", self.split)
        for file_name in os.listdir(image_dir):
            image_path = os.path.join(image_dir, file_name)
            target_filename = os.path.splitext(file_name)[0] + ".png"
            target_path = os.path.join(target_dir, target_filename)
            assert os.path.isfile(target_path), f"missing label for {file_name}"
            self.images.append(image_path)
            self.targets.append(target_path)
        # Lookup table over BDD label ids: BDD id -> trainId -> KITTI-360 label id.
        self.class_mapping = torch.Tensor([trainId2label[c.trainId].id for c in BDD_LABEL]).int()
    def __getitem__(self, index):
        _start_time = time.time()
        image = Image.open(self.images[index]).convert("RGB")
        target = Image.open(self.targets[index])
        image = self.image_transform(image)
        target = self.target_transform(target)
        image = 2.0 * image - 1.0  # rescale from [0, 1] to [-1, 1]
        poses = torch.eye(4)  # identity placeholder pose, (4, 4)
        projs = torch.eye(3)  # identity placeholder intrinsics, (3, 3)
        target = target[0]  # ("instance", "semantic", "polygon", "color")
        target = self.class_mapping[target]  # remap BDD ids through the lookup table
        _proc_time = time.time() - _start_time
        data = {
            "imgs": [image.numpy()],
            "poses": [poses.numpy()],
            "projs": [projs.numpy()],
            "segs": [target.numpy()],
            "t__get_item__": np.array([_proc_time]),
            "index": [np.array([index])],
        }
        return data
    def __len__(self):
        return len(self.images)
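
if __name__ == "__main__":
    # Minimal smoke test (illustrative, not part of the original file).
    # Assumes BDD100K is unpacked under /data/bdd100k (a placeholder path)
    # with the images/10k/<split> and labels/pan_seg/bitmasks/<split> layout
    # expected by __init__ above.
    dataset = BDDSeg(root="/data/bdd100k", image_set="val")
    sample = dataset[0]
    print(len(dataset), sample["imgs"][0].shape, sample["segs"][0].shape)
    # Expected shapes: imgs (3, 192, 640), segs (192, 640).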