| | |
| | import os.path as osp |
| | import platform |
| | import shutil |
| |
|
| | import torch |
| | from torch.optim import Optimizer |
| |
|
| | import mmcv |
| | from mmcv.runner import RUNNERS, EpochBasedRunner |
| | from .checkpoint import save_checkpoint |
| |
|
# apex is an optional dependency: it is only needed for the AMP state
# save/load paths below. A bare `except:` would also swallow
# KeyboardInterrupt/SystemExit and real errors raised while importing
# apex itself, so catch only the failure we expect.
try:
    import apex
except ImportError:
    print('apex is not installed')
| |
|
| |
|
@RUNNERS.register_module()
class EpochBasedRunnerAmp(EpochBasedRunner):
    """Epoch-based Runner with AMP support.

    This runner train models epoch by epoch. In addition to the standard
    :class:`EpochBasedRunner` behavior, ``resume`` restores the apex AMP
    loss-scaling state when an ``'amp'`` entry is present in the checkpoint.
    """

    def save_checkpoint(self,
                        out_dir,
                        filename_tmpl='epoch_{}.pth',
                        save_optimizer=True,
                        meta=None,
                        create_symlink=True):
        """Save the checkpoint.

        Args:
            out_dir (str): The directory that checkpoints are saved.
            filename_tmpl (str, optional): The checkpoint filename template,
                which contains a placeholder for the epoch number.
                Defaults to 'epoch_{}.pth'.
            save_optimizer (bool, optional): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            meta (dict, optional): The meta information to be saved in the
                checkpoint. Defaults to None.
            create_symlink (bool, optional): Whether to create a symlink
                "latest.pth" to point to the latest checkpoint.
                Defaults to True.

        Raises:
            TypeError: If ``meta`` is neither ``None`` nor a dict.
        """
        if meta is None:
            meta = {}
        elif not isinstance(meta, dict):
            raise TypeError(
                f'meta should be a dict or None, but got {type(meta)}')
        # Merge self.meta BEFORE writing the current progress: self.meta may
        # carry stale 'epoch'/'iter' keys (e.g. from a resumed run), and
        # updating it last would overwrite the current training progress,
        # breaking subsequent resumes.
        if self.meta is not None:
            meta.update(self.meta)
        meta.update(epoch=self.epoch + 1, iter=self.iter)

        filename = filename_tmpl.format(self.epoch + 1)
        filepath = osp.join(out_dir, filename)
        optimizer = self.optimizer if save_optimizer else None
        save_checkpoint(self.model, filepath, optimizer=optimizer, meta=meta)
        # Keep a "latest.pth" pointer next to the checkpoint. Windows lacks
        # reliable symlink support for unprivileged users, so copy instead.
        if create_symlink:
            dst_file = osp.join(out_dir, 'latest.pth')
            if platform.system() != 'Windows':
                mmcv.symlink(filename, dst_file)
            else:
                shutil.copy(filepath, dst_file)

    def resume(self,
               checkpoint,
               resume_optimizer=True,
               map_location='default'):
        """Resume training state (epoch, iter, optimizer, AMP) from a
        checkpoint.

        Args:
            checkpoint (str): Path or identifier accepted by
                ``self.load_checkpoint``.
            resume_optimizer (bool, optional): Whether to restore the
                optimizer state(s) from the checkpoint. Defaults to True.
            map_location (str, optional): With the default value, tensors are
                mapped onto the current CUDA device when one is available;
                otherwise the value is forwarded to ``load_checkpoint``
                unchanged. Defaults to 'default'.

        Raises:
            TypeError: If ``self.optimizer`` is neither a dict nor a
                ``torch.optim.Optimizer`` while optimizer state is resumed.
        """
        if map_location == 'default':
            if torch.cuda.is_available():
                # Map storages onto the device this process is bound to, so
                # multi-GPU workers do not all load onto GPU 0.
                device_id = torch.cuda.current_device()
                checkpoint = self.load_checkpoint(
                    checkpoint,
                    map_location=lambda storage, loc: storage.cuda(device_id))
            else:
                checkpoint = self.load_checkpoint(checkpoint)
        else:
            checkpoint = self.load_checkpoint(
                checkpoint, map_location=map_location)

        self._epoch = checkpoint['meta']['epoch']
        self._iter = checkpoint['meta']['iter']
        if 'optimizer' in checkpoint and resume_optimizer:
            if isinstance(self.optimizer, Optimizer):
                self.optimizer.load_state_dict(checkpoint['optimizer'])
            elif isinstance(self.optimizer, dict):
                # One optimizer per key (e.g. per sub-module); restore each.
                for k in self.optimizer.keys():
                    self.optimizer[k].load_state_dict(
                        checkpoint['optimizer'][k])
            else:
                raise TypeError(
                    'Optimizer should be dict or torch.optim.Optimizer '
                    f'but got {type(self.optimizer)}')

        # Restore apex AMP loss-scaler state if the checkpoint carries one.
        # NOTE(review): this assumes apex imported successfully at module
        # load; if apex is missing and the checkpoint has an 'amp' key, this
        # raises NameError — confirm desired behavior with callers.
        if 'amp' in checkpoint:
            apex.amp.load_state_dict(checkpoint['amp'])
            self.logger.info('load amp state dict')

        self.logger.info('resumed epoch %d, iter %d', self.epoch, self.iter)
| |
|