|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Any, Dict, Optional, Union |
|
|
|
|
|
import numpy as np |
|
|
import torch |
|
|
|
|
|
from mmengine.registry import HOOKS |
|
|
from mmengine.utils import get_git_hash |
|
|
from mmengine.version import __version__ |
|
|
from .hook import Hook |
|
|
|
|
|
# Annotation used for the ``data_batch`` argument of the per-iteration hook
# methods below: a dict, tuple or list, or ``None`` when no batch is given.
DATA_BATCH = Union[dict, tuple, list, None]
|
|
|
|
|
|
|
|
def _is_scalar(value: Any) -> bool:
    """Tell whether ``value`` can be treated as a single scalar.

    Args:
        value (Any): value of log.

    Returns:
        bool: ``True`` for Python ``int``/``float`` and ``np.number``
        instances, and for ``np.ndarray`` / ``torch.Tensor`` objects that
        hold exactly one element. ``False`` for everything else.
    """
    # Plain numbers are scalars by definition.
    if isinstance(value, (int, float, np.number)):
        return True
    # Array-like containers count only when they hold a single element.
    if isinstance(value, np.ndarray):
        return value.size == 1
    if isinstance(value, torch.Tensor):
        return value.numel() == 1
    return False
|
|
|
|
|
|
|
|
@HOOKS.register_module()
class RuntimeInfoHook(Hook):
    """Hook that mirrors the runner's runtime state into its message hub.

    During training it publishes ``epoch``, ``iter``, ``max_epochs`` and
    ``max_iters``; it also keeps ``loop_stage`` up to date and records
    learning rates, model outputs and evaluation metrics. Components that
    cannot access the runner can read this information from the message hub.
    """

    priority = 'VERY_HIGH'

    def before_run(self, runner) -> None:
        """Publish run-level meta information.

        Args:
            runner (Runner): The runner of the training process.
        """
        run_info = {
            'cfg': runner.cfg.pretty_text,
            'seed': runner.seed,
            'experiment_name': runner.experiment_name,
            'mmengine_version': __version__ + get_git_hash(),
        }
        runner.message_hub.update_info_dict(run_info)

        # No loop is active yet, so there is no stage to restore later.
        self.last_loop_stage = None

    def before_train(self, runner) -> None:
        """Publish the (possibly resumed) training state.

        Args:
            runner (Runner): The runner of the training process.
        """
        hub = runner.message_hub
        hub.update_info('loop_stage', 'train')
        # Each of these runner attributes is stored under the same name.
        for attr in ('epoch', 'iter', 'max_epochs', 'max_iters'):
            hub.update_info(attr, getattr(runner, attr))

        dataset = runner.train_dataloader.dataset
        if hasattr(dataset, 'metainfo'):
            hub.update_info('dataset_meta', dataset.metainfo)

    def after_train(self, runner) -> None:
        """Remove ``loop_stage`` from the message hub once training ends.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.pop_info('loop_stage')

    def before_train_epoch(self, runner) -> None:
        """Refresh the ``epoch`` entry before every training epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        runner.message_hub.update_info('epoch', runner.epoch)

    def before_train_iter(self,
                          runner,
                          batch_idx: int,
                          data_batch: DATA_BATCH = None) -> None:
        """Refresh ``iter`` and record learning rates before every iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
        """
        hub = runner.message_hub
        hub.update_info('iter', runner.iter)

        lr_dict = runner.optim_wrapper.get_lr()
        assert isinstance(lr_dict, dict), (
            '`runner.optim_wrapper.get_lr()` should return a dict '
            'of learning rate when training with OptimWrapper(single '
            'optimizer) or OptimWrapperDict(multiple optimizer), '
            f'but got {type(lr_dict)} please check your optimizer '
            'constructor return an `OptimWrapper` or `OptimWrapperDict` '
            'instance')
        # Only the first entry of each learning-rate sequence is recorded.
        for name, lr in lr_dict.items():
            hub.update_scalar(f'train/{name}', lr[0])

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Record the model outputs of the iteration as ``train/*`` scalars.

        Entries whose key starts with ``vis_`` are not recorded.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (Sequence[dict], optional): Data from dataloader.
                Defaults to None.
            outputs (dict, optional): Outputs from model. Defaults to None.
        """
        if outputs is None:
            return
        for key, value in outputs.items():
            if not key.startswith('vis_'):
                runner.message_hub.update_scalar(f'train/{key}', value)

    def before_val(self, runner) -> None:
        """Remember the current ``loop_stage`` and switch it to ``'val'``.

        Args:
            runner (Runner): The runner of the validation process.
        """
        hub = runner.message_hub
        # Saved so that ``after_val`` can put the previous stage back.
        self.last_loop_stage = hub.get_info('loop_stage')
        hub.update_info('loop_stage', 'val')

    def after_val_epoch(self,
                        runner,
                        metrics: Optional[Dict[str, float]] = None) -> None:
        """Record validation metrics into the message hub.

        Values accepted by ``_is_scalar`` are stored with ``update_scalar``;
        all other values are stored with ``update_info``. Every key is
        prefixed with ``val/``.

        Args:
            runner (Runner): The runner of the validation process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on validation dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if metrics is None:
            return
        hub = runner.message_hub
        for key, value in metrics.items():
            store = hub.update_scalar if _is_scalar(value) else hub.update_info
            store(f'val/{key}', value)

    def after_val(self, runner) -> None:
        """Restore or remove ``loop_stage`` once validation has finished.

        Args:
            runner (Runner): The runner of the validation process.
        """
        # Anything other than a saved 'train' stage means there is nothing
        # to go back to, so the entry is simply dropped.
        if self.last_loop_stage != 'train':
            runner.message_hub.pop_info('loop_stage')
            return
        # Validation began while the stage was 'train': put it back.
        runner.message_hub.update_info('loop_stage', self.last_loop_stage)
        self.last_loop_stage = None

    def before_test(self, runner) -> None:
        """Set ``loop_stage`` to ``'test'``.

        Args:
            runner (Runner): The runner of the testing process.
        """
        runner.message_hub.update_info('loop_stage', 'test')

    def after_test(self, runner) -> None:
        """Remove ``loop_stage`` from the message hub once testing ends.

        Args:
            runner (Runner): The runner of the testing process.
        """
        runner.message_hub.pop_info('loop_stage')

    def after_test_epoch(self,
                         runner,
                         metrics: Optional[Dict[str, float]] = None) -> None:
        """Record test metrics into the message hub.

        Values accepted by ``_is_scalar`` are stored with ``update_scalar``;
        all other values are stored with ``update_info``. Every key is
        prefixed with ``test/``.

        Args:
            runner (Runner): The runner of the testing process.
            metrics (Dict[str, float], optional): Evaluation results of all
                metrics on test dataset. The keys are the names of the
                metrics, and the values are corresponding results.
        """
        if metrics is None:
            return
        hub = runner.message_hub
        for key, value in metrics.items():
            store = hub.update_scalar if _is_scalar(value) else hub.update_info
            store(f'test/{key}', value)
|
|
|