| | |
| | import argparse |
| | import logging |
| | from pathlib import Path |
| | import sys |
| | from typing import Iterable |
| | from typing import Union |
| |
|
| | import numpy as np |
| |
|
| | from espnet.utils.cli_utils import get_commandline_args |
| |
|
| |
|
def _read_keys(path: Path) -> list:
    """Return the stripped, non-blank lines of the key-list file at *path*."""
    with path.open("r", encoding="utf-8") as f:
        return [line.strip() for line in f if line.strip() != ""]


def _iter_entries(path: Path):
    """Yield the non-blank lines of *path*, each terminated by a newline."""
    with path.open("r", encoding="utf-8") as fin:
        for line in fin:
            if line.strip():
                # A final line without "\n" would otherwise be glued to the
                # entry written after it.
                yield line if line.endswith("\n") else line + "\n"


def _merge_shape_files(input_dirs: list, output_dir: Path, mode: str, key: str) -> None:
    """Concatenate the ``{key}_shape`` files of all input dirs into the output dir."""
    name = f"{key}_shape"
    with (output_dir / mode / name).open("w", encoding="utf-8") as fout:
        for idir in input_dirs:
            # Each input file is read completely and sorted by its first
            # whitespace-separated field before being appended.
            entries = sorted(
                _iter_entries(idir / mode / name), key=lambda x: x.split()[0]
            )
            fout.writelines(entries)


def _sum_stats(input_dirs: list, mode: str, key: str) -> dict:
    """Return the per-array sum of the ``{key}_stats.npz`` files of all input dirs."""
    total = None
    for idir in input_dirs:
        # np.load keeps the archive open; the context manager closes it once
        # the arrays have been read into memory.
        with np.load(idir / mode / f"{key}_stats.npz") as stats:
            if total is None:
                total = {name: stats[name] for name in stats.files}
            else:
                for name in stats.files:
                    total[name] += stats[name]
    return total


def _concat_scp(input_dirs: list, output_dir: Path, rel_path: Path) -> None:
    """Concatenate the file *rel_path* of every input dir into the output dir."""
    target = output_dir / rel_path
    target.parent.mkdir(parents=True, exist_ok=True)
    with target.open("w", encoding="utf-8") as fout:
        for idir in input_dirs:
            fout.writelines(_iter_entries(idir / rel_path))


def aggregate_stats_dirs(
    input_dir: Iterable[Union[str, Path]],
    output_dir: Union[str, Path],
    log_level: str,
    skip_sum_stats: bool,
):
    """Merge several statistics directories into a single output directory.

    For each of the modes ``train`` and ``valid`` the following is done:

    * ``<mode>/batch_keys`` and ``<mode>/stats_keys`` are read from the
      *first* input directory only.
    * For every batch key, the ``<mode>/{key}_shape`` files of all input
      directories are concatenated (each one sorted by its first field).
    * For every stats key, unless ``skip_sum_stats`` is true, the arrays in
      ``<mode>/{key}_stats.npz`` are summed over all input directories and
      saved under the same name.
    * For every stats key, if ``<mode>/collect_feats/{key}.scp`` exists in
      the first input directory, the corresponding files of all input
      directories are concatenated.

    Args:
        input_dir: The directories to aggregate; at least one is required.
        output_dir: Destination directory; created if missing.
        log_level: Level handed to ``logging.basicConfig``.
        skip_sum_stats: If true, the ``*_stats.npz`` files are not summed.

    Raises:
        ValueError: If ``input_dir`` is ``None`` or empty.
    """
    logging.basicConfig(
        level=log_level,
        # "%(levelname)s" needs its leading "%"; without it the literal text
        # "(levelname)s" is printed instead of the record's level.
        format="%(asctime)s (%(module)s:%(lineno)d) %(levelname)s: %(message)s",
    )

    input_dirs = [] if input_dir is None else [Path(p) for p in input_dir]
    if not input_dirs:
        raise ValueError("At least one input directory is required")
    output_dir = Path(output_dir)

    for mode in ["train", "valid"]:
        batch_keys = _read_keys(input_dirs[0] / mode / "batch_keys")
        stats_keys = _read_keys(input_dirs[0] / mode / "stats_keys")
        (output_dir / mode).mkdir(parents=True, exist_ok=True)

        for key in batch_keys:
            _merge_shape_files(input_dirs, output_dir, mode, key)

        for key in stats_keys:
            if not skip_sum_stats:
                sum_stats = _sum_stats(input_dirs, mode, key)
                np.savez(output_dir / mode / f"{key}_stats.npz", **sum_stats)

            # The scp file is optional: it is merged only when the first
            # input directory contains it.
            p = Path(mode) / "collect_feats" / f"{key}.scp"
            if (input_dirs[0] / p).exists():
                _concat_scp(input_dirs, output_dir, p)
| |
|
| |
|
def get_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for the aggregation script.

    Options:
        --log_level: Logging level name; upper-cased before being checked
            against the allowed choices, so it is case-insensitive.
        --skip_sum_stats: Flag; when given, statistics are not summed.
        --input_dir: Input directory; repeat the option to pass several.
            At least one occurrence is required.
        --output_dir: Output directory (required).

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description="Aggregate statistics directories into one directory",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "--log_level",
        type=lambda x: x.upper(),
        default="INFO",
        choices=("CRITICAL", "ERROR", "WARNING", "INFO", "DEBUG", "NOTSET"),
        help="The verbose level of logging",
    )
    parser.add_argument(
        "--skip_sum_stats",
        default=False,
        action="store_true",
        help="Skip computing the sum of statistics.",
    )

    # Without required=True an omitted --input_dir is parsed as None, which
    # only fails later with an unrelated-looking TypeError.
    parser.add_argument(
        "--input_dir",
        action="append",
        required=True,
        help="Input directories",
    )
    parser.add_argument("--output_dir", required=True, help="Output directory")
    return parser
| |
|
| |
|
def main(cmd=None):
    """Parse the command line and run the aggregation.

    Args:
        cmd: Argument list handed to ``parse_args``; ``None`` makes argparse
            read the process arguments.
    """
    # Echo the value returned by get_commandline_args() to stderr.
    print(get_commandline_args(), file=sys.stderr)
    namespace = get_parser().parse_args(cmd)
    # The parsed option names map one-to-one onto the keyword arguments.
    aggregate_stats_dirs(**vars(namespace))
| |
|
| |
|
# Run the command-line entry point only when this file is executed as a
# script, not when it is imported.
if __name__ == "__main__":
    main()
| |
|