import os.path as osp
import platform
import shutil
import time
import warnings

import torch

import annotator.uniformer.mmcv as mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .utils import get_host_info


@RUNNERS.register_module()
class EpochBasedRunner(BaseRunner):
    """Epoch-based Runner.

    This runner trains models epoch by epoch.
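
    The example below is a minimal usage sketch; ``model``, ``optimizer``,
    ``logger`` and ``train_loader`` are assumed to be built elsewhere, and
    the model is expected to implement ``train_step`` (and ``val_step`` for
    validation) as required by :meth:`run_iter`.

    Example:
        >>> runner = EpochBasedRunner(
        ...     model,
        ...     optimizer=optimizer,
        ...     work_dir='./work_dir',
        ...     logger=logger,
        ...     max_epochs=12)
        >>> runner.run([train_loader], [('train', 1)])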
| | """ |
| |
|
    def run_iter(self, data_batch, train_mode, **kwargs):
        """Run a single iteration.

        Dispatches to ``batch_processor`` if one is set, otherwise to the
        model's ``train_step`` or ``val_step``. The callee must return a
        dict, whose optional ``log_vars`` and ``num_samples`` entries are
        recorded in the log buffer.
        """
        if self.batch_processor is not None:
            outputs = self.batch_processor(
                self.model, data_batch, train_mode=train_mode, **kwargs)
        elif train_mode:
            outputs = self.model.train_step(data_batch, self.optimizer,
                                            **kwargs)
        else:
            outputs = self.model.val_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('"batch_processor()" or "model.train_step()" '
                            'and "model.val_step()" must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'], outputs['num_samples'])
        self.outputs = outputs

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._max_iters = self._max_epochs * len(self.data_loader)
        self.call_hook('before_train_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_train_iter')
            self.run_iter(data_batch, train_mode=True, **kwargs)
            self.call_hook('after_train_iter')
            self._iter += 1

        self.call_hook('after_train_epoch')
        self._epoch += 1

    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        self.call_hook('before_val_epoch')
        time.sleep(2)  # Prevent possible deadlock during epoch transition
        for i, data_batch in enumerate(self.data_loader):
            self._inner_iter = i
            self.call_hook('before_val_iter')
            self.run_iter(data_batch, train_mode=False)
            self.call_hook('after_val_iter')

        self.call_hook('after_val_epoch')

    def run(self, data_loaders, workflow, max_epochs=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, epochs) to specify the
                running order and epochs. E.g., [('train', 2), ('val', 1)]
                means running 2 epochs for training and 1 epoch for
                validation, iteratively.
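
        The call below is an illustrative sketch; ``train_loader`` and
        ``val_loader`` are assumed to exist. It trains for 2 epochs, then
        validates for 1, repeating until ``max_epochs`` is reached.

        Example:
            >>> runner.run([train_loader, val_loader],
            ...            [('train', 2), ('val', 1)])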
| | """ |
| | assert isinstance(data_loaders, list) |
| | assert mmcv.is_list_of(workflow, tuple) |
| | assert len(data_loaders) == len(workflow) |
| | if max_epochs is not None: |
| | warnings.warn( |
| | 'setting max_epochs in run is deprecated, ' |
| | 'please set max_epochs in runner_config', DeprecationWarning) |
| | self._max_epochs = max_epochs |
| |
|
| | assert self._max_epochs is not None, ( |
| | 'max_epochs must be specified during instantiation') |
| |
|
| | for i, flow in enumerate(workflow): |
| | mode, epochs = flow |
| | if mode == 'train': |
| | self._max_iters = self._max_epochs * len(data_loaders[i]) |
| | break |
| |
|
| | work_dir = self.work_dir if self.work_dir is not None else 'NONE' |
| | self.logger.info('Start running, host: %s, work_dir: %s', |
| | get_host_info(), work_dir) |
| | self.logger.info('Hooks will be executed in the following order:\n%s', |
| | self.get_hook_info()) |
| | self.logger.info('workflow: %s, max: %d epochs', workflow, |
| | self._max_epochs) |
| | self.call_hook('before_run') |
| |
|
        while self.epoch < self._max_epochs:
            for i, flow in enumerate(workflow):
                mode, epochs = flow
                if isinstance(mode, str):
                    if not hasattr(self, mode):
                        raise ValueError(
                            f'runner has no method named "{mode}" to run an '
                            'epoch')
                    epoch_runner = getattr(self, mode)
                else:
                    raise TypeError(
                        f'mode in workflow must be a str, but got {type(mode)}')

                for _ in range(epochs):
                    if mode == 'train' and self.epoch >= self._max_epochs:
                        break
                    epoch_runner(data_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_run')

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory to save checkpoints to.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.
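
        The call below is a minimal sketch: after the first training epoch
        it writes ``epoch_1.pth`` under the given directory and refreshes
        ``latest.pth`` to point at it.

        Example:
            >>> runner.save_checkpoint('./work_dir')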
| | """ |
| | if meta is None: |
| | meta = {} |
| | elif not isinstance(meta, dict): |
| | raise TypeError( |
| | f'meta should be a dict or None, but got {type(meta)}') |
| | if self.meta is not None: |
| | meta.update(self.meta) |
| | |
| | |
| | |
| | |
| | meta.update(epoch=self.epoch + 1, iter=self.iter) |
| |
|
| | filename = filename_tmpl.format(self.epoch + 1) |
| | filepath = osp.join(out_dir, filename) |
| | optimizer = self.optimizer if save_optimizer else None |
| | save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta) |
| | |
| | |
| | if create_symlink: |
| | dst_file = osp.join(out_dir, 'latest.pth') |
| | if platform.system() != 'Windows': |
| | mmcv.symlink(filename, dst_file) |
| | else: |
| | shutil.copy(filepath, dst_file) |
| |
|
| |
|
@RUNNERS.register_module()
class Runner(EpochBasedRunner):
    """Deprecated name of EpochBasedRunner."""

    def __init__(self, *args, **kwargs):
        warnings.warn(
            'Runner is deprecated, please use EpochBasedRunner instead',
            DeprecationWarning)
        super().__init__(*args, **kwargs)