Spaces:
No application file
No application file
| """ | |
| Copied from RT-DETR (https://github.com/lyuwenyu/RT-DETR) | |
| Copyright(c) 2023 lyuwenyu. All Rights Reserved. | |
| """ | |
| from typing import Any, Dict, List, Optional | |
| import torch | |
| import torch.nn as nn | |
| import torchvision | |
| import torchvision.transforms.v2 as T | |
| from ...core import GLOBAL_CONFIG, register | |
| from ._transforms import EmptyTransform | |
# Silence torchvision's warning about the beta `transforms.v2` API imported above.
# NOTE(review): newer torchvision releases may make this call a no-op — confirm it
# is still needed for the pinned version.
torchvision.disable_beta_transforms_warning()


class Compose(T.Compose):
    """Chain of data transforms with optional policies for switching some off.

    Transforms are built from ``ops``. Each entry is either a config dict whose
    ``"type"`` names a class registered in ``GLOBAL_CONFIG`` (the remaining keys
    are passed to its constructor), or an already-constructed ``nn.Module``.

    ``policy["name"]`` selects how :meth:`forward` applies the chain:

    * ``"default"``     -- apply every transform.
    * ``"stop_epoch"``  -- skip transforms whose class name is in
      ``policy["ops"]`` once ``dataset.epoch >= policy["epoch"]``, where
      ``dataset`` is the last element of the packed sample.
    * ``"stop_sample"`` -- skip those transforms once this instance has
      processed at least ``policy["sample"]`` samples.
    """

    def __init__(
        self,
        ops: Optional[List[Any]],
        policy: Optional[Dict[str, Any]] = None,
    ) -> None:
        """Build the transform chain.

        Args:
            ops: Transform specs (config dicts or ``nn.Module`` instances). When
                ``None`` or empty, the chain is a single ``EmptyTransform``
                (an empty list would otherwise be rejected by the base class).
            policy: Policy dict; defaults to ``{"name": "default"}``.

        Raises:
            KeyError: If a config dict has no ``"type"`` key or names a
                transform that is not registered in ``GLOBAL_CONFIG``.
            ValueError: If an entry is neither a dict nor an ``nn.Module``.
        """
        if ops:
            transforms = [self._build_op(op) for op in ops]
        else:
            transforms = [EmptyTransform()]
        super().__init__(transforms=transforms)

        self.policy = {"name": "default"} if policy is None else policy
        # Number of samples processed by stop_sample_forward on this instance.
        # NOTE(review): with a multi-worker DataLoader each worker presumably
        # holds its own copy of this counter — confirm a per-worker count is intended.
        self.global_samples = 0

    @staticmethod
    def _build_op(op: Any) -> Any:
        """Turn one ``ops`` entry into a transform instance."""
        if isinstance(op, dict):
            # Read "type" without popping so the caller's config dict is left
            # untouched, even when the transform constructor raises.
            name = op["type"]
            kwargs = {key: value for key, value in op.items() if key != "type"}
            entry = GLOBAL_CONFIG[name]
            return getattr(entry["_pymodule"], entry["_name"])(**kwargs)
        if isinstance(op, nn.Module):
            return op
        raise ValueError(
            f"Unsupported transform spec {op!r}: expected a config dict or nn.Module"
        )

    def forward(self, *inputs: Any) -> Any:
        """Apply the chain using the forward method selected by ``policy["name"]``."""
        return self.get_forward(self.policy["name"])(*inputs)

    def get_forward(self, name):
        """Return the bound forward method registered for policy ``name``.

        Raises:
            KeyError: If ``name`` is not a known policy.
        """
        forwards = {
            "default": self.default_forward,
            "stop_epoch": self.stop_epoch_forward,
            "stop_sample": self.stop_sample_forward,
        }
        try:
            return forwards[name]
        except KeyError:
            raise KeyError(
                f"Unknown transform policy {name!r}; expected one of {sorted(forwards)}"
            ) from None

    @staticmethod
    def _pack(inputs: tuple) -> Any:
        """Return a single positional input as-is, or several inputs as one tuple."""
        return inputs if len(inputs) > 1 else inputs[0]

    def _apply(self, sample: Any, skip: Any = ()) -> Any:
        """Run the transforms in order, skipping any whose class name is in ``skip``."""
        for transform in self.transforms:
            if type(transform).__name__ in skip:
                continue
            sample = transform(sample)
        return sample

    def default_forward(self, *inputs: Any) -> Any:
        """Apply every transform to the packed sample."""
        return self._apply(self._pack(inputs))

    def stop_epoch_forward(self, *inputs: Any):
        """Apply the chain, dropping ``policy["ops"]`` from ``policy["epoch"]`` on.

        The packed sample's last element must be the dataset, whose ``epoch``
        attribute is compared with ``policy["epoch"]``.
        """
        sample = self._pack(inputs)
        dataset = sample[-1]
        policy_ops = self.policy["ops"]
        # The epoch check does not depend on the transform, so evaluate it once.
        stopped = dataset.epoch >= self.policy["epoch"]
        return self._apply(sample, policy_ops if stopped else ())

    def stop_sample_forward(self, *inputs: Any):
        """Apply the chain, dropping ``policy["ops"]`` after ``policy["sample"]`` samples.

        Increments ``self.global_samples`` once the sample has been transformed.
        """
        sample = self._pack(inputs)
        policy_ops = self.policy["ops"]
        # Decide once per sample, before counting it, so every transform of this
        # sample sees the same threshold result.
        stopped = self.global_samples >= self.policy["sample"]
        sample = self._apply(sample, policy_ops if stopped else ())
        self.global_samples += 1
        return sample