| | |
| | from typing import Sequence |
| |
|
| | from torch.utils.data import BatchSampler, Sampler |
| |
|
| | from mmdet.registry import DATA_SAMPLERS |
| |
|
| |
|
| | |
@DATA_SAMPLERS.register_module()
class AspectRatioBatchSampler(BatchSampler):
    """A sampler wrapper that groups images with similar aspect ratio into
    the same batch.

    Indices are split into two buckets: portrait images
    (``width < height``) and all others (``width >= height``). A batch is
    emitted as soon as a bucket holds ``batch_size`` indices. Once the base
    sampler is exhausted, the indices remaining in both buckets are
    concatenated (portrait bucket first) and emitted in chunks of
    ``batch_size``; these trailing batches may therefore mix both groups.

    Args:
        sampler (Sampler): Base sampler. It must expose a ``dataset``
            attribute whose ``get_data_info(idx)`` returns a dict containing
            ``'width'`` and ``'height'`` keys.
        batch_size (int): Size of mini-batch.
        drop_last (bool): If ``True``, the sampler will drop the last batch if
            its size would be less than ``batch_size``.
    """

    def __init__(self,
                 sampler: Sampler,
                 batch_size: int,
                 drop_last: bool = False) -> None:
        if not isinstance(sampler, Sampler):
            raise TypeError('sampler should be an instance of ``Sampler``, '
                            f'but got {sampler}')
        if not isinstance(batch_size, int) or batch_size <= 0:
            raise ValueError('batch_size should be a positive integer value, '
                             f'but got batch_size={batch_size}')
        self.sampler = sampler
        self.batch_size = batch_size
        self.drop_last = drop_last
        # Two buckets: index 0 for width < height, index 1 otherwise.
        self._aspect_ratio_buckets = [[] for _ in range(2)]

    def __iter__(self) -> Sequence[int]:
        """Yield batches (lists of dataset indices) grouped by aspect ratio.

        Batches of exactly ``batch_size`` are always yielded; only a final,
        genuinely incomplete batch is dropped when ``drop_last`` is True.
        """
        # Start from empty buckets so that indices left over from a previous,
        # partially consumed iteration cannot leak into this one.
        self._aspect_ratio_buckets = [[] for _ in range(2)]

        for idx in self.sampler:
            data_info = self.sampler.dataset.get_data_info(idx)
            width, height = data_info['width'], data_info['height']
            bucket_id = 0 if width < height else 1
            bucket = self._aspect_ratio_buckets[bucket_id]
            bucket.append(idx)

            if len(bucket) == self.batch_size:
                yield bucket[:]
                del bucket[:]

        # Flush what remains in both buckets (portrait first), then reset.
        portrait_left, other_left = self._aspect_ratio_buckets
        left_data = portrait_left + other_left
        self._aspect_ratio_buckets = [[] for _ in range(2)]
        while len(left_data) > 0:
            # Strictly less-than: a leftover of exactly ``batch_size`` is a
            # complete batch and must not be discarded by ``drop_last``;
            # otherwise the number of batches would disagree with __len__.
            if len(left_data) < self.batch_size:
                if not self.drop_last:
                    yield left_data[:]
                left_data = []
            else:
                yield left_data[:self.batch_size]
                left_data = left_data[self.batch_size:]

    def __len__(self) -> int:
        """Return the number of batches produced by one full iteration."""
        if self.drop_last:
            return len(self.sampler) // self.batch_size
        else:
            return (len(self.sampler) + self.batch_size - 1) // self.batch_size
| |
|