Spaces:
Running
on
Zero
Running
on
Zero
File size: 4,765 Bytes
11aa70b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 |
"""
DEIM: DETR with Improved Matching for Fast Convergence
Copyright (c) 2024 The DEIM Authors. All Rights Reserved.
---------------------------------------------------------------------------------
Modified from D-FINE (https://github.com/Peterande/D-FINE)
Copyright (c) 2024 D-FINE authors. All Rights Reserved.
"""
import torch
import torch.nn as nn
import torchvision
import torchvision.transforms.v2 as T
from typing import Any, Dict, List, Optional
from ._transforms import EmptyTransform
from ...core import register, GLOBAL_CONFIG
torchvision.disable_beta_transforms_warning()
import random
@register()
class Compose(T.Compose):
    """Config-driven transform pipeline with epoch/sample-based augmentation schedules.

    Args:
        ops: Sequence of transform specs, or ``None``. Each entry is either a
            dict whose ``'type'`` key names a transform registered in
            ``GLOBAL_CONFIG`` (the remaining keys are passed to its
            constructor), or an already-built ``nn.Module``. When ``None``,
            the pipeline consists of a single ``EmptyTransform``.
        policy: Scheduling policy dict, or ``None`` for ``{'name': 'default'}``.
            ``policy['name']`` selects the forward strategy (``'default'``,
            ``'stop_epoch'`` or ``'stop_sample'``). ``policy['ops']`` lists
            transform class names that the schedule may switch off.
            ``policy['epoch']`` (read by ``'stop_epoch'``) is either a single
            epoch threshold or a 3-element list of stage boundaries.
            ``policy['sample']`` (read by ``'stop_sample'``) is a sample-count
            threshold.
        mosaic_prob: Per-sample probability of applying ``Mosaic`` inside the
            mosaic window of the 3-boundary ``'stop_epoch'`` schedule. Any
            value <= 0 (including the default -0.1) disables Mosaic there.

    Raises:
        KeyError: if a dict entry of ``ops`` has no ``'type'`` key or names an
            unregistered transform.
        ValueError: if an entry of ``ops`` is neither a dict nor an ``nn.Module``.
    """

    def __init__(self, ops, policy: Optional[Dict[str, Any]] = None, mosaic_prob=-0.1) -> None:
        transforms = []
        if ops is not None:
            for op in ops:
                if isinstance(op, dict):
                    # Work on a copy so the caller's config dict is never mutated,
                    # even when the transform constructor raises.
                    kwargs = dict(op)
                    name = kwargs.pop('type')
                    entry = GLOBAL_CONFIG[name]
                    transform = getattr(entry['_pymodule'], entry['_name'])(**kwargs)
                    transforms.append(transform)
                    print(" ### Transform @{} ### ".format(type(transform).__name__))
                elif isinstance(op, nn.Module):
                    transforms.append(op)
                else:
                    raise ValueError(
                        "Unsupported transform spec of type {}: expected a dict with a "
                        "'type' key or an nn.Module".format(type(op).__name__))
        else:
            transforms = [EmptyTransform(), ]
        super().__init__(transforms=transforms)

        self.mosaic_prob = mosaic_prob
        if policy is None:
            policy = {'name': 'default'}
        else:
            if self.mosaic_prob > 0:
                print(" ### Mosaic with Prob.@{} and ZoomOut/IoUCrop existed ### ".format(self.mosaic_prob))
            # A 'stop_sample' policy is only required to carry 'sample' (and an
            # explicit 'default' policy may carry neither key), so indexing
            # 'epoch'/'ops' unconditionally would raise KeyError here.
            if 'epoch' in policy:
                print(" ### ImgTransforms Epochs: {} ### ".format(policy['epoch']))
            if 'sample' in policy:
                print(" ### ImgTransforms Samples: {} ### ".format(policy['sample']))
            print(' ### Policy_ops@{} ###'.format(policy.get('ops')))
        # Number of samples processed so far; drives the 'stop_sample' policy.
        # NOTE(review): this counter lives on this instance only; if the
        # pipeline is copied into several worker processes, each copy counts
        # independently — confirm that is intended.
        self.global_samples = 0
        self.policy = policy

    def forward(self, *inputs: Any) -> Any:
        """Run the forward strategy selected by ``self.policy['name']``."""
        return self.get_forward(self.policy['name'])(*inputs)

    def get_forward(self, name):
        """Return the bound forward method registered under policy ``name``.

        Raises:
            KeyError: if ``name`` is not 'default', 'stop_epoch' or 'stop_sample'.
        """
        forwards = {
            'default': self.default_forward,
            'stop_epoch': self.stop_epoch_forward,
            'stop_sample': self.stop_sample_forward,
        }
        return forwards[name]

    def default_forward(self, *inputs: Any) -> Any:
        """Apply every transform in order, with no scheduling.

        A single positional input is passed through as the sample; several
        inputs are packed into one tuple sample.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        for transform in self.transforms:
            sample = transform(sample)
        return sample

    def stop_epoch_forward(self, *inputs: Any):
        """Apply transforms while switching ``policy['ops']`` on/off by epoch.

        The last element of the sample must expose an ``epoch`` attribute
        (presumably the dataset; verify against callers).

        With ``policy['epoch'] == [p0, p1, p2]`` (three boundaries, four stages):
          * ``epoch < p0`` or ``epoch >= p2``: transforms named in
            ``policy['ops']`` are skipped.
          * ``p0 <= epoch < p1``: ``Mosaic`` is applied with probability
            ``mosaic_prob``; when it is, ``RandomZoomOut``/``RandomIoUCrop``
            are skipped for that sample.
          * ``Mosaic`` is never applied outside ``[p0, p1)``.
        Otherwise ``policy['epoch']`` is a single threshold: transforms named in
        ``policy['ops']`` are skipped once ``epoch >= policy['epoch']``.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        cur_epoch = sample[-1].epoch
        policy_ops = self.policy['ops']
        policy_epoch = self.policy['epoch']

        if isinstance(policy_epoch, list) and len(policy_epoch) == 3:  # 4 stages
            # Mosaic window is half-open: [policy_epoch[0], policy_epoch[1]).
            # random.random() is in [0, 1), so `<` yields exactly mosaic_prob.
            if policy_epoch[0] <= cur_epoch < policy_epoch[1]:
                with_mosaic = random.random() < self.mosaic_prob
            else:
                with_mosaic = False
            # Policy ops are disabled in the first (NoAug) and last (NoAug) stages.
            policy_ops_enabled = policy_epoch[0] <= cur_epoch < policy_epoch[-1]

            for transform in self.transforms:
                name = type(transform).__name__
                if name in policy_ops and not policy_ops_enabled:
                    continue
                if name == 'Mosaic' and not with_mosaic:
                    continue
                # Mosaic and ZoomOut/IoUCrop must not be applied to the same sample.
                if name in ('RandomZoomOut', 'RandomIoUCrop') and with_mosaic:
                    continue
                sample = transform(sample)
        else:  # the default data scheduler: a single stop epoch
            for transform in self.transforms:
                if type(transform).__name__ in policy_ops and cur_epoch >= policy_epoch:
                    continue
                sample = transform(sample)
        return sample

    def stop_sample_forward(self, *inputs: Any):
        """Apply transforms, skipping ``policy['ops']`` after ``policy['sample']`` samples.

        ``self.global_samples`` is compared against the threshold before this
        sample is processed and then incremented by one.
        """
        sample = inputs if len(inputs) > 1 else inputs[0]
        policy_ops = self.policy['ops']
        stop_reached = self.global_samples >= self.policy['sample']
        for transform in self.transforms:
            if stop_reached and type(transform).__name__ in policy_ops:
                continue
            sample = transform(sample)
        self.global_samples += 1
        return sample
return sample
|