Spaces:
Build error
Build error
| import cv2 | |
| import numpy as np | |
| # np.set_printoptions(threshold=np.inf) | |
| import random | |
| import torch | |
| import torchvision.transforms as transforms | |
| # from visualization import plot_img_and_mask,plot_one_box,show_seg_result | |
| from pathlib import Path | |
| from PIL import Image | |
| from torch.utils.data import Dataset | |
| from ..utils import letterbox, augment_hsv, random_perspective, xyxy2xywh, cutout | |
| class AutoDriveDataset(Dataset): | |
| """ | |
| A general Dataset for some common function | |
| """ | |
| def __init__(self, cfg, is_train, inputsize=640, transform=None): | |
| """ | |
| initial all the characteristic | |
| Inputs: | |
| -cfg: configurations | |
| -is_train(bool): whether train set or not | |
| -transform: ToTensor and Normalize | |
| Returns: | |
| None | |
| """ | |
| self.is_train = is_train | |
| self.cfg = cfg | |
| self.transform = transform | |
| self.inputsize = inputsize | |
| self.Tensor = transforms.ToTensor() | |
| img_root = Path(cfg.DATASET.DATAROOT) | |
| label_root = Path(cfg.DATASET.LABELROOT) | |
| mask_root = Path(cfg.DATASET.MASKROOT) | |
| lane_root = Path(cfg.DATASET.LANEROOT) | |
| if is_train: | |
| indicator = cfg.DATASET.TRAIN_SET | |
| else: | |
| indicator = cfg.DATASET.TEST_SET | |
| self.img_root = img_root / indicator | |
| self.label_root = label_root / indicator | |
| self.mask_root = mask_root / indicator | |
| self.lane_root = lane_root / indicator | |
| # self.label_list = self.label_root.iterdir() | |
| self.mask_list = self.mask_root.iterdir() | |
| self.db = [] | |
| self.data_format = cfg.DATASET.DATA_FORMAT | |
| self.scale_factor = cfg.DATASET.SCALE_FACTOR | |
| self.rotation_factor = cfg.DATASET.ROT_FACTOR | |
| self.flip = cfg.DATASET.FLIP | |
| self.color_rgb = cfg.DATASET.COLOR_RGB | |
| # self.target_type = cfg.MODEL.TARGET_TYPE | |
| self.shapes = np.array(cfg.DATASET.ORG_IMG_SIZE) | |
| def _get_db(self): | |
| """ | |
| finished on children Dataset(for dataset which is not in Bdd100k format, rewrite children Dataset) | |
| """ | |
| raise NotImplementedError | |
| def evaluate(self, cfg, preds, output_dir): | |
| """ | |
| finished on children dataset | |
| """ | |
| raise NotImplementedError | |
| def __len__(self,): | |
| """ | |
| number of objects in the dataset | |
| """ | |
| return len(self.db) | |
| def __getitem__(self, idx): | |
| """ | |
| Get input and groud-truth from database & add data augmentation on input | |
| Inputs: | |
| -idx: the index of image in self.db(database)(list) | |
| self.db(list) [a,b,c,...] | |
| a: (dictionary){'image':, 'information':} | |
| Returns: | |
| -image: transformed image, first passed the data augmentation in __getitem__ function(type:numpy), then apply self.transform | |
| -target: ground truth(det_gt,seg_gt) | |
| function maybe useful | |
| cv2.imread | |
| cv2.cvtColor(data, cv2.COLOR_BGR2RGB) | |
| cv2.warpAffine | |
| """ | |
| data = self.db[idx] | |
| img = cv2.imread(data["image"], cv2.IMREAD_COLOR | cv2.IMREAD_IGNORE_ORIENTATION) | |
| img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) | |
| # seg_label = cv2.imread(data["mask"], 0) | |
| if self.cfg.num_seg_class == 3: | |
| seg_label = cv2.imread(data["mask"]) | |
| else: | |
| seg_label = cv2.imread(data["mask"], 0) | |
| lane_label = cv2.imread(data["lane"], 0) | |
| #print(lane_label.shape) | |
| # print(seg_label.shape) | |
| # print(lane_label.shape) | |
| # print(seg_label.shape) | |
| resized_shape = self.inputsize | |
| if isinstance(resized_shape, list): | |
| resized_shape = max(resized_shape) | |
| h0, w0 = img.shape[:2] # orig hw | |
| r = resized_shape / max(h0, w0) # resize image to img_size | |
| if r != 1: # always resize down, only resize up if training with augmentation | |
| interp = cv2.INTER_AREA if r < 1 else cv2.INTER_LINEAR | |
| img = cv2.resize(img, (int(w0 * r), int(h0 * r)), interpolation=interp) | |
| seg_label = cv2.resize(seg_label, (int(w0 * r), int(h0 * r)), interpolation=interp) | |
| lane_label = cv2.resize(lane_label, (int(w0 * r), int(h0 * r)), interpolation=interp) | |
| h, w = img.shape[:2] | |
| (img, seg_label, lane_label), ratio, pad = letterbox((img, seg_label, lane_label), resized_shape, auto=True, scaleup=self.is_train) | |
| shapes = (h0, w0), ((h / h0, w / w0), pad) # for COCO mAP rescaling | |
| # ratio = (w / w0, h / h0) | |
| # print(resized_shape) | |
| det_label = data["label"] | |
| labels=[] | |
| if det_label.size > 0: | |
| # Normalized xywh to pixel xyxy format | |
| labels = det_label.copy() | |
| labels[:, 1] = ratio[0] * w * (det_label[:, 1] - det_label[:, 3] / 2) + pad[0] # pad width | |
| labels[:, 2] = ratio[1] * h * (det_label[:, 2] - det_label[:, 4] / 2) + pad[1] # pad height | |
| labels[:, 3] = ratio[0] * w * (det_label[:, 1] + det_label[:, 3] / 2) + pad[0] | |
| labels[:, 4] = ratio[1] * h * (det_label[:, 2] + det_label[:, 4] / 2) + pad[1] | |
| if self.is_train: | |
| combination = (img, seg_label, lane_label) | |
| (img, seg_label, lane_label), labels = random_perspective( | |
| combination=combination, | |
| targets=labels, | |
| degrees=self.cfg.DATASET.ROT_FACTOR, | |
| translate=self.cfg.DATASET.TRANSLATE, | |
| scale=self.cfg.DATASET.SCALE_FACTOR, | |
| shear=self.cfg.DATASET.SHEAR | |
| ) | |
| #print(labels.shape) | |
| augment_hsv(img, hgain=self.cfg.DATASET.HSV_H, sgain=self.cfg.DATASET.HSV_S, vgain=self.cfg.DATASET.HSV_V) | |
| # img, seg_label, labels = cutout(combination=combination, labels=labels) | |
| if len(labels): | |
| # convert xyxy to xywh | |
| labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) | |
| # Normalize coordinates 0 - 1 | |
| labels[:, [2, 4]] /= img.shape[0] # height | |
| labels[:, [1, 3]] /= img.shape[1] # width | |
| # if self.is_train: | |
| # random left-right flip | |
| lr_flip = True | |
| if lr_flip and random.random() < 0.5: | |
| img = np.fliplr(img) | |
| seg_label = np.fliplr(seg_label) | |
| lane_label = np.fliplr(lane_label) | |
| if len(labels): | |
| labels[:, 1] = 1 - labels[:, 1] | |
| # random up-down flip | |
| ud_flip = False | |
| if ud_flip and random.random() < 0.5: | |
| img = np.flipud(img) | |
| seg_label = np.filpud(seg_label) | |
| lane_label = np.filpud(lane_label) | |
| if len(labels): | |
| labels[:, 2] = 1 - labels[:, 2] | |
| else: | |
| if len(labels): | |
| # convert xyxy to xywh | |
| labels[:, 1:5] = xyxy2xywh(labels[:, 1:5]) | |
| # Normalize coordinates 0 - 1 | |
| labels[:, [2, 4]] /= img.shape[0] # height | |
| labels[:, [1, 3]] /= img.shape[1] # width | |
| labels_out = torch.zeros((len(labels), 6)) | |
| if len(labels): | |
| labels_out[:, 1:] = torch.from_numpy(labels) | |
| # Convert | |
| # img = img[:, :, ::-1].transpose(2, 0, 1) # BGR to RGB, to 3x416x416 | |
| # img = img.transpose(2, 0, 1) | |
| img = np.ascontiguousarray(img) | |
| # seg_label = np.ascontiguousarray(seg_label) | |
| # if idx == 0: | |
| # print(seg_label[:,:,0]) | |
| if self.cfg.num_seg_class == 3: | |
| _,seg0 = cv2.threshold(seg_label[:,:,0],128,255,cv2.THRESH_BINARY) | |
| _,seg1 = cv2.threshold(seg_label[:,:,1],1,255,cv2.THRESH_BINARY) | |
| _,seg2 = cv2.threshold(seg_label[:,:,2],1,255,cv2.THRESH_BINARY) | |
| else: | |
| _,seg1 = cv2.threshold(seg_label,1,255,cv2.THRESH_BINARY) | |
| _,seg2 = cv2.threshold(seg_label,1,255,cv2.THRESH_BINARY_INV) | |
| _,lane1 = cv2.threshold(lane_label,1,255,cv2.THRESH_BINARY) | |
| _,lane2 = cv2.threshold(lane_label,1,255,cv2.THRESH_BINARY_INV) | |
| # _,seg2 = cv2.threshold(seg_label[:,:,2],1,255,cv2.THRESH_BINARY) | |
| # # seg1[cutout_mask] = 0 | |
| # # seg2[cutout_mask] = 0 | |
| # seg_label /= 255 | |
| # seg0 = self.Tensor(seg0) | |
| if self.cfg.num_seg_class == 3: | |
| seg0 = self.Tensor(seg0) | |
| seg1 = self.Tensor(seg1) | |
| seg2 = self.Tensor(seg2) | |
| # seg1 = self.Tensor(seg1) | |
| # seg2 = self.Tensor(seg2) | |
| lane1 = self.Tensor(lane1) | |
| lane2 = self.Tensor(lane2) | |
| # seg_label = torch.stack((seg2[0], seg1[0]),0) | |
| if self.cfg.num_seg_class == 3: | |
| seg_label = torch.stack((seg0[0],seg1[0],seg2[0]),0) | |
| else: | |
| seg_label = torch.stack((seg2[0], seg1[0]),0) | |
| lane_label = torch.stack((lane2[0], lane1[0]),0) | |
| # _, gt_mask = torch.max(seg_label, 0) | |
| # _ = show_seg_result(img, gt_mask, idx, 0, save_dir='debug', is_gt=True) | |
| target = [labels_out, seg_label, lane_label] | |
| img = self.transform(img) | |
| return img, target, data["image"], shapes | |
| def select_data(self, db): | |
| """ | |
| You can use this function to filter useless images in the dataset | |
| Inputs: | |
| -db: (list)database | |
| Returns: | |
| -db_selected: (list)filtered dataset | |
| """ | |
| db_selected = ... | |
| return db_selected | |
| def collate_fn(batch): | |
| img, label, paths, shapes= zip(*batch) | |
| label_det, label_seg, label_lane = [], [], [] | |
| for i, l in enumerate(label): | |
| l_det, l_seg, l_lane = l | |
| l_det[:, 0] = i # add target image index for build_targets() | |
| label_det.append(l_det) | |
| label_seg.append(l_seg) | |
| label_lane.append(l_lane) | |
| return torch.stack(img, 0), [torch.cat(label_det, 0), torch.stack(label_seg, 0), torch.stack(label_lane, 0)], paths, shapes | |