| import logging |
| from pathlib import Path |
| from typing import Sequence |
| from typing import Union |
| import warnings |
|
|
| import torch |
| from typeguard import check_argument_types |
| from typing import Collection |
|
|
| from espnet2.train.reporter import Reporter |
|
|
|
|
def _force_symlink(link: Path, target_name: str) -> None:
    """(Re)create ``link`` as a relative symlink pointing at ``target_name``."""
    # is_symlink() also catches dangling links, for which exists() is False.
    if link.is_symlink() or link.exists():
        link.unlink()
    link.symlink_to(target_name)


def _average_state_dicts(state_dicts):
    """Return a new state dict holding the element-wise mean of ``state_dicts``.

    Tensors whose dtype name starts with ``torch.int`` (e.g.
    ``BatchNorm.num_batches_tracked``) are only accumulated, not averaged.
    The input dicts are never modified: the result is a shallow copy of the
    first dict whose entries are rebound to freshly computed tensors.
    """
    n = len(state_dicts)
    # Copy the container so that rebinding keys below cannot leak into the
    # caller's (cached) state dict; ``+`` and ``/`` allocate new tensors.
    avg = state_dicts[0].copy()
    for states in state_dicts[1:]:
        for k in avg:
            avg[k] = avg[k] + states[k]
    for k in avg:
        if str(avg[k].dtype).startswith("torch.int"):
            # For int type, not averaged, but only accumulated.
            # (If any integer entry needs averaging or another reduction,
            #  e.g. max/min, please report.)
            continue
        avg[k] = avg[k] / n
    return avg


@torch.no_grad()
def average_nbest_models(
    output_dir: Path,
    reporter: Reporter,
    best_model_criterion: Sequence[Sequence[str]],
    nbest: Union[Collection[int], int],
) -> None:
    """Generate averaged models from the n-best epochs of each criterion.

    For every ``(phase, key, mode)`` criterion that ``reporter`` has values
    for, epochs are ranked with ``reporter.sort_epochs_and_values`` and, for
    each usable ``n``:

    * ``n == 1``: ``{phase}.{key}.ave_1best.pth`` becomes a symlink to the
      best ``{epoch}epoch.pth``.
    * ``n > 1``: the state dicts of the top-``n`` epochs are averaged and
      saved as ``{phase}.{key}.ave_{n}best.pth``.

    Finally ``{phase}.{key}.ave.pth`` becomes a symlink to the model built
    for the largest usable ``n``.

    Args:
        output_dir: The directory that contains ``{epoch}epoch.pth`` for each
            epoch; averaged models and symlinks are written here as well.
        reporter: Reporter instance holding the per-epoch statistics.
        best_model_criterion: Criterions to decide the best model, as
            ``(phase, key, mode)`` triples,
            e.g. [("valid", "loss", "min"), ("train", "acc", "max")]
        nbest: Number (or numbers) of best models to average. Values that are
            not positive or that exceed the number of ranked epochs are
            ignored; if none remain for a criterion, the single best model
            is used.
    """
    assert check_argument_types()
    if isinstance(nbest, int):
        nbests = [nbest]
    else:
        nbests = list(nbest)
    if len(nbests) == 0:
        warnings.warn("At least 1 nbest values are required")
        nbests = [1]

    # 1. Rank epochs for every criterion the reporter knows about. Keep at
    # least one epoch so the 1-best fallback below always has a target.
    top = max(max(nbests), 1)
    nbest_epochs = [
        (ph, k, reporter.sort_epochs_and_values(ph, k, m)[:top])
        for ph, k, m in best_model_criterion
        if reporter.has(ph, k)
    ]

    # Cache of loaded state dicts shared across criteria and n values.
    # Entries must never be mutated, because the same epoch is usually reused.
    _loaded = {}
    for ph, cr, epoch_and_values in nbest_epochs:
        if len(epoch_and_values) == 0:
            # Nothing recorded for this criterion: nothing to link or average.
            continue
        _nbests = [i for i in nbests if 0 < i <= len(epoch_and_values)]
        if len(_nbests) == 0:
            _nbests = [1]

        for n in _nbests:
            if n == 1:
                # The averaged model is the same as the best model.
                e, _ = epoch_and_values[0]
                _force_symlink(
                    output_dir / f"{ph}.{cr}.ave_1best.pth", f"{e}epoch.pth"
                )
                continue

            op = output_dir / f"{ph}.{cr}.ave_{n}best.pth"
            logging.info(
                f"Averaging {n}best models: " f'criterion="{ph}.{cr}": {op}'
            )

            # 2.a. Load (or reuse) the top-n state dicts and average them.
            state_dicts = []
            for e, _ in epoch_and_values[:n]:
                if e not in _loaded:
                    _loaded[e] = torch.load(
                        output_dir / f"{e}epoch.pth",
                        map_location="cpu",
                    )
                state_dicts.append(_loaded[e])

            # 2.b. Save the averaged model.
            torch.save(_average_state_dicts(state_dicts), op)

        # 3. *.*.ave.pth is a symlink to the largest-n averaged model.
        _force_symlink(
            output_dir / f"{ph}.{cr}.ave.pth",
            f"{ph}.{cr}.ave_{max(_nbests)}best.pth",
        )
|
|