import time
from typing import Callable, Dict, List, Optional, Union

import torch.nn as nn

import mmengine
from mmengine.device import get_device
from mmengine.model import revert_sync_batchnorm
from mmengine.optim import BaseOptimWrapper, _ParamScheduler
from mmengine.registry import STRATEGIES
from mmengine.utils import get_git_hash

from .base import BaseStrategy


@STRATEGIES.register_module()
class SingleDeviceStrategy(BaseStrategy):
    """Strategy for single device training."""

    def prepare(
        self,
        model: Union[nn.Module, dict],
        *,
        optim_wrapper: Union[BaseOptimWrapper, dict, None] = None,
        param_scheduler: Union[_ParamScheduler, Dict, List, None] = None,
        compile: Union[dict, bool] = False,
        dispatch_kwargs: Optional[dict] = None,
    ):
"""Prepare model and some components. |
|
|
|
|
|
Args: |
|
|
model (:obj:`torch.nn.Module` or dict): The model to be run. It |
|
|
can be a dict used for build a model. |
|
|
|
|
|
Keyword Args: |
|
|
optim_wrapper (BaseOptimWrapper or dict, optional): Computing the |
|
|
gradient of model parameters and updating them. |
|
|
Defaults to None. |
|
|
See :meth:`build_optim_wrapper` for examples. |
|
|
param_scheduler (_ParamScheduler or dict or list, optional): |
|
|
Parameter scheduler for updating optimizer parameters. If |
|
|
specified, :attr:`optim_wrapper` should also be specified. |
|
|
Defaults to None. |
|
|
See :meth:`build_param_scheduler` for examples. |
|
|
compile (dict, optional): Config to compile model. |
|
|
Defaults to False. Requires PyTorch>=2.0. |
|
|
dispatch_kwargs (dict, optional): Kwargs to be passed to other |
|
|
methods of Strategy. Defaults to None. |
|
|
If ``accumulative_counts`` is set in ``optim_wrapper``, you |
|
|
need to provide ``max_iters`` in ``dispatch_kwargs``. |
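
        Examples:
            A minimal usage sketch. ``ToyModel`` and the optimizer settings
            below are illustrative placeholders, not part of this module:

            >>> strategy = SingleDeviceStrategy()
            >>> strategy.prepare(
            ...     ToyModel(),
            ...     optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
            ...     dispatch_kwargs=dict(max_iters=100))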
        """
        if self._prepared:
            return self._prepared_components()
        if dispatch_kwargs is not None:
            self.dispatch_kwargs.update(dispatch_kwargs)

        model = self.build_model(model)
        model = self._init_model_weights(model)
        model = self._wrap_model(model)
        model = self.compile_model(model, compile=compile)

        self.model = model

        if optim_wrapper is not None:
            self.optim_wrapper = self.build_optim_wrapper(optim_wrapper, model)

        if param_scheduler is not None:
            self.param_schedulers = self.build_param_scheduler(
                param_scheduler, self.optim_wrapper)

        if optim_wrapper is not None:
            self._scale_lr()

            accumulative_counts = getattr(self.optim_wrapper,
                                          '_accumulative_counts', 1)
            if accumulative_counts > 1:
                if 'max_iters' not in self.dispatch_kwargs:
                    raise ValueError(
                        '"max_iters" must be specified because '
                        '"accumulative_counts" was set as '
                        f'{accumulative_counts} which is greater than 1.')

                self.optim_wrapper.initialize_count_status(
                    self.model, 0, self.dispatch_kwargs['max_iters'])
        self._prepared = True
        return self._prepared_components()

    def _wrap_model(self, model: nn.Module) -> nn.Module:
        """Convert the model and move it to the current device."""
        model = self.convert_model(model)
        current_device = get_device()
        return model.to(current_device)

    def convert_model(self, model: nn.Module) -> nn.Module:
"""Convert layers of model. |
|
|
|
|
|
convert all ``SyncBatchNorm`` (SyncBN) and |
|
|
``mmcv.ops.sync_bn.SyncBatchNorm`` (MMSyncBN) layers in the model to |
|
|
``BatchNormXd`` layers. |
|
|
|
|
|
Args: |
|
|
model (nn.Module): Model to convert. |
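
        Examples:
            A rough illustration, assuming ``strategy`` is an already
            constructed :class:`SingleDeviceStrategy`:

            >>> import torch.nn as nn
            >>> layer = nn.Sequential(nn.Conv2d(3, 8, 3),
            ...                       nn.SyncBatchNorm(8))
            >>> layer = strategy.convert_model(layer)
            >>> # The ``SyncBatchNorm`` layer has been reverted to a plain
            >>> # ``BatchNormXd`` layer.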
        """
        self.logger.info(
            'Distributed training is not used, all SyncBatchNorm (SyncBN) '
            'layers in the model will be automatically reverted to '
            'BatchNormXd layers if they are used.')
        model = revert_sync_batchnorm(model)
        return model

    def load_checkpoint(
        self,
        filename: str,
        *,
        map_location: Union[str, Callable] = 'cpu',
        strict: bool = False,
        revise_keys: list = [(r'^module.', '')],
        callback: Optional[Callable] = None,
    ) -> dict:
"""Load checkpoint from given ``filename``. |
|
|
|
|
|
Args: |
|
|
filename (str): Accept local filepath, URL, ``torchvision://xxx``, |
|
|
``open-mmlab://xxx``. |
|
|
|
|
|
Keyword Args: |
|
|
map_location (str or callable): A string or a callable function to |
|
|
specifying how to remap storage locations. |
|
|
Defaults to 'cpu'. |
|
|
strict (bool): strict (bool): Whether to allow different params for |
|
|
the model and checkpoint. |
|
|
revise_keys (list): A list of customized keywords to modify the |
|
|
state_dict in checkpoint. Each item is a (pattern, replacement) |
|
|
pair of the regular expression operations. Defaults to strip |
|
|
the prefix 'module.' by [(r'^module\\.', '')]. |
|
|
callback (callable, callable): Callback function to modify the |
|
|
checkpoint after loading the checkpoint. |
|
|
Defaults to None. |
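
        Examples:
            An illustrative call, assuming ``strategy`` is a prepared
            :class:`SingleDeviceStrategy`; the checkpoint path is a
            placeholder:

            >>> checkpoint = strategy.load_checkpoint(
            ...     'work_dirs/epoch_1.pth', map_location='cpu')
            >>> # The model weights have been loaded into ``self.model``;
            >>> # ``checkpoint`` keeps the remaining entries such as 'meta'.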
        """
        from mmengine.runner.checkpoint import _load_checkpoint

        self.logger.info(f'Load checkpoint from {filename}')

        if map_location == 'default':
            device = get_device()
            checkpoint = _load_checkpoint(filename, map_location=device)
        else:
            checkpoint = _load_checkpoint(filename, map_location=map_location)

        # Let users modify the checkpoint right after it is loaded.
        if callback is not None:
            callback(checkpoint)

        state_dict = checkpoint.pop('state_dict')
        self.load_model_state_dict(
            state_dict, strict=strict, revise_keys=revise_keys)

        return checkpoint

    def resume(
        self,
        filename: str,
        *,
        resume_optimizer: bool = True,
        resume_param_scheduler: bool = True,
        map_location: Union[str, Callable] = 'default',
        callback: Optional[Callable] = None,
    ) -> dict:
"""Resume training from given ``filename``. |
|
|
|
|
|
Four types of states will be resumed. |
|
|
|
|
|
- model state |
|
|
- optimizer state |
|
|
- scheduler state |
|
|
- randomness state |
|
|
|
|
|
Args: |
|
|
filename (str): Accept local filepath, URL, ``torchvision://xxx``, |
|
|
``open-mmlab://xxx``. |
|
|
|
|
|
Keyword Args: |
|
|
resume_optimizer (bool): Whether to resume optimizer state. |
|
|
Defaults to True. |
|
|
resume_param_scheduler (bool): Whether to resume param scheduler |
|
|
state. Defaults to True. |
|
|
map_location (str or callable):A string or a callable function to |
|
|
specifying how to remap storage locations. |
|
|
Defaults to 'default'. |
|
|
callback (callable, callable): Callback function to modify the |
|
|
checkpoint before saving the checkpoint. |
|
|
Defaults to None. |
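
        Examples:
            An illustrative call, assuming ``strategy`` is a prepared
            :class:`SingleDeviceStrategy`; the checkpoint path is a
            placeholder:

            >>> checkpoint = strategy.resume('work_dirs/iter_1000.pth')
            >>> # Model, optimizer, scheduler and randomness states are
            >>> # restored; the remaining checkpoint dict is returned so the
            >>> # caller can read bookkeeping such as ``checkpoint['meta']``.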
        """
        self.logger.info(f'Resume checkpoint from {filename}')

        checkpoint = self.load_checkpoint(
            filename, map_location=map_location, callback=callback)

        if resume_optimizer:
            self.load_optim_state_dict(checkpoint.pop('optimizer'))

        if resume_param_scheduler and hasattr(self, 'param_schedulers'):
            self.load_scheduler_state_dict(checkpoint.pop('param_schedulers'))

        # Resume the random seed recorded in the checkpoint if it differs
        # from the one in the current `randomness` config.
        resumed_seed = checkpoint['meta'].get('seed', None)
        current_seed = self._randomness.get('seed')
        if resumed_seed is not None and resumed_seed != current_seed:
            if current_seed is not None:
                self.logger.warning(f'The value of random seed in the '
                                    f'checkpoint "{resumed_seed}" is '
                                    f'different from the value in '
                                    f'`randomness` config "{current_seed}"')
            self._randomness.update(seed=resumed_seed)
            self._set_randomness(**self._randomness)

        # Re-initialize the gradient accumulation counter from the resumed
        # iteration.
        cur_iter = checkpoint['meta']['iter']

        if hasattr(self, 'optim_wrapper'):
            accumulative_counts = getattr(self.optim_wrapper,
                                          '_accumulative_counts', 1)
            if accumulative_counts > 1:
                if 'max_iters' not in self.dispatch_kwargs:
                    raise ValueError(
                        '"max_iters" must be specified because '
                        '"accumulative_counts" was set as '
                        f'{accumulative_counts} which is greater than 1.')

                self.optim_wrapper.initialize_count_status(
                    self.model, cur_iter, self.dispatch_kwargs['max_iters'])

        return checkpoint

    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
"""Save checkpoint to given ``filename``. |
|
|
|
|
|
Args: |
|
|
filename (str): Filename to save checkpoint. |
|
|
|
|
|
Keyword Args: |
|
|
save_optimizer (bool): Whether to save the optimizer to |
|
|
the checkpoint. Defaults to True. |
|
|
save_param_scheduler (bool): Whether to save the param_scheduler |
|
|
to the checkpoint. Defaults to True. |
|
|
extra_ckpt (dict, optional): Extra checkpoint to save. |
|
|
Defaults to None. |
|
|
callback (callable, callable): Callback function to modify the |
|
|
checkpoint before saving the checkpoint. |
|
|
Defaults to None. |
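
        Examples:
            An illustrative call, assuming ``strategy`` is a prepared
            :class:`SingleDeviceStrategy`; the path and the extra ``meta``
            fields are placeholders chosen by the caller:

            >>> strategy.save_checkpoint(
            ...     'work_dirs/iter_1000.pth',
            ...     extra_ckpt=dict(meta=dict(epoch=1, iter=1000)))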
        """
        from mmengine.runner.checkpoint import save_checkpoint

        state_dict: dict = dict()
        state_dict['state_dict'] = self.model_state_dict()

        # Save the optimizer state dict if an optim_wrapper has been built.
        if save_optimizer and hasattr(self, 'optim_wrapper'):
            state_dict['optimizer'] = self.optim_state_dict()

        if save_param_scheduler and hasattr(self, 'param_schedulers'):
            state_dict['param_schedulers'] = self.scheduler_state_dict()

        # Merge user-provided extra information into the `meta` field.
        if extra_ckpt is None:
            extra_ckpt = dict()
        if 'meta' not in extra_ckpt:
            extra_ckpt['meta'] = dict()
        extra_ckpt['meta'].update(
            seed=self.seed,
            time=time.strftime('%Y%m%d_%H%M%S', time.localtime()),
            mmengine=mmengine.__version__ + get_git_hash(),
        )

        state_dict.update(extra_ckpt)

        # Let users modify the checkpoint before it is written to disk.
        if callback is not None:
            callback(state_dict)

        save_checkpoint(state_dict, filename)