| """ |
| DEIM: DETR with Improved Matching for Fast Convergence |
| Copyright (c) 2024 The DEIM Authors. All Rights Reserved. |
| --------------------------------------------------------------------------------- |
| Modified from D-FINE (https://github.com/Peterande/D-FINE) |
| Copyright (c) 2024 D-FINE authors. All Rights Reserved. |
| """ |
|
|
| import torch |
| import torch.utils.data as data |
| import torch.nn.functional as F |
| from torch.utils.data import default_collate |
|
|
| import torchvision |
| import torchvision.transforms.v2 as VT |
| from torchvision.transforms.v2 import functional as VF, InterpolationMode |
|
|
| import random |
| from functools import partial |
|
|
| from ..core import register |
| torchvision.disable_beta_transforms_warning() |
| from copy import deepcopy |
| from PIL import Image, ImageDraw |
| import os |
|
|
|
|
# Public names exported by ``from <this module> import *``.
# ``generate_scales`` is a module-level helper and is not listed here.
__all__ = [
    'DataLoader',
    'BaseCollateFunction',
    'BatchImageCollateFunction',
    'batch_image_collate_fn'
]
|
|
|
|
@register()
class DataLoader(data.DataLoader):
    """``torch.utils.data.DataLoader`` subclass with epoch tracking and a readable repr.

    ``__inject__`` lists constructor arguments that are presumably resolved by
    the project's ``register`` / config machinery (see ``..core``) — TODO
    confirm there.
    """
    __inject__ = ['dataset', 'collate_fn']

    def __repr__(self) -> str:
        """Show the dataset, batching settings and collate function, one per line."""
        fields = ['dataset', 'batch_size', 'num_workers', 'drop_last', 'collate_fn']
        body = "".join("\n" + " {0}: {1}".format(n, getattr(self, n)) for n in fields)
        return self.__class__.__name__ + "(" + body + "\n)"

    def set_epoch(self, epoch):
        """Record ``epoch`` and forward it to the dataset and the collate function.

        Forwarding is skipped for a component that has no callable ``set_epoch``.
        Plain collate functions such as ``batch_image_collate_fn`` or
        ``torch.utils.data.default_collate`` have none, and calling it
        unconditionally would raise ``AttributeError``.
        """
        self._epoch = epoch
        # Same order as before: dataset first, then collate_fn.
        for component in (self.dataset, self.collate_fn):
            set_fn = getattr(component, 'set_epoch', None)
            if callable(set_fn):
                set_fn(epoch)

    @property
    def epoch(self):
        """Most recently set epoch, or ``-1`` if ``set_epoch`` was never called."""
        return getattr(self, '_epoch', -1)

    @property
    def shuffle(self):
        """Shuffle flag assigned through the setter.

        Within this class ``_shuffle`` is only written by the setter below, so
        reading it before assignment raises ``AttributeError``.
        """
        return self._shuffle

    @shuffle.setter
    def shuffle(self, shuffle):
        assert isinstance(shuffle, bool), 'shuffle must be a boolean'
        self._shuffle = shuffle
|
|
|
|
@register()
def batch_image_collate_fn(items):
    """Batch only the images of ``(image, target)`` samples.

    The images (element 0 of each sample) are stacked along a new leading
    dimension. The targets (element 1) are returned as a list in input order,
    untouched.
    """
    images = [sample[0] for sample in items]
    targets = [sample[1] for sample in items]
    return torch.stack(images, dim=0), targets
|
|
|
|
class BaseCollateFunction:
    """Base class for collate callables that need to know the current epoch.

    Subclasses override ``__call__``. The epoch stored with ``set_epoch`` is
    exposed through the read-only ``epoch`` property.
    """

    def set_epoch(self, epoch):
        """Remember ``epoch`` for later reads through ``self.epoch``."""
        self._epoch = epoch

    @property
    def epoch(self):
        """Stored epoch; ``-1`` until ``set_epoch`` has been called."""
        return getattr(self, '_epoch', -1)

    def __call__(self, items):
        """Collate ``items``; the base class provides no implementation."""
        raise NotImplementedError('')
|
|
|
|
def generate_scales(base_size, base_size_repeat, *, stride=32, min_ratio=0.75, max_ratio=1.25):
    """Build the candidate sizes for multi-scale training.

    Let ``low = int(base_size * min_ratio / stride) * stride`` and
    ``high = int(base_size * max_ratio / stride) * stride``. Let
    ``scale_repeat`` be the number of whole strides between ``low`` and
    ``base_size``.

    The result has three parts:

    1. ``scale_repeat`` sizes ascending from ``low`` in steps of ``stride``.
    2. ``base_size`` repeated ``base_size_repeat`` times.
    3. ``scale_repeat`` sizes descending from ``high`` in steps of ``stride``.

    Repeating ``base_size`` raises its weight when the list is sampled
    uniformly (e.g. with ``random.choice``).

    Args:
        base_size: Nominal input size in pixels.
        base_size_repeat: Number of copies of ``base_size`` in the list.
        stride: Step between consecutive sizes (default 32, as before).
        min_ratio: Lower bound as a fraction of ``base_size`` (default 0.75).
        max_ratio: Upper bound as a fraction of ``base_size`` (default 1.25).

    Returns:
        list[int]: The candidate sizes in the order described above.

    Example:
        >>> generate_scales(640, 2)
        [480, 512, 544, 576, 608, 640, 640, 800, 768, 736, 704, 672]
    """
    low = int(base_size * min_ratio / stride) * stride
    high = int(base_size * max_ratio / stride) * stride
    # The same count is used below and above base_size.
    scale_repeat = (base_size - low) // stride
    scales = [low + i * stride for i in range(scale_repeat)]
    scales += [base_size] * base_size_repeat
    scales += [high - i * stride for i in range(scale_repeat)]
    return scales
|
|
|
|
@register()
class BatchImageCollateFunction(BaseCollateFunction):
    """Collate function with optional batch MixUp and multi-scale resizing.

    Each call does the following:

    1. Stacks the sample images into a batch.
    2. Optionally blends each image with the previous one in the batch (MixUp).
    3. While ``epoch < stop_epoch`` and a scale list is configured, resizes the
       whole batch to a size picked from ``self.scales``.
    """

    def __init__(
        self,
        stop_epoch=None,
        ema_restart_decay=0.9999,
        base_size=640,
        base_size_repeat=None,
        mixup_prob=0.0,
        mixup_epochs=(0, 0),
        data_vis=False,
        vis_save='./vis_dataset/'
    ) -> None:
        """
        Args:
            stop_epoch: Multi-scale resizing is applied while ``epoch < stop_epoch``.
                ``None`` is stored as 100000000, i.e. no practical stop.
            ema_restart_decay: Stored on the instance; not used in this class.
            base_size: Nominal size forwarded to ``generate_scales``.
            base_size_repeat: Copies of ``base_size`` in the scale list. ``None``
                disables multi-scale resizing (``self.scales`` is ``None``).
            mixup_prob: Per-batch probability of applying MixUp.
            mixup_epochs: ``[start, end]``; MixUp may fire for ``start <= epoch < end``.
            data_vis: When MixUp is enabled, also save mixed images with boxes drawn.
            vis_save: Directory for those images (created if ``data_vis`` and MixUp).
        """
        super().__init__()
        self.base_size = base_size
        self.scales = generate_scales(base_size, base_size_repeat) if base_size_repeat is not None else None
        self.stop_epoch = stop_epoch if stop_epoch is not None else 100000000
        self.ema_restart_decay = ema_restart_decay

        # Copy into a fresh list so the default is never shared between
        # instances, while the attribute stays a list.
        self.mixup_prob, self.mixup_epochs = mixup_prob, list(mixup_epochs)
        # Define unconditionally so apply_mixup never sees missing attributes.
        self.data_vis, self.vis_save = data_vis, vis_save
        if self.mixup_prob > 0:
            if self.data_vis:
                os.makedirs(self.vis_save, exist_ok=True)
            print(" ### Using MixUp with Prob@{} in {} epochs ### ".format(self.mixup_prob, self.mixup_epochs))
        if stop_epoch is not None:
            print(" ### Multi-scale Training until {} epochs ### ".format(self.stop_epoch))
            print(" ### Multi-scales@ {} ### ".format(self.scales))
        self.print_info_flag = True

    def apply_mixup(self, images, targets):
        """
        Applies Mixup augmentation to the batch if conditions are met.

        When applied, ``beta`` is drawn uniformly from [0.45, 0.55] and rounded
        to 6 decimals. Image ``i`` becomes
        ``beta * images[i] + (1 - beta) * images[i - 1]``, wrapping around via
        ``roll``. Its target gains the objects of target ``i - 1``.

        Args:
            images (torch.Tensor): Batch of images, batch along dim 0.
            targets (list[dict]): One dict per image, each containing
                ``'boxes'``, ``'labels'`` and ``'area'``.

        Returns:
            tuple: Updated images and targets; the inputs are returned unchanged
            when MixUp does not fire.
        """
        if self.epoch == self.mixup_epochs[-1] and self.print_info_flag:
            print(f" ### Attention --- Mixup is closed after epoch@ {self.epoch} ###")
            self.print_info_flag = False

        # random.random() is drawn first on every call, so the RNG stream (and
        # therefore the later scale choice) is identical to before.
        if random.random() < self.mixup_prob and self.mixup_epochs[0] <= self.epoch < self.mixup_epochs[-1]:
            beta = round(random.uniform(0.45, 0.55), 6)
            # roll() returns a new tensor, so the in-place ops leave `images` intact.
            images = images.roll(shifts=1, dims=0).mul_(1.0 - beta).add_(images.mul(beta))
            targets = self._mix_targets(targets, beta)
            if self.data_vis:
                self._save_mixup_vis(images, targets)

        return images, targets

    @staticmethod
    def _mix_targets(targets, beta):
        """Return deep copies of ``targets``, each merged with its predecessor's objects."""
        # Matches images.roll(shifts=1, dims=0): sample i is mixed with sample i - 1.
        shifted_targets = targets[-1:] + targets[:-1]
        updated_targets = deepcopy(targets)
        for tgt, shifted, updated in zip(targets, shifted_targets, updated_targets):
            for key in ('boxes', 'labels', 'area'):
                updated[key] = torch.cat([tgt[key], shifted[key]], dim=0)
            # Per-object blending weight: beta for own objects, 1 - beta for borrowed ones.
            updated['mixup'] = torch.tensor(
                [beta] * len(tgt['labels']) + [1.0 - beta] * len(shifted['labels']),
                dtype=torch.float32
            )
        return updated_targets

    def _save_mixup_vis(self, images, targets):
        """Save each mixed image with its boxes drawn into ``self.vis_save``.

        The drawing assumes boxes are normalized ``(cx, cy, w, h)``; that is an
        assumption inherited from the original drawing math (TODO confirm
        against the dataset transforms). Coordinates are scaled by each image's
        own width and height, not by a fixed size.
        """
        for i, target in enumerate(targets):
            image_tensor = images[i]
            height, width = image_tensor.shape[-2:]
            # Clamp before the cast so values slightly outside [0, 1] don't wrap.
            image_tensor_uint8 = (image_tensor * 255).clamp(0, 255).to(torch.uint8)
            image_numpy = image_tensor_uint8.cpu().numpy().transpose((1, 2, 0))
            pil_image = Image.fromarray(image_numpy)
            draw = ImageDraw.Draw(pil_image)
            boxes = target['boxes']
            print('mix_vis:', i, 'boxes.len=', len(boxes))
            for box in boxes:
                cx, cy, bw, bh = (float(v) for v in box[:4])
                draw.rectangle(
                    [int(cx * width - bw * width / 2), int(cy * height - bh * height / 2),
                     int(cx * width + bw * width / 2), int(cy * height + bh * height / 2)],
                    outline=(255, 255, 0),
                )
            # os.path.join works whether or not vis_save ends with a separator.
            pil_image.save(os.path.join(self.vis_save, str(i) + "_" + str(len(boxes)) + '_out.jpg'))

    def __call__(self, items):
        """Collate ``(image, target)`` samples into ``(images, targets)``.

        Images are stacked along a new dim 0 and passed through ``apply_mixup``.
        If ``self.scales`` is set and ``epoch < stop_epoch``, the batch is then
        resized with ``F.interpolate`` to a size from ``random.choice(self.scales)``.

        Raises:
            NotImplementedError: If resizing is active and the targets carry
                ``'masks'``; resizing masks is not supported.
        """
        images = torch.stack([x[0] for x in items], dim=0)
        targets = [x[1] for x in items]

        images, targets = self.apply_mixup(images, targets)

        if self.scales is not None and self.epoch < self.stop_epoch:
            sz = random.choice(self.scales)
            if 'masks' in targets[0]:
                # Raise before touching any target so callers' dicts stay unmodified.
                raise NotImplementedError('multi-scale resizing of masks is not supported')
            images = F.interpolate(images, size=sz)

        return images, targets
|
|