Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from mmcv.cnn import VGG | |
| from mmengine.hooks import Hook | |
| from mmengine.runner import Runner | |
| from mmdet.registry import HOOKS | |
@HOOKS.register_module()
class NumClassCheckHook(Hook):
    """Check whether the `num_classes` in head matches the length of `classes`
    in `dataset.metainfo`.

    The check runs before every training epoch (against the training
    dataset) and before every validation epoch (against the validation
    dataset).
    """

    def _check_head(self, runner: Runner, mode: str) -> None:
        """Check whether the `num_classes` in head matches the length of
        `classes` in `dataset.metainfo`.

        If the dataset's ``metainfo`` has no ``classes`` entry, only a
        warning is logged. Otherwise, every submodule of ``runner.model``
        that has a ``num_classes`` attribute must satisfy
        ``num_classes == len(classes)``. The exceptions are modules whose
        qualified name ends with ``rpn_head`` and instances of ``VGG`` or
        ``FusedSemanticHead``.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
            mode (str): Which dataset to check against: ``'train'`` uses
                ``runner.train_dataloader.dataset`` and ``'val'`` uses
                ``runner.val_dataloader.dataset``.

        Raises:
            AssertionError: If ``mode`` is not ``'train'`` or ``'val'``, if
                ``classes`` is a plain string, or if a head's
                ``num_classes`` differs from ``len(classes)``.
        """
        assert mode in ['train', 'val']
        model = runner.model
        # Only touch the dataloader that belongs to the requested mode.
        dataloader = (runner.train_dataloader
                      if mode == 'train' else runner.val_dataloader)
        dataset = dataloader.dataset
        dataset_name = dataset.__class__.__name__

        classes = dataset.metainfo.get('classes', None)
        if classes is None:
            # Nothing to compare against; warn instead of failing.
            runner.logger.warning(
                f'Please set `classes` '
                f'in the {dataset_name} `metainfo` and '
                f'check if it is consistent with the `num_classes` '
                f'of head')
            return

        # A bare string (e.g. `classes = ('person')` without the trailing
        # comma) would make `len(classes)` count characters, not classes.
        assert not isinstance(classes, str), \
            (f'`classes` in {dataset_name} '
             f'should be a tuple of str. '
             f'Add comma if number of classes is 1 as '
             f'classes = ({classes},)')

        # Kept as a local import like the original; presumably to avoid a
        # circular import between hooks and models at module load time.
        from mmdet.models.roi_heads.mask_heads import FusedSemanticHead

        num_classes = len(classes)
        for name, module in model.named_modules():
            if not hasattr(module, 'num_classes'):
                continue
            # These modules expose a `num_classes` that is deliberately not
            # compared with the dataset (presumably it does not denote the
            # dataset's categories for them).
            if name.endswith('rpn_head') or isinstance(
                    module, (VGG, FusedSemanticHead)):
                continue
            assert module.num_classes == num_classes, \
                (f'The `num_classes` ({module.num_classes}) in '
                 f'{module.__class__.__name__} of '
                 f'{model.__class__.__name__} does not match '
                 f'the length of `classes` '
                 f'({num_classes}) in '
                 f'{dataset_name}')

    def before_train_epoch(self, runner: Runner) -> None:
        """Check whether the training dataset is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'train')

    def before_val_epoch(self, runner: Runner) -> None:
        """Check whether the dataset in val epoch is compatible with head.

        Args:
            runner (:obj:`Runner`): The runner of the training or evaluation
                process.
        """
        self._check_head(runner, 'val')