| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import argparse |
| import inspect |
| import logging |
| from functools import lru_cache |
| from pathlib import Path |
| from typing import Any, Dict, List, Optional |
|
|
| from lhotse import CutSet, Fbank, FbankConfig, load_manifest, load_manifest_lazy |
| from lhotse.dataset import ( |
| CutConcatenate, |
| CutMix, |
| DynamicBucketingSampler, |
| K2SpeechRecognitionDataset, |
| PrecomputedFeatures, |
| SimpleCutSampler, |
| SpecAugment, |
| ) |
| from lhotse.dataset.input_strategies import OnTheFlyFeatures |
| from torch.utils.data import DataLoader |
|
|
| from icefall.utils import str2bool |
|
|
|
|
class AishellAsrDataModule:
    """
    DataModule for k2 ASR experiments.
    It assumes there is always one train and valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in ASR
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers,
    - cut concatenation,
    - augmentation,
    - on-the-fly feature extraction

    This class should be derived for specific corpora used in ASR tasks.
    """

    def __init__(self, args: argparse.Namespace):
        # All dataloader behavior is driven by the CLI options registered
        # in :meth:`add_arguments`.
        self.args = args

    @classmethod
    def add_arguments(cls, parser: argparse.ArgumentParser):
        """Register all data-pipeline CLI options on ``parser``.

        Args:
          parser:
            The parser (typically the training script's parser) that will
            receive a new "ASR data related options" argument group.
        """
        group = parser.add_argument_group(
            title="ASR data related options",
            description="These options are used for the preparation of "
            "PyTorch DataLoaders from Lhotse CutSet's -- they control the "
            "effective batch sizes, sampling strategies, applied data "
            "augmentations, etc.",
        )
        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/fbank"),
            help="Path to directory with train/valid/test cuts.",
        )
        # NOTE: type is float (not int) so that fractional durations given
        # on the command line are not silently truncated; the default has
        # always been the float 200.0.
        group.add_argument(
            "--max-duration",
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=15,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--concatenate-cuts",
            type=str2bool,
            default=False,
            help="When enabled, utterances (cuts) will be concatenated "
            "to minimize the amount of padding.",
        )
        group.add_argument(
            "--duration-factor",
            type=float,
            default=1.0,
            help="Determines the maximum duration of a concatenated cut "
            "relative to the duration of the longest cut in a batch.",
        )
        group.add_argument(
            "--gap",
            type=float,
            default=1.0,
            help="The amount of padding (in seconds) inserted between "
            "concatenated cuts. This padding is filled with noise when "
            "noise augmentation is used.",
        )
        group.add_argument(
            "--on-the-fly-feats",
            type=str2bool,
            default=False,
            help="When enabled, use on-the-fly cut mixing and feature "
            "extraction. Will drop existing precomputed feature manifests "
            "if available.",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop last batch. Used by sampler.",
        )
        group.add_argument(
            "--return-cuts",
            type=str2bool,
            default=True,
            help="When enabled, each batch will have the "
            "field: batch['supervisions']['cut'] with the cuts that "
            "were used to construct it.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )
        group.add_argument(
            "--enable-spec-aug",
            type=str2bool,
            default=True,
            help="When enabled, use SpecAugment for training dataset.",
        )
        group.add_argument(
            "--spec-aug-time-warp-factor",
            type=int,
            default=80,
            help="Used only when --enable-spec-aug is True. "
            "It specifies the factor for time warping in SpecAugment. "
            "Larger values mean more warping. "
            "A value less than 1 means to disable time warp.",
        )
        group.add_argument(
            "--enable-musan",
            type=str2bool,
            default=True,
            help="When enabled, select noise from MUSAN and mix it "
            "with training dataset. ",
        )

    def train_dataloaders(
        self, cuts_train: CutSet, sampler_state_dict: Optional[Dict[str, Any]] = None
    ) -> DataLoader:
        """Build the training dataloader.

        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler (used to resume
            training mid-epoch from a checkpoint).
        """
        logging.info("About to get Musan cuts")
        cuts_musan = load_manifest(self.args.manifest_dir / "musan_cuts.jsonl.gz")

        transforms = []
        if self.args.enable_musan:
            logging.info("Enable MUSAN")
            # preserve_id=True keeps the original cut ids after mixing,
            # which is needed when --return-cuts is enabled.
            transforms.append(
                CutMix(cuts=cuts_musan, p=0.5, snr=(10, 20), preserve_id=True)
            )
        else:
            logging.info("Disable MUSAN")

        if self.args.concatenate_cuts:
            logging.info(
                f"Using cut concatenation with duration factor "
                f"{self.args.duration_factor} and gap {self.args.gap}."
            )
            # Concatenation must run before noise mixing so that the noise
            # also covers the gaps between the concatenated cuts.
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        input_transforms = []
        if self.args.enable_spec_aug:
            logging.info("Enable SpecAugment")
            logging.info(f"Time warp factor: {self.args.spec_aug_time_warp_factor}")
            # lhotse changed the default of SpecAugment's ``num_frame_masks``
            # (and its masking semantics) across versions; inspect the
            # installed default to pick a value appropriate for it.
            # NOTE(review): presumably default==1 identifies the newer
            # per-mask semantics -- confirm against the installed lhotse.
            num_frame_masks = 10
            num_frame_masks_parameter = inspect.signature(
                SpecAugment.__init__
            ).parameters["num_frame_masks"]
            if num_frame_masks_parameter.default == 1:
                num_frame_masks = 2
            logging.info(f"Num frame mask: {num_frame_masks}")
            input_transforms.append(
                SpecAugment(
                    time_warp_factor=self.args.spec_aug_time_warp_factor,
                    num_frame_masks=num_frame_masks,
                    features_mask_size=27,
                    num_feature_masks=2,
                    frames_mask_size=100,
                )
            )
        else:
            logging.info("Disable SpecAugment")

        logging.info("About to create train dataset")
        # Build the dataset exactly once; previously it was constructed a
        # second time (discarding the first instance) when on-the-fly
        # features were requested.
        if self.args.on_the_fly_feats:
            input_strategy = OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
        else:
            # PrecomputedFeatures is K2SpeechRecognitionDataset's default
            # input strategy; spelled out here for symmetry.
            input_strategy = PrecomputedFeatures()
        train = K2SpeechRecognitionDataset(
            input_strategy=input_strategy,
            cut_transforms=transforms,
            input_transforms=input_transforms,
            return_cuts=self.args.return_cuts,
        )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                buffer_size=self.args.num_buckets * 5000,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # 'sampler' + batch_size=None: the sampler yields whole batches of
        # cuts, so PyTorch's own batching must be disabled.
        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )

        return train_dl

    def valid_dataloaders(self, cuts_valid: CutSet) -> DataLoader:
        """Build the validation dataloader (no shuffling, no augmentation)."""
        transforms = []
        if self.args.concatenate_cuts:
            transforms = [
                CutConcatenate(
                    duration_factor=self.args.duration_factor, gap=self.args.gap
                )
            ] + transforms

        logging.info("About to create dev dataset")
        if self.args.on_the_fly_feats:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80))),
                return_cuts=self.args.return_cuts,
            )
        else:
            validate = K2SpeechRecognitionDataset(
                cut_transforms=transforms,
                return_cuts=self.args.return_cuts,
            )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        # Honor --num-workers instead of a hardcoded 2 (the flag's default
        # is 2, so default behavior is unchanged).
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(self, cuts: CutSet) -> DataLoader:
        """Build a test dataloader for the given cuts."""
        logging.info("About to create test dataset")
        test = K2SpeechRecognitionDataset(
            input_strategy=OnTheFlyFeatures(Fbank(FbankConfig(num_mel_bins=80)))
            if self.args.on_the_fly_feats
            else PrecomputedFeatures(),
            return_cuts=self.args.return_cuts,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    # NOTE: lru_cache on these methods keys on ``self`` and keeps the
    # instance alive for the cache's lifetime; acceptable here because a
    # single data module instance lives for the whole run.
    @lru_cache()
    def train_cuts(self) -> CutSet:
        logging.info("About to get train cuts")
        cuts_train = load_manifest_lazy(
            self.args.manifest_dir / "aishell_cuts_train.jsonl.gz"
        )
        return cuts_train

    @lru_cache()
    def valid_cuts(self) -> CutSet:
        logging.info("About to get dev cuts")
        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_dev.jsonl.gz")

    @lru_cache()
    def test_cuts(self) -> CutSet:
        # Fixed annotation: a single lazily-loaded CutSet is returned,
        # not a List[CutSet].
        logging.info("About to get test cuts")
        return load_manifest_lazy(self.args.manifest_dir / "aishell_cuts_test.jsonl.gz")
|
|