| from collections import defaultdict |
| import logging |
| from pathlib import Path |
| from typing import Dict |
| from typing import Iterable |
| from typing import List |
| from typing import Optional |
| from typing import Tuple |
|
|
| import numpy as np |
| import torch |
| from torch.nn.parallel import data_parallel |
| from torch.utils.data import DataLoader |
| from typeguard import check_argument_types |
|
|
| from espnet2.fileio.datadir_writer import DatadirWriter |
| from espnet2.fileio.npy_scp import NpyScpWriter |
| from espnet2.torch_utils.device_funcs import to_device |
| from espnet2.torch_utils.forward_adaptor import ForwardAdaptor |
| from espnet2.train.abs_espnet_model import AbsESPnetModel |
|
|
|
|
@torch.no_grad()
def collect_stats(
    model: AbsESPnetModel,
    train_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    valid_iter: DataLoader and Iterable[Tuple[List[str], Dict[str, torch.Tensor]]],
    output_dir: Path,
    ngpu: Optional[int],
    log_interval: Optional[int],
    write_collected_feats: bool,
) -> None:
    """Perform on collect_stats mode.

    Running for deriving the shape information from data
    and gathering statistics.
    This method is used before executing train().

    For each of the "train" and "valid" iterators this function:

    1. records the shape of every non-``*_lengths`` batch entry per utterance
       (truncated by its ``{name}_lengths`` entry when present) into
       ``{name}_shape`` via ``DatadirWriter(output_dir / mode)``;
    2. calls ``model.collect_feats(**batch)`` (through ``data_parallel`` when
       ``ngpu > 1``);
    3. accumulates, per returned key, the sum, the sum of squares and the
       count along the first axis, and saves them to
       ``output_dir / mode / "{key}_stats.npz"``;
    4. optionally dumps each utterance's feature with ``NpyScpWriter`` under
       ``output_dir / mode / "collect_feats"``;
    5. writes ``batch_keys`` (keys of the last batch, without ``*_lengths``)
       and ``stats_keys`` (keys for which statistics were saved).

    Args:
        model: Model providing ``collect_feats``.
        train_iter: Iterable yielding ``(utterance_ids, batch)`` pairs.
        valid_iter: Same as ``train_iter`` for the validation split.
        output_dir: Root output directory; ``train`` and ``valid``
            subdirectories are written below it.
        ngpu: Number of GPUs. ``None`` or ``0`` runs on CPU.
        log_interval: Log progress every this many batches. If ``None``, it
            is derived separately for each iterator from its ``len()``
            (``max(len // 20, 10)``), or 100 if it has no length.
        write_collected_feats: Whether to dump the collected features.
    """
    assert check_argument_types()

    # Optional[int]: treat "no GPU count given" as CPU instead of crashing
    # on ``None > 0``.
    if ngpu is None:
        ngpu = 0
    device = "cuda" if ngpu > 0 else "cpu"

    for itr, mode in zip([train_iter, valid_iter], ["train", "valid"]):
        # Resolve the interval per iterator so the valid split does not
        # silently inherit the value derived from len(train_iter).
        mode_log_interval = log_interval
        if mode_log_interval is None:
            try:
                mode_log_interval = max(len(itr) // 20, 10)
            except TypeError:
                # The iterable has no __len__.
                mode_log_interval = 100

        sum_dict = defaultdict(int)
        sq_dict = defaultdict(int)
        count_dict = defaultdict(int)
        # Keys of the most recent batch (without *_lengths). Stays empty when
        # the iterator yields nothing, instead of leaving `batch` unbound.
        batch_keys = []
        # Feature writers for this mode, keyed by feature name.
        npy_scp_writers = {}

        try:
            with DatadirWriter(output_dir / mode) as datadir_writer:
                for iiter, (keys, batch) in enumerate(itr, 1):
                    batch = to_device(batch, device)
                    batch_keys = [k for k in batch if not k.endswith("_lengths")]

                    # 1. Write shape file
                    for name in batch_keys:
                        lengths_name = f"{name}_lengths"
                        for i, (key, data) in enumerate(zip(keys, batch[name])):
                            if lengths_name in batch:
                                # Drop the zero-padded tail before recording.
                                data = data[: int(batch[lengths_name][i])]
                            datadir_writer[f"{name}_shape"][key] = ",".join(
                                map(str, data.shape)
                            )

                    # 2. Extract feats
                    if ngpu <= 1:
                        feats = model.collect_feats(**batch)
                    else:
                        # data_parallel invokes the module itself (forward);
                        # ForwardAdaptor presumably routes that call to
                        # collect_feats.
                        feats = data_parallel(
                            ForwardAdaptor(model, "collect_feats"),
                            (),
                            range(ngpu),
                            module_kwargs=batch,
                        )

                    # 3. Accumulate sum, square sum and count per key
                    for key, value in feats.items():
                        lengths = None
                        if f"{key}_lengths" in feats:
                            # Move lengths to host once per key rather than
                            # indexing a (possibly GPU) tensor per utterance.
                            lengths = feats[f"{key}_lengths"].cpu().numpy()
                        for i, (uttid, seq) in enumerate(
                            zip(keys, value.cpu().numpy())
                        ):
                            if lengths is not None:
                                # Truncate the zero-padded region along axis 0.
                                seq = seq[: int(lengths[i])]
                            else:
                                # No length info: treat as a single frame,
                                # (...) -> (1, ...).
                                seq = seq[None]
                            sum_dict[key] += seq.sum(0)
                            sq_dict[key] += (seq ** 2).sum(0)
                            count_dict[key] += len(seq)

                            # 4. [Option] Dump the feature as npy + scp
                            if write_collected_feats:
                                if key not in npy_scp_writers:
                                    p = output_dir / mode / "collect_feats"
                                    npy_scp_writers[key] = NpyScpWriter(
                                        p / f"data_{key}", p / f"{key}.scp"
                                    )
                                npy_scp_writers[key][uttid] = seq

                    if iiter % mode_log_interval == 0:
                        logging.info(f"Niter: {iiter}")
        finally:
            # Release the writers' output files even if collection fails
            # part-way; previously they were never closed.
            for writer in npy_scp_writers.values():
                writer.close()

        for key in sum_dict:
            np.savez(
                output_dir / mode / f"{key}_stats.npz",
                count=count_dict[key],
                sum=sum_dict[key],
                sum_square=sq_dict[key],
            )

        # The directory may not exist yet if the iterator was empty.
        (output_dir / mode).mkdir(parents=True, exist_ok=True)
        with (output_dir / mode / "batch_keys").open("w", encoding="utf-8") as f:
            f.write("\n".join(batch_keys) + "\n")
        with (output_dir / mode / "stats_keys").open("w", encoding="utf-8") as f:
            f.write("\n".join(sum_dict) + "\n")
|
|