|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
|
from itertools import islice |
|
|
from pathlib import Path |
|
|
|
|
|
from lhotse.cut import Cut |
|
|
from lhotse.dataset.sampling.dynamic_bucketing import estimate_duration_buckets |
|
|
from omegaconf import OmegaConf |
|
|
|
|
|
from nemo.collections.common.data.lhotse.cutset import read_cutset_from_config |
|
|
from nemo.collections.common.data.lhotse.dataloader import LhotseDataLoadingConfig |
|
|
|
|
|
|
|
|
def _str2bool(value):
    """Convert a command-line token into a real boolean.

    ``argparse`` with ``type=bool`` calls ``bool("False")``, which is ``True``
    for every non-empty string. This parser interprets the usual textual
    spellings instead.

    Args:
        value: The raw token from the command line (or an actual ``bool``).

    Returns:
        The parsed boolean value.

    Raises:
        argparse.ArgumentTypeError: If the token is not a recognised boolean.
    """
    if isinstance(value, bool):
        return value
    normalized = value.strip().lower()
    if normalized in {"1", "true", "t", "yes", "y", "on"}:
        return True
    # The empty string stays falsy, matching what ``bool("")`` used to return.
    if normalized in {"", "0", "false", "f", "no", "n", "off"}:
        return False
    raise argparse.ArgumentTypeError(f"expected a boolean value, got {value!r}")


def parse_args():
    """Build the command-line parser and parse ``sys.argv``.

    Returns:
        argparse.Namespace with the attributes ``input``, ``buckets``,
        ``num_examples``, ``min_duration``, ``max_duration`` and ``quiet``.
    """
    parser = argparse.ArgumentParser(
        description="Estimate duration bins for Lhotse dynamic bucketing using a sample of the input dataset. "
        "The dataset is read either from one or more manifest files and supports data weighting.",
        formatter_class=argparse.ArgumentDefaultsHelpFormatter,
    )
    parser.add_argument(
        "input",
        help='Data input. Options: '
        '1) "path.json" - any single NeMo manifest; '
        '2) "[[path1.json],[path2.json],...]" - any collection of NeMo manifests; '
        '3) "[[path1.json,weight1],[path2.json,weight2],...]" - any collection of weighted NeMo manifests; '
        '4) "input_cfg.yaml" - a new option supporting input configs, same as in model training \'input_cfg\' arg; '
        '5) "path/to/shar_data" - a path to Lhotse Shar data directory; '
        '6) "key=val" - in case none of the previous variants cover your case: "key" is the key you\'d use in NeMo training config with its corresponding value ',
    )
    parser.add_argument("-b", "--buckets", type=int, default=30, help="The desired number of buckets.")
    parser.add_argument(
        "-n",
        "--num_examples",
        type=int,
        default=-1,
        help="The number of examples (utterances) to estimate the bins. -1 means use all data "
        "(be careful: it could be iterated over infinitely).",
    )
    parser.add_argument(
        "-l",
        "--min_duration",
        type=float,
        default=-float("inf"),
        help="If specified, we'll filter out utterances shorter than this.",
    )
    parser.add_argument(
        "-u",
        "--max_duration",
        type=float,
        default=float("inf"),
        help="If specified, we'll filter out utterances longer than this.",
    )
    # ``nargs="?"`` + ``const=True`` lets a bare ``-q`` work as a flag while
    # still accepting the legacy ``-q True`` / ``-q False`` spelling; the
    # custom type makes ``-q False`` actually mean False.
    parser.add_argument(
        "-q",
        "--quiet",
        type=_str2bool,
        nargs="?",
        const=True,
        default=False,
        help="When specified, only print the estimated duration bins.",
    )
    return parser.parse_args()
|
|
|
|
|
|
|
|
def main():
    """Estimate duration bins for the given input and print them.

    Translates the positional ``input`` argument into a dataloading config
    override, reads the cuts, filters them by duration, passes them to
    ``estimate_duration_buckets`` and prints the resulting bins (plus filtering
    statistics unless ``--quiet`` is set).
    """
    cli = parse_args()

    def to_dotlist_entry(raw):
        # Order matters: an explicit "key=val" wins over the suffix/dir checks.
        if '=' in raw:
            return raw
        if raw.endswith(".yaml"):
            return f"input_cfg={raw}"
        if Path(raw).is_dir():
            return f"shar_path={raw}"
        return f"manifest_filepath={raw}"

    overrides = OmegaConf.from_dotlist([to_dotlist_entry(cli.input), "metadata_only=true"])
    config = OmegaConf.merge(OmegaConf.structured(LhotseDataLoadingConfig), overrides)
    cuts, _ = read_cutset_from_config(config)

    lower, upper = cli.min_duration, cli.max_duration
    # Mutated by the filter predicate below; read only after the estimation
    # step, since that is what consumes ``cuts`` (presumably ``filter`` is
    # lazy - TODO confirm against Lhotse).
    stats = {"tot": 0, "nonaudio": 0, "discarded": 0, "longest": 0}

    def duration_ok(cut) -> bool:
        """Return True for audio cuts within [lower, upper]; update ``stats``."""
        stats["tot"] += 1
        if not isinstance(cut, Cut):
            stats["nonaudio"] += 1
            return False
        if lower <= cut.duration <= upper:
            stats["longest"] = max(cut.duration, stats["longest"])
            return True
        stats["discarded"] += 1
        return False

    cuts = cuts.filter(duration_ok)
    limit = cli.num_examples
    if limit > 0:
        # Non-positive values mean "no cap on the number of examples".
        cuts = islice(cuts, limit)

    bin_edges = estimate_duration_buckets(cuts, num_buckets=cli.buckets)
    rounded_edges = [str(round(edge, ndigits=5)) for edge in bin_edges]
    duration_bins = "[" + ",".join(rounded_edges) + "]"

    if cli.quiet:
        print(duration_bins)
        return

    tot = stats["tot"]
    nonaudio = stats["nonaudio"]
    discarded = stats["discarded"]
    observed_max_dur = stats["longest"]

    if discarded:
        ratio = discarded / tot
        print(f"Note: we discarded {discarded}/{tot} ({ratio:.2%}) utterances due to min/max duration filtering.")
    if nonaudio:
        print(f"Note: we discarded {nonaudio} non-audio examples found during iteration.")
    print(f"Used {tot - nonaudio - discarded} examples for the estimation.")
    print("Use the following options in your config:")
    print(f"\tnum_buckets={cli.buckets}")
    print(f"\tbucket_duration_bins={duration_bins}")
    print(f"\tmax_duration={observed_max_dur}")
|
|
|
|
|
|
|
|
# Run the estimator only when this file is executed as a script, so importing
# the module does not trigger argument parsing.
if __name__ == "__main__":
    main()
|
|
|