|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Dict, Optional, Union |
|
|
|
|
|
from mmengine.optim import _ParamScheduler |
|
|
from mmengine.registry import HOOKS |
|
|
from mmengine.utils import is_list_of |
|
|
from .hook import Hook |
|
|
|
|
|
# Type of a batch yielded by the dataloader as seen by hooks: a dict, tuple or
# list, or None when no batch is passed.
DATA_BATCH = Union[dict, tuple, list, None]
|
|
|
|
|
|
|
|
@HOOKS.register_module()
class ParamSchedulerHook(Hook):
    """A hook to update some hyper-parameters in optimizer, e.g., learning rate
    and momentum.

    ``runner.param_schedulers`` is expected to be ``None`` (in which case every
    hook is a no-op), a list of schedulers, or a dict whose values are lists of
    schedulers. Any other type raises ``TypeError``.
    """

    priority = 'LOW'

    @staticmethod
    def _step_param_schedulers(runner, step) -> None:
        """Apply ``step`` to every list of schedulers held by ``runner``.

        Args:
            runner (Runner): The runner whose ``param_schedulers`` are used.
                ``runner.param_schedulers`` must not be ``None``; callers
                check that first.
            step (Callable[[list], None]): Function called once for the list
                itself, or once per value when ``param_schedulers`` is a dict.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither a list nor a
                dict.
        """
        param_schedulers = runner.param_schedulers
        if isinstance(param_schedulers, list):
            step(param_schedulers)
        elif isinstance(param_schedulers, dict):
            # e.g. one list of schedulers per optimizer
            for schedulers in param_schedulers.values():
                step(schedulers)
        else:
            raise TypeError(
                'runner.param_schedulers should be list of ParamScheduler or '
                'a dict containing list of ParamScheduler, '
                f'but got {runner.param_schedulers}')

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Call step function for each scheduler after each training iteration.

        Only schedulers whose ``by_epoch`` is falsy (iteration-based
        schedulers) are stepped here.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
                In order to keep this interface consistent with other hooks,
                we keep ``data_batch`` here.
            outputs (dict, optional): Outputs from model.
                In order to keep this interface consistent with other hooks,
                we keep ``outputs`` here.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``, a
                list, nor a dict.
        """
        if runner.param_schedulers is None:
            return

        def step(param_schedulers):
            assert isinstance(param_schedulers, list), (
                'each group of runner.param_schedulers should be a list, '
                f'but got {type(param_schedulers)}')
            for scheduler in param_schedulers:
                if not scheduler.by_epoch:
                    scheduler.step()

        self._step_param_schedulers(runner, step)

    def after_train_epoch(self, runner) -> None:
        """Call step function for each scheduler after each training epoch.

        Only schedulers whose ``by_epoch`` is truthy (epoch-based schedulers)
        are stepped here.

        Args:
            runner (Runner): The runner of the training process.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``, a
                list, nor a dict.
        """
        if runner.param_schedulers is None:
            return

        def step(param_schedulers):
            assert isinstance(param_schedulers, list), (
                'each group of runner.param_schedulers should be a list, '
                f'but got {type(param_schedulers)}')
            for scheduler in param_schedulers:
                if scheduler.by_epoch:
                    scheduler.step()

        self._step_param_schedulers(runner, step)

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Call ``step(metrics)`` after each validation epoch for every
        epoch-based scheduler whose ``need_val_args`` attribute is truthy.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.

        Raises:
            TypeError: If ``runner.param_schedulers`` is neither ``None``, a
                list, nor a dict.

        Note:
            if ``runner.param_schedulers`` is not built before,
            the hook ``after_val_epoch`` will be skipped.
        """
        if runner.param_schedulers is None:
            return

        # The schedulers stepped here are called with ``metrics``; without
        # metrics there is nothing to step them with.
        if metrics is None:
            return

        def step(param_schedulers):
            # Skip groups that are not lists of built ``_ParamScheduler``
            # instances (see the Note above) instead of raising.
            if not is_list_of(param_schedulers, _ParamScheduler):
                return

            for scheduler in param_schedulers:
                if (scheduler.by_epoch
                        and getattr(scheduler, 'need_val_args', False)):
                    scheduler.step(metrics)

        self._step_param_schedulers(runner, step)
|
|
|