import argparse
import logging
from functools import lru_cache
from pathlib import Path
from typing import Any, Dict, Optional

import torch
from dataset import HubertDataset
from lhotse import CutSet, load_manifest_lazy
from lhotse.dataset import DynamicBucketingSampler, SimpleCutSampler
from lhotse.utils import fix_random_seed
from torch.utils.data import DataLoader

from icefall.utils import str2bool


class _SeedWorkers:
    def __init__(self, seed: int):
        self.seed = seed

    def __call__(self, worker_id: int):
        fix_random_seed(self.seed + worker_id)


class LibriSpeechDataModule:
    """
    DataModule for SSL experiments.
    It assumes there is always one train and one valid dataloader,
    but there can be multiple test dataloaders (e.g. LibriSpeech test-clean
    and test-other).

    It contains all the common data pipeline modules used in SSL
    experiments, e.g.:
    - dynamic batch size,
    - bucketing samplers.

    This class should be derived for specific corpora used in SSL tasks.
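
    A minimal usage sketch (a ``train.py``-style driver; the variable names
    here are illustrative, not part of this module):

        parser = argparse.ArgumentParser()
        LibriSpeechDataModule.add_arguments(parser)
        args = parser.parse_args()
        librispeech = LibriSpeechDataModule(args)
        train_cuts = (
            librispeech.train_all_shuf_cuts()
            if args.full_libri
            else librispeech.train_clean_100_cuts()
        )
        train_dl = librispeech.train_dataloaders(train_cuts)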
| | """ |
| |
|
| | def __init__(self, args: argparse.Namespace): |
| | self.args = args |
| |
|
| | @classmethod |
| | def add_arguments(cls, parser: argparse.ArgumentParser): |
| | group = parser.add_argument_group( |
| | title="SSL data related options", |
| | description="These options are used for the preparation of " |
| | "PyTorch DataLoaders from Lhotse CutSet's -- they control the " |
| | "effective batch sizes, sampling strategies.", |
| | ) |
        group.add_argument(
            "--full-libri",
            type=str2bool,
            default=True,
            help="When enabled, use the full 960h LibriSpeech. "
            "Otherwise, use the 100h subset.",
        )

        group.add_argument(
            "--manifest-dir",
            type=Path,
            default=Path("data/kmeans"),
            help="Path to directory with train/valid/test cuts.",
        )
        group.add_argument(
            "--max-duration",
            type=float,
            default=200.0,
            help="Maximum pooled recordings duration (seconds) in a "
            "single batch. You can reduce it if it causes CUDA OOM.",
        )
        group.add_argument(
            "--bucketing-sampler",
            type=str2bool,
            default=True,
            help="When enabled, the batches will come from buckets of "
            "similar duration (saves padding frames).",
        )
        group.add_argument(
            "--num-buckets",
            type=int,
            default=30,
            help="The number of buckets for the DynamicBucketingSampler "
            "(you might want to increase it for larger datasets).",
        )
        group.add_argument(
            "--shuffle",
            type=str2bool,
            default=True,
            help="When enabled (=default), the examples will be "
            "shuffled for each epoch.",
        )
        group.add_argument(
            "--drop-last",
            type=str2bool,
            default=True,
            help="Whether to drop the last batch. Used by the sampler.",
        )
        group.add_argument(
            "--num-workers",
            type=int,
            default=2,
            help="The number of training dataloader workers that "
            "collect the batches.",
        )
        group.add_argument(
            "--do-normalize",
            type=str2bool,
            default=True,
            help="Whether to normalize the audio data.",
        )
        group.add_argument(
            "--random-crop",
            type=str2bool,
            default=True,
            help="When enabled, crop a random segment from each example; "
            "otherwise, always crop from the beginning.",
        )
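
    # For reference, a hypothetical invocation of a training script that
    # has called add_arguments() on its parser might look like:
    #
    #   python train.py --full-libri false --max-duration 300 \
    #       --num-buckets 50 --num-workers 4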

    def train_dataloaders(
        self,
        cuts_train: CutSet,
        max_sample_size: Optional[int] = None,
        sample_rate: float = 16000,
        label_rate: float = 50,
        random_crop: bool = True,
        pad_audio: bool = False,
        num_classes: list = [504],
        do_normalize: bool = True,
        sampler_state_dict: Optional[Dict[str, Any]] = None,
    ) -> DataLoader:
        """
        Args:
          cuts_train:
            CutSet for training.
          sampler_state_dict:
            The state dict for the training sampler.
        """
        logging.info("About to create train dataset")
        train = HubertDataset(
            max_sample_size=max_sample_size,
            sample_rate=sample_rate,
            label_rate=label_rate,
            random_crop=random_crop,
            pad_audio=pad_audio,
            num_classes=num_classes,
            do_normalize=do_normalize,
        )

        if self.args.bucketing_sampler:
            logging.info("Using DynamicBucketingSampler.")
            train_sampler = DynamicBucketingSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
                num_buckets=self.args.num_buckets,
                drop_last=self.args.drop_last,
            )
        else:
            logging.info("Using SimpleCutSampler.")
            train_sampler = SimpleCutSampler(
                cuts_train,
                max_duration=self.args.max_duration,
                shuffle=self.args.shuffle,
            )
        logging.info("About to create train dataloader")

        if sampler_state_dict is not None:
            logging.info("Loading sampler state dict")
            train_sampler.load_state_dict(sampler_state_dict)

        # "seed" is derived from the current random state, which will have
        # previously been set in the main process.
        seed = torch.randint(0, 100000, ()).item()
        worker_init_fn = _SeedWorkers(seed)

        train_dl = DataLoader(
            train,
            sampler=train_sampler,
            batch_size=None,
            num_workers=self.args.num_workers,
            persistent_workers=False,
            worker_init_fn=worker_init_fn,
        )

        return train_dl
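
    # Resuming a run mid-epoch is possible by passing the sampler state
    # saved in a checkpoint (a sketch; the "sampler" checkpoint key and the
    # surrounding variable names are assumptions, not defined in this module):
    #
    #   ckpt = torch.load(ckpt_path, map_location="cpu")
    #   train_dl = datamodule.train_dataloaders(
    #       train_cuts, sampler_state_dict=ckpt.get("sampler")
    #   )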

    def valid_dataloaders(
        self,
        cuts_valid: CutSet,
        max_sample_size: Optional[int] = None,
        sample_rate: float = 16000,
        label_rate: float = 50,
        random_crop: bool = True,
        pad_audio: bool = False,
        num_classes: list = [504],
        do_normalize: bool = True,
    ) -> DataLoader:
        logging.info("About to create dev dataset")
        validate = HubertDataset(
            max_sample_size=max_sample_size,
            sample_rate=sample_rate,
            label_rate=label_rate,
            random_crop=random_crop,
            pad_audio=pad_audio,
            num_classes=num_classes,
            do_normalize=do_normalize,
        )
        valid_sampler = DynamicBucketingSampler(
            cuts_valid,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.info("About to create dev dataloader")
        valid_dl = DataLoader(
            validate,
            sampler=valid_sampler,
            batch_size=None,
            num_workers=2,
            persistent_workers=False,
        )

        return valid_dl

    def test_dataloaders(
        self,
        cuts: CutSet,
        sample_rate: float = 16000,
        label_rate: float = 50,
        random_crop: bool = True,
        pad_audio: bool = False,
        num_classes: list = [504],
        do_normalize: bool = True,
    ) -> DataLoader:
        logging.debug("About to create test dataset")
        test = HubertDataset(
            sample_rate=sample_rate,
            label_rate=label_rate,
            random_crop=random_crop,
            pad_audio=pad_audio,
            num_classes=num_classes,
            do_normalize=do_normalize,
        )
        sampler = DynamicBucketingSampler(
            cuts,
            max_duration=self.args.max_duration,
            shuffle=False,
        )
        logging.debug("About to create test dataloader")
        test_dl = DataLoader(
            test,
            batch_size=None,
            sampler=sampler,
            num_workers=self.args.num_workers,
        )
        return test_dl

    @lru_cache()
    def train_clean_100_cuts(self) -> CutSet:
        logging.info("About to get train-clean-100 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-100.jsonl.gz"
        )

    @lru_cache()
    def train_clean_360_cuts(self) -> CutSet:
        logging.info("About to get train-clean-360 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-clean-360.jsonl.gz"
        )

    @lru_cache()
    def train_other_500_cuts(self) -> CutSet:
        logging.info("About to get train-other-500 cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_train-other-500.jsonl.gz"
        )

    @lru_cache()
    def train_all_shuf_cuts(self) -> CutSet:
        logging.info(
            "About to get the shuffled train-clean-100, "
            "train-clean-360 and train-other-500 cuts"
        )
        train_clean_100_cuts = self.train_clean_100_cuts()
        train_clean_360_cuts = self.train_clean_360_cuts()
        train_other_500_cuts = self.train_other_500_cuts()
        # Mux draws from the three subsets with probabilities proportional
        # to the weights, i.e. to the number of cuts in each subset.
        return CutSet.mux(
            train_clean_100_cuts,
            train_clean_360_cuts,
            train_other_500_cuts,
            weights=[
                28539,  # len(train_clean_100_cuts)
                104014,  # len(train_clean_360_cuts)
                148688,  # len(train_other_500_cuts)
            ],
        )

    @lru_cache()
    def dev_clean_cuts(self) -> CutSet:
        logging.info("About to get dev-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-clean.jsonl.gz"
        )

    @lru_cache()
    def dev_other_cuts(self) -> CutSet:
        logging.info("About to get dev-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_dev-other.jsonl.gz"
        )

    @lru_cache()
    def test_clean_cuts(self) -> CutSet:
        logging.info("About to get test-clean cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-clean.jsonl.gz"
        )

    @lru_cache()
    def test_other_cuts(self) -> CutSet:
        logging.info("About to get test-other cuts")
        return load_manifest_lazy(
            self.args.manifest_dir / "librispeech_cuts_test-other.jsonl.gz"
        )