#!/usr/bin/env python3
| """ |
| This file computes fbank features of the aishell dataset. |
| It looks for manifests in the directory data/manifests. |
| |
| The generated fbank features are saved in data/fbank. |
| """ |
import argparse
import logging
import os
from pathlib import Path

import torch
from lhotse import (
    CutSet,
    Fbank,
    FbankConfig,
    LilcomChunkyWriter,
    WhisperFbank,
    WhisperFbankConfig,
)
from lhotse.recipes.utils import read_manifests_if_cached

from icefall.utils import get_executor, str2bool

# Disable torch's intra-op and inter-op multithreading: feature extraction is
# parallelized across worker processes, and extra threads per process would
# oversubscribe the CPU and slow things down.
torch.set_num_threads(1)
torch.set_num_interop_threads(1)
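# Note: read_manifests_if_cached() below expects lhotse's standard manifest
# naming under data/manifests, e.g. aishell_recordings_train.jsonl.gz and
# aishell_supervisions_train.jsonl.gz (and likewise for "dev" and "test").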
def compute_fbank_aishell(
    num_mel_bins: int = 80,
    perturb_speed: bool = False,
    whisper_fbank: bool = False,
    output_dir: str = "data/fbank",
):
    src_dir = Path("data/manifests")
    output_dir = Path(output_dir)
    # Use at most 15 parallel jobs, or fewer on machines with fewer cores.
    num_jobs = min(15, os.cpu_count())

    dataset_parts = (
        "train",
        "dev",
        "test",
    )
    prefix = "aishell"
    suffix = "jsonl.gz"
    manifests = read_manifests_if_cached(
        dataset_parts=dataset_parts,
        output_dir=src_dir,
        prefix=prefix,
        suffix=suffix,
    )
    assert manifests is not None

    assert len(manifests) == len(dataset_parts), (
        len(manifests),
        len(dataset_parts),
        list(manifests.keys()),
        dataset_parts,
    )

    if whisper_fbank:
        # WhisperFbank computes Whisper-style log-mel features. It is
        # configured to run on GPU here, so a CUDA device is required.
        extractor = WhisperFbank(
            WhisperFbankConfig(num_filters=num_mel_bins, device="cuda")
        )
    else:
        extractor = Fbank(FbankConfig(num_mel_bins=num_mel_bins))

    # get_executor() yields a distributed executor when one is available,
    # otherwise None, in which case lhotse falls back to local processes.
    with get_executor() as ex:
        for partition, m in manifests.items():
            cuts_filename = f"{prefix}_cuts_{partition}.{suffix}"
            if (output_dir / cuts_filename).is_file():
                logging.info(f"{partition} already exists - skipping.")
                continue
            logging.info(f"Processing {partition}")
            cut_set = CutSet.from_manifests(
                recordings=m["recordings"],
                supervisions=m["supervisions"],
            )
            if "train" in partition and perturb_speed:
                # Triple the training data by adding 0.9x and 1.1x
                # speed-perturbed copies of every cut.
                logging.info("Doing speed perturb")
                cut_set = (
                    cut_set + cut_set.perturb_speed(0.9) + cut_set.perturb_speed(1.1)
                )
            cut_set = cut_set.compute_and_store_features(
                extractor=extractor,
                storage_path=f"{output_dir}/{prefix}_feats_{partition}",
                # Parallelize more aggressively when a distributed executor
                # is available.
                num_jobs=num_jobs if ex is None else 80,
                executor=ex,
                storage_type=LilcomChunkyWriter,
            )
            cut_set.to_file(output_dir / cuts_filename)


def get_args():
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--num-mel-bins",
        type=int,
        default=80,
        help="The number of mel bins for Fbank. Default: 80.",
    )
    parser.add_argument(
        "--perturb-speed",
        type=str2bool,
        default=False,
        help="Enable 0.9x and 1.1x speed perturbation for data augmentation. "
        "Default: False.",
    )
    parser.add_argument(
        "--whisper-fbank",
        type=str2bool,
        default=False,
        help="Use WhisperFbank instead of Fbank. Default: False.",
    )
    parser.add_argument(
        "--output-dir",
        type=str,
        default="data/fbank",
        help="Output directory. Default: data/fbank.",
    )
    return parser.parse_args()


if __name__ == "__main__":
    formatter = "%(asctime)s %(levelname)s [%(filename)s:%(lineno)d] %(message)s"
    logging.basicConfig(format=formatter, level=logging.INFO)

    args = get_args()
    compute_fbank_aishell(
        num_mel_bins=args.num_mel_bins,
        perturb_speed=args.perturb_speed,
        whisper_fbank=args.whisper_fbank,
        output_dir=args.output_dir,
    )