| | |
| | import random |
| | from typing import Any, Sequence |
| |
|
| | import torch |
| | from mmengine.dataset import COLLATE_FUNCTIONS |
| | from mmengine.logging import print_log |
| | from mmyolo.datasets.yolov5_coco import BatchShapePolicyDataset |
| |
|
| |
|
class RobustBatchShapePolicyDataset(BatchShapePolicyDataset):
    """Batch-shape-policy dataset that retries when a sample fails to load.

    Like its parent, this is a dataset with the batch shape policy that makes
    paddings with least pixels during batch inference process, which does not
    require the image scales of all batches to be the same throughout
    validation.

    On top of that, :meth:`prepare_data` catches any exception raised while
    preparing a sample and retries a bounded number of times instead of
    failing on the first error.
    """

    def _prepare_data(self, idx: int) -> Any:
        """Prepare the sample at ``idx`` once, without any retry.

        Args:
            idx (int): Index of the sample to prepare.

        Returns:
            Any: The pipeline output when ``self.test_mode is False``,
            otherwise whatever the parent class ``prepare_data`` returns.
        """
        if self.test_mode is False:
            data_info = self.get_data_info(idx)
            # Hand the dataset to the pipeline so transforms can reach it.
            data_info['dataset'] = self
            return self.pipeline(data_info)
        return super().prepare_data(idx)

    def prepare_data(self, idx: int, timeout: int = 10) -> Any:
        """Pass the dataset to the pipeline during training to support mixed
        data augmentation, such as Mosaic and MixUp, retrying on failure.

        Args:
            idx (int): Index of the sample to prepare.
            timeout (int): Number of retries allowed after the first failed
                attempt, so at most ``timeout + 1`` attempts are made.
                Defaults to 10.

        Returns:
            Any: The result of the first successful :meth:`_prepare_data`
            call.

        Raises:
            Exception: The exception of the last attempt, re-raised unchanged
                once the retries are used up.
        """
        # Loop rather than recurse: retries then neither deepen the call
        # stack nor chain every earlier failure onto the final exception.
        while True:
            try:
                return self._prepare_data(idx)
            except Exception as e:
                if timeout <= 0:
                    # Bare ``raise`` keeps the original traceback intact.
                    raise
                print_log(f'Failed to prepare data, due to {e}. '
                          f'Retrying {timeout} attempts.')
                if not self.test_mode:
                    # Outside test mode a different random sample is tried;
                    # in test mode the same index is retried.
                    idx = random.randrange(len(self))
                timeout -= 1
| |
|
| |
|
@COLLATE_FUNCTIONS.register_module()
def yolow_collate(data_batch: Sequence,
                  use_ms_training: bool = False) -> dict:
    """Rewrite collate_fn to get faster training speed.

    Args:
        data_batch (Sequence): Batch of data. Each item is indexed with the
            keys ``'inputs'`` and ``'data_samples'``.
        use_ms_training (bool): Whether to use multi-scale training. When
            True the inputs are returned as a list, otherwise they are
            stacked into a single tensor.

    Returns:
        dict: ``'data_samples'`` maps to a dict holding ``'bboxes_labels'``
        (rows are ``[batch index, label, *bbox]`` concatenated over the
        batch) plus, when available, ``'masks'``, ``'texts'`` and
        ``'is_detection'``; ``'inputs'`` holds the images.
    """
    images = []
    targets_per_image = []
    masks_per_image = []
    for img_idx, item in enumerate(data_batch):
        sample = item['data_samples']
        images.append(item['inputs'])

        boxes = sample.gt_instances.bboxes.tensor
        labels = sample.gt_instances.labels
        if 'masks' in sample.gt_instances:
            masks_per_image.append(
                sample.gt_instances.masks.to_tensor(
                    dtype=torch.bool, device=boxes.device))
        # One column filled with this image's position in the batch.
        index_column = labels.new_full((len(labels), 1), img_idx)
        targets_per_image.append(
            torch.cat((index_column, labels[:, None], boxes), dim=1))

    sample_fields = {'bboxes_labels': torch.cat(targets_per_image, 0)}
    collated = {'data_samples': sample_fields}
    if masks_per_image:
        sample_fields['masks'] = torch.cat(masks_per_image, 0)

    # Multi-scale training keeps the images as a list instead of stacking.
    collated['inputs'] = (
        images if use_ms_training else torch.stack(images, 0))

    # Optional per-sample attributes are detected on the first sample only.
    first_sample = data_batch[0]['data_samples']
    if hasattr(first_sample, 'texts'):
        sample_fields['texts'] = [
            item['data_samples'].texts for item in data_batch
        ]

    if hasattr(first_sample, 'is_detection'):
        detection_flags = [
            item['data_samples'].is_detection for item in data_batch
        ]
        sample_fields['is_detection'] = torch.tensor(detection_flags)

    return collated
| |
|