import copy
import os.path as osp
import platform
import time
from abc import ABCMeta, abstractmethod
from collections import OrderedDict
from typing import Callable, Dict, List, Optional, Sequence, Tuple, Union

import torch
import torch.nn as nn
from torch.optim import Optimizer

import mmengine
from mmengine.config import Config, ConfigDict
from mmengine.dist import (broadcast, get_dist_info, infer_launcher,
                           is_distributed)
from mmengine.logging import MMLogger
from mmengine.model.wrappers import is_model_wrapper
from mmengine.optim import (BaseOptimWrapper, OptimWrapperDict,
                            _ParamScheduler, build_optim_wrapper)
from mmengine.registry import MODELS, OPTIM_WRAPPERS, PARAM_SCHEDULERS
from mmengine.utils import digit_version
from mmengine.utils.dl_utils import (TORCH_VERSION, collect_env,
                                     set_multi_processing)

ParamSchedulerType = Union[List[_ParamScheduler], Dict[str,
                                                       List[_ParamScheduler]]]


class BaseStrategy(metaclass=ABCMeta):
    """Base class for all strategies.

    In the process of supporting FSDP, DeepSpeed, and ColossalAI, the
    scalability of the Runner faced challenges, which led to the redefinition
    of the Runner's responsibilities. The Strategy abstraction was split out,
    which is responsible for constructing, initializing, and saving/loading
    the state of training components such as models, optimizers, and parameter
    schedulers.

    Warning:
        This is an experimental feature, and its interface is subject to
        change.

    Keyword Args:
        work_dir (str): The working directory to save checkpoints. The logs
            will be saved in the subdirectory of `work_dir` named
            :attr:`timestamp`. Defaults to 'work_dirs'.
        experiment_name (str, optional): Name of current experiment. If not
            specified, timestamp will be used as :attr:`experiment_name`.
            Defaults to None.
        env_kwargs (dict, optional): Environment config passed in
            :meth:`setup_env`. Defaults to None.
        log_kwargs (dict, optional): Logger config passed in
            :meth:`build_logger`. Defaults to None.
        auto_scale_lr (dict, optional): Config to scale the learning rate
            automatically. It includes ``base_batch_size`` and ``enable``.
            ``base_batch_size`` is the batch size that the optimizer lr is
            based on. ``enable`` is the switch to turn on and off the feature.
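
    Example:
        A minimal construction sketch (``SingleDeviceStrategy`` is assumed to
        be an available concrete subclass; any concrete strategy is
        constructed the same way)::

            from mmengine._strategy import SingleDeviceStrategy

            strategy = SingleDeviceStrategy(
                work_dir='work_dirs',
                experiment_name='toy_experiment',
                auto_scale_lr=dict(enable=True, base_batch_size=16))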
    """
    model: nn.Module
    optim_wrapper: BaseOptimWrapper
    param_schedulers: ParamSchedulerType

    def __init__(
        self,
        *,
        work_dir: str = 'work_dirs',
        experiment_name: Optional[str] = None,
        env_kwargs: Optional[dict] = None,
        log_kwargs: Optional[dict] = None,
        auto_scale_lr: Optional[dict] = None,
    ):
        self._work_dir = osp.abspath(work_dir)
        mmengine.mkdir_or_exist(self._work_dir)

        self._env_kwargs = env_kwargs or {}
        self._setup_env(**self._env_kwargs)

        if experiment_name is not None:
            self._experiment_name = f'{experiment_name}_{self.timestamp}'
        else:
            self._experiment_name = self.timestamp

        self._log_dir = osp.join(self.work_dir, self.timestamp)
        mmengine.mkdir_or_exist(self._log_dir)

        log_kwargs = log_kwargs or {}
        self.logger = self.build_logger(**log_kwargs)

        self._auto_scale_lr = auto_scale_lr

        self.dispatch_kwargs: dict = {}
        self._prepared = False

    @property
    def work_dir(self):
        return self._work_dir

    @property
    def log_dir(self):
        return self._log_dir

    @property
    def experiment_name(self):
        return self._experiment_name

    @property
    def launcher(self):
        return self._launcher

    @property
    def distributed(self):
        return self._distributed

    @property
    def seed(self):
        return self._seed

    @property
    def rank(self):
        return self._rank

    @property
    def world_size(self):
        return self._world_size

    @property
    def timestamp(self):
        return self._timestamp

    @property
    def randomness(self):
        return self._randomness

    @abstractmethod
    def prepare(
        self,
        model: Union[nn.Module, dict],
        *,
        optim_wrapper: Union[BaseOptimWrapper, dict, None] = None,
        param_scheduler: Union[_ParamScheduler, Dict, List, None] = None,
        compile: Union[dict, bool] = False,
        dispatch_kwargs: Optional[dict] = None,
    ):
        """Prepare model and some components.

        Args:
            model (:obj:`torch.nn.Module` or dict): The model to be run. It
                can be a dict used for building a model.

        Keyword Args:
            optim_wrapper (BaseOptimWrapper or dict, optional): Computes the
                gradients of model parameters and updates them.
                Defaults to None.
                See :meth:`build_optim_wrapper` for examples.
            param_scheduler (_ParamScheduler or dict or list, optional):
                Parameter scheduler for updating optimizer parameters. If
                specified, :attr:`optim_wrapper` should also be specified.
                Defaults to None.
                See :meth:`build_param_scheduler` for examples.
            compile (dict or bool, optional): Config to compile the model.
                Defaults to False. Requires PyTorch >= 2.0.
            dispatch_kwargs (dict, optional): Kwargs to be passed to other
                methods of Strategy. Defaults to None.
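
        Example:
            A minimal sketch with a concrete strategy
            (``SingleDeviceStrategy`` and a registered ``'ResNet'`` model are
            assumed; the exact ``dispatch_kwargs`` required depend on the
            strategy)::

                strategy = SingleDeviceStrategy()
                model, optim_wrapper = strategy.prepare(
                    dict(type='ResNet'),
                    optim_wrapper=dict(optimizer=dict(type='SGD', lr=0.01)),
                    dispatch_kwargs=dict(max_epochs=10, epoch_length=100))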
        """

    def _setup_env(
        self,
        *,
        launcher: Optional[str] = None,
        cudnn_benchmark: bool = False,
        mp_cfg: Optional[dict] = None,
        dist_cfg: Optional[dict] = None,
        resource_limit: int = 4096,
        randomness: dict = dict(seed=None),
    ):
        """Setup environment.

        This method will do the following things:

        1. setup multi-processing
        2. setup distributed
        3. set random seed

        Keyword Args:
            launcher (str, optional): Way to launch multi-process jobs.
                Supported launchers are 'pytorch', 'mpi', 'slurm' and 'none'.
                If 'none' is provided, a non-distributed environment will be
                launched. If launcher is None, the launcher will be inferred
                from the environment variables. Defaults to None.
            cudnn_benchmark (bool): Whether to enable cudnn benchmark.
                Defaults to False.
            mp_cfg (dict, optional): Multi-processing config. Defaults to
                None.
            dist_cfg (dict, optional): Distributed config. Defaults to None.
            resource_limit (int): Resource limit. Defaults to 4096.
            randomness (dict): Some settings to make the experiment as
                reproducible as possible like seed and deterministic.
                Defaults to ``dict(seed=None)``. If seed is None, a random
                number will be generated and it will be broadcasted to all
                other processes if in distributed environment.
                If ``cudnn_benchmark`` is ``True`` but ``deterministic`` is
                ``True`` in ``randomness``, the value of
                ``torch.backends.cudnn.benchmark`` will be ``False`` finally.
        """
        if launcher is None:
            launcher = infer_launcher()

        self._launcher = launcher
        if self._launcher == 'none':
            self._distributed = False
        else:
            self._distributed = True

        if cudnn_benchmark:
            torch.backends.cudnn.benchmark = True

        mp_cfg = mp_cfg if mp_cfg is not None else {}
        set_multi_processing(**mp_cfg, distributed=self._distributed)

        if self._distributed and not is_distributed():
            dist_cfg = dist_cfg if dist_cfg is not None else {}
            self._setup_distributed(launcher, **dist_cfg)

        self._rank, self._world_size = get_dist_info()

        timestamp = torch.tensor(time.time(), dtype=torch.float64)
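        # Broadcast the timestamp from rank 0 so that every process shares
        # the same value and therefore the same log/work sub-directory name.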
        broadcast(timestamp)
        self._timestamp = time.strftime('%Y%m%d_%H%M%S',
                                        time.localtime(timestamp.item()))
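
        # Raise the soft limit on the number of open file descriptors (not
        # applicable on Windows) to avoid "Too many open files" errors when
        # many dataloader workers are used.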
        if platform.system() != 'Windows':
            import resource
            rlimit = resource.getrlimit(resource.RLIMIT_NOFILE)
            base_soft_limit = rlimit[0]
            hard_limit = rlimit[1]
            soft_limit = min(max(resource_limit, base_soft_limit), hard_limit)
            resource.setrlimit(resource.RLIMIT_NOFILE,
                               (soft_limit, hard_limit))

        self._randomness = randomness
        self._set_randomness(**randomness)

    def _setup_distributed(self, *args, **kwargs):
        """Setup distributed environment."""
        pass

    def _set_randomness(
        self,
        seed: Optional[int] = None,
        diff_rank_seed: bool = False,
        deterministic: bool = False,
    ) -> None:
        """Set random seed to guarantee reproducible results.

        Args:
            seed (int, optional): A number to set random modules.
                Defaults to None.
            diff_rank_seed (bool): Whether or not to set different seeds
                according to global rank. Defaults to False.
            deterministic (bool): Whether to set the deterministic option for
                CUDNN backend, i.e., set `torch.backends.cudnn.deterministic`
                to True and `torch.backends.cudnn.benchmark` to False.
                Defaults to False.
                See https://pytorch.org/docs/stable/notes/randomness.html for
                more details.
        """
        from mmengine.runner import set_random_seed
        self._seed = set_random_seed(
            seed=seed,
            deterministic=deterministic,
            diff_rank_seed=diff_rank_seed)

    def build_model(self, model: Union[nn.Module, dict]) -> nn.Module:
        """Build model.

        If ``model`` is a dict, it will be used to build a ``nn.Module``
        object. Otherwise, if ``model`` is a ``nn.Module`` object it will be
        returned directly.

        An example of ``model``::

            model = dict(type='ResNet')

        Args:
            model (nn.Module or dict): A ``nn.Module`` object or a dict to
                build a ``nn.Module`` object. If ``model`` is a ``nn.Module``
                object, it is returned as-is.

        Note:
            The returned model must implement ``train_step`` and ``test_step``
            if ``runner.train`` or ``runner.test`` will be called. If
            ``runner.val`` will be called or ``val_cfg`` is configured, the
            model must implement ``val_step``.

        Returns:
            nn.Module: Model built from ``model``.
        """
        if isinstance(model, nn.Module):
            return model
        elif isinstance(model, dict):
            model = MODELS.build(model)
            return model
        else:
            raise TypeError('model should be a nn.Module object or dict, '
                            f'but got {model}')

    def compile_model(
        self,
        model: nn.Module,
        compile: Union[dict, bool] = False,
    ) -> nn.Module:
        """Compile model.

        Args:
            model (nn.Module): Model to compile.
            compile (dict or bool): Config to compile the model with
                ``torch.compile``. Defaults to False.

        Returns:
            nn.Module: Compiled model.
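
        Example:
            A minimal sketch (assumes PyTorch >= 2.0;
            ``mode='reduce-overhead'`` is one of the modes accepted by
            ``torch.compile``)::

                model = strategy.compile_model(
                    model,
                    compile=dict(target='forward', mode='reduce-overhead'))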
        """
        if isinstance(compile, bool) and not compile:
            return model

        assert digit_version(TORCH_VERSION) >= digit_version('2.0.0'), (
            'PyTorch >= 2.0.0 is required to enable torch.compile')

        if isinstance(compile, bool):
            compile = dict()

        target = compile.pop('target', 'forward')
        func = getattr(model, target)
        compiled_func = torch.compile(func, **compile)
        setattr(model, target, compiled_func)
        self.logger.info('Model has been "compiled". The first few iterations '
                         'will be slow, please be patient.')

        return model

    def _init_model_weights(self, model: nn.Module) -> nn.Module:
        """Initialize the model weights if the model has
        :meth:`init_weights`."""
        if (hasattr(model, 'init_weights') and self.dispatch_kwargs.get(
                'init_weights_for_test_or_val', True)):
            model.init_weights()
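
        # Synchronize the (possibly randomly initialized) parameters across
        # processes so that every rank starts from identical weights.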
        for _, params in model.state_dict().items():
            broadcast(params)

        return model

    def build_optim_wrapper(
        self,
        optim_wrapper: Union[Optimizer, BaseOptimWrapper, dict],
        model: Optional[nn.Module] = None,
    ) -> BaseOptimWrapper:
        """Build optimizer wrapper.

        If ``optim_wrapper`` is a config dict for only one optimizer,
        the keys must contain ``optimizer``, and ``type`` is optional.
        It will build a :obj:`OptimWrapper` by default.

        If ``optim_wrapper`` is a config dict for multiple optimizers, i.e.,
        it has multiple keys and each key is for an optimizer wrapper, a
        constructor must be specified since
        :obj:`DefaultOptimizerConstructor` cannot handle the building of
        training with multiple optimizers.

        If ``optim_wrapper`` is a dict of pre-built optimizer wrappers, i.e.,
        each value of ``optim_wrapper`` is an ``OptimWrapper`` instance,
        ``build_optim_wrapper`` will directly build the
        :obj:`OptimWrapperDict` instance from ``optim_wrapper``.

        Args:
            optim_wrapper (BaseOptimWrapper or dict): An OptimWrapper object
                or a dict to build OptimWrapper objects. If ``optim_wrapper``
                is an OptimWrapper instance, it is returned directly.

        Note:
            For single optimizer training, if `optim_wrapper` is a config
            dict, `type` is optional (defaults to :obj:`OptimWrapper`) and it
            must contain `optimizer` to build the corresponding optimizer.

        Examples:
            >>> # build an optimizer
            >>> optim_wrapper_cfg = dict(type='OptimWrapper', optimizer=dict(
            ...     type='SGD', lr=0.01))
            >>> # optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
            >>> # is also valid.
            >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
            >>> optim_wrapper
            Type: OptimWrapper
            accumulative_counts: 1
            optimizer:
            SGD (
            Parameter Group 0
                dampening: 0
                lr: 0.01
                momentum: 0
                nesterov: False
                weight_decay: 0
            )
            >>> # build optimizer without `type`
            >>> optim_wrapper_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
            >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
            >>> optim_wrapper
            Type: OptimWrapper
            accumulative_counts: 1
            optimizer:
            SGD (
            Parameter Group 0
                dampening: 0
                lr: 0.01
                maximize: False
                momentum: 0
                nesterov: False
                weight_decay: 0
            )
            >>> # build multiple optimizers
            >>> optim_wrapper_cfg = dict(
            ...    generator=dict(type='OptimWrapper', optimizer=dict(
            ...        type='SGD', lr=0.01)),
            ...    discriminator=dict(type='OptimWrapper', optimizer=dict(
            ...        type='Adam', lr=0.001)),
            ...    # need to customize a multiple optimizer constructor
            ...    constructor='CustomMultiOptimizerConstructor',
            ... )
            >>> optim_wrapper = runner.build_optim_wrapper(optim_wrapper_cfg)
            >>> optim_wrapper
            name: generator
            Type: OptimWrapper
            accumulative_counts: 1
            optimizer:
            SGD (
            Parameter Group 0
                dampening: 0
                lr: 0.01
                momentum: 0
                nesterov: False
                weight_decay: 0
            )
            name: discriminator
            Type: OptimWrapper
            accumulative_counts: 1
            optimizer:
            'discriminator': Adam (
            Parameter Group 0
                dampening: 0
                lr: 0.001
                momentum: 0
                nesterov: False
                weight_decay: 0
            )

        Important:
            If you need to build multiple optimizers, you should implement a
            MultiOptimWrapperConstructor which gets parameters passed to
            corresponding optimizers and composes the ``OptimWrapperDict``.
            More details about how to customize OptimizerConstructor can be
            found at `optimizer-docs`_.

        Returns:
            BaseOptimWrapper: Optimizer wrapper built from ``optim_wrapper``.
        """
        if isinstance(optim_wrapper, BaseOptimWrapper):
            return optim_wrapper
        if isinstance(optim_wrapper, (dict, ConfigDict, Config)):
            optimizer = optim_wrapper.get('optimizer', None)
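
            # If `optimizer` is already a built `torch.optim.Optimizer`
            # instance, only the optimizer wrapper itself needs to be built
            # from the `OPTIM_WRAPPERS` registry.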
            if isinstance(optimizer, Optimizer):
                optim_wrapper.setdefault('type', 'OptimWrapper')
                return OPTIM_WRAPPERS.build(optim_wrapper)

            if optimizer is not None or 'constructor' in optim_wrapper:
                assert model is not None
                return build_optim_wrapper(model, optim_wrapper)
            else:
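                # Neither `optimizer` nor `constructor` is given, so
                # `optim_wrapper` is expected to be a dict of pre-built
                # optimizer wrappers that can be collected into an
                # `OptimWrapperDict` directly.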
                optim_wrappers = OrderedDict()
                for name, optim in optim_wrapper.items():
                    if not isinstance(optim, BaseOptimWrapper):
                        raise ValueError(
                            'each item must be an optimizer object when '
                            '"type" and "constructor" are not in '
                            f'optimizer, but got {name}={optim}')
                    optim_wrappers[name] = optim
                return OptimWrapperDict(**optim_wrappers)
        else:
            raise TypeError('optimizer wrapper should be an OptimWrapper '
                            f'object or dict, but got {optim_wrapper}')

    def _build_param_scheduler(
        self,
        scheduler: Union[_ParamScheduler, Dict, List],
        optim_wrapper: BaseOptimWrapper,
        default_args: dict,
    ) -> List[_ParamScheduler]:
        """Build parameter schedulers for a single optimizer.

        Args:
            scheduler (_ParamScheduler or dict or list): A Param Scheduler
                object or a dict or list of dict to build parameter
                schedulers.
            optim_wrapper (BaseOptimWrapper): An optimizer wrapper object
                passed to construct ParamScheduler objects.

        Returns:
            list[_ParamScheduler]: List of parameter schedulers built from
            ``scheduler``.

        Note:
            If the train loop is built, when building parameter schedulers,
            it supports setting the max epochs/iters as the default ``end``
            of schedulers, and supports converting epoch-based schedulers
            to iter-based according to the ``convert_to_iter_based`` key.
        """
        if not isinstance(scheduler, Sequence):
            schedulers = [scheduler]
        else:
            schedulers = scheduler

        max_epochs = default_args.pop('max_epochs', None)
        max_iters = default_args.pop('max_iters', None)

        param_schedulers = []
        for scheduler in schedulers:
            if isinstance(scheduler, _ParamScheduler):
                param_schedulers.append(scheduler)
            elif isinstance(scheduler, dict):
                _scheduler = copy.deepcopy(scheduler)

                if _scheduler.get('by_epoch', True):
                    if max_epochs is None:
                        raise ValueError(
                            'max_epochs must be specified in default_args')
                    default_end = max_epochs
                else:
                    if max_iters is None:
                        raise ValueError(
                            'max_iters must be specified in default_args')
                    default_end = max_iters
                _scheduler.setdefault('end', default_end)
                self.logger.debug(
                    f'The `end` of {_scheduler["type"]} is not set. '
                    'Use the max epochs/iters of train loop as default.')

                param_schedulers.append(
                    PARAM_SCHEDULERS.build(
                        _scheduler,
                        default_args=dict(
                            optimizer=optim_wrapper, **default_args)))
            else:
                raise TypeError(
                    'scheduler should be a _ParamScheduler object or dict, '
                    f'but got {scheduler}')
        return param_schedulers

    def build_param_scheduler(
        self,
        scheduler: Union[_ParamScheduler, Dict, List],
        optim_wrapper: BaseOptimWrapper,
        default_args: Optional[dict] = None,
    ) -> ParamSchedulerType:
        """Build parameter schedulers.

        ``build_param_scheduler`` should be called after
        ``build_optim_wrapper`` because the building logic will change
        according to the number of optimizers built by the runner.
        The cases are as below:

        - Single optimizer: When only one optimizer is built and used in the
          runner, ``build_param_scheduler`` will return a list of
          parameter schedulers.
        - Multiple optimizers: When two or more optimizers are built and used
          in the runner, ``build_param_scheduler`` will return a dict
          containing the same keys as the optimizers, where each value is a
          list of parameter schedulers. Note that, if you want different
          optimizers to use different parameter schedulers to update their
          hyper-parameters, the input parameter ``scheduler`` also needs to
          be a dict whose keys are consistent with the optimizer names.
          Otherwise, the same parameter schedulers will be used to update all
          optimizers' hyper-parameters.

        Args:
            scheduler (_ParamScheduler or dict or list): A Param Scheduler
                object or a dict or list of dict to build parameter
                schedulers.

        Examples:
            >>> # build one scheduler
            >>> optim_cfg = dict(optimizer=dict(type='SGD', lr=0.01))
            >>> runner.optim_wrapper = runner.build_optim_wrapper(
            ...     optim_cfg)
            >>> scheduler_cfg = dict(type='MultiStepLR', milestones=[1, 2])
            >>> schedulers = runner.build_param_scheduler(scheduler_cfg)
            >>> schedulers
            [<mmengine.optim.scheduler.lr_scheduler.MultiStepLR at 0x7f70f6966290>]  # noqa: E501

            >>> # build multiple schedulers
            >>> scheduler_cfg = [
            ...    dict(type='MultiStepLR', milestones=[1, 2]),
            ...    dict(type='StepLR', step_size=1)
            ... ]
            >>> schedulers = runner.build_param_scheduler(scheduler_cfg)
            >>> schedulers
            [<mmengine.optim.scheduler.lr_scheduler.MultiStepLR at 0x7f70f60dd3d0>,  # noqa: E501
            <mmengine.optim.scheduler.lr_scheduler.StepLR at 0x7f70f6eb6150>]

        The above examples only cover the case of a single optimizer with one
        or more schedulers. If you want to know how to set parameter
        schedulers when using multiple optimizers, you can find more examples
        at `optimizer-docs`_.

        Returns:
            list[_ParamScheduler] or dict[str, list[_ParamScheduler]]: List of
            parameter schedulers or a dictionary containing lists of parameter
            schedulers built from ``scheduler``.

        .. _optimizer-docs:
           https://mmengine.readthedocs.io/en/latest/tutorials/optim_wrapper.html
        """
        if default_args is None:
            default_args = {}
            if 'epoch_length' in self.dispatch_kwargs:
                default_args['epoch_length'] = self.dispatch_kwargs[
                    'epoch_length']
            if 'max_epochs' in self.dispatch_kwargs:
                default_args['max_epochs'] = self.dispatch_kwargs['max_epochs']
            if 'max_iters' in self.dispatch_kwargs:
                default_args['max_iters'] = self.dispatch_kwargs['max_iters']

        param_schedulers: ParamSchedulerType
        if not isinstance(optim_wrapper, OptimWrapperDict):
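            # A single optimizer wrapper: every scheduler config shares it
            # and a flat list of parameter schedulers is returned.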
            assert isinstance(optim_wrapper, BaseOptimWrapper), (
                '`build_optim_wrapper` should be called before '
                '`build_param_scheduler` because the latter depends '
                'on the former')
            param_schedulers = self._build_param_scheduler(
                scheduler, optim_wrapper, default_args)
            return param_schedulers
        else:
            param_schedulers = dict()
            for name, optimizer in optim_wrapper.items():
                if isinstance(scheduler, dict) and 'type' not in scheduler:
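                    # `scheduler` is a dict keyed by optimizer name, so each
                    # optimizer gets its own scheduler config.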
                    param_schedulers[name] = self._build_param_scheduler(
                        scheduler[name], optimizer, default_args)
                else:
                    param_schedulers[name] = self._build_param_scheduler(
                        scheduler, optimizer, default_args)

            return param_schedulers

    def _scale_lr(self) -> None:
        """Automatically scale the learning rate according to the ratio
        between ``base_batch_size`` in ``auto_scale_lr`` and the real batch
        size.

        It scales the learning rate linearly according to the
        `paper <https://arxiv.org/abs/1706.02677>`_.

        Note:
            ``scale_lr`` must be called after building optimizer wrappers
            and before building parameter schedulers.
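
        For example, with ``auto_scale_lr=dict(enable=True,
        base_batch_size=16)``, 8 processes and a micro batch size of 4 per
        process, the real batch size is ``8 * 4 = 32`` and every learning
        rate is multiplied by ``32 / 16 = 2``.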
        """
        if (self._auto_scale_lr is None
                or not self._auto_scale_lr.get('enable', False)):
            return None

        assert 'base_batch_size' in self._auto_scale_lr, \
            'Lack of `base_batch_size` in `auto_scale_lr`.'

        try:
            real_bs = (self.world_size *
                       self.dispatch_kwargs['train_micro_batch_size_per_gpu'])
        except KeyError:
            real_bs = self.world_size * self.train_micro_batch_size_per_gpu

        base_bs = self._auto_scale_lr['base_batch_size']
        ratio = float(real_bs) / float(base_bs)
        self.logger.info(f'LR is set based on batch size of {base_bs} '
                         f'and the current batch size is {real_bs}. '
                         f'Scaling the original LR by {ratio}.')

        def _is_built(schedulers):
            if isinstance(schedulers, dict):
                return False if 'type' in schedulers else any(
                    _is_built(s) for s in schedulers.values())
            if isinstance(schedulers, list):
                return any(_is_built(s) for s in schedulers)
            return isinstance(schedulers, _ParamScheduler)

        if _is_built(self.param_schedulers):
            raise RuntimeError('`scale_lr` should be called before building '
                               'ParamScheduler because ParamScheduler will '
                               'store initial lr from optimizer wrappers')

        assert isinstance(self.optim_wrapper, BaseOptimWrapper), \
            '`scale_lr` should be called after building OptimWrapper'

        if isinstance(self.optim_wrapper, OptimWrapperDict):
            wrappers = list(self.optim_wrapper.values())
        else:
            wrappers = [self.optim_wrapper]

        for wrapper in wrappers:
            for group in wrapper.optimizer.param_groups:
                group['lr'] = group['lr'] * ratio

    def build_logger(
        self,
        log_level: Union[int, str] = 'INFO',
        log_file: Optional[str] = None,
        **kwargs,
    ) -> MMLogger:
        """Build a globally accessible MMLogger.

        Args:
            log_level (int or str): The log level of MMLogger handlers.
                Defaults to 'INFO'.
            log_file (str, optional): Path of filename to save log.
                Defaults to None.
            **kwargs: Remaining parameters passed to ``MMLogger``.

        Returns:
            MMLogger: An ``MMLogger`` object built from the given arguments.
        """
        if log_file is None:
            log_file = osp.join(self.log_dir, f'{self._timestamp}.log')

        log_cfg = dict(log_level=log_level, log_file=log_file, **kwargs)
        log_cfg.setdefault('name', self.experiment_name)
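
        # Open the log file in append mode by default so that an existing log
        # file (e.g. from a resumed run) is not overwritten.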
        log_cfg.setdefault('file_mode', 'a')

        return MMLogger.get_instance(**log_cfg)

    def model_state_dict(self) -> dict:
        """Get model state dict."""
        from mmengine.runner import weights_to_cpu
        return weights_to_cpu(self.model.state_dict())

    def optim_state_dict(self) -> dict:
        """Get optimizer state dict."""
        if isinstance(self.optim_wrapper, BaseOptimWrapper):
            return self.optim_wrapper.state_dict()
        else:
            raise TypeError('self.optim_wrapper should be a `BaseOptimWrapper`'
                            f' instance, but got {self.optim_wrapper}')

    def scheduler_state_dict(self) -> Union[dict, list]:
        """Get parameter scheduler state dict."""
        if isinstance(self.param_schedulers, dict):
            state_dict: dict = dict()
            for name, schedulers in self.param_schedulers.items():
                state_dict[name] = []
                for scheduler in schedulers:
                    state_dict[name].append(scheduler.state_dict())
            return state_dict
        else:
            state_list = []
            for scheduler in self.param_schedulers:
                state_list.append(scheduler.state_dict())
            return state_list

    def load_model_state_dict(
        self,
        state_dict: dict,
        *,
        strict: bool = False,
        revise_keys: list = [(r'^module.', '')],
    ) -> None:
        """Load model state from dict."""
        from mmengine.runner.checkpoint import _load_checkpoint_to_model

        if is_model_wrapper(self.model):
            model = self.model.module
        else:
            model = self.model

        _load_checkpoint_to_model(model, state_dict, strict, revise_keys)

    def load_optim_state_dict(self, state_dict: dict) -> None:
        """Load optimizer state from dict."""
        self.optim_wrapper.load_state_dict(state_dict)

    def load_scheduler_state_dict(self, state_dict: Union[dict, list]) -> None:
        """Load scheduler state from dict."""
        if isinstance(self.param_schedulers, dict):
            assert isinstance(state_dict, dict)
            for name, schedulers in self.param_schedulers.items():
                for scheduler, ckpt_scheduler in zip(schedulers,
                                                     state_dict[name]):
                    scheduler.load_state_dict(ckpt_scheduler)
        else:
            for scheduler, ckpt_scheduler in zip(
                    self.param_schedulers,
                    state_dict):
                scheduler.load_state_dict(ckpt_scheduler)

    def load_or_resume(
        self,
        *,
        load_from: Optional[str] = None,
        resume: Union[bool, str] = False,
    ) -> Optional[dict]:
        """Load checkpoint or resume from checkpoint.

        Args:
            load_from (str, optional): The checkpoint file to load from.
                Defaults to None.
            resume (bool or str): Whether to resume training. Defaults to
                False. If ``resume`` is True and ``load_from`` is None, the
                latest checkpoint in ``work_dir`` is found and resumed from
                automatically. If none is found, resuming does nothing.
                If ``resume`` is a string, it will be treated as the
                checkpoint file to resume from.
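
        Example:
            A minimal sketch that resumes from the latest checkpoint in
            ``work_dir`` if one exists, and otherwise does nothing::

                strategy.load_or_resume(resume=True)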
        """
        from mmengine.runner import find_latest_checkpoint

        if not resume and load_from is None:
            return None

        resume_from = None
        if isinstance(resume, str):
            resume_from = resume
        elif resume and load_from is None:
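            # Automatically resume from the latest checkpoint found in
            # `work_dir`; `resume_from` may be None if no checkpoint exists
            # yet.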
            resume_from = find_latest_checkpoint(self._work_dir)
            self.logger.info(
                f'Auto resumed from the latest checkpoint {resume_from}.')
        elif resume and load_from is not None:
            resume_from = load_from

        if resume_from is not None:
            return self.resume(resume_from)
        elif load_from is not None:
            return self.load_checkpoint(load_from)

        return None

    @abstractmethod
    def load_checkpoint(
        self,
        filename: str,
        *,
        map_location: Union[str, Callable] = 'cpu',
        strict: bool = False,
        revise_keys: list = [(r'^module.', '')],
        callback: Optional[Callable] = None,
    ) -> dict:
        """Load checkpoint from given ``filename``.

        Args:
            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
                ``open-mmlab://xxx``.

        Keyword Args:
            map_location (str or callable): A string or a callable function
                specifying how to remap storage locations.
                Defaults to 'cpu'.
            strict (bool): Whether to allow different params for the model
                and checkpoint.
            revise_keys (list): A list of customized keywords to modify the
                state_dict in checkpoint. Each item is a (pattern, replacement)
                pair of the regular expression operations. Defaults to strip
                the prefix 'module.' by [(r'^module\\.', '')].
            callback (callable, optional): Callback function to modify the
                checkpoint after loading it.
                Defaults to None.
        """

    @abstractmethod
    def resume(
        self,
        filename: str,
        *,
        resume_optimizer: bool = True,
        resume_param_scheduler: bool = True,
        map_location: Union[str, Callable] = 'default',
        callback: Optional[Callable] = None,
    ) -> dict:
        """Resume training from given ``filename``.

        Four types of states will be resumed.

        - model state
        - optimizer state
        - scheduler state
        - randomness state

        Args:
            filename (str): Accept local filepath, URL, ``torchvision://xxx``,
                ``open-mmlab://xxx``.

        Keyword Args:
            resume_optimizer (bool): Whether to resume optimizer state.
                Defaults to True.
            resume_param_scheduler (bool): Whether to resume param scheduler
                state. Defaults to True.
            map_location (str or callable): A string or a callable function
                specifying how to remap storage locations.
                Defaults to 'default'.
            callback (callable, optional): Callback function to modify the
                checkpoint after loading it.
                Defaults to None.
        """

    @abstractmethod
    def save_checkpoint(
        self,
        filename: str,
        *,
        save_optimizer: bool = True,
        save_param_scheduler: bool = True,
        extra_ckpt: Optional[dict] = None,
        callback: Optional[Callable] = None,
    ) -> None:
        """Save checkpoint to given ``filename``.

        Args:
            filename (str): Filename to save checkpoint.

        Keyword Args:
            save_optimizer (bool): Whether to save the optimizer to
                the checkpoint. Defaults to True.
            save_param_scheduler (bool): Whether to save the param_scheduler
                to the checkpoint. Defaults to True.
            extra_ckpt (dict, optional): Extra checkpoint to save.
                Defaults to None.
            callback (callable, optional): Callback function to modify the
                checkpoint before it is saved.
                Defaults to None.
        """

    def collect_env(self) -> Tuple[dict, dict]:
        """Collect the information of the running environments."""
        system_env = collect_env()
        runtime_env: OrderedDict = OrderedDict()
        runtime_env.update(self._env_kwargs)
        runtime_env.update(self.randomness)
        runtime_env['Distributed launcher'] = self.launcher
        runtime_env['Distributed training'] = self.distributed
        runtime_env['GPU number'] = self.world_size

        return system_env, runtime_env

    def _prepared_components(self):
        return_items = [self.model]
        if hasattr(self, 'optim_wrapper'):
            return_items.append(self.optim_wrapper)

        if hasattr(self, 'param_schedulers'):
            return_items.append(self.param_schedulers)

        return return_items[0] if len(return_items) == 1 else return_items