Spaces:
Running
on
Zero
Running
on
Zero
| """ | |
| DEIM: DETR with Improved Matching for Fast Convergence | |
| Copyright (c) 2024 The DEIM Authors. All Rights Reserved. | |
| """ | |
| import torch | |
| import torchvision.transforms.v2 as T | |
| import torchvision.transforms.v2.functional as F | |
| import random | |
| from PIL import Image | |
| from .._misc import convert_to_tv_tensor | |
| from ...core import register | |
class Mosaic(T.Transform):
    """
    Mosaic augmentation: combines four images — the incoming sample plus three
    partners drawn either from a rolling cache or directly from the dataset —
    into a single 2x2 composite, then applies a random affine transform.
    """

    def __init__(self, output_size=320, max_size=None, rotation_range=0, translation_range=(0.1, 0.1),
                 scaling_range=(0.5, 1.5), probability=1.0, fill_value=114, use_cache=True, max_cached_images=50,
                 random_pop=True) -> None:
        """
        Args:
            output_size (int): Target size for resizing individual images.
            max_size (int, optional): Upper bound on the longer side after resizing.
            rotation_range (float): Range of rotation in degrees for the affine transform.
            translation_range (tuple): Range of translation for the affine transform.
            scaling_range (tuple): Range of scaling factors for the affine transform.
            probability (float): Probability of applying the Mosaic augmentation.
            fill_value (int): Fill value used by the affine transform.
            use_cache (bool): Sample partner images from a rolling cache instead of
                re-loading them from the dataset. Defaults to True.
            max_cached_images (int): The maximum length of the cache.
            random_pop (bool): Evict a random cache entry (instead of the oldest)
                when the cache is full.
        """
        super().__init__()
        self.resize = T.Resize(size=output_size, max_size=max_size)
        self.probability = probability
        self.affine_transform = T.RandomAffine(degrees=rotation_range, translate=translation_range,
                                               scale=scaling_range, fill=fill_value)
        self.use_cache = use_cache
        self.mosaic_cache = []
        self.max_cached_images = max_cached_images
        self.random_pop = random_pop

    def load_samples_from_dataset(self, image, target, dataset):
        """Resize the main sample plus three randomly drawn dataset samples.

        Returns:
            tuple: (resized_images, resized_targets, max_height, max_width),
            with the main sample first in both lists.
        """
        # torchvision >= 0.17 renamed get_spatial_size -> get_size
        get_size_func = F.get_size if hasattr(F, "get_size") else F.get_spatial_size
        image, target = self.resize(image, target)
        resized_images, resized_targets = [image], [target]
        max_height, max_width = get_size_func(resized_images[0])

        # Randomly select 3 partner images (with replacement).
        sample_indices = random.choices(range(len(dataset)), k=3)
        for idx in sample_indices:
            image, target = self.resize(dataset.load_item(idx))
            height, width = get_size_func(image)
            max_height, max_width = max(max_height, height), max(max_width, width)
            resized_images.append(image)
            resized_targets.append(target)

        return resized_images, resized_targets, max_height, max_width

    def load_samples_from_cache(self, image, target, cache):
        """Push the current sample into the cache and sample three partners from it.

        Returns:
            tuple: (mosaic_samples, max_height, max_width) where mosaic_samples is a
            list of 4 dicts with keys 'img' and 'labels', current sample first.
        """
        image, target = self.resize(image, target)
        cache.append(dict(img=image, labels=target))
        if len(cache) > self.max_cached_images:
            if self.random_pop:
                index = random.randint(0, len(cache) - 2)  # never evict the just-added sample
            else:
                index = 0
            cache.pop(index)

        # Sample 3 partners; copy images and clone targets so later in-place box
        # shifts do not corrupt the cached entries.
        sample_indices = random.choices(range(len(cache)), k=3)
        mosaic_samples = [dict(img=cache[idx]["img"].copy(), labels=self._clone(cache[idx]["labels"]))
                          for idx in sample_indices]
        mosaic_samples = [dict(img=image.copy(), labels=self._clone(target))] + mosaic_samples

        get_size_func = F.get_size if hasattr(F, "get_size") else F.get_spatial_size
        sizes = [get_size_func(sample["img"]) for sample in mosaic_samples]
        max_height = max(size[0] for size in sizes)
        max_width = max(size[1] for size in sizes)
        return mosaic_samples, max_height, max_width

    def create_mosaic_from_cache(self, mosaic_samples, max_height, max_width):
        """Paste the four cached samples into a 2x2 grid and merge their targets."""
        placement_offsets = [[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]
        merged_image = Image.new(mode=mosaic_samples[0]["img"].mode, size=(max_width * 2, max_height * 2), color=0)
        # One (x, y, x, y) offset per quadrant so xyxy boxes shift in a single add.
        offsets = torch.tensor([[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]).repeat(1, 2)
        mosaic_target = []
        for i, sample in enumerate(mosaic_samples):
            img = sample["img"]
            target = sample["labels"]
            merged_image.paste(img, placement_offsets[i])
            target['boxes'] = target['boxes'] + offsets[i]
            mosaic_target.append(target)

        merged_target = {}
        for key in mosaic_target[0]:
            values = [target[key] for target in mosaic_target]
            # Only concatenate tensor entries; keep anything else as a list,
            # consistent with create_mosaic_from_dataset.
            merged_target[key] = torch.cat(values) if isinstance(values[0], torch.Tensor) else values
        return merged_image, merged_target

    def create_mosaic_from_dataset(self, images, targets, max_height, max_width):
        """Paste the four resized dataset images into a 2x2 grid and merge their targets."""
        placement_offsets = [[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]
        merged_image = Image.new(mode=images[0].mode, size=(max_width * 2, max_height * 2), color=0)
        for i, img in enumerate(images):
            merged_image.paste(img, placement_offsets[i])

        # One (x, y, x, y) offset per quadrant so xyxy boxes shift in a single add.
        offsets = torch.tensor([[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]).repeat(1, 2)
        merged_target = {}
        for key in targets[0]:
            if key == 'boxes':
                values = [target[key] + offsets[i] for i, target in enumerate(targets)]
            else:
                values = [target[key] for target in targets]
            merged_target[key] = torch.cat(values, dim=0) if isinstance(values[0], torch.Tensor) else values
        return merged_image, merged_target

    @staticmethod
    def _clone(tensor_dict):
        """Return a copy of a dict of tensors with every value cloned.

        Fix: this was previously a plain method without a ``self`` parameter,
        so every ``self._clone(...)`` call raised TypeError (two positional
        arguments passed to a one-parameter function).
        """
        return {key: value.clone() for (key, value) in tensor_dict.items()}

    def forward(self, *inputs):
        """
        Args:
            inputs (tuple): Input tuple containing (image, target, dataset).

        Returns:
            tuple: Augmented (image, target, dataset); the input is passed through
            unchanged with probability 1 - self.probability.
        """
        if len(inputs) == 1:
            inputs = inputs[0]
        image, target, dataset = inputs

        # Skip mosaic augmentation with probability 1 - self.probability
        if self.probability < 1.0 and random.random() > self.probability:
            return image, target, dataset

        # Build the 2x2 composite either from the rolling cache or from the dataset.
        if self.use_cache:
            mosaic_samples, max_height, max_width = self.load_samples_from_cache(image, target, self.mosaic_cache)
            mosaic_image, mosaic_target = self.create_mosaic_from_cache(mosaic_samples, max_height, max_width)
        else:
            resized_images, resized_targets, max_height, max_width = self.load_samples_from_dataset(image, target, dataset)
            mosaic_image, mosaic_target = self.create_mosaic_from_dataset(resized_images, resized_targets, max_height, max_width)

        # Wrap targets as tv_tensors; PIL .size is (w, h), so reverse to (h, w).
        if 'boxes' in mosaic_target:
            mosaic_target['boxes'] = convert_to_tv_tensor(mosaic_target['boxes'], 'boxes', box_format='xyxy',
                                                          spatial_size=mosaic_image.size[::-1])
        if 'masks' in mosaic_target:
            mosaic_target['masks'] = convert_to_tv_tensor(mosaic_target['masks'], 'masks')

        # Apply random affine jitter to the finished composite.
        mosaic_image, mosaic_target = self.affine_transform(mosaic_image, mosaic_target)
        return mosaic_image, mosaic_target, dataset