Spaces:
Runtime error
Runtime error
| # Copyright (c) OpenMMLab. All rights reserved. | |
| from mmcv.transforms import Compose | |
| from mmengine.hooks import Hook | |
| from mmdet.registry import HOOKS | |
class PipelineSwitchHook(Hook):
    """Switch the training data pipeline once ``switch_epoch`` is reached.

    The switch is performed exactly once, on the first training epoch whose
    index is greater than or equal to ``switch_epoch``. Using ``>=`` instead
    of ``==`` means the pipeline is still switched when training starts at an
    epoch that is already past ``switch_epoch`` (for example after resuming
    from a later checkpoint).

    Args:
        switch_epoch (int): switch pipeline at this epoch.
        switch_pipeline (list[dict]): the pipeline to switch to.
    """

    def __init__(self, switch_epoch: int, switch_pipeline: list) -> None:
        self.switch_epoch = switch_epoch
        self.switch_pipeline = switch_pipeline
        # True while the dataloader's initialization flag is cleared and
        # still has to be restored on a later epoch.
        self._restart_dataloader = False
        # True once the pipeline has been replaced; keeps the switch (and
        # the forced dataloader restart) from being repeated every epoch.
        self._has_switched = False

    def before_train_epoch(self, runner) -> None:
        """Install ``switch_pipeline`` on the training dataset when due.

        Args:
            runner: the training runner; this hook reads ``runner.epoch``,
                ``runner.train_dataloader`` and ``runner.logger``.
        """
        epoch = runner.epoch
        train_loader = runner.train_dataloader

        if epoch >= self.switch_epoch and not self._has_switched:
            runner.logger.info('Switch pipeline now!')
            # The dataset pipeline cannot be updated when persistent_workers
            # is True, so we need to force the dataloader's multi-process
            # restart. This is a very hacky approach.
            train_loader.dataset.pipeline = Compose(self.switch_pipeline)
            if getattr(train_loader, 'persistent_workers', False) is True:
                # Clear the loader's name-mangled private flag and drop its
                # cached iterator so the restart can take place.
                train_loader._DataLoader__initialized = False
                train_loader._iterator = None
                self._restart_dataloader = True
            self._has_switched = True
        elif self._restart_dataloader:
            # Once the restart is complete, we need to restore
            # the initialization flag.
            train_loader._DataLoader__initialized = True