Spaces:
Running
Running
| import torchvision | |
| import random | |
| from PIL import Image, ImageOps | |
| import numbers | |
| import torch | |
| import numpy as np | |
| import math | |
class GroupRandomCrop(object):
    """Crop every image of a group at one shared, randomly drawn location.

    ``size`` is either a single number (square crop) or a ``(height, width)``
    pair, which is stored as given.
    """

    def __init__(self, size):
        is_scalar = isinstance(size, numbers.Number)
        self.size = (int(size), int(size)) if is_scalar else size

    def __call__(self, img_group):
        """Return one crop per image, all taken from the same box.

        Every image must have the size of the first one.  When that size
        already equals the crop size the images are passed through untouched.
        """
        full_w, full_h = img_group[0].size
        crop_h, crop_w = self.size

        # Offsets are always drawn (x first, then y), even when no crop is
        # needed, so the random stream is consumed identically in both cases.
        left = random.randint(0, full_w - crop_w)
        top = random.randint(0, full_h - crop_h)
        box = (left, top, left + crop_w, top + crop_h)
        already_sized = full_w == crop_w and full_h == crop_h

        cropped = []
        for frame in img_group:
            assert(frame.size[0] == full_w and frame.size[1] == full_h)
            cropped.append(frame if already_sized else frame.crop(box))
        return cropped
class GroupCenterCrop(object):
    """Apply one torchvision ``CenterCrop`` to every image of a group."""

    def __init__(self, size):
        # A single worker is built once and reused for every call.
        self.worker = torchvision.transforms.CenterCrop(size)

    def __call__(self, img_group):
        """Return a new list holding the worker's output for each image."""
        return list(map(self.worker, img_group))
class GroupRandomHorizontalFlip(object):
    """Mirror a whole image group left-to-right with a probability of 0.5.

    When ``is_flow`` is set on the instance, every even-indexed image of a
    flipped group is additionally passed through ``ImageOps.invert``.
    """

    def __init__(self, is_flow=False):
        self.is_flow = is_flow

    def __call__(self, img_group, is_flow=False):
        """Return the flipped group, or ``img_group`` itself when not flipping.

        The call-time ``is_flow`` argument is accepted but not consulted; only
        the flag given to the constructor is used.
        """
        if random.random() >= 0.5:
            return img_group

        flipped = [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in img_group]
        if self.is_flow:
            # invert flow pixel values when flipping
            for idx in range(0, len(flipped), 2):
                flipped[idx] = ImageOps.invert(flipped[idx])
        return flipped
class GroupNormalize(object):
    """Normalize a stacked clip tensor in place: ``(x - mean) / std`` per channel.

    Args:
        mean: Per-channel means (any sequence, ndarray or tensor).
        std: Per-channel standard deviations, same conventions as ``mean``.
        threed_data: When true, ``mean``/``std`` are stored as ``(len, 1, 1, 1)``
            float tensors and broadcast directly against the input.  Otherwise
            they are stored as given and tiled along dim 0 of the input.
    """

    def __init__(self, mean, std, threed_data=False):
        self.threed_data = threed_data
        if self.threed_data:
            # convert to the proper format
            self.mean = torch.FloatTensor(mean).view(len(mean), 1, 1, 1)
            self.std = torch.FloatTensor(std).view(len(std), 1, 1, 1)
        else:
            self.mean = mean
            self.std = std

    def __call__(self, tensor):
        """Normalize ``tensor`` in place and return it.

        In the non-3D case the statistics are repeated a whole number of times
        along dim 0; leading channels covered by those repeats are normalized
        and any remaining trailing channels are left untouched.
        """
        if self.threed_data:
            tensor.sub_(self.mean).div_(self.std)
            return tensor

        # Keep the statistics in the input's precision; for non-float inputs
        # use the default float dtype so the in-place op fails loudly rather
        # than silently truncating the statistics to integers.
        stat_dtype = tensor.dtype if tensor.is_floating_point() else torch.get_default_dtype()
        mean = torch.as_tensor(self.mean, dtype=stat_dtype, device=tensor.device).flatten()
        std = torch.as_tensor(self.std, dtype=stat_dtype, device=tensor.device).flatten()

        channels = tensor.size(0)
        mean_reps = channels // mean.numel()
        std_reps = channels // std.numel()
        covered = min(mean_reps * mean.numel(), std_reps * std.numel())
        if covered:
            # One broadcast op instead of a Python loop over channels.
            stat_shape = (covered,) + (1,) * (tensor.dim() - 1)
            rep_mean = mean.repeat(mean_reps)[:covered].view(stat_shape)
            rep_std = std.repeat(std_reps)[:covered].view(stat_shape)
            tensor[:covered].sub_(rep_mean).div_(rep_std)
        return tensor
class GroupCutout(object):
    """Randomly mask out one or more square patches, shared by all channels.

    Args:
        n_holes (int): Number of patches to cut out.
        length (int): The length (in pixels) of each square patch.
    """

    def __init__(self, n_holes, length):
        self.n_holes = n_holes
        self.length = length

    def __call__(self, imgs):
        """
        Args:
            imgs (Tensor): Stacked images of size (C, H, W).
        Returns:
            Tensor: New tensor of the same shape with ``n_holes`` patches of
            dimension length x length zeroed at the same positions in every
            channel.
        """
        _, height, width = imgs.shape
        half = self.length // 2

        # The mask is built once and broadcast over all channels, so every
        # frame of the clip is occluded at the same positions.
        mask = np.ones((height, width), np.float32)
        for _ in range(self.n_holes):
            # A fresh centre per hole; row index from the height range and
            # column index from the width range.
            y = np.random.randint(height)
            x = np.random.randint(width)
            y1, y2 = max(y - half, 0), min(y + half, height)
            x1, x2 = max(x - half, 0), min(x + half, width)
            mask[y1:y2, x1:x2] = 0.

        mask = torch.from_numpy(mask).to(imgs.device)
        return imgs * mask
class GroupScale(object):
    """Rescale every PIL.Image of a group to the given 'size'.

    'size' will be the size of the smaller edge.
    For example, if height > width, then each image will be
    rescaled to (size * height / width, size)

    size: size of the smaller edge
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        # The resizing itself is delegated to a torchvision Resize transform.
        self.worker = torchvision.transforms.Resize(size, interpolation)

    def __call__(self, img_group):
        """Return a new list with the resize worker applied to each image."""
        return list(map(self.worker, img_group))
class GroupRandomScale(object):
    """Rescale a group of PIL.Images to a randomly chosen smaller-edge size.

    'size' is indexed as a (low, high) pair; on every call one integer is
    drawn from that range (both ends included) and used as the smaller-edge
    size for the whole group, as in GroupScale.

    size: range the smaller edge is drawn from
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        """Draw one edge length and resize all images with it."""
        lower, upper = self.size[0], self.size[1]
        # randint excludes ``high``, hence the +1 to make the range inclusive.
        edge = np.random.randint(low=lower, high=upper + 1, dtype=int)
        resizer = GroupScale(edge, interpolation=self.interpolation)
        return resizer(img_group)
class GroupOverSample(object):
    """Produce several fixed crops (optionally mirrored) of every image in a group.

    Args:
        crop_size: Crop size; an int gives a square crop, otherwise it is
            unpacked as ``(crop_w, crop_h)``.
        scale_size: When not None, the group is first resized with
            ``GroupScale(scale_size)``.
        num_crops: One of 1, 3, 5, 10.  1 takes the centre crop only; 3 takes
            three crops along the diagonal, or along the longer free axis when
            one image side equals the crop side; 5 and 10 take the four
            corners plus the centre.
            NOTE(review): 10 uses the same five offsets as 5 - ten views
            presumably rely on ``flip=True``; confirm with callers.
        flip: When true, a left-right mirrored copy of each crop is appended.

    Raises:
        ValueError: If ``num_crops`` is not one of the supported values.
    """

    def __init__(self, crop_size, scale_size=None, num_crops=5, flip=False):
        self.crop_size = crop_size if not isinstance(crop_size, int) else (crop_size, crop_size)
        self.scale_worker = GroupScale(scale_size) if scale_size is not None else None

        if num_crops not in [1, 3, 5, 10]:
            raise ValueError("num_crops should be in [1, 3, 5, 10] but ({})".format(num_crops))
        self.num_crops = num_crops
        self.flip = flip

    def __call__(self, img_group):
        """Return the crops grouped per offset: all plain crops, then (if
        ``flip``) all mirrored crops, for each offset in turn."""
        if self.scale_worker is not None:
            img_group = self.scale_worker(img_group)

        image_w, image_h = img_group[0].size
        crop_w, crop_h = self.crop_size

        oversample_group = []
        for o_w, o_h in self._crop_offsets(image_w, image_h, crop_w, crop_h):
            box = (o_w, o_h, o_w + crop_w, o_h + crop_h)
            normal_group = [img.crop(box) for img in img_group]
            oversample_group.extend(normal_group)
            if not self.flip:
                continue
            for i, (img, crop) in enumerate(zip(img_group, normal_group)):
                # transpose already returns a new image, no copy needed first
                flip_crop = crop.transpose(Image.FLIP_LEFT_RIGHT)
                if img.mode == 'L' and i % 2 == 0:
                    # even-indexed greyscale frames are inverted when mirrored
                    flip_crop = ImageOps.invert(flip_crop)
                oversample_group.append(flip_crop)
        return oversample_group

    def _crop_offsets(self, image_w, image_h, crop_w, crop_h):
        """Return the list of (x, y) top-left corners for the configured num_crops."""
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4
        center = (2 * w_step, 2 * h_step)

        if self.num_crops == 1:
            # A single crop means the centre crop, not the five-crop set.
            return [center]

        if self.num_crops == 3:
            if image_w != crop_w and image_h != crop_h:
                # both axes have slack: top-left, bottom-right, centre
                return [(0 * w_step, 0 * h_step), (4 * w_step, 4 * h_step), center]
            if image_w < image_h:
                # slide vertically: top, bottom, centre
                return [(2 * w_step, 0 * h_step), (2 * w_step, 4 * h_step), center]
            # slide horizontally: left, right, centre
            return [(0 * w_step, 2 * h_step), (4 * w_step, 2 * h_step), center]

        return GroupMultiScaleCrop.fill_fix_offset(False, image_w, image_h, crop_w, crop_h)
class GroupMultiScaleCrop(object):
    """Crop a group at a randomly chosen scale and position, then resize it.

    Args:
        input_size: Output size; an int gives a square, otherwise it is
            indexed as ``[width, height]``.
        scales: Candidate crop sizes as fractions of the shorter image side.
            Defaults to ``[1, .875, .75, .66]``.
        max_distort: Largest allowed difference between the index of the scale
            used for the crop height and the one used for the crop width.
        fix_crop: When true, the crop position is drawn from a fixed grid of
            offsets; otherwise it is drawn uniformly.
        more_fix_crop: When true, the fixed grid has 13 positions instead of 5.
    """

    # Immutable default; each instance receives its own list copy.
    _DEFAULT_SCALES = (1, .875, .75, .66)

    def __init__(self, input_size, scales=None, max_distort=1, fix_crop=True, more_fix_crop=True):
        self.scales = scales if scales is not None else list(self._DEFAULT_SCALES)
        self.max_distort = max_distort
        self.fix_crop = fix_crop
        self.more_fix_crop = more_fix_crop
        self.input_size = input_size if not isinstance(input_size, int) else [input_size, input_size]
        self.interpolation = Image.BILINEAR

    def __call__(self, img_group):
        """Apply one sampled crop box to every image and resize to ``input_size``."""
        im_size = img_group[0].size
        crop_w, crop_h, offset_w, offset_h = self._sample_crop_size(im_size)
        box = (offset_w, offset_h, offset_w + crop_w, offset_h + crop_h)
        out_size = (self.input_size[0], self.input_size[1])
        return [img.crop(box).resize(out_size, self.interpolation) for img in img_group]

    def _sample_crop_size(self, im_size):
        """Draw ``(crop_w, crop_h, w_offset, h_offset)`` for an image of ``im_size``."""
        image_w, image_h = im_size[0], im_size[1]

        # find a crop size
        base_size = min(image_w, image_h)
        crop_sizes = [int(base_size * x) for x in self.scales]
        # Snap sizes within 3 px of the target so no resampling is needed there.
        crop_h = [self.input_size[1] if abs(x - self.input_size[1]) < 3 else x for x in crop_sizes]
        crop_w = [self.input_size[0] if abs(x - self.input_size[0]) < 3 else x for x in crop_sizes]

        pairs = [
            (w, h)
            for i, h in enumerate(crop_h)
            for j, w in enumerate(crop_w)
            if abs(i - j) <= self.max_distort
        ]
        crop_pair = random.choice(pairs)

        if not self.fix_crop:
            w_offset = random.randint(0, image_w - crop_pair[0])
            h_offset = random.randint(0, image_h - crop_pair[1])
        else:
            w_offset, h_offset = self._sample_fix_offset(image_w, image_h, crop_pair[0], crop_pair[1])
        return crop_pair[0], crop_pair[1], w_offset, h_offset

    def _sample_fix_offset(self, image_w, image_h, crop_w, crop_h):
        """Pick one offset at random from the fixed grid."""
        offsets = self.fill_fix_offset(self.more_fix_crop, image_w, image_h, crop_w, crop_h)
        return random.choice(offsets)

    @staticmethod
    def fill_fix_offset(more_fix_crop, image_w, image_h, crop_w, crop_h):
        """Return the fixed grid of (x, y) crop offsets.

        The free space on each axis is split into four integer steps; the
        first five entries are the corners and the centre, and
        ``more_fix_crop`` appends eight further positions.  Static so it can
        be called both on the class and on an instance.
        """
        w_step = (image_w - crop_w) // 4
        h_step = (image_h - crop_h) // 4

        ret = [
            (0, 0),  # upper left
            (4 * w_step, 0),  # upper right
            (0, 4 * h_step),  # lower left
            (4 * w_step, 4 * h_step),  # lower right
            (2 * w_step, 2 * h_step),  # center
        ]
        if more_fix_crop:
            ret.extend([
                (0, 2 * h_step),  # center left
                (4 * w_step, 2 * h_step),  # center right
                (2 * w_step, 4 * h_step),  # lower center
                (2 * w_step, 0 * h_step),  # upper center
                (1 * w_step, 1 * h_step),  # upper left quarter
                (3 * w_step, 1 * h_step),  # upper right quarter
                (1 * w_step, 3 * h_step),  # lower left quarter
                (3 * w_step, 3 * h_step),  # lower right quarter
            ])
        return ret
class GroupRandomSizedCrop(object):
    """Random crop the given PIL.Images to a random size of (0.08 to 1.0) of the
    original area and a random aspect ratio of 3/4 to 4/3, then resize the
    crop to a square of side 'size'.
    This is popularly used to train the Inception networks

    size: side length of the square output
    interpolation: Default: PIL.Image.BILINEAR
    """

    def __init__(self, size, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img_group):
        """Try up to ten random boxes; fall back to scale + random crop."""
        full_w, full_h = img_group[0].size[0], img_group[0].size[1]

        for _ in range(10):
            target_area = random.uniform(0.08, 1.0) * (full_w * full_h)
            aspect_ratio = random.uniform(3. / 4, 4. / 3)

            w = int(round(math.sqrt(target_area * aspect_ratio)))
            h = int(round(math.sqrt(target_area / aspect_ratio)))
            if random.random() < 0.5:
                w, h = h, w

            if w > full_w or h > full_h:
                # Box does not fit inside the image; draw again.
                continue

            x1 = random.randint(0, full_w - w)
            y1 = random.randint(0, full_h - h)

            out_group = []
            for img in img_group:
                patch = img.crop((x1, y1, x1 + w, y1 + h))
                assert(patch.size == (w, h))
                out_group.append(patch.resize((self.size, self.size), self.interpolation))
            return out_group

        # Fallback
        scale = GroupScale(self.size, interpolation=self.interpolation)
        crop = GroupRandomCrop(self.size)
        return crop(scale(img_group))
class Stack(object):
    """Merge a group of images into a single numpy array.

    Mode 'L' images are joined along a new last axis.  Mode 'RGB' images are
    stacked along a new first axis when ``threed_data`` is set, otherwise
    concatenated along the channel axis, with each image's channel order
    reversed first when ``roll`` is set.  Any other mode yields ``None``.
    """

    def __init__(self, roll=False, threed_data=False):
        self.roll = roll
        self.threed_data = threed_data

    def __call__(self, img_group):
        mode = img_group[0].mode

        if mode == 'L':
            planes = [np.expand_dims(frame, 2) for frame in img_group]
            return np.concatenate(planes, axis=2)

        if mode != 'RGB':
            return None

        if self.threed_data:
            return np.stack(img_group, axis=0)

        if not self.roll:
            return np.concatenate(img_group, axis=2)

        # roll: reverse the channel order of every frame before joining
        reversed_frames = [np.array(frame)[:, :, ::-1] for frame in img_group]
        return np.concatenate(reversed_frames, axis=2)
class ToTorchFormatTensor(object):
    """ Converts a PIL.Image (RGB) or numpy.ndarray (H x W x C) in the range [0, 255]
    to a torch.FloatTensor of shape (C x H x W) in the range [0.0, 1.0]

    A 4-D numpy.ndarray is moved from (N x H x W x C) to (C x N x H x W).

    Args:
        div: When true the result is divided by 255; otherwise the values are
            only cast to float.
        num_clips_crops: Stored on the instance; not used by the conversion.
    """

    def __init__(self, div=True, num_clips_crops=1):
        self.div = div
        self.num_clips_crops = num_clips_crops

    def __call__(self, pic):
        if isinstance(pic, np.ndarray):
            # handle numpy array: channels-last -> channels-first
            if len(pic.shape) == 4:
                # (N x H x W x C) --> (C x N x H x W)
                img = torch.from_numpy(pic).permute(3, 0, 1, 2).contiguous()
            else:  # data is HW(FC)
                img = torch.from_numpy(pic).permute(2, 0, 1).contiguous()
        else:
            # handle PIL Image: read the raw bytes through a writable buffer
            # rather than the deprecated typed-storage API.
            raw = np.frombuffer(bytearray(pic.tobytes()), dtype=np.uint8)
            img = torch.from_numpy(raw)
            # PIL reports size as (width, height); one byte per mode letter.
            img = img.view(pic.size[1], pic.size[0], len(pic.mode))
            # put it from HWC to CHW format
            img = img.permute(2, 0, 1).contiguous()
        return img.float().div(255) if self.div else img.float()
class IdentityTransform(object):
    """No-op transform: hands its input back unchanged."""

    def __call__(self, data):
        unchanged = data
        return unchanged
if __name__ == "__main__":
    # Demo pipeline: resize, random crop, three-crop oversampling, stacking,
    # tensor conversion and per-channel normalization.
    trans = torchvision.transforms.Compose(
        [
            GroupScale(256),
            GroupRandomCrop(224),
            GroupOverSample(224, 224, num_crops=3, flip=False),
            Stack(),
            ToTorchFormatTensor(num_clips_crops=9),
            GroupNormalize(mean=[.485, .456, .406], std=[.229, .224, .225]),
        ]
    )