"""
DEIM: DETR with Improved Matching for Fast Convergence
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.
"""
import torch
import torchvision.transforms.v2 as T
import torchvision.transforms.v2.functional as F
import random
from PIL import Image
from .._misc import convert_to_tv_tensor
from ...core import register
@register()
class Mosaic(T.Transform):
    """
    Applies Mosaic augmentation to a batch of images. Combines four randomly selected images
    into a single composite image with randomized transformations.

    Two sourcing strategies are supported:
      * cache mode (``use_cache=True``): partner images are drawn from a FIFO cache of
        recently seen samples, so no random dataset access is needed;
      * dataset mode (``use_cache=False``): partner images are loaded directly from the
        dataset via ``dataset.load_item``.
    """
    def __init__(self, output_size=320, max_size=None, rotation_range=0, translation_range=(0.1, 0.1),
                 scaling_range=(0.5, 1.5), probability=1.0, fill_value=114, use_cache=True, max_cached_images=50,
                 random_pop=True) -> None:
        """
        Args:
            output_size (int): Target size for resizing individual images.
            max_size (int | None): Optional cap on the longer image edge, forwarded to ``T.Resize``.
            rotation_range (float): Range of rotation in degrees for affine transformation.
            translation_range (tuple): Range of translation for affine transformation.
            scaling_range (tuple): Range of scaling factors for affine transformation.
            probability (float): Probability of applying the Mosaic augmentation.
            fill_value (int): Fill value for padding or affine transformations.
            use_cache (bool): Whether to use cache. Defaults to True.
            max_cached_images (int): The maximum length of the cache.
            random_pop (bool): Whether to randomly pop a result from the cache.
        """
        super().__init__()
        self.resize = T.Resize(size=output_size, max_size=max_size)
        self.probability = probability
        self.affine_transform = T.RandomAffine(degrees=rotation_range, translate=translation_range,
                                               scale=scaling_range, fill=fill_value)
        self.use_cache = use_cache
        # FIFO cache of dicts with keys 'img' (PIL image) and 'labels' (target dict);
        # only used when use_cache is True.
        self.mosaic_cache = []
        self.max_cached_images = max_cached_images
        self.random_pop = random_pop
    def load_samples_from_dataset(self, image, target, dataset):
        """Loads and resizes a set of images and their corresponding targets.

        Returns the resized main (image, target) plus three samples drawn at random
        from ``dataset`` (with replacement, so duplicates are possible), together
        with the maximum height and width observed over all four images.
        """
        # Append the main image
        get_size_func = F.get_size if hasattr(F, "get_size") else F.get_spatial_size # torchvision >=0.17 is get_size
        image, target = self.resize(image, target)
        resized_images, resized_targets = [image], [target]
        max_height, max_width = get_size_func(resized_images[0])
        # randomly select 3 images
        sample_indices = random.choices(range(len(dataset)), k=3)
        for idx in sample_indices:
            # image, target = dataset.load_item(idx)
            # NOTE(review): load_item presumably returns an (image, target) pair; the
            # v2 transform resizes both elements of the tuple and returns them —
            # confirm against the dataset API.
            image, target = self.resize(dataset.load_item(idx))
            height, width = get_size_func(image)
            max_height, max_width = max(max_height, height), max(max_width, width)
            resized_images.append(image)
            resized_targets.append(target)
        return resized_images, resized_targets, max_height, max_width
    def load_samples_from_cache(self, image, target, cache):
        """Resizes and caches the incoming sample, then draws 3 partners from the cache.

        Returns:
            tuple: ``(mosaic_samples, max_height, max_width)`` where ``mosaic_samples``
            is a list of four dicts (current sample first) with keys 'img' and 'labels'.
        """
        image, target = self.resize(image, target)
        cache.append(dict(img=image, labels=target))
        # Evict one entry once the cache exceeds its capacity.
        if len(cache) > self.max_cached_images:
            if self.random_pop:
                index = random.randint(0, len(cache) - 2) # do not remove last image
            else:
                index = 0
            cache.pop(index)
        # Sample with replacement; copy images and clone label tensors so the later
        # in-place box-offset edits do not corrupt the cached entries.
        sample_indices = random.choices(range(len(cache)), k=3)
        mosaic_samples = [dict(img=cache[idx]["img"].copy(), labels=self._clone(cache[idx]["labels"])) for idx in
                          sample_indices] # sample 3 images
        mosaic_samples = [dict(img=image.copy(), labels=self._clone(target))] + mosaic_samples
        get_size_func = F.get_size if hasattr(F, "get_size") else F.get_spatial_size
        sizes = [get_size_func(mosaic_samples[idx]["img"]) for idx in range(4)]
        max_height = max(size[0] for size in sizes)
        max_width = max(size[1] for size in sizes)
        return mosaic_samples, max_height, max_width
    def create_mosaic_from_cache(self, mosaic_samples, max_height, max_width):
        """Pastes the four cached samples into a 2x2 grid and merges their targets.

        NOTE(review): unlike create_mosaic_from_dataset, this path concatenates every
        target value with torch.cat, i.e. it assumes all target values are tensors.
        """
        # Top-left corner of each quadrant: TL, TR, BL, BR.
        placement_offsets = [[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]
        merged_image = Image.new(mode=mosaic_samples[0]["img"].mode, size=(max_width * 2, max_height * 2), color=0)
        # repeat(1, 2) turns each (dx, dy) row into (dx, dy, dx, dy) so it can be
        # added directly to xyxy boxes.
        offsets = torch.tensor([[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]).repeat(1, 2)
        mosaic_target = []
        for i, sample in enumerate(mosaic_samples):
            img = sample["img"]
            target = sample["labels"]
            merged_image.paste(img, placement_offsets[i])
            # Shift boxes into the coordinate frame of the merged canvas.
            target['boxes'] = target['boxes'] + offsets[i]
            mosaic_target.append(target)
        merged_target = {}
        for key in mosaic_target[0]:
            merged_target[key] = torch.cat([target[key] for target in mosaic_target])
        return merged_image, merged_target
    def create_mosaic_from_dataset(self, images, targets, max_height, max_width):
        """Creates a mosaic image by combining multiple images."""
        # Top-left corner of each quadrant: TL, TR, BL, BR.
        placement_offsets = [[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]
        merged_image = Image.new(mode=images[0].mode, size=(max_width * 2, max_height * 2), color=0)
        for i, img in enumerate(images):
            merged_image.paste(img, placement_offsets[i])
        """Merges targets into a single target dictionary for the mosaic."""
        # Each (dx, dy) row becomes (dx, dy, dx, dy) for direct addition to xyxy boxes.
        offsets = torch.tensor([[0, 0], [max_width, 0], [0, max_height], [max_width, max_height]]).repeat(1, 2)
        merged_target = {}
        for key in targets[0]:
            if key == 'boxes':
                # Shift each sample's boxes by its quadrant's top-left corner.
                values = [target[key] + offsets[i] for i, target in enumerate(targets)]
            else:
                values = [target[key] for target in targets]
            # Non-tensor values (if any) are kept as a plain list instead of concatenated.
            merged_target[key] = torch.cat(values, dim=0) if isinstance(values[0], torch.Tensor) else values
        return merged_image, merged_target
    @staticmethod
    def _clone(tensor_dict):
        # Deep-copies the tensor values of a target dict (keys are shared).
        return {key: value.clone() for (key, value) in tensor_dict.items()}
    def forward(self, *inputs):
        """
        Args:
            inputs (tuple): Input tuple containing (image, target, dataset).
        Returns:
            tuple: Augmented (image, target, dataset).
        """
        # Upstream transforms may pass the sample as a single packed tuple.
        if len(inputs) == 1:
            inputs = inputs[0]
        image, target, dataset = inputs
        # Skip mosaic augmentation with probability 1 - self.probability
        if self.probability < 1.0 and random.random() > self.probability:
            return image, target, dataset
        # Prepare mosaic components
        if self.use_cache:
            mosaic_samples, max_height, max_width = self.load_samples_from_cache(image, target, self.mosaic_cache)
            mosaic_image, mosaic_target = self.create_mosaic_from_cache(mosaic_samples, max_height, max_width)
        else:
            resized_images, resized_targets, max_height, max_width = self.load_samples_from_dataset(image, target,dataset)
            mosaic_image, mosaic_target = self.create_mosaic_from_dataset(resized_images, resized_targets, max_height, max_width)
        # Clamp boxes and convert target formats
        if 'boxes' in mosaic_target:
            # PIL .size is (W, H); reversed to the (H, W) spatial_size convention.
            mosaic_target['boxes'] = convert_to_tv_tensor(mosaic_target['boxes'], 'boxes', box_format='xyxy',
                                                          spatial_size=mosaic_image.size[::-1])
        if 'masks' in mosaic_target:
            mosaic_target['masks'] = convert_to_tv_tensor(mosaic_target['masks'], 'masks')
        # Apply affine transformations
        mosaic_image, mosaic_target = self.affine_transform(mosaic_image, mosaic_target)
        return mosaic_image, mosaic_target, dataset