| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| import copy |
|
|
| from nemo.collections.asr.data import audio_to_label |
| from nemo.collections.asr.data.audio_to_text_dataset import convert_to_config_list, get_chain_dataset |
| from nemo.collections.common.data.dataset import ConcatDataset |
|
|
|
|
def get_classification_label_dataset(featurizer, config: dict) -> audio_to_label.AudioToClassificationLabelDataset:
    """
    Build a classification AudioLabelDataset from a config mapping.

    Args:
        featurizer: Object forwarded unchanged as the ``featurizer`` argument of the dataset.
        config: Config of the AudioToClassificationLabelDataset. ``manifest_filepath`` and
            ``labels`` are required keys. ``max_duration`` and ``min_duration`` default to
            ``None``; ``trim_silence`` (passed on as ``trim``), ``is_regression_task`` and
            ``cal_labels_occurrence`` default to ``False``.

    Returns:
        An instance of AudioToClassificationLabelDataset.
    """
    dataset_cls = audio_to_label.AudioToClassificationLabelDataset
    # Collect the constructor arguments first so the config -> argument mapping is easy to scan.
    dataset_kwargs = {
        'manifest_filepath': config['manifest_filepath'],
        'labels': config['labels'],
        'featurizer': featurizer,
        'max_duration': config.get('max_duration', None),
        'min_duration': config.get('min_duration', None),
        # The config key is 'trim_silence' while the dataset argument is 'trim'.
        'trim': config.get('trim_silence', False),
        'is_regression_task': config.get('is_regression_task', False),
        'cal_labels_occurrence': config.get('cal_labels_occurrence', False),
    }
    return dataset_cls(**dataset_kwargs)
|
|
|
|
def get_speech_label_dataset(featurizer, config: dict) -> audio_to_label.AudioToSpeechLabelDataset:
    """
    Build a Speech Label (e.g. VAD, speaker recognition) AudioLabelDataset from a config mapping.

    Args:
        featurizer: Object forwarded unchanged as the ``featurizer`` argument of the dataset.
        config: Config of the AudioToSpeechLabelDataset. ``manifest_filepath`` and ``labels``
            are required keys. Optional keys and their fallbacks: ``max_duration`` (None),
            ``min_duration`` (None), ``trim_silence`` (False, passed on as ``trim``),
            ``window_length_in_sec`` (0.31), ``shift_length_in_sec`` (0.01),
            ``normalize_audio`` (False) and ``cal_labels_occurrence`` (False).

    Returns:
        An instance of AudioToSpeechLabelDataset.
    """
    dataset_cls = audio_to_label.AudioToSpeechLabelDataset
    # Collect the constructor arguments first so the config -> argument mapping is easy to scan.
    dataset_kwargs = {
        'manifest_filepath': config['manifest_filepath'],
        'labels': config['labels'],
        'featurizer': featurizer,
        'max_duration': config.get('max_duration', None),
        'min_duration': config.get('min_duration', None),
        # The config key is 'trim_silence' while the dataset argument is 'trim'.
        'trim': config.get('trim_silence', False),
        'window_length_in_sec': config.get('window_length_in_sec', 0.31),
        'shift_length_in_sec': config.get('shift_length_in_sec', 0.01),
        'normalize_audio': config.get('normalize_audio', False),
        'cal_labels_occurrence': config.get('cal_labels_occurrence', False),
    }
    return dataset_cls(**dataset_kwargs)
|
|
|
|
def get_tarred_classification_label_dataset(
    featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int
) -> audio_to_label.TarredAudioToClassificationLabelDataset:
    """
    Instantiates a Classification TarredAudioLabelDataset.

    One TarredAudioToClassificationLabelDataset is created per
    (tarred_audio_filepaths, manifest_filepath) bucket, and the buckets are then
    combined with ``get_chain_dataset``.

    Args:
        featurizer: Object forwarded unchanged as the ``featurizer`` argument of each dataset.
        config: Config of the TarredAudioToClassificationLabelDataset.
            ``tarred_audio_filepaths``, ``manifest_filepath`` and ``labels`` are required keys.
            Optional keys: ``bucketing_weights`` (one positive int per bucket; a bucket with
            weight ``w`` is added ``w`` times), ``max_duration``, ``min_duration``,
            ``trim_silence``, ``tarred_shard_strategy`` (default ``'scatter'``) and
            ``is_regression_task``.
        shuffle_n: How many samples to look ahead and load to be shuffled.
            See WebDataset documentation for more details.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.

    Returns:
        The result of ``get_chain_dataset`` applied to the per-bucket
        TarredAudioToClassificationLabelDataset instances.

    Raises:
        ValueError: If a bucketing weight is not a positive integer, if the number of
            manifest buckets differs from the number of tarred-audio buckets, or if
            fewer bucketing weights than buckets are given.
    """
    tarred_audio_filepaths = config['tarred_audio_filepaths']
    manifest_filepaths = config['manifest_filepath']
    tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
    manifest_filepaths = convert_to_config_list(manifest_filepaths)

    bucketing_weights = config.get('bucketing_weights', None)  # For upsampling buckets
    if bucketing_weights:
        for weight in bucketing_weights:
            if not isinstance(weight, int) or weight <= 0:
                raise ValueError("bucket weights must be positive integers")

    if len(manifest_filepaths) != len(tarred_audio_filepaths):
        raise ValueError(
            f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
        )

    # Fail fast: without this check a short weight list only surfaces as a bare
    # IndexError inside the loop, after earlier datasets were already built.
    if bucketing_weights and len(bucketing_weights) < len(tarred_audio_filepaths):
        raise ValueError(
            f"bucketing_weights (length={len(bucketing_weights)}) must provide a weight for each of the "
            f"{len(tarred_audio_filepaths)} buckets."
        )

    datasets = []
    for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
        zip(tarred_audio_filepaths, manifest_filepaths)
    ):
        # A bucket holding a single path is passed on as that path rather than a 1-element list.
        if len(tarred_audio_filepath) == 1:
            tarred_audio_filepath = tarred_audio_filepath[0]
        dataset = audio_to_label.TarredAudioToClassificationLabelDataset(
            audio_tar_filepaths=tarred_audio_filepath,
            manifest_filepath=manifest_filepath,
            labels=config['labels'],
            featurizer=featurizer,
            shuffle_n=shuffle_n,
            max_duration=config.get('max_duration', None),
            min_duration=config.get('min_duration', None),
            trim=config.get('trim_silence', False),
            shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
            global_rank=global_rank,
            world_size=world_size,
            is_regression_task=config.get('is_regression_task', False),
        )

        # Upsample the bucket by listing the same dataset object `repeats` times.
        repeats = bucketing_weights[dataset_idx] if bucketing_weights else 1
        datasets.extend([dataset] * repeats)

    return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)
|
|
|
|
def get_concat_tarred_speech_label_dataset(
    featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int,
):
    """
    Build one tarred speech-label dataset per configured source and wrap them in a ConcatDataset.

    ``config['tarred_audio_filepaths']`` and ``config['manifest_filepath']`` are iterated
    in lockstep with ``zip`` (so iteration stops at the shorter of the two). For every pair
    a deep copy of ``config`` is made with those two keys replaced by the pair's values,
    and that copy is handed to ``get_tarred_speech_label_dataset``.

    Args:
        featurizer: Object forwarded unchanged to ``get_tarred_speech_label_dataset``.
        config: Shared dataset config. Required keys: ``tarred_audio_filepaths``,
            ``manifest_filepath`` and ``shuffle``. Optional keys:
            ``concat_sampling_technique`` (default ``'temperature'``),
            ``concat_sampling_temperature`` (default ``5``) and
            ``concat_sampling_probabilities`` (default ``None``).
        shuffle_n: Forwarded to ``get_tarred_speech_label_dataset``.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.

    Returns:
        A ConcatDataset over the per-source datasets.
    """
    audio_path_groups = config['tarred_audio_filepaths']
    manifest_groups = config['manifest_filepath']

    def _build_member(audio_paths, manifest):
        # Work on a private copy so the caller's config is never mutated.
        member_config = copy.deepcopy(config)
        member_config['manifest_filepath'] = manifest
        member_config['tarred_audio_filepaths'] = audio_paths
        return get_tarred_speech_label_dataset(
            config=member_config,
            featurizer=featurizer,
            shuffle_n=shuffle_n,
            global_rank=global_rank,
            world_size=world_size,
        )

    members = [
        _build_member(audio_paths, manifest) for audio_paths, manifest in zip(audio_path_groups, manifest_groups)
    ]

    return ConcatDataset(
        members,
        sampling_technique=config.get('concat_sampling_technique', 'temperature'),
        sampling_temperature=config.get('concat_sampling_temperature', 5),
        sampling_probabilities=config.get('concat_sampling_probabilities', None),
        global_rank=global_rank,
        world_size=world_size,
        shuffle=config['shuffle'],
    )
|
|
|
|
def get_tarred_speech_label_dataset(
    featurizer, config: dict, shuffle_n: int, global_rank: int, world_size: int,
) -> audio_to_label.TarredAudioToSpeechLabelDataset:
    """
    Instantiates a Speech Label (e.g. VAD, speaker recognition) TarredAudioLabelDataset.

    One TarredAudioToSpeechLabelDataset is created per
    (tarred_audio_filepaths, manifest_filepath) bucket, and the buckets are then
    combined with ``get_chain_dataset``.

    Args:
        featurizer: Object forwarded unchanged as the ``featurizer`` argument of each dataset.
        config: Config of the TarredAudioToSpeechLabelDataset.
            ``tarred_audio_filepaths``, ``manifest_filepath`` and ``labels`` are required keys.
            Optional keys: ``bucketing_weights`` (one positive int per bucket; a bucket with
            weight ``w`` is added ``w`` times), ``max_duration``, ``min_duration``,
            ``trim_silence``, ``window_length_in_sec`` (default ``8``),
            ``shift_length_in_sec`` (default ``0.075``), ``normalize_audio`` and
            ``tarred_shard_strategy`` (default ``'scatter'``).
        shuffle_n: How many samples to look ahead and load to be shuffled.
            See WebDataset documentation for more details.
        global_rank: Global rank of this device.
        world_size: Global world size in the training method.

    Returns:
        The result of ``get_chain_dataset`` applied to the per-bucket
        TarredAudioToSpeechLabelDataset instances.

    Raises:
        ValueError: If a bucketing weight is not a positive integer, if the number of
            manifest buckets differs from the number of tarred-audio buckets, or if
            fewer bucketing weights than buckets are given.
    """
    tarred_audio_filepaths = config['tarred_audio_filepaths']
    manifest_filepaths = config['manifest_filepath']
    tarred_audio_filepaths = convert_to_config_list(tarred_audio_filepaths)
    manifest_filepaths = convert_to_config_list(manifest_filepaths)

    bucketing_weights = config.get('bucketing_weights', None)  # For upsampling buckets
    if bucketing_weights:
        for weight in bucketing_weights:
            if not isinstance(weight, int) or weight <= 0:
                raise ValueError("bucket weights must be positive integers")

    if len(manifest_filepaths) != len(tarred_audio_filepaths):
        raise ValueError(
            f"manifest_filepaths (length={len(manifest_filepaths)}) and tarred_audio_filepaths (length={len(tarred_audio_filepaths)}) need to have the same number of buckets."
        )

    # Fail fast: without this check a short weight list only surfaces as a bare
    # IndexError inside the loop, after earlier datasets were already built.
    if bucketing_weights and len(bucketing_weights) < len(tarred_audio_filepaths):
        raise ValueError(
            f"bucketing_weights (length={len(bucketing_weights)}) must provide a weight for each of the "
            f"{len(tarred_audio_filepaths)} buckets."
        )

    datasets = []
    for dataset_idx, (tarred_audio_filepath, manifest_filepath) in enumerate(
        zip(tarred_audio_filepaths, manifest_filepaths)
    ):
        # A bucket holding a single path is passed on as that path rather than a 1-element list.
        if len(tarred_audio_filepath) == 1:
            tarred_audio_filepath = tarred_audio_filepath[0]
        dataset = audio_to_label.TarredAudioToSpeechLabelDataset(
            audio_tar_filepaths=tarred_audio_filepath,
            manifest_filepath=manifest_filepath,
            labels=config['labels'],
            featurizer=featurizer,
            shuffle_n=shuffle_n,
            max_duration=config.get('max_duration', None),
            min_duration=config.get('min_duration', None),
            trim=config.get('trim_silence', False),
            window_length_in_sec=config.get('window_length_in_sec', 8),
            shift_length_in_sec=config.get('shift_length_in_sec', 0.075),
            normalize_audio=config.get('normalize_audio', False),
            shard_strategy=config.get('tarred_shard_strategy', 'scatter'),
            global_rank=global_rank,
            world_size=world_size,
        )

        # Upsample the bucket by listing the same dataset object `repeats` times.
        repeats = bucketing_weights[dataset_idx] if bucketing_weights else 1
        datasets.extend([dataset] * repeats)

    return get_chain_dataset(datasets=datasets, ds_config=config, rank=global_rank)
|
|