Spaces:
Runtime error
Runtime error
| import cv2 | |
| import torch | |
| from PIL import Image | |
| import os.path as osp | |
| import numpy as np | |
| from torch.utils import data | |
| import torchvision.transforms as transforms | |
| import torchvision.transforms.functional as TF | |
| import random | |
class RandomResizedCrop(object):
    """Per-sample square crop followed by a bilinear resize.

    One scale factor and one pair of offset fractions are drawn for each of
    the ``N`` sample slots at construction time and then reused every time
    that slot is requested, so a given index always receives the same crop.

    Args:
        N: Number of sample slots to pre-draw parameters for.
        res: Target side length after resizing.
        scale: ``(low, high)`` bounds of the uniform distribution the crop
            scale factor is drawn from.
    """

    def __init__(self, N, res, scale=(0.5, 1.0)):
        self.res = res
        self.scale = scale
        # Draw order (all scales first, then the offset pairs) is kept as-is
        # so results under a fixed NumPy seed do not change.
        self.rscale = [np.random.uniform(*scale) for _ in range(N)]
        self.rcrop = [(np.random.uniform(0, 1), np.random.uniform(0, 1)) for _ in range(N)]

    def random_crop(self, idx, img):
        """Crop ``img`` with the parameters stored for slot ``idx``.

        Args:
            idx: Slot index into the pre-drawn scale/offset lists.
            img: Tensor indexed as ``img[:, :, rows, cols]``. The crop side is
                derived from the last dimension only, so square inputs are
                presumably expected (TODO confirm with callers).

        Returns:
            The cropped view ``img[:, :, i:i+side, j:j+side]``.
        """
        frac_rows, frac_cols = self.rcrop[idx]
        full = int(img.size(-1))
        # Never let the crop collapse to zero pixels for very small inputs.
        side = max(1, int(self.rscale[idx] * full))
        top = int(round((full - side) * frac_rows))
        left = int(round((full - side) * frac_cols))
        return img[:, :, top:top + side, left:left + side]

    def __call__(self, indice, image):
        """Crop and resize every sample of a batch.

        Args:
            indice: Iterable of slot indices, one per batch element.
            image: Batched tensor; element ``i`` is cropped with the
                parameters of slot ``indice[i]``.

        Returns:
            Tensor of the resized crops concatenated along dimension 0.
        """
        # Inputs with more than 5 channels are resized to a quarter of the
        # target resolution (the original comment asks "View 1 or View 2?";
        # presumably these are feature maps rather than images -- confirm).
        res_tar = self.res // 4 if image.size(1) > 5 else self.res
        # `torch.nn.functional` is reached through the already-imported
        # `torch`; the previous code used an undefined alias `F`.
        interpolate = torch.nn.functional.interpolate
        resized = [
            interpolate(
                self.random_crop(idx, image[[i]]),
                res_tar,
                mode='bilinear',
                align_corners=False,
            )
            for i, idx in enumerate(indice)
        ]
        return torch.cat(resized)
class RandomVerticalFlip(object):
    """Flip selected batch elements along their second-to-last batch axis.

    A random number is drawn for each of the ``N`` sample slots at
    construction time; a slot is flipped whenever its number is below ``p``,
    so the decision for a given index is the same on every call.

    Args:
        N: Number of sample slots to pre-draw a random number for.
        p: Threshold below which a slot is flipped.
    """

    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image):
        """Return a copy of ``image`` with the selected elements flipped.

        Args:
            indice: Slot indices, one per batch element of ``image``.
            image: Batched tensor. For a 3-D batch dimension 1 is flipped,
                otherwise dimension 2 (dimensions counted on the batch).

        Returns:
            A new tensor shaped like ``image``; the input is not modified.
        """
        # Batch positions whose pre-drawn number falls below the threshold.
        chosen = np.nonzero(self.plist[indice] < self.p_ref)[0]
        flip_dim = 1 if len(image.size()) == 3 else 2

        flipped = image.clone()
        if len(chosen) > 0:
            rows = torch.as_tensor(chosen, dtype=torch.long, device=image.device)
            # One indexed assignment instead of rebuilding the batch
            # element by element with a membership test per sample.
            flipped[rows] = image[rows].flip([flip_dim])
        return flipped
class RandomHorizontalTensorFlip(object):
    """Flip selected batch elements along their last batch axis.

    A random number is drawn for each of the ``N`` sample slots at
    construction time; a slot is flipped whenever its number is below ``p``,
    so the decision for a given index is the same on every call.

    Args:
        N: Number of sample slots to pre-draw a random number for.
        p: Threshold below which a slot is flipped.
    """

    def __init__(self, N, p=0.5):
        self.p_ref = p
        self.plist = np.random.random_sample(N)

    def __call__(self, indice, image, is_label=False):
        """Return a copy of ``image`` with the selected elements flipped.

        Args:
            indice: Slot indices, one per batch element of ``image``.
            image: Batched tensor. For a 3-D batch dimension 2 is flipped,
                otherwise dimension 3 (dimensions counted on the batch).
            is_label: Accepted for interface compatibility; it is not read.

        Returns:
            A new tensor shaped like ``image``; the input is not modified.
        """
        # Batch positions whose pre-drawn number falls below the threshold.
        chosen = np.nonzero(self.plist[indice] < self.p_ref)[0]
        flip_dim = 2 if len(image.size()) == 3 else 3

        flipped = image.clone()
        if len(chosen) > 0:
            rows = torch.as_tensor(chosen, dtype=torch.long, device=image.device)
            # One indexed assignment instead of rebuilding the batch
            # element by element with a membership test per sample.
            flipped[rows] = image[rows].flip([flip_dim])
        return flipped
class _Coco164kCuratedFew(data.Dataset):
    """Base class

    This contains fields and methods common to all COCO 164k curated few datasets:

    (curated) Coco164kFew_Stuff
    (curated) Coco164kFew_Stuff_People
    (curated) Coco164kFew_Stuff_Animals
    (curated) Coco164kFew_Stuff_People_Animals

    Args:
        root: Dataset root directory. The image-id list is read from
            ``<root>/curated/<split>/<name>.txt`` and images from
            ``<root>/images/<split>/<id>.jpg``.
        img_size: Size handed to ``transforms.Resize`` (cast to ``int``).
        crop_size: Size handed to ``transforms.RandomCrop``.
        split: ``"train2017"`` or ``"val2017"``.
        res1: Target resolution for the tensor-level ``RandomResizedCrop``
            helper. When omitted, an existing ``res1`` attribute (for example
            one defined by a subclass) is used, otherwise ``crop_size``.
        scale: Scale range for the same helper. When omitted, an existing
            ``scale`` attribute is used, otherwise ``(0.5, 1.0)``.

    Raises:
        ValueError: If ``split`` is not one of the supported names.
    """

    def __init__(self, root, img_size, crop_size, split="train2017",
                 res1=None, scale=None):
        super(_Coco164kCuratedFew, self).__init__()

        # work out name
        self.split = split
        self.root = root
        self.include_things_labels = False  # people
        self.incl_animal_things = False  # animals

        # The crop helper below needs `res1` and `scale`; this constructor
        # used to read both attributes without ever assigning them.
        # NOTE(review): the fallbacks are assumptions -- confirm the values.
        self.res1 = res1 if res1 is not None else getattr(self, "res1", crop_size)
        self.scale = scale if scale is not None else getattr(self, "scale", (0.5, 1.0))

        version = 6
        name = "Coco164kFew_Stuff"
        if self.include_things_labels and self.incl_animal_things:
            name += "_People_Animals"
        elif self.include_things_labels:
            name += "_People"
        elif self.incl_animal_things:
            name += "_Animals"
        self.name = (name + "_%d" % version)
        print("Specific type of _Coco164kCuratedFew dataset: %s" % self.name)

        self._set_files()

        # PIL-level augmentation applied to every image in __getitem__:
        # one randomly chosen mild colour jitter, random flips, resize, crop.
        self.transform = transforms.Compose([
            transforms.RandomChoice([
                transforms.ColorJitter(brightness=0.05),
                transforms.ColorJitter(contrast=0.05),
                transforms.ColorJitter(saturation=0.01),
                transforms.ColorJitter(hue=0.01)]),
            transforms.RandomHorizontalFlip(),
            transforms.RandomVerticalFlip(),
            transforms.Resize(int(img_size)),
            transforms.RandomCrop(crop_size)])

        # Tensor-level helpers with one pre-drawn random state per image.
        N = len(self.files)
        self.random_horizontal_flip = RandomHorizontalTensorFlip(N=N)
        self.random_vertical_flip = RandomVerticalFlip(N=N)
        self.random_resized_crop = RandomResizedCrop(N=N, res=self.res1, scale=self.scale)

    def _set_files(self):
        """Load the image-id list for the configured split into ``self.files``.

        Raises:
            ValueError: If ``self.split`` is not a supported split name.
        """
        if self.split not in ["train2017", "val2017"]:
            raise ValueError("Invalid split name: {}".format(self.split))

        list_path = osp.join(self.root, "curated", self.split, self.name + ".txt")
        # Context manager so the list file is closed once it has been read.
        with open(list_path, "r") as list_file:
            self.files = [id_.rstrip() for id_ in list_file]
        print("In total {} images.".format(len(self.files)))

    def __getitem__(self, index):
        """Return ``[image]`` for the sample at ``index``.

        The image is a float tensor in channel-first layout with values
        divided by 255. Label maps are not loaded.
        """
        image_id = self.files[index]
        image_path = osp.join(self.root, "images", self.split, image_id + ".jpg")

        # Load an image and run the PIL-level augmentation pipeline.
        ori_img = Image.open(image_path)
        ori_img = self.transform(ori_img)
        ori_img = np.array(ori_img)

        # Single-channel images are replicated to three channels; any
        # channels beyond the third are dropped.
        if ori_img.ndim < 3:
            ori_img = np.expand_dims(ori_img, axis=2).repeat(3, axis=2)
        ori_img = ori_img[:, :, :3]

        ori_img = torch.from_numpy(ori_img).float().permute(2, 0, 1)
        ori_img = ori_img / 255.0

        return [ori_img]

    def __len__(self):
        """Return the number of image ids loaded from the list file."""
        return len(self.files)