Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from typing import Callable, Union | |
| import numpy as np | |
| import torch | |
| from mmpretrain.registry import BATCH_AUGMENTS | |
class RandomBatchAugment:
    """Randomly choose one batch augmentation to apply.

    Args:
        augments (Callable | dict | list): configs of batch
            augmentations. Dict items are built via ``BATCH_AUGMENTS.build``;
            any other item is used as-is and is called as
            ``aug(batch_input, batch_score)``.
        probs (float | int | Sequence[float] | None): The probabilities of
            each batch augmentation. If None, choose evenly among
            ``augments`` (one of them is then always applied). Otherwise the
            remaining probability ``1 - sum(probs)`` is assigned to doing
            nothing. The given sequence is copied and never modified in
            place. Defaults to None.

    Example:
        >>> import torch
        >>> import torch.nn.functional as F
        >>> from mmpretrain.models import RandomBatchAugment
        >>> augments_cfg = [
        ...     dict(type='CutMix', alpha=1.),
        ...     dict(type='Mixup', alpha=1.)
        ... ]
        >>> batch_augment = RandomBatchAugment(augments_cfg, probs=[0.5, 0.3])
        >>> imgs = torch.rand(16, 3, 32, 32)
        >>> label = F.one_hot(torch.randint(0, 10, (16, )), num_classes=10)
        >>> imgs, label = batch_augment(imgs, label)

    .. note ::

        To decide which batch augmentation will be used, it picks one of
        ``augments`` based on the probabilities. In the example above, the
        probability to use CutMix is 0.5, to use Mixup is 0.3, and to do
        nothing is 0.2.
    """

    def __init__(self, augments: Union[Callable, dict, list], probs=None):
        if not isinstance(augments, (tuple, list)):
            augments = [augments]

        # Build registry configs; keep ready-made callables untouched.
        self.augments = [
            BATCH_AUGMENTS.build(aug) if isinstance(aug, dict) else aug
            for aug in augments
        ]

        if isinstance(probs, (int, float)):
            probs = [probs]

        if probs is not None:
            # Work on a private copy: the "no-op" probability appended below
            # must not leak into the caller's list (e.g. a config reused to
            # build several instances), and tuples/arrays have no ``append``.
            probs = list(probs)
            assert len(augments) == len(probs), \
                '``augments`` and ``probs`` must have same lengths. ' \
                f'Got {len(augments)} vs {len(probs)}.'
            assert sum(probs) <= 1, \
                'The total probability of batch augments exceeds 1.'
            # The leftover probability mass selects "no augmentation".
            self.augments.append(None)
            probs.append(1 - sum(probs))

        self.probs = probs

    def __call__(self, batch_input: torch.Tensor, batch_score: torch.Tensor):
        """Randomly apply one batch augmentation to the batch inputs and
        batch scores.

        Args:
            batch_input (torch.Tensor): The batch inputs.
            batch_score (torch.Tensor): The batch scores, passed to the
                chosen augmentation together with ``batch_input``.

        Returns:
            The result of the chosen augmentation, or
            ``(batch_input, batch_score.float())`` when the "do nothing"
            option is drawn.
        """
        aug_index = np.random.choice(len(self.augments), p=self.probs)
        aug = self.augments[aug_index]

        if aug is None:
            # No-op branch still casts scores to float, presumably so the
            # output dtype matches what the augmentations return — confirm.
            return batch_input, batch_score.float()
        return aug(batch_input, batch_score)