| |
| |
| |
|
|
| import copy |
|
|
| import torch |
| import torchvision.transforms as transforms |
| import torchvision.transforms.functional as F |
|
|
| from src.efficientvit.models.utils import torch_random_choices |
|
|
# Names exported by ``from <this module> import *``.
__all__ = ["RRSController", "get_interpolate", "MyRandomResizedCrop"]
|
|
|
|
class RRSController:
    """Shared, class-level state for random-resolution sampling.

    All state lives in class attributes and every method is a staticmethod, so
    any reader (e.g. ``MyRandomResizedCrop.forward``) sees the same
    ``ACTIVE_SIZE`` without holding a controller instance.

    Intended flow, as visible in this module: call :meth:`set_epoch` to pre-draw
    one resolution per batch into ``CHOICE_LIST``, then call
    :meth:`sample_resolution` with the batch index to publish that batch's
    resolution in ``ACTIVE_SIZE``.
    """

    # Resolution currently in use; passed as the ``size`` argument of
    # ``F.resized_crop`` by MyRandomResizedCrop.
    ACTIVE_SIZE: tuple[int, int] = (224, 224)
    # Pool of candidate resolutions that set_epoch samples from.
    IMAGE_SIZE_LIST: list[tuple[int, int]] = [(224, 224)]

    # Per-batch resolutions drawn by set_epoch; None until set_epoch is called.
    CHOICE_LIST = None

    @staticmethod
    def get_candidates() -> list[tuple[int, int]]:
        """Return a deep copy of ``IMAGE_SIZE_LIST``.

        The copy lets callers mutate the result without altering the shared
        candidate pool.
        """
        return copy.deepcopy(RRSController.IMAGE_SIZE_LIST)

    @staticmethod
    def sample_resolution(batch_id: int) -> None:
        """Set ``ACTIVE_SIZE`` to the resolution pre-drawn for ``batch_id``.

        Args:
            batch_id: Index into ``CHOICE_LIST``.

        Raises:
            RuntimeError: If :meth:`set_epoch` has not populated ``CHOICE_LIST``.
            IndexError: If ``batch_id`` is outside the drawn schedule.
        """
        choices = RRSController.CHOICE_LIST
        if choices is None:
            # Previously this surfaced as an opaque "'NoneType' object is not
            # subscriptable" TypeError.
            raise RuntimeError(
                "RRSController.set_epoch() must be called before sample_resolution()"
            )
        RRSController.ACTIVE_SIZE = choices[batch_id]

    @staticmethod
    def set_epoch(epoch: int, batch_per_epoch: int) -> None:
        """Pre-draw the per-batch resolution schedule for ``epoch``.

        A fresh ``torch.Generator`` seeded with ``epoch`` is handed to
        ``torch_random_choices``, so the same epoch yields the same schedule.
        Presumably this keeps independent processes in agreement; confirm
        against the training loop.

        Args:
            epoch: Seed for the generator.
            batch_per_epoch: Passed as the number of draws (assumed to be the
                ``k`` argument of ``torch_random_choices``).
        """
        g = torch.Generator()
        g.manual_seed(epoch)
        # NOTE(review): if torch_random_choices returns a bare element rather
        # than a list when k == 1, CHOICE_LIST would be a single size tuple when
        # batch_per_epoch == 1 and sample_resolution would index into it. Confirm.
        RRSController.CHOICE_LIST = torch_random_choices(
            RRSController.get_candidates(),
            g,
            batch_per_epoch,
        )
|
|
|
|
def get_interpolate(name: str) -> F.InterpolationMode:
    """Translate an interpolation name into a torchvision ``InterpolationMode``.

    Args:
        name: One of ``"nearest"``, ``"bilinear"``, ``"bicubic"``, ``"box"``,
            ``"hamming"``, ``"lanczos"``, or ``"random"``. ``"random"`` picks
            one of those six modes via ``torch_random_choices`` on every call.

    Returns:
        The matching ``F.InterpolationMode`` member.

    Raises:
        NotImplementedError: If ``name`` is not a recognised option.
    """
    mapping = {
        "nearest": F.InterpolationMode.NEAREST,
        "bilinear": F.InterpolationMode.BILINEAR,
        "bicubic": F.InterpolationMode.BICUBIC,
        "box": F.InterpolationMode.BOX,
        "hamming": F.InterpolationMode.HAMMING,
        "lanczos": F.InterpolationMode.LANCZOS,
    }
    if name in mapping:
        return mapping[name]
    if name == "random":
        # Derive the candidates from the mapping so the two cannot drift apart.
        # Dicts keep insertion order, so the candidate order is unchanged.
        # NOTE(review): torchvision documents BOX/HAMMING/LANCZOS as PIL-only;
        # with tensor inputs a random pick may be unsupported. Confirm input type.
        return torch_random_choices(list(mapping.values()))
    raise NotImplementedError(
        f"unsupported interpolation {name!r}; expected one of "
        f"{list(mapping)} or 'random'"
    )
|
|
|
|
class MyRandomResizedCrop(transforms.RandomResizedCrop):
    """``RandomResizedCrop`` whose output size follows ``RRSController.ACTIVE_SIZE``.

    The crop box is sampled with the inherited ``get_params`` using ``scale``
    and ``ratio``. The resize target is read from ``RRSController`` on every
    call rather than fixed at construction, so the output resolution can change
    between batches.
    """

    def __init__(
        self,
        scale=(0.08, 1.0),
        ratio=(3.0 / 4.0, 4.0 / 3.0),
        interpolation: str = "random",
    ):
        """Configure the crop.

        Args:
            scale: Forwarded to the parent; range of crop area fractions.
            ratio: Forwarded to the parent; range of crop aspect ratios.
            interpolation: A name accepted by ``get_interpolate``.

        Raises:
            NotImplementedError: If ``interpolation`` is not a recognised name.
        """
        # The parent needs a size; 224 is a placeholder because ``forward``
        # never reads ``self.size`` and uses RRSController.ACTIVE_SIZE instead.
        super().__init__(224, scale, ratio)
        # Reject unknown names now instead of on the first ``forward``.
        # "random" is skipped so that construction draws nothing from the RNG.
        if interpolation != "random":
            get_interpolate(interpolation)
        # Store the *name*, overriding what the parent stored, so "random" is
        # re-resolved on every call.
        self.interpolation = interpolation

    def forward(self, img: torch.Tensor) -> torch.Tensor:
        """Crop a random region of ``img`` and resize it to the active size."""
        i, j, h, w = self.get_params(img, list(self.scale), list(self.ratio))
        # Read at call time so updates from RRSController.sample_resolution
        # take effect immediately.
        target_size = RRSController.ACTIVE_SIZE
        return F.resized_crop(
            img, i, j, h, w, list(target_size), get_interpolate(self.interpolation)
        )

    def __repr__(self) -> str:
        """Describe the candidate sizes, scale, ratio and interpolation name."""
        parts = [
            self.__class__.__name__,
            f"(\n\tsize={RRSController.get_candidates()},\n",
            f"\tscale={tuple(round(s, 4) for s in self.scale)},\n",
            f"\tratio={tuple(round(r, 4) for r in self.ratio)},\n",
            f"\tinterpolation={self.interpolation})",
        ]
        return "".join(parts)
|
|