|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import hashlib |
|
|
import logging |
|
|
import os.path as osp |
|
|
import pickle |
|
|
from collections import deque |
|
|
from math import inf |
|
|
from pathlib import Path |
|
|
from typing import Callable, Dict, List, Optional, Sequence, Union |
|
|
|
|
|
from mmengine.dist import is_main_process, master_only |
|
|
from mmengine.fileio import FileClient, get_file_backend |
|
|
from mmengine.logging import print_log |
|
|
from mmengine.registry import HOOKS |
|
|
from mmengine.utils import is_list_of, is_seq_of |
|
|
from .hook import Hook |
|
|
|
|
|
# Type alias for the ``data_batch`` argument of iteration hooks: a batch
# coming from the dataloader (dict, tuple or list), or None.
DATA_BATCH = Union[dict, tuple, list, None]
|
|
|
|
|
|
|
|
@HOOKS.register_module()
class CheckpointHook(Hook):
    """Save checkpoints periodically.

    Args:
        interval (int): The saving period. If ``by_epoch=True``, interval
            indicates epochs, otherwise it indicates iterations.
            Defaults to -1, which means "never".
        by_epoch (bool): Saving checkpoints by epoch or by iteration.
            Defaults to True.
        save_optimizer (bool): Whether to save optimizer state_dict in the
            checkpoint. It is usually used for resuming experiments.
            Defaults to True.
        save_param_scheduler (bool): Whether to save param_scheduler state_dict
            in the checkpoint. It is usually used for resuming experiments.
            Defaults to True.
        out_dir (str, Path, Optional): The root directory to save checkpoints.
            If not specified, ``runner.work_dir`` will be used by default. If
            specified, the ``out_dir`` will be the concatenation of ``out_dir``
            and the last level directory of ``runner.work_dir``. For example,
            if the input ``out_dir`` is ``./tmp`` and ``runner.work_dir`` is
            ``./work_dir/cur_exp``, then the ckpt will be saved in
            ``./tmp/cur_exp``. Defaults to None.
        max_keep_ckpts (int): The maximum checkpoints to keep.
            In some cases we want only the latest few checkpoints and would
            like to delete old ones to save the disk space.
            Defaults to -1, which means unlimited.
        save_last (bool): Whether to force the last checkpoint to be
            saved regardless of interval. Defaults to True.
        save_best (str, List[str], optional): If a metric is specified, it
            would measure the best checkpoint during evaluation. If a list of
            metrics is passed, it would measure a group of best checkpoints
            corresponding to the passed metrics. The information about best
            checkpoint(s) would be saved in ``runner.message_hub`` to keep
            best score value and best checkpoint path, which will be also
            loaded when resuming checkpoint. Options are the evaluation metrics
            reported after each validation epoch, e.g., ``bbox_mAP``,
            ``segm_mAP`` for bbox detection and instance segmentation,
            ``AR@100`` for proposal recall. If ``save_best`` is ``auto``, the
            first key of the returned ``OrderedDict`` result will be used.
            Defaults to None.
        rule (str, List[str], optional): Comparison rule for best score. If
            set to None, it will infer a reasonable rule. Keys such as 'acc',
            'top', etc. will be inferred by the 'greater' rule. Keys containing
            'loss' will be inferred by the 'less' rule. If ``save_best`` is a
            list of metrics and ``rule`` is a str, all metrics in
            ``save_best`` will share the comparison rule. If ``save_best`` and
            ``rule`` are both lists, their length must be the same, and
            metrics in ``save_best`` will use the corresponding comparison
            rule in ``rule``. Options are 'greater', 'less', None and list
            which contains 'greater' and 'less'. Defaults to None.
        greater_keys (List[str], optional): Metric keys that will be
            inferred by 'greater' comparison rule. If ``None``,
            _default_greater_keys will be used. Defaults to None.
        less_keys (List[str], optional): Metric keys that will be
            inferred by 'less' comparison rule. If ``None``, _default_less_keys
            will be used. Defaults to None.
        file_client_args (dict, optional): Arguments to instantiate a
            FileClient. See :class:`mmengine.fileio.FileClient` for details.
            Defaults to None. It will be deprecated in future. Please use
            ``backend_args`` instead.
        filename_tmpl (str, optional): String template to indicate checkpoint
            name. If specified, must contain one and only one "{}", which will
            be replaced with ``epoch + 1`` if ``by_epoch=True`` else
            ``iteration + 1``.
            Defaults to None, which means "epoch_{}.pth" or "iter_{}.pth"
            accordingly.
        backend_args (dict, optional): Arguments to instantiate the
            prefix of uri corresponding backend. Defaults to None.
            `New in version 0.2.0.`
        published_keys (str, List[str], optional): If ``save_last`` is ``True``
            or ``save_best`` is not ``None``, it will automatically
            publish model with keys in the list after training.
            Defaults to None.
            `New in version 0.7.1.`
        save_begin (int): Control the epoch number or iteration number
            at which checkpoint saving begins. Defaults to 0, which means
            saving at the beginning.
            `New in version 0.8.3.`
        **kwargs: Extra keyword arguments forwarded verbatim to
            ``runner.save_checkpoint`` when saving periodic checkpoints.

    Examples:
        >>> # Save best based on single metric
        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
        ...                rule='greater')
        >>> # Save best based on multi metrics with the same comparison rule
        >>> CheckpointHook(interval=2, by_epoch=True,
        ...                save_best=['acc', 'mIoU'], rule='greater')
        >>> # Save best based on multi metrics with different comparison rule
        >>> CheckpointHook(interval=2, by_epoch=True,
        ...                save_best=['FID', 'IS'], rule=['less', 'greater'])
        >>> # Save best based on single metric and publish model after training
        >>> CheckpointHook(interval=2, by_epoch=True, save_best='acc',
        ...                rule='greater', published_keys=['meta', 'state_dict'])
    """
    out_dir: str

    priority = 'VERY_LOW'

    # Comparison functions and initial "worst" scores for each rule.
    rule_map = {'greater': lambda x, y: x > y, 'less': lambda x, y: x < y}
    init_value_map = {'greater': -inf, 'less': inf}
    _default_greater_keys = [
        'acc', 'top', 'AR@', 'auc', 'precision', 'mAP', 'mDice', 'mIoU',
        'mAcc', 'aAcc'
    ]
    _default_less_keys = ['loss']

    def __init__(self,
                 interval: int = -1,
                 by_epoch: bool = True,
                 save_optimizer: bool = True,
                 save_param_scheduler: bool = True,
                 out_dir: Optional[Union[str, Path]] = None,
                 max_keep_ckpts: int = -1,
                 save_last: bool = True,
                 save_best: Union[str, List[str], None] = None,
                 rule: Union[str, List[str], None] = None,
                 greater_keys: Optional[Sequence[str]] = None,
                 less_keys: Optional[Sequence[str]] = None,
                 file_client_args: Optional[dict] = None,
                 filename_tmpl: Optional[str] = None,
                 backend_args: Optional[dict] = None,
                 published_keys: Union[str, List[str], None] = None,
                 save_begin: int = 0,
                 **kwargs) -> None:
        self.interval = interval
        self.by_epoch = by_epoch
        self.save_optimizer = save_optimizer
        self.save_param_scheduler = save_param_scheduler
        self.out_dir = out_dir  # type: ignore
        self.max_keep_ckpts = max_keep_ckpts
        self.save_last = save_last
        # Forwarded verbatim to ``runner.save_checkpoint`` for periodic saves.
        self.args = kwargs

        if file_client_args is not None:
            print_log(
                '"file_client_args" will be deprecated in future. '
                'Please use "backend_args" instead',
                logger='current',
                level=logging.WARNING)
            if backend_args is not None:
                raise ValueError(
                    '"file_client_args" and "backend_args" cannot be set '
                    'at the same time.')

        self.file_client_args = file_client_args
        self.backend_args = backend_args

        if filename_tmpl is None:
            if self.by_epoch:
                self.filename_tmpl = 'epoch_{}.pth'
            else:
                self.filename_tmpl = 'iter_{}.pth'
        else:
            self.filename_tmpl = filename_tmpl

        # Validate and normalize ``save_best`` into a list of str (or None).
        assert (isinstance(save_best, str) or is_list_of(save_best, str)
                or (save_best is None)), (
                    '"save_best" should be a str or list of str or None, '
                    f'but got {type(save_best)}')

        if isinstance(save_best, list):
            if 'auto' in save_best:
                assert len(save_best) == 1, (
                    'Only support one "auto" in "save_best" list.')
            assert len(save_best) == len(
                set(save_best)), ('Find duplicate element in "save_best".')
        elif save_best is not None:
            save_best = [save_best]  # type: ignore
        self.save_best = save_best

        # Validate and normalize ``rule`` into a list.
        assert (isinstance(rule, str) or is_list_of(rule, str)
                or (rule is None)), (
                    '"rule" should be a str or list of str or None, '
                    f'but got {type(rule)}')
        if isinstance(rule, list):
            # Only meaningful when ``save_best`` is set; checking it
            # otherwise would call ``len(None)``.
            if self.save_best is not None:
                assert len(rule) in [
                    1,
                    len(self.save_best)
                ], ('Number of "rule" must be 1 or the same as number of '
                    f'"save_best", but got {len(rule)}.')
        else:
            # Wrap a single str or None so it can be broadcast later.
            rule = [rule]  # type: ignore

        if greater_keys is None:
            self.greater_keys = self._default_greater_keys
        else:
            if not isinstance(greater_keys, (list, tuple)):
                greater_keys = (greater_keys, )  # type: ignore
            assert is_seq_of(greater_keys, str)
            self.greater_keys = greater_keys  # type: ignore

        if less_keys is None:
            self.less_keys = self._default_less_keys
        else:
            if not isinstance(less_keys, (list, tuple)):
                less_keys = (less_keys, )  # type: ignore
            assert is_seq_of(less_keys, str)
            self.less_keys = less_keys  # type: ignore

        if self.save_best is not None:
            self.is_better_than: Dict[str, Callable] = dict()
            self._init_rule(rule, self.save_best)
            if len(self.key_indicators) == 1:
                self.best_ckpt_path: Optional[str] = None
            else:
                self.best_ckpt_path_dict: Dict = dict()

        # Validate and normalize ``published_keys``.
        if not (isinstance(published_keys, str)
                or is_seq_of(published_keys, str) or published_keys is None):
            raise TypeError(
                '"published_keys" should be a str or a sequence of str or '
                f'None, but got {type(published_keys)}')

        if isinstance(published_keys, str):
            published_keys = [published_keys]
        elif isinstance(published_keys, (list, tuple)):
            assert len(published_keys) == len(set(published_keys)), (
                'Find duplicate elements in "published_keys".')
        self.published_keys = published_keys

        self.last_ckpt = None
        if save_begin < 0:
            raise ValueError(
                f'save_begin should not be less than 0, but got {save_begin}')
        self.save_begin = save_begin

    def before_train(self, runner) -> None:
        """Finish all operations, related to checkpoint.

        This function will get the appropriate file client, and the directory
        to save these checkpoints of the model. It also restores best-checkpoint
        paths and the list of kept checkpoint ids from ``runner.message_hub``.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.out_dir is None:
            self.out_dir = runner.work_dir

        # ``self.file_client`` is kept for backward compatibility; it is only
        # used as the backend when ``file_client_args`` was given.
        self.file_client = FileClient.infer_client(self.file_client_args,
                                                   self.out_dir)

        if self.file_client_args is None:
            self.file_backend = get_file_backend(
                self.out_dir, backend_args=self.backend_args)
        else:
            self.file_backend = self.file_client

        # A user-provided ``out_dir`` is joined with the last directory level
        # of ``runner.work_dir``.
        if self.out_dir != runner.work_dir:
            basename = osp.basename(runner.work_dir.rstrip(osp.sep))
            self.out_dir = self.file_backend.join_path(
                self.out_dir, basename)  # type: ignore

        runner.logger.info(f'Checkpoints will be saved to {self.out_dir}.')

        if self.save_best is not None:
            if len(self.key_indicators) == 1:
                if 'best_ckpt' not in runner.message_hub.runtime_info:
                    self.best_ckpt_path = None
                else:
                    self.best_ckpt_path = runner.message_hub.get_info(
                        'best_ckpt')
            else:
                for key_indicator in self.key_indicators:
                    best_ckpt_name = f'best_ckpt_{key_indicator}'
                    if best_ckpt_name not in runner.message_hub.runtime_info:
                        self.best_ckpt_path_dict[key_indicator] = None
                    else:
                        self.best_ckpt_path_dict[
                            key_indicator] = runner.message_hub.get_info(
                                best_ckpt_name)

        if self.max_keep_ckpts > 0:
            keep_ckpt_ids = []
            if 'keep_ckpt_ids' in runner.message_hub.runtime_info:
                keep_ckpt_ids = runner.message_hub.get_info('keep_ckpt_ids')

                # A resumed run may have kept more checkpoints than the
                # current limit; delete the oldest surplus ones.
                while len(keep_ckpt_ids) > self.max_keep_ckpts:
                    step = keep_ckpt_ids.pop(0)
                    if is_main_process():
                        path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(step))
                        self._remove_ckpt_path(path)

            self.keep_ckpt_ids: deque = deque(keep_ckpt_ids,
                                              self.max_keep_ckpts)

    def after_train_epoch(self, runner) -> None:
        """Save the checkpoint and synchronize buffers after each epoch.

        Args:
            runner (Runner): The runner of the training process.
        """
        if not self.by_epoch:
            return

        # Save every ``self.interval`` epochs (starting at ``save_begin``),
        # and at the last training epoch when ``save_last`` is set.
        if self.every_n_epochs(runner, self.interval, self.save_begin) or (
                self.save_last and self.is_last_train_epoch(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.epoch + 1} epochs')
            self._save_checkpoint(runner)

    def after_val_epoch(self, runner, metrics: Dict) -> None:
        """Save the best checkpoint(s) after each evaluation epoch.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if len(metrics) == 0:
            runner.logger.warning(
                'Since `metrics` is an empty dict, the behavior to save '
                'the best checkpoint will be skipped in this evaluation.')
            return

        self._save_best_checkpoint(runner, metrics)

    def after_train(self, runner) -> None:
        """Publish the checkpoint after training.

        The last checkpoint (if ``save_last``) and every recorded best
        checkpoint are published. Best-checkpoint slots that were never
        filled (``None``) are skipped.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.published_keys is None:
            return

        if self.save_last and self.last_ckpt is not None:
            self._publish_model(runner, self.last_ckpt)

        if getattr(self, 'best_ckpt_path', None) is not None:
            self._publish_model(runner, str(self.best_ckpt_path))
        if getattr(self, 'best_ckpt_path_dict', None) is not None:
            for best_ckpt in self.best_ckpt_path_dict.values():
                # An indicator that never improved has no checkpoint.
                if best_ckpt is None:
                    continue
                self._publish_model(runner, best_ckpt)

    @master_only
    def _publish_model(self, runner, ckpt_path: str) -> None:
        """Remove unnecessary keys from ckpt_path and save the new checkpoint.

        The published file is named after ``ckpt_path`` with the first 8 hex
        digits of the SHA-256 of the pickled, filtered checkpoint appended.

        Args:
            runner (Runner): The runner of the training process.
            ckpt_path (str): The checkpoint path that ought to be published.
        """
        from mmengine.runner import save_checkpoint
        from mmengine.runner.checkpoint import _load_checkpoint
        checkpoint = _load_checkpoint(ckpt_path)
        assert self.published_keys is not None
        removed_keys = []
        for key in list(checkpoint.keys()):
            if key not in self.published_keys:
                removed_keys.append(key)
                checkpoint.pop(key)
        if removed_keys:
            print_log(
                f'Key {removed_keys} will be removed because they are not '
                'found in published_keys. If you want to keep them, '
                f'please set `{removed_keys}` in published_keys',
                logger='current')
        # pickle is only used to serialize (not load) for hashing here.
        checkpoint_data = pickle.dumps(checkpoint)
        sha = hashlib.sha256(checkpoint_data).hexdigest()
        final_path = osp.splitext(ckpt_path)[0] + f'-{sha[:8]}.pth'
        save_checkpoint(checkpoint, final_path)
        print_log(
            f'The checkpoint ({ckpt_path}) is published to '
            f'{final_path}.',
            logger='current')

    def _remove_ckpt_path(self, path: str) -> bool:
        """Delete ``path`` through ``self.file_backend`` if it exists.

        Both files and directories are handled because a checkpoint may be
        stored as a directory (presumably by strategies that save sharded
        state -- not visible from this file).

        Args:
            path (str): Checkpoint file or directory to delete.

        Returns:
            bool: True if something was removed, False otherwise.
        """
        if self.file_backend.isfile(path):
            self.file_backend.remove(path)
            return True
        if self.file_backend.isdir(path):
            self.file_backend.rmtree(path)
            return True
        return False

    def _save_checkpoint_with_step(self, runner, step, meta):
        """Save a regular checkpoint named after ``step``.

        Rotates out the oldest kept checkpoint when ``max_keep_ckpts`` is
        reached, records ``last_ckpt`` in the message hub, saves the
        checkpoint, and (on the main process) writes its path to
        ``<work_dir>/last_checkpoint``.

        Args:
            runner (Runner): The runner of the training process.
            step (int): Epoch or iteration number used in the filename.
            meta (dict): Meta information stored in the checkpoint.
        """
        # Remove outdated checkpoints before saving so that the stored
        # ``keep_ckpt_ids`` reflects what is on disk.
        if self.max_keep_ckpts > 0:
            # ``_save_checkpoint`` and ``_save_best_checkpoint`` may both call
            # this for the same step; only register the step once.
            already_kept = (
                len(self.keep_ckpt_ids) > 0
                and self.keep_ckpt_ids[-1] == step)
            if not already_kept:
                if len(self.keep_ckpt_ids) == self.max_keep_ckpts:
                    _step = self.keep_ckpt_ids.popleft()
                    if is_main_process():
                        ckpt_path = self.file_backend.join_path(
                            self.out_dir, self.filename_tmpl.format(_step))
                        self._remove_ckpt_path(ckpt_path)

                self.keep_ckpt_ids.append(step)
                runner.message_hub.update_info('keep_ckpt_ids',
                                               list(self.keep_ckpt_ids))

        ckpt_filename = self.filename_tmpl.format(step)
        self.last_ckpt = self.file_backend.join_path(self.out_dir,
                                                     ckpt_filename)
        runner.message_hub.update_info('last_ckpt', self.last_ckpt)

        runner.save_checkpoint(
            self.out_dir,
            ckpt_filename,
            self.file_client_args,
            save_optimizer=self.save_optimizer,
            save_param_scheduler=self.save_param_scheduler,
            meta=meta,
            by_epoch=self.by_epoch,
            backend_args=self.backend_args,
            **self.args)

        # ``runner.save_checkpoint`` is called on every rank; only the main
        # process records the last checkpoint path.
        if not is_main_process():
            return

        save_file = osp.join(runner.work_dir, 'last_checkpoint')
        with open(save_file, 'w', encoding='utf-8') as f:
            f.write(self.last_ckpt)

    def _save_checkpoint(self, runner) -> None:
        """Save the current checkpoint and delete outdated checkpoint.

        Args:
            runner (Runner): The runner of the training process.
        """
        if self.by_epoch:
            step = runner.epoch + 1
            meta = dict(epoch=step, iter=runner.iter)
        else:
            step = runner.iter + 1
            meta = dict(epoch=runner.epoch, iter=step)

        self._save_checkpoint_with_step(runner, step, meta=meta)

    def _save_best_checkpoint(self, runner, metrics) -> None:
        """Save a best checkpoint for every key indicator that improved.

        For each indicator whose score beats the best score recorded in
        ``runner.message_hub``, the previous best checkpoint is removed, a new
        one (without optimizer/scheduler state) is saved, and the new score
        and path are written to the message hub. If any best checkpoint was
        updated, the regular checkpoint is saved again so that it carries the
        updated best-score information.

        Args:
            runner (Runner): The runner of the training process.
            metrics (dict): Evaluation results of all metrics.
        """
        if not self.save_best:
            return

        if self.by_epoch:
            ckpt_filename = self.filename_tmpl.format(runner.epoch)
            cur_type, cur_time = 'epoch', runner.epoch
        else:
            ckpt_filename = self.filename_tmpl.format(runner.iter)
            cur_type, cur_time = 'iter', runner.iter

        meta = dict(epoch=runner.epoch, iter=runner.iter)

        # Resolve 'auto' to the first reported metric on first use.
        if 'auto' in self.key_indicators:
            self._init_rule(self.rules, [list(metrics.keys())[0]])

        best_ckpt_updated = False
        single_indicator = len(self.key_indicators) == 1

        for key_indicator, rule in zip(self.key_indicators, self.rules):
            key_score = metrics[key_indicator]

            if single_indicator:
                best_score_key = 'best_score'
                runtime_best_ckpt_key = 'best_ckpt'
                best_ckpt_path = self.best_ckpt_path
            else:
                best_score_key = f'best_score_{key_indicator}'
                runtime_best_ckpt_key = f'best_ckpt_{key_indicator}'
                best_ckpt_path = self.best_ckpt_path_dict[key_indicator]

            if best_score_key not in runner.message_hub.runtime_info:
                best_score = self.init_value_map[rule]
            else:
                best_score = runner.message_hub.get_info(best_score_key)

            if key_score is None or not self.is_better_than[key_indicator](
                    key_score, best_score):
                continue

            best_ckpt_updated = True

            best_score = key_score
            runner.message_hub.update_info(best_score_key, best_score)

            if best_ckpt_path and is_main_process():
                if self._remove_ckpt_path(best_ckpt_path):
                    runner.logger.info(
                        f'The previous best checkpoint {best_ckpt_path} '
                        'is removed')

            best_ckpt_name = f'best_{key_indicator}_{ckpt_filename}'
            # Metric names may contain '/', which is illegal in filenames.
            best_ckpt_name = best_ckpt_name.replace('/', '_')
            new_best_path = self.file_backend.join_path(
                self.out_dir, best_ckpt_name)
            if single_indicator:
                self.best_ckpt_path = new_best_path
            else:
                self.best_ckpt_path_dict[key_indicator] = new_best_path
            runner.message_hub.update_info(runtime_best_ckpt_key,
                                           new_best_path)
            runner.save_checkpoint(
                self.out_dir,
                filename=best_ckpt_name,
                file_client_args=self.file_client_args,
                save_optimizer=False,
                save_param_scheduler=False,
                meta=meta,
                by_epoch=False,
                backend_args=self.backend_args)
            runner.logger.info(
                f'The best checkpoint with {best_score:0.4f} {key_indicator} '
                f'at {cur_time} {cur_type} is saved to {best_ckpt_name}.')

        # The regular checkpoint saved in ``after_train_epoch`` /
        # ``after_train_iter`` holds the old best_score/best_ckpt in the
        # message hub, so save it again with the updated values.
        if best_ckpt_updated and self.last_ckpt is not None:
            self._save_checkpoint_with_step(runner, cur_time, meta)

    def _init_rule(self, rules, key_indicators) -> None:
        """Initialize rule, key_indicator, comparison_func, and best score. If
        key_indicator is a list of string and rule is a string, all metric in
        the key_indicator will share the same rule.

        Here is the rule to determine which rule is used for key indicator when
        the rule is not specific (note that the key indicator matching is case-
        insensitive):

        1. If the key indicator is in ``self.greater_keys``, the rule
           will be specified as 'greater'.
        2. Or if the key indicator is in ``self.less_keys``, the rule
           will be specified as 'less'.
        3. Or if any one item in ``self.greater_keys`` is a substring of
           key_indicator, the rule will be specified as 'greater'.
        4. Or if any one item in ``self.less_keys`` is a substring of
           key_indicator, the rule will be specified as 'less'.

        Args:
            rules (List[Optional[str]]): Comparison rule for best score.
            key_indicators (List[str]): Key indicator to determine
                the comparison rule.

        Raises:
            KeyError: If a rule is neither 'greater', 'less' nor None.
            ValueError: If a rule is None and cannot be inferred.
        """
        if len(rules) == 1:
            rules = rules * len(key_indicators)

        # Lower-cased key sets for case-insensitive matching; built once.
        greater_keys = {key.lower() for key in self.greater_keys}
        less_keys = {key.lower() for key in self.less_keys}

        self.rules = []
        for rule, key_indicator in zip(rules, key_indicators):

            if rule not in self.rule_map and rule is not None:
                raise KeyError('rule must be greater, less or None, '
                               f'but got {rule}.')

            # 'auto' is resolved later, once the metric names are known.
            if rule is None and key_indicator != 'auto':
                key_indicator_lc = key_indicator.lower()

                if key_indicator_lc in greater_keys:
                    rule = 'greater'
                elif key_indicator_lc in less_keys:
                    rule = 'less'
                elif any(key in key_indicator_lc for key in greater_keys):
                    rule = 'greater'
                elif any(key in key_indicator_lc for key in less_keys):
                    rule = 'less'
                else:
                    raise ValueError('Cannot infer the rule for key '
                                     f'{key_indicator}, thus a specific rule '
                                     'must be specified.')
            if rule is not None:
                self.is_better_than[key_indicator] = self.rule_map[rule]
            self.rules.append(rule)

        self.key_indicators = key_indicators

    def after_train_iter(self,
                         runner,
                         batch_idx: int,
                         data_batch: DATA_BATCH = None,
                         outputs: Optional[dict] = None) -> None:
        """Save the checkpoint and synchronize buffers after each iteration.

        Args:
            runner (Runner): The runner of the training process.
            batch_idx (int): The index of the current batch in the train loop.
            data_batch (dict or tuple or list, optional): Data from dataloader.
            outputs (dict, optional): Outputs from model.
        """
        if self.by_epoch:
            return

        # Save every ``self.interval`` iterations (starting at
        # ``save_begin``), and at the last training iteration when
        # ``save_last`` is set.
        if self.every_n_train_iters(runner, self.interval,
                                    self.save_begin) or \
                (self.save_last and
                 self.is_last_train_iter(runner)):
            runner.logger.info(
                f'Saving checkpoint at {runner.iter + 1} iterations')
            self._save_checkpoint(runner)
|
|
|