Spaces:
Sleeping
Sleeping
Evgeny Zhukov
Origin: https://github.com/ali-vilab/UniAnimate/commit/d7814fa44a0a1154524b92fce0e3133a2604d333
2ba4412
| import torch | |
| import torchvision.transforms.functional as F | |
| import random | |
| import math | |
| import numpy as np | |
| from PIL import Image, ImageFilter | |
# Names exported by ``from <module> import *``.
__all__ = [
    'Compose',
    'Resize',
    'Rescale',
    'CenterCrop',
    'CenterCropV2',
    'CenterCropWide',
    'RandomCrop',
    'RandomCropV2',
    'RandomHFlip',
    'GaussianBlur',
    'ColorJitter',
    'RandomGray',
    'ToTensor',
    'Normalize',
    'ResizeRandomCrop',
    'ExtractResizeRandomCrop',
    'ExtractResizeAssignCrop',
]
class Compose(object):
    """Chain several transforms and apply them one after another."""

    def __init__(self, transforms):
        # Sequence of callables; each one receives the previous one's output.
        self.transforms = transforms

    def __len__(self):
        """Return how many transforms the pipeline holds."""
        return len(self.transforms)

    def __getitem__(self, index):
        """Return one transform, or a new ``Compose`` when given a slice."""
        picked = self.transforms[index]
        return Compose(picked) if isinstance(index, slice) else picked

    def __call__(self, rgb):
        """Feed ``rgb`` through every transform in order and return the result."""
        result = rgb
        for transform in self.transforms:
            result = transform(result)
        return result
class Resize(object):
    """Resize one image, or each image of a list, to a fixed size (bilinear)."""

    def __init__(self, size=256):
        # An int becomes a (size, size) pair; any other value is stored as given.
        self.size = (size, size) if isinstance(size, int) else size

    def __call__(self, rgb):
        """Return the resized image, or a list of resized images for list input."""
        if not isinstance(rgb, list):
            return rgb.resize(self.size, Image.BILINEAR)
        return [frame.resize(self.size, Image.BILINEAR) for frame in rgb]
class Rescale(object):
    """Scale a list of images so the shorter side becomes ``size``."""

    def __init__(self, size=256, interpolation=Image.BILINEAR):
        self.size = size
        self.interpolation = interpolation

    def __call__(self, rgb):
        """Resize every image to the target computed from the first image."""
        # Only the first image's dimensions are inspected; the same output
        # size is then applied to all images in the list.
        width, height = rgb[0].size
        factor = self.size / min(width, height)
        target = (int(round(width * factor)), int(round(height * factor)))
        return [frame.resize(target, self.interpolation) for frame in rgb]
class CenterCrop(object):
    """Cut a centered ``size`` x ``size`` square out of every image in a list."""

    def __init__(self, size=224):
        self.size = size

    def __call__(self, rgb):
        """Crop all images with the box derived from the first image's size."""
        side = self.size
        width, height = rgb[0].size
        # The first image must be at least as large as the crop on both axes.
        assert min(width, height) >= side
        left = (width - side) // 2
        top = (height - side) // 2
        box = (left, top, left + side, top + side)
        return [frame.crop(box) for frame in rgb]
class ResizeRandomCrop(object):
    """Rescale a list of images to a short side of ``size_short``, then take
    one random ``size`` x ``size`` crop shared by all of them."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        """Return the list of cropped images."""
        short_side = self.size_short

        # Halve with the BOX filter while the short side of the first image
        # is still at least twice the target.
        while min(rgb[0].size) >= 2 * short_side:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2), resample=Image.BOX)
                for frame in rgb
            ]

        # Final bicubic resize so the first image's short side hits the target.
        factor = short_side / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(factor * frame.width), round(factor * frame.height)),
                resample=Image.BICUBIC,
            )
            for frame in rgb
        ]

        # One offset is drawn and reused so every image gets the same window.
        crop_w = crop_h = self.size
        width, height = rgb[0].size
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        box = (left, top, left + crop_w, top + crop_h)
        return [frame.crop(box) for frame in rgb]
class ExtractResizeRandomCrop(object):
    """Rescale a list of images to a short side of ``size_short``, take one
    shared random ``size`` x ``size`` crop, and also report the crop box."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb):
        """Return ``(cropped_images, [x1, y1, x2, y2])``."""
        short_side = self.size_short

        # Halve with the BOX filter while the short side of the first image
        # is still at least twice the target.
        while min(rgb[0].size) >= 2 * short_side:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2), resample=Image.BOX)
                for frame in rgb
            ]

        # Final bicubic resize so the first image's short side hits the target.
        factor = short_side / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(factor * frame.width), round(factor * frame.height)),
                resample=Image.BICUBIC,
            )
            for frame in rgb
        ]

        # Draw a single window and apply it to every image.
        crop_w = crop_h = self.size
        width, height = rgb[0].size
        left = random.randint(0, width - crop_w)
        top = random.randint(0, height - crop_h)
        right, bottom = left + crop_w, top + crop_h
        cropped = [frame.crop((left, top, right, bottom)) for frame in rgb]

        # The box is returned as a list so it can be handed to another transform.
        wh = [left, top, right, bottom]
        return cropped, wh
class ExtractResizeAssignCrop(object):
    """Rescale a list of images to a short side of ``size_short``, crop them
    with a caller-supplied box, and resize the result to ``size`` x ``size``."""

    def __init__(self, size=256, size_short=292):
        self.size = size
        self.size_short = size_short

    def __call__(self, rgb, wh):
        """Return the list of images cropped to ``wh`` and resized (bilinear).

        ``wh`` is passed unchanged to each image's ``crop`` method.
        """
        short_side = self.size_short

        # Halve with the BOX filter while the short side of the first image
        # is still at least twice the target.
        while min(rgb[0].size) >= 2 * short_side:
            rgb = [
                frame.resize((frame.width // 2, frame.height // 2), resample=Image.BOX)
                for frame in rgb
            ]

        # Final bicubic resize so the first image's short side hits the target.
        factor = short_side / min(rgb[0].size)
        rgb = [
            frame.resize(
                (round(factor * frame.width), round(factor * frame.height)),
                resample=Image.BICUBIC,
            )
            for frame in rgb
        ]

        output_size = (self.size, self.size)
        cropped = [frame.crop(wh) for frame in rgb]
        return [frame.resize(output_size, Image.BILINEAR) for frame in cropped]
class CenterCropV2(object):
    """Rescale a list of images to a short side of ``size``, then cut out the
    centered ``size`` x ``size`` square."""

    def __init__(self, size):
        self.size = size

    def __call__(self, img):
        """Return the list of center-cropped images."""
        side = self.size

        # Fast pre-shrink: halve with the BOX filter while the short side of
        # the first image is still at least twice the target.
        while min(img[0].size) >= 2 * side:
            img = [
                frame.resize((frame.width // 2, frame.height // 2), resample=Image.BOX)
                for frame in img
            ]

        # Final bicubic resize so the first image's short side hits the target.
        factor = side / min(img[0].size)
        img = [
            frame.resize(
                (round(factor * frame.width), round(factor * frame.height)),
                resample=Image.BICUBIC,
            )
            for frame in img
        ]

        # Centered window, computed once from the first image.
        left = (img[0].width - side) // 2
        top = (img[0].height - side) // 2
        box = (left, top, left + side, top + side)
        return [frame.crop(box) for frame in img]
class CenterCropWide(object):
    """Resize to cover a ``(width, height)`` target, then center-crop to it.

    Accepts either a single image or a list of images; a list in gives a list
    out, a single image in gives a single image out.
    """

    def __init__(self, size, interpolation=Image.BOX):
        # ``size`` is indexed as size[0] (target width) and size[1] (target height).
        self.size = size
        self.interpolation = interpolation

    def __call__(self, img):
        """Return the resized and center-cropped image(s)."""
        if isinstance(img, list):
            return self._resize_and_crop(img)
        return self._resize_and_crop([img])[0]

    def _resize_and_crop(self, frames):
        """Resize every frame by the cover scale of the first one and crop."""
        target_w, target_h = self.size[0], self.size[1]
        ref_w, ref_h = frames[0].size
        # Smallest shrink factor, so both resized sides still cover the target.
        scale = min(ref_w / target_w, ref_h / target_h)

        # True division + round: floor-dividing by a float can land one pixel
        # short of the target (e.g. 100 // (100 / 30) == 29.0), which makes
        # the crop below reach outside the image. max() guards the lower
        # bound against any remaining rounding slack.
        resized = [
            frame.resize(
                (
                    max(target_w, round(frame.width / scale)),
                    max(target_h, round(frame.height / scale)),
                ),
                resample=self.interpolation,
            )
            for frame in frames
        ]

        # Centered window, computed once from the first resized frame.
        x1 = (resized[0].width - target_w) // 2
        y1 = (resized[0].height - target_h) // 2
        box = (x1, y1, x1 + target_w, y1 + target_h)
        return [frame.crop(box) for frame in resized]
class RandomCrop(object):
    """Random-area, random-aspect crop shared by a list of images, resized to
    a ``size`` x ``size`` square with bilinear resampling.

    Args:
        size: side length of the square output.
        min_area: lower bound of the sampled crop area, as a fraction of the
            first image's area.
        max_attempts: how many random boxes to try before falling back to a
            central crop. Without a bound the sampling never terminates for
            frames whose aspect ratio lies too far outside [3/4, 4/3].
    """

    def __init__(self, size=224, min_area=0.4, max_attempts=100):
        self.size = size
        self.min_area = min_area
        self.max_attempts = max_attempts

    def _get_params(self, w, h):
        """Return ``(x1, y1, out_w, out_h)`` for a frame of width ``w`` and height ``h``."""
        area = w * h
        for _ in range(self.max_attempts):
            target_area = random.uniform(self.min_area, 1.0) * area
            aspect_ratio = random.uniform(3. / 4., 4. / 3.)
            out_w = int(round(math.sqrt(target_area * aspect_ratio)))
            out_h = int(round(math.sqrt(target_area / aspect_ratio)))
            if out_w <= w and out_h <= h:
                x1 = random.randint(0, w - out_w)
                y1 = random.randint(0, h - out_h)
                return x1, y1, out_w, out_h

        # Fallback: central crop with the aspect ratio clamped to [3/4, 4/3].
        in_ratio = float(w) / float(h)
        if in_ratio < 3. / 4.:
            out_w = w
            out_h = min(h, int(round(w / (3. / 4.))))
        elif in_ratio > 4. / 3.:
            out_h = h
            out_w = min(w, int(round(h * (4. / 3.))))
        else:
            out_w, out_h = w, h
        return (w - out_w) // 2, (h - out_h) // 2, out_w, out_h

    def __call__(self, rgb):
        """Crop every image with one shared box, then resize to the output size."""
        # The box is derived from the first image and reused for all of them.
        w, h = rgb[0].size
        x1, y1, out_w, out_h = self._get_params(w, h)
        box = (x1, y1, x1 + out_w, y1 + out_h)
        rgb = [u.crop(box) for u in rgb]
        return [u.resize((self.size, self.size), Image.BILINEAR) for u in rgb]
class RandomCropV2(object):
    """Random-area, random-aspect crop shared by a list of images, resized to
    ``size``; falls back to a central crop after ten failed samples."""

    def __init__(self, size=224, min_area=0.4, ratio=(3. / 4., 4. / 3.)):
        # A tuple/list is kept as given; any other value is duplicated into a pair.
        self.size = size if isinstance(size, (tuple, list)) else (size, size)
        self.min_area = min_area
        self.ratio = ratio

    def _get_params(self, img):
        """Return ``(top, left, crop_height, crop_width)`` for ``img``."""
        img_w, img_h = img.size
        total = img_h * img_w

        for _attempt in range(10):
            sampled_area = random.uniform(self.min_area, 1.0) * total
            # The aspect ratio is drawn uniformly in log space.
            log_lo, log_hi = math.log(self.ratio[0]), math.log(self.ratio[1])
            aspect = math.exp(random.uniform(log_lo, log_hi))
            crop_w = int(round(math.sqrt(sampled_area * aspect)))
            crop_h = int(round(math.sqrt(sampled_area / aspect)))
            fits = 0 < crop_w <= img_w and 0 < crop_h <= img_h
            if not fits:
                continue
            top = random.randint(0, img_h - crop_h)
            left = random.randint(0, img_w - crop_w)
            return top, left, crop_h, crop_w

        # No sample fitted: use a central crop whose aspect ratio is clamped
        # to the configured bounds.
        lowest, highest = min(self.ratio), max(self.ratio)
        in_ratio = float(img_w) / float(img_h)
        if in_ratio < lowest:
            crop_w = img_w
            crop_h = int(round(crop_w / lowest))
        elif in_ratio > highest:
            crop_h = img_h
            crop_w = int(round(crop_h * highest))
        else:
            crop_w, crop_h = img_w, img_h
        top = (img_h - crop_h) // 2
        left = (img_w - crop_w) // 2
        return top, left, crop_h, crop_w

    def __call__(self, rgb):
        """Apply one box, sampled from the first image, to every image."""
        top, left, crop_h, crop_w = self._get_params(rgb[0])
        return [F.resized_crop(frame, top, left, crop_h, crop_w, self.size) for frame in rgb]
class RandomHFlip(object):
    """Mirror every image in a list left-to-right with probability ``p``."""

    def __init__(self, p=0.5):
        self.p = p

    def __call__(self, rgb):
        """Return the flipped list, or the input itself when no flip is drawn."""
        # A single draw decides for the whole list, so all images flip together.
        flip = random.random() < self.p
        if not flip:
            return rgb
        return [frame.transpose(Image.FLIP_LEFT_RIGHT) for frame in rgb]
class GaussianBlur(object):
    """Blur every image in a list with one random Gaussian radius, with probability ``p``.

    Args:
        sigmas: two values ``(low, high)``; the blur radius is drawn uniformly
            between them.
        p: probability of applying the blur at all.
    """

    def __init__(self, sigmas=[0.1, 2.0], p=0.5):
        # Store a copy: the default list is shared between calls, so keeping a
        # reference would let in-place edits on one instance leak into the
        # default used by every later instance.
        self.sigmas = list(sigmas)
        self.p = p

    def __call__(self, rgb):
        """Return the blurred list, or the input itself when no blur is drawn."""
        if random.random() < self.p:
            # One radius is drawn and applied to all images in the list.
            sigma = random.uniform(*self.sigmas)
            rgb = [u.filter(ImageFilter.GaussianBlur(radius=sigma)) for u in rgb]
        return rgb
class ColorJitter(object):
    """With probability ``p``, apply brightness, contrast, saturation and hue
    adjustments in random order, using the same factors for every image."""

    def __init__(self, brightness=0.4, contrast=0.4, saturation=0.4, hue=0.1, p=0.5):
        self.brightness = brightness
        self.contrast = contrast
        self.saturation = saturation
        self.hue = hue
        self.p = p

    def _random_params(self):
        """Draw ``(brightness, contrast, saturation, hue)`` factors."""

        def around_one(amount):
            # Uniform in [max(0, 1 - amount), 1 + amount].
            return random.uniform(max(0, 1 - amount), 1 + amount)

        brightness = around_one(self.brightness)
        contrast = around_one(self.contrast)
        saturation = around_one(self.saturation)
        hue = random.uniform(-self.hue, self.hue)
        return brightness, contrast, saturation, hue

    def __call__(self, rgb):
        """Return the jittered list, or the input itself when no jitter is drawn."""
        if not random.random() < self.p:
            return rgb

        brightness, contrast, saturation, hue = self._random_params()
        steps = [
            (F.adjust_brightness, brightness),
            (F.adjust_contrast, contrast),
            (F.adjust_saturation, saturation),
            (F.adjust_hue, hue),
        ]
        # The order of the four adjustments is randomized on every call.
        random.shuffle(steps)
        for adjust, factor in steps:
            rgb = [adjust(frame, factor) for frame in rgb]
        return rgb
class RandomGray(object):
    """With probability ``p``, convert every image in a list to mode 'L' and
    back to 'RGB'."""

    def __init__(self, p=0.2):
        self.p = p

    def __call__(self, rgb):
        """Return the converted list, or the input itself when nothing is drawn."""
        # A single draw decides for the whole list.
        convert = random.random() < self.p
        if not convert:
            return rgb
        return [frame.convert('L').convert('RGB') for frame in rgb]
class ToTensor(object):
    """Convert an image to a tensor, or a list of images to one stacked tensor."""

    def __call__(self, rgb):
        """Return ``F.to_tensor`` of the input; list inputs are stacked on a new dim 0."""
        if not isinstance(rgb, list):
            return F.to_tensor(rgb)
        tensors = [F.to_tensor(frame) for frame in rgb]
        return torch.stack(tensors, dim=0)
class Normalize(object):
    """Clamp a tensor to [0, 1], then subtract ``mean`` and divide by ``std``
    per channel.

    Args:
        mean: per-channel means (sequence, scalar, or tensor).
        std: per-channel standard deviations (sequence, scalar, or tensor).
    """

    def __init__(self, mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]):
        # Copy list arguments so the shared default lists cannot be edited
        # through an instance.
        self.mean = list(mean) if isinstance(mean, list) else mean
        self.std = list(std) if isinstance(std, list) else std

    def __call__(self, rgb):
        """Return a normalized copy of ``rgb``; the input tensor is not modified.

        4-D input is normalized along dim 1 and 3-D input along dim 0. Any
        other rank is returned clamped but otherwise unchanged.
        """
        rgb = rgb.clone()
        rgb.clamp_(0, 1)

        # Build the statistics for this call instead of caching them on self:
        # a cached tensor stays tied to the first input's device, which breaks
        # a later call whose input lives on another device.
        mean = torch.as_tensor(self.mean, dtype=rgb.dtype, device=rgb.device).view(-1)
        std = torch.as_tensor(self.std, dtype=rgb.dtype, device=rgb.device).view(-1)

        if rgb.dim() == 4:
            shape = (1, -1, 1, 1)
        elif rgb.dim() == 3:
            shape = (-1, 1, 1)
        else:
            return rgb
        rgb.sub_(mean.view(shape)).div_(std.view(shape))
        return rgb