# himipo's picture
# first
# 11aa70b
"""
DEIM: DETR with Improved Matching for Fast Convergence
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from D-FINE (https://github.com/Peterande/D-FINE)
Copyright (c) 2024 D-FINE authors. All Rights Reserved.
"""
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as T
from typing import Any, Dict, List, Optional
from ._transforms import EmptyTransform
from ...core import register, GLOBAL_CONFIG
torchvision.disable_beta_transforms_warning()
import random
@register()
class Compose(T.Compose):
    """Sequential transform container with an optional scheduling policy.

    Args:
        ops: ``None`` or a list whose items are either config dicts or
            ``nn.Module`` instances. A config dict must hold a ``'type'`` key
            naming an entry of ``GLOBAL_CONFIG``; the remaining keys are passed
            as keyword arguments to the constructor found through that entry's
            ``'_pymodule'`` / ``'_name'``. When ``ops`` is ``None`` a single
            ``EmptyTransform`` is used.
        policy: ``None`` (equivalent to ``{'name': 'default'}``) or a dict
            whose ``'name'`` selects the forward implementation:

            * ``'default'``: apply every transform in order.
            * ``'stop_epoch'``: needs ``'ops'`` (transform class names) and
              ``'epoch'``; see :meth:`stop_epoch_forward`.
            * ``'stop_sample'``: needs ``'ops'`` and ``'sample'``; see
              :meth:`stop_sample_forward`.
        mosaic_prob: probability of keeping the ``Mosaic`` transform during the
            mosaic stage of the 4-stage ``stop_epoch`` policy. The negative
            default means Mosaic is never selected.

    Raises:
        ValueError: if an item of ``ops`` is neither a dict nor an ``nn.Module``.
    """

    def __init__(self, ops, policy=None, mosaic_prob=-0.1) -> None:
        transforms = []
        if ops is not None:
            for op in ops:
                if isinstance(op, dict):
                    # Work on a copy so the caller's config dict keeps its
                    # 'type' key even when the constructor below raises.
                    kwargs = dict(op)
                    name = kwargs.pop('type')
                    entry = GLOBAL_CONFIG[name]
                    transform = getattr(entry['_pymodule'], entry['_name'])(**kwargs)
                    transforms.append(transform)
                    print(" ### Transform @{} ### ".format(type(transform).__name__))
                elif isinstance(op, nn.Module):
                    transforms.append(op)
                else:
                    raise ValueError(
                        "Unsupported transform spec of type {}; expected a config "
                        "dict or an nn.Module".format(type(op).__name__)
                    )
        else:
            transforms = [EmptyTransform(), ]

        super().__init__(transforms=transforms)

        self.mosaic_prob = mosaic_prob
        if policy is None:
            policy = {'name': 'default'}
        else:
            if self.mosaic_prob > 0:
                print(" ### Mosaic with Prob.@{} and ZoomOut/IoUCrop existed ### ".format(self.mosaic_prob))
            # Not every policy carries every key (e.g. 'stop_sample' uses
            # 'sample' instead of 'epoch'), so only report what is present.
            if 'epoch' in policy:
                print(" ### ImgTransforms Epochs: {} ### ".format(policy['epoch']))
            if 'ops' in policy:
                print(' ### Policy_ops@{} ###'.format(policy['ops']))

        # Number of forward calls seen so far; used by the 'stop_sample' policy.
        self.global_samples = 0
        self.policy = policy

    def forward(self, *inputs: Any) -> Any:
        """Apply the transforms according to the configured policy."""
        return self.get_forward(self.policy['name'])(*inputs)

    def get_forward(self, name):
        """Return the bound forward implementation registered under ``name``.

        Raises:
            KeyError: if ``name`` is not a known policy name.
        """
        forwards = {
            'default': self.default_forward,
            'stop_epoch': self.stop_epoch_forward,
            'stop_sample': self.stop_sample_forward,
        }
        try:
            return forwards[name]
        except KeyError:
            raise KeyError(
                "Unknown policy name {!r}; expected one of {}".format(name, sorted(forwards))
            ) from None

    def default_forward(self, *inputs: Any) -> Any:
        """Apply every transform in order and return the resulting sample."""
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def stop_epoch_forward(self, *inputs: Any):
        """Apply the transforms, disabling policy ops depending on the epoch.

        The current epoch is read from ``sample[-1].epoch`` (the code treats
        the last element of the sample as the dataset).

        * If ``policy['epoch']`` is a list of three epochs ``[e0, e1, e2]``
          (4-stage schedule): transforms named in ``policy['ops']`` are skipped
          before ``e0`` and from ``e2`` on. ``Mosaic`` is applied only when
          ``e0 <= epoch < e1`` and a random draw falls below ``mosaic_prob``;
          when it is applied, ``RandomZoomOut`` and ``RandomIoUCrop`` are
          skipped for that sample.
        * Otherwise: transforms named in ``policy['ops']`` are skipped once the
          epoch reaches ``policy['epoch']``.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        dataset = sample[-1]
        cur_epoch = dataset.epoch
        policy_ops = self.policy['ops']
        policy_epoch = self.policy['epoch']

        if isinstance(policy_epoch, list) and len(policy_epoch) == 3:  # 4-stages
            start, mosaic_end, stop = policy_epoch
            # The random draw happens only inside the mosaic stage. Strict '<'
            # so that a probability of exactly 0 can never select Mosaic
            # (random.random() may return 0.0).
            with_mosaic = (
                start <= cur_epoch < mosaic_end
                and random.random() < self.mosaic_prob
            )

            for transform in self.transforms:
                name = type(transform).__name__
                # First and last stage: policy ops are switched off (NoAug).
                if name in policy_ops and (cur_epoch < start or cur_epoch >= stop):
                    continue
                # Mosaic runs only when it was selected for this sample.
                if name == 'Mosaic' and not with_mosaic:
                    continue
                # Mosaic and ZoomOut/IoUCrop must not co-exist in one sample.
                if name in ('RandomZoomOut', 'RandomIoUCrop') and with_mosaic:
                    continue
                sample = transform(sample)
        else:  # the default data scheduler
            for transform in self.transforms:
                if type(transform).__name__ in policy_ops and cur_epoch >= policy_epoch:
                    continue
                sample = transform(sample)

        return sample

    def stop_sample_forward(self, *inputs: Any):
        """Apply the transforms, disabling policy ops after a number of calls.

        Transforms named in ``policy['ops']`` are skipped once
        ``self.global_samples`` has reached ``policy['sample']``. The counter
        is incremented by one on every call.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        policy_ops = self.policy['ops']
        policy_sample = self.policy['sample']

        for transform in self.transforms:
            if type(transform).__name__ in policy_ops and self.global_samples >= policy_sample:
                continue
            sample = transform(sample)

        self.global_samples += 1
        return sample