""" DEIM: DETR with Improved Matching for Fast Convergence Copyright (c) 2024 The DEIM Authors. All Rights Reserved. """ import torch import torchvision.transforms.v2 as T import torchvision.transforms.v2.functional as F import random from PIL import Image from .._misc import convert_to_tv_tensor from ...core import register @register() class Mosaic(T.Transform): """ Applies Mosaic augmentation to a batch of images. Combines four randomly selected images into a single composite image with randomized transformations. """ def __init__(self, output_size=320, max_size=None, rotation_range=0, translation_range=(0.1, 0.1), scaling_range=(0.5, 1.5), probability=1.0, fill_value=114, use_cache=True, max_cached_images=50, random_pop=True) -> None: """ Args: output_size (int): Target size for resizing individual images. rotation_range (float): Range of rotation in degrees for affine transformation. translation_range (tuple): Range of translation for affine transformation. scaling_range (tuple): Range of scaling factors for affine transformation. probability (float): Probability of applying the Mosaic augmentation. fill_value (int): Fill value for padding or affine transformations. use_cache (bool): Whether to use cache. Defaults to True. max_cached_images (int): The maximum length of the cache. random_pop (bool): Whether to randomly pop a result from the cache. """ super().__init__() self.resize = T.Resize(size=output_size, max_size=max_size) self.probability = probability self.affine_transform = T.RandomAffine(degrees=rotation_range, translate=translation_range, scale=scaling_range, fill=fill_value) self.use_cache = use_cache self.mosaic_cache = [] self.max_cached_images = max_cached_images self.random_pop = random_pop def load_samples_from_dataset(self, image, target, dataset): """Loads and resizes a set of images and their corresponding targets.""" # Append the main image get_size_func = F.get_size if hasattr(F, "get_size") else F.get_spatial_size # torchvision >=0.17 is get_size image, target = self.resize(image, target) resized_images, resized_targets = [image], [target] max_height, max_width = get_size_func(resized_images[0]) # randomly select 3 images sample_indices = random.choices(range(len(dataset)), k=3) for idx in sample_indices: # image, target = dataset.load_item(idx) image, target = self.resize(dataset.load_item(idx)) height, width = get_size_func(image) max_height, max_width = max(max_height, height), max(max_width, width) resized_images.append(image) resized_targets.append(target) return resized_images, resized_targets, max_height, max_width def load_samples_from_cache(self, image, target, cache): image, target = self.resize(image, target) cache.append(dict(img=image, labels=target)) if len(cache) > self.max_cached_images: if self.random_pop: index = random.randint(0, len(cache) - 2) # do not remove last image else: index = 0 cache.pop(index) sample_indices = random.choices(range(len(cache)), k=3) mosaic_samples = [dict(img=cache[idx]["img"].copy(), labels=self._clone(cache[idx]["labels"])) for idx in sample_indices] # sample 3 images mosaic_samples = [dict(img=image.copy(), labels=self._clone(target))] + mosaic_samples get_size_func = F.get_size if hasattr(F, "get_size") else F.get_spatial_size sizes = [get_size_func(mosaic_samples[idx]["img"]) for idx in range(4)] max_height = max(size[0] for size in sizes) max_width = max(size[1] for size in sizes) return mosaic_samples, max_height, max_width def create_mosaic_from_cache(self, mosaic_samples, max_height, max_width): placement_offsets = [[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]] merged_image = Image.new(mode=mosaic_samples[0]["img"].mode, size=(max_width * 2, max_height * 2), color=0) offsets = torch.tensor([[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]).repeat(1, 2) mosaic_target = [] for i, sample in enumerate(mosaic_samples): img = sample["img"] target = sample["labels"] merged_image.paste(img, placement_offsets[i]) target['boxes'] = target['boxes'] + offsets[i] mosaic_target.append(target) merged_target = {} for key in mosaic_target[0]: merged_target[key] = torch.cat([target[key] for target in mosaic_target]) return merged_image, merged_target def create_mosaic_from_dataset(self, images, targets, max_height, max_width): """Creates a mosaic image by combining multiple images.""" placement_offsets = [[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]] merged_image = Image.new(mode=images[0].mode, size=(max_width * 2, max_height * 2), color=0) for i, img in enumerate(images): merged_image.paste(img, placement_offsets[i]) """Merges targets into a single target dictionary for the mosaic.""" offsets = torch.tensor([[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]).repeat(1, 2) merged_target = {} for key in targets[0]: if key == 'boxes': values = [target[key] + offsets[i] for i, target in enumerate(targets)] else: values = [target[key] for target in targets] merged_target[key] = torch.cat(values, dim=0) if isinstance(values[0], torch.Tensor) else values return merged_image, merged_target @staticmethod def _clone(tensor_dict): return {key: value.clone() for (key, value) in tensor_dict.items()} def forward(self, *inputs): """ Args: inputs (tuple): Input tuple containing (image, target, dataset). Returns: tuple: Augmented (image, target, dataset). """ if len(inputs) == 1: inputs = inputs[0] image, target, dataset = inputs # Skip mosaic augmentation with probability 1 - self.probability if self.probability < 1.0 and random.random() > self.probability: return image, target, dataset # Prepare mosaic components if self.use_cache: mosaic_samples, max_height, max_width = self.load_samples_from_cache(image, target, self.mosaic_cache) mosaic_image, mosaic_target = self.create_mosaic_from_cache(mosaic_samples, max_height, max_width) else: resized_images, resized_targets, max_height, max_width = self.load_samples_from_dataset(image, target,dataset) mosaic_image, mosaic_target = self.create_mosaic_from_dataset(resized_images, resized_targets, max_height, max_width) # Clamp boxes and convert target formats if 'boxes' in mosaic_target: mosaic_target['boxes'] = convert_to_tv_tensor(mosaic_target['boxes'], 'boxes', box_format='xyxy', spatial_size=mosaic_image.size[::-1]) if 'masks' in mosaic_target: mosaic_target['masks'] = convert_to_tv_tensor(mosaic_target['masks'], 'masks') # Apply affine transformations mosaic_image, mosaic_target = self.affine_transform(mosaic_image, mosaic_target) return mosaic_image, mosaic_target, dataset