Spaces:
Runtime error
Runtime error
| # Copyright 2024 EPFL and Apple Inc. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import random | |
| from abc import ABC, abstractmethod | |
| import numpy as np | |
| import torchvision | |
| from fourm.utils import to_2tuple | |
class AbstractImageAugmenter(ABC):
    """Abstract base class for image augmenters.

    An augmenter is a callable that receives a dictionary of modalities and
    optional crop settings. The concrete augmenters in this module all return
    a 5-tuple ``(crop_coords, flip, orig_size, target_size, rand_aug_idx)``.
    """

    @abstractmethod
    def __call__(self, mod_dict, crop_settings):
        """Compute augmentation parameters for one sample.

        Marked abstract so that the base class cannot be instantiated and a
        subclass that forgets to implement ``__call__`` fails at construction
        time instead of silently returning ``None``.

        Args:
            mod_dict: Dictionary mapping modality names to their data.
            crop_settings: Crop settings for the sample, or ``None``.
        """
        pass
class RandomCropImageAugmenter(AbstractImageAugmenter):
    """Augmenter that samples a random resized crop and a random horizontal flip.

    Args:
        target_size: Target size; stored after being passed through ``to_2tuple``.
        hflip: Threshold compared against ``random.random()`` to decide the flip.
        crop_scale: ``scale`` argument forwarded to ``RandomResizedCrop.get_params``.
        crop_ratio: ``ratio`` argument forwarded to ``RandomResizedCrop.get_params``.
        main_domain: Key in ``mod_dict`` of the image used to sample the crop.
            If ``None``, the first entry of ``mod_dict`` is used.
    """

    def __init__(self, target_size=224, hflip=0.5, crop_scale=(0.2, 1.0), crop_ratio=(0.75, 1.3333), main_domain='rgb'):
        self.main_domain = main_domain
        self.crop_scale = crop_scale
        self.crop_ratio = crop_ratio
        self.hflip = hflip
        self.target_size = to_2tuple(target_size)

    def __call__(self, mod_dict, crop_settings):
        """Sample crop coordinates and a flip decision for one sample.

        Args:
            mod_dict: Dictionary mapping modality names to their data.
            crop_settings: Must be ``None``; this augmenter samples its own crop.

        Returns:
            Tuple ``(crop_coords, flip, orig_size, target_size, rand_aug_idx)``
            where ``crop_coords`` is ``(top, left, h, w)``, ``orig_size`` is
            ``(height, width)`` and ``rand_aug_idx`` is always ``None``.

        Raises:
            ValueError: If ``crop_settings`` is not ``None``.
        """
        if crop_settings is not None:
            raise ValueError("Crop settings are provided but not used by this augmenter.")

        # Pick the reference image: the main domain, or the first entry as a fallback.
        if self.main_domain is None:
            reference = mod_dict[list(mod_dict.keys())[0]]
        else:
            reference = mod_dict[self.main_domain]

        # `.size` is unpacked as (width, height); orig_size is reported as (height, width).
        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
        width, height = reference.size
        orig_size = (height, width)

        sampler = torchvision.transforms.RandomResizedCrop
        top, left, crop_h, crop_w = sampler.get_params(
            reference, scale=self.crop_scale, ratio=self.crop_ratio
        )
        crop_coords = (top, left, crop_h, crop_w)

        do_flip = random.random() < self.hflip

        return crop_coords, do_flip, orig_size, self.target_size, None
class NoImageAugmenter(AbstractImageAugmenter):
    """Augmenter that performs no augmentation and returns fixed values.

    This is for non-image modalities like poses where we don't do any augs,
    e.g. during tokenization.

    Args:
        no_aug: Stored on the instance; not read by ``__call__``.
        main_domain: Stored on the instance; not read by ``__call__``.
    """

    def __init__(self, no_aug=True, main_domain='human_poses'):
        self.main_domain = main_domain
        self.no_aug = no_aug
        # No target size is configured for this augmenter.
        self.target_size = None

    def __call__(self, mod_dict, crop_settings):
        """Return constant augmentation parameters, ignoring both arguments.

        Returns:
            Tuple ``(crop_coords, flip, orig_size, target_size, rand_aug_idx)``
            with ``crop_coords == (0, 0, 224, 224)``, ``flip == 0``,
            ``orig_size == (224, 224)``, ``target_size`` as stored on the
            instance (``None``) and ``rand_aug_idx == 0``.
        """
        side = 224
        crop_coords = (0, 0, side, side)
        orig_size = (side, side)
        return crop_coords, 0, orig_size, self.target_size, 0
class PreTokenizedImageAugmenter(AbstractImageAugmenter):
    """Augmenter that selects one entry from pre-computed crop settings.

    Args:
        target_size: Target size; stored after being passed through ``to_2tuple``.
        no_aug: If true, always select the first crop setting (index 0);
            otherwise an index is drawn with ``np.random.randint``.
        main_domain: Key in ``mod_dict`` of the image used to read the
            original size. The size is only read when this key is present in
            ``mod_dict`` and does not contain the substring ``'tok'``.
    """

    def __init__(self, target_size, no_aug=False, main_domain='rgb'):
        self.target_size = to_2tuple(target_size)
        self.no_aug = no_aug
        self.main_domain = main_domain

    def __call__(self, mod_dict, crop_settings):
        """Pick one pre-computed crop setting for the sample.

        Args:
            mod_dict: Dictionary mapping modality names to their data.
            crop_settings: Indexable collection whose entries unpack into
                ``(top, left, h, w, flip)``.

        Returns:
            Tuple ``(crop_coords, flip, orig_size, target_size, rand_aug_idx)``
            where ``crop_coords`` is ``(top, left, h, w)`` and ``orig_size``
            is ``(height, width)`` or ``None`` when no image size was read.

        Raises:
            ValueError: If ``crop_settings`` is ``None``.
        """
        # This augmenter cannot work without crop settings; fail with a clear
        # message instead of an opaque TypeError from len()/indexing on None.
        if crop_settings is None:
            raise ValueError("Crop settings are required by this augmenter but were not provided.")

        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
        if self.main_domain in mod_dict and 'tok' not in self.main_domain:
            # The guard guarantees main_domain is a key of mod_dict, so no
            # fallback to another entry is needed here.
            orig_width, orig_height = mod_dict[self.main_domain].size
            orig_size = (orig_height, orig_width)
        else:
            orig_size = None

        rand_aug_idx = 0 if self.no_aug else np.random.randint(len(crop_settings))
        top, left, h, w, flip = crop_settings[rand_aug_idx]
        crop_coords = (top, left, h, w)

        return crop_coords, flip, orig_size, self.target_size, rand_aug_idx
class CenterCropImageAugmenter(AbstractImageAugmenter):
    """Augmenter that takes the largest centered square crop, with a random flip.

    Args:
        target_size: Target size; stored after being passed through ``to_2tuple``.
        hflip: Threshold compared against ``random.random()`` to decide the flip.
        main_domain: Key in ``mod_dict`` of the image used to read the size.
            If ``None``, the first entry of ``mod_dict`` is used.
    """

    def __init__(self, target_size, hflip=0.0, main_domain='rgb'):
        self.main_domain = main_domain
        self.hflip = hflip
        self.target_size = to_2tuple(target_size)

    def __call__(self, mod_dict, crop_settings=None):
        """Compute the centered square crop for one sample.

        Args:
            mod_dict: Dictionary mapping modality names to their data.
            crop_settings: Ignored by this augmenter.

        Returns:
            Tuple ``(crop_coords, flip, orig_size, target_size, rand_aug_idx)``
            where ``crop_coords`` is ``(top, left, h, w)``, ``orig_size`` is
            ``(height, width)`` and ``rand_aug_idx`` is always ``None``.
        """
        if self.main_domain is None:
            reference = mod_dict[list(mod_dict.keys())[0]]
        else:
            reference = mod_dict[self.main_domain]

        # `.size` is unpacked as (width, height).
        width, height = reference.size

        # The square side is the shorter image side; the offset along the
        # shorter side is therefore zero and the longer side is centered.
        side = min(width, height)
        top = (height - side) // 2
        left = (width - side) // 2

        do_flip = random.random() < self.hflip

        return (top, left, side, side), do_flip, (height, width), self.target_size, None
class PaddingImageAugmenter(AbstractImageAugmenter):
    """Augmenter whose crop is a square of the longer image side, anchored at (0, 0).

    Args:
        target_size: Target size; stored after being passed through ``to_2tuple``.
        hflip: Threshold compared against ``random.random()`` to decide the flip.
        main_domain: Key in ``mod_dict`` of the image used to read the size.
            If ``None``, the first entry of ``mod_dict`` is used.
    """

    def __init__(self, target_size, hflip=0.0, main_domain='rgb'):
        self.main_domain = main_domain
        self.hflip = hflip
        self.target_size = to_2tuple(target_size)

    def __call__(self, mod_dict, crop_settings):
        """Compute the padding crop for one sample.

        Args:
            mod_dict: Dictionary mapping modality names to their data.
            crop_settings: Ignored by this augmenter.

        Returns:
            Tuple ``(crop_coords, flip, orig_size, target_size, rand_aug_idx)``
            where ``crop_coords`` is ``(0, 0, side, side)`` with ``side`` the
            longer image side, ``orig_size`` is ``(height, width)`` and
            ``rand_aug_idx`` is always ``None``.
        """
        if self.main_domain is None:
            reference = mod_dict[list(mod_dict.keys())[0]]
        else:
            reference = mod_dict[self.main_domain]

        # `.size` is unpacked as (width, height).
        width, height = reference.size

        # The crop window extends past the shorter side of the image.
        side = max(width, height)
        crop_coords = (0, 0, side, side)

        do_flip = random.random() < self.hflip

        return crop_coords, do_flip, (height, width), self.target_size, None
class ScaleJitteringImageAugmenter(AbstractImageAugmenter):
    """Augmenter that samples a square crop with a random scale and position.

    Args:
        target_size: Target size; stored after being passed through ``to_2tuple``.
        hflip: Threshold compared against ``random.random()`` to decide the flip.
        scale: Pair whose first two entries bound the uniformly sampled scale factor.
        main_domain: Key in ``mod_dict`` of the image used to read the size.
            If ``None``, the first entry of ``mod_dict`` is used.
    """

    def __init__(self, target_size, hflip=0.0, scale=(0.1, 2.0), main_domain='rgb'):
        self.main_domain = main_domain
        self.scale = scale
        self.hflip = hflip
        self.target_size = to_2tuple(target_size)

    def scale_jitter(self, orig_height, orig_width):
        """Sample a square crop ``(top, left, h, w)`` for an image of the given size.

        The side length is the longer image side divided by a scale factor
        drawn uniformly from ``self.scale``. The offsets are drawn uniformly
        and clamped so they are never negative.
        """
        factor = np.random.uniform(self.scale[0], self.scale[1])
        side = round(max(orig_height, orig_width) / factor)

        # max(0, ...) clamps the offset to 0 whenever the draw is negative,
        # e.g. when the crop is larger than the image along that axis.
        top = round(max(0, np.random.uniform(0, orig_height - side)))
        left = round(max(0, np.random.uniform(0, orig_width - side)))

        return top, left, side, side

    def __call__(self, mod_dict, crop_settings):
        """Sample crop coordinates and a flip decision for one sample.

        Args:
            mod_dict: Dictionary mapping modality names to their data.
            crop_settings: Must be ``None``; this augmenter samples its own crop.

        Returns:
            Tuple ``(crop_coords, flip, orig_size, target_size, rand_aug_idx)``
            where ``crop_coords`` is ``(top, left, h, w)``, ``orig_size`` is
            ``(height, width)`` and ``rand_aug_idx`` is always ``None``.

        Raises:
            ValueError: If ``crop_settings`` is not ``None``.
        """
        if crop_settings is not None:
            raise ValueError("Crop settings are provided but not used by this augmenter.")

        if self.main_domain is None:
            reference = mod_dict[list(mod_dict.keys())[0]]
        else:
            reference = mod_dict[self.main_domain]

        # `.size` is unpacked as (width, height).
        # With torchvision 0.13+, can also be: orig_size = TF.get_dimensions(image)
        width, height = reference.size

        crop_coords = self.scale_jitter(height, width)
        do_flip = random.random() < self.hflip

        return crop_coords, do_flip, (height, width), self.target_size, None
class EmptyAugmenter(AbstractImageAugmenter):
    """Augmenter that produces no augmentation parameters at all."""

    def __init__(self):
        """Create the augmenter; it has no configuration."""

    def __call__(self, mod_dict, crop_settings):
        """Return ``None`` for every field of the augmenter result tuple.

        Both arguments are ignored. The result has the same five slots as the
        other augmenters: ``(crop_coords, flip, orig_size, target_size,
        rand_aug_idx)``.
        """
        return (None,) * 5