import os.path as osp
import platform
import shutil
import time
import warnings

import torch
from torch.optim import Optimizer

import annotator.uniformer.mmcv as mmcv
from .base_runner import BaseRunner
from .builder import RUNNERS
from .checkpoint import save_checkpoint
from .hooks import IterTimerHook
from .utils import get_host_info


class IterLoader:
    """Iterator wrapper that cycles through a ``DataLoader`` indefinitely.

    When the wrapped dataloader is exhausted, a fresh iterator is built and
    the internal epoch counter is incremented, so ``next()`` never raises
    ``StopIteration``.
    """

    def __init__(self, dataloader):
        self._dataloader = dataloader
        self.iter_loader = iter(self._dataloader)
        self._epoch = 0

    @property
    def epoch(self):
        return self._epoch

    def __next__(self):
        try:
            data = next(self.iter_loader)
        except StopIteration:
            # The current epoch is exhausted: bump the epoch counter so
            # distributed samplers can reshuffle deterministically, then
            # rebuild the iterator and fetch the first batch of the new
            # epoch.
            self._epoch += 1
            if hasattr(self._dataloader.sampler, 'set_epoch'):
                self._dataloader.sampler.set_epoch(self._epoch)
            time.sleep(2)  # prevent possible deadlock during epoch transition
            self.iter_loader = iter(self._dataloader)
            data = next(self.iter_loader)

        return data

    def __len__(self):
        return len(self._dataloader)


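# Illustrative sketch (not part of the original module): how ``IterLoader``
# is typically consumed. ``dataset`` below is a placeholder for any
# ``torch.utils.data.Dataset``.
#
#     from torch.utils.data import DataLoader
#     loader = IterLoader(DataLoader(dataset, batch_size=2, shuffle=True))
#     batch = next(loader)   # never raises StopIteration
#     print(loader.epoch)    # 0 until the first pass completes, then 1, ...

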
@RUNNERS.register_module()
class IterBasedRunner(BaseRunner):
    """Iteration-based Runner.

    This runner trains models iteration by iteration.
    """

    def train(self, data_loader, **kwargs):
        self.model.train()
        self.mode = 'train'
        self.data_loader = data_loader
        self._epoch = data_loader.epoch
        data_batch = next(data_loader)
        self.call_hook('before_train_iter')
        outputs = self.model.train_step(data_batch, self.optimizer, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.train_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_train_iter')
        self._inner_iter += 1
        self._iter += 1

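    # Sketch of the ``train_step``/``val_step`` contract enforced above
    # (inferred from the type checks; the names and keys shown are
    # illustrative, not a definitive API):
    #
    #     def train_step(self, data_batch, optimizer, **kwargs):
    #         loss = self(**data_batch)
    #         return {'loss': loss,
    #                 'log_vars': {'loss': loss.item()},
    #                 'num_samples': len(data_batch['img'])}
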
    @torch.no_grad()
    def val(self, data_loader, **kwargs):
        self.model.eval()
        self.mode = 'val'
        self.data_loader = data_loader
        data_batch = next(data_loader)
        self.call_hook('before_val_iter')
        outputs = self.model.val_step(data_batch, **kwargs)
        if not isinstance(outputs, dict):
            raise TypeError('model.val_step() must return a dict')
        if 'log_vars' in outputs:
            self.log_buffer.update(outputs['log_vars'],
                                   outputs['num_samples'])
        self.outputs = outputs
        self.call_hook('after_val_iter')
        self._inner_iter += 1

    def run(self, data_loaders, workflow, max_iters=None, **kwargs):
        """Start running.

        Args:
            data_loaders (list[:obj:`DataLoader`]): Dataloaders for training
                and validation.
            workflow (list[tuple]): A list of (phase, iters) to specify the
                running order and iterations. E.g., [('train', 10000),
                ('val', 1000)] means running 10000 iterations for training
                and 1000 iterations for validation, iteratively.
        """
        assert isinstance(data_loaders, list)
        assert mmcv.is_list_of(workflow, tuple)
        assert len(data_loaders) == len(workflow)
        if max_iters is not None:
            warnings.warn(
                'setting max_iters in run is deprecated, '
                'please set max_iters in runner_config', DeprecationWarning)
            self._max_iters = max_iters
        assert self._max_iters is not None, (
            'max_iters must be specified during instantiation')

        work_dir = self.work_dir if self.work_dir is not None else 'NONE'
        self.logger.info('Start running, host: %s, work_dir: %s',
                         get_host_info(), work_dir)
        self.logger.info('Hooks will be executed in the following order:\n%s',
                         self.get_hook_info())
        self.logger.info('workflow: %s, max: %d iters', workflow,
                         self._max_iters)
        self.call_hook('before_run')

        iter_loaders = [IterLoader(x) for x in data_loaders]

        self.call_hook('before_epoch')

        while self.iter < self._max_iters:
            for i, flow in enumerate(workflow):
                self._inner_iter = 0
                mode, iters = flow
                if not isinstance(mode, str) or not hasattr(self, mode):
                    raise ValueError(
                        'runner has no method named "{}" to run a workflow'.
                        format(mode))
                iter_runner = getattr(self, mode)
                for _ in range(iters):
                    if mode == 'train' and self.iter >= self._max_iters:
                        break
                    iter_runner(iter_loaders[i], **kwargs)

        time.sleep(1)  # wait for some hooks like loggers to finish
        self.call_hook('after_epoch')
        self.call_hook('after_run')

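    # Illustrative sketch (not from the original file): a typical call,
    # assuming the runner was built with ``max_iters`` set and hypothetical
    # ``train_loader``/``val_loader`` objects:
    #
    #     runner.run([train_loader, val_loader],
    #                workflow=[('train', 10000), ('val', 1000)])
    #
    # i.e. alternate 10000 training iterations with 1000 validation
    # iterations until ``max_iters`` is reached.
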
    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume model from checkpoint.

        Args:
            checkpoint (str): Checkpoint to resume from.
            resume_optimizer (bool, optional): Whether to resume the
                optimizer(s) if the checkpoint file includes optimizer(s).
                Defaults to True.
            map_location (str, optional): Same as :func:`torch.load`.
                Defaults to 'default'.
        """
        if map_location == 'default':
            # Map the checkpoint tensors onto the current CUDA device.
            device_id = torch.cuda.current_device()
            checkpoint = self.load_checkpoint(
                checkpoint,
                map_location=lambda storage, loc: storage.cuda(device_id))
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        self._inner_iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        self.logger.info(f'resumed from epoch: {self.epoch}, iter {self.iter}')

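    # Illustrative sketch (assumed usage, not from the original file):
    # resume on a CPU-only machine by overriding ``map_location``:
    #
    #     runner.resume('work_dirs/iter_20000.pth', map_location='cpu')
    #
    # With the default 'default' value, tensors are mapped onto the current
    # CUDA device instead.
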
    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='iter_{}.pth',
                        meta=None,
                        save_optimizer=True,
                        create_symlink=True):
        """Save checkpoint to file.

        Args:
            out_dir (str): Directory to save checkpoint files.
            filename_tmpl (str, optional): Checkpoint file template.
                Defaults to 'iter_{}.pth'.
            meta (dict, optional): Metadata to be saved in checkpoint.
                Defaults to None.
            save_optimizer (bool, optional): Whether to save the optimizer.
                Defaults to True.
            create_symlink (bool, optional): Whether to create a symlink to
                the latest checkpoint file. Defaults to True.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        if self.meta is not None:
            meta.update(self.meta)
        # self.meta is merged first so that the runner's current epoch and
        # iteration (set below) always take precedence when resuming.
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.iter + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # Keep a 'latest.pth' pointer to the newest checkpoint; fall back to
        # a copy on Windows, where symlinks may be unavailable.
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def register_training_hooks(self,
                                lr_config,
                                optimizer_config=None,
                                checkpoint_config=None,
                                log_config=None,
                                momentum_config=None,
                                custom_hooks_config=None):
        """Register default hooks for iter-based training.

        Checkpoint hook, optimizer stepper hook and logger hooks will be set
        to `by_epoch=False` by default.

        Default hooks include:

        +----------------------+-------------------------+
        | Hooks                | Priority                |
        +======================+=========================+
        | LrUpdaterHook        | VERY_HIGH (10)          |
        +----------------------+-------------------------+
        | MomentumUpdaterHook  | HIGH (30)               |
        +----------------------+-------------------------+
        | OptimizerStepperHook | ABOVE_NORMAL (40)       |
        +----------------------+-------------------------+
        | CheckpointSaverHook  | NORMAL (50)             |
        +----------------------+-------------------------+
        | IterTimerHook        | LOW (70)                |
        +----------------------+-------------------------+
        | LoggerHook(s)        | VERY_LOW (90)           |
        +----------------------+-------------------------+
        | CustomHook(s)        | defaults to NORMAL (50) |
        +----------------------+-------------------------+

        If custom hooks have the same priority as default hooks, custom
        hooks will be triggered after default hooks.
        """
        if checkpoint_config is not None:
            checkpoint_config.setdefault('by_epoch', False)
        if lr_config is not None:
            lr_config.setdefault('by_epoch', False)
        if log_config is not None:
            for info in log_config['hooks']:
                info.setdefault('by_epoch', False)
        super(IterBasedRunner, self).register_training_hooks(
            lr_config=lr_config,
            momentum_config=momentum_config,
            optimizer_config=optimizer_config,
            checkpoint_config=checkpoint_config,
            log_config=log_config,
            timer_config=IterTimerHook(),
            custom_hooks_config=custom_hooks_config)
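
    # Illustrative sketch (assumed config values, not from the original
    # file): registering the default iter-based hooks. ``by_epoch=False``
    # is filled in automatically for the lr, checkpoint and log configs.
    #
    #     runner.register_training_hooks(
    #         lr_config=dict(policy='step', step=[60000, 80000]),
    #         optimizer_config=dict(grad_clip=None),
    #         checkpoint_config=dict(interval=10000),
    #         log_config=dict(interval=50,
    #                         hooks=[dict(type='TextLoggerHook')]))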