| from collections.abc import Callable |
| from typing import List |
|
|
| import hw_asr.augmentations.spectrogram_augmentations |
| import hw_asr.augmentations.wave_augmentations |
| from hw_asr.augmentations.random_choice import RandomChoice |
| from hw_asr.augmentations.sequential_random_apply import SequentialRandomApply |
| |
| |
| from hw_asr.utils.parse_config import ConfigParser |
|
|
|
|
| def from_configs(configs: ConfigParser): |
| wave_augs = [] |
| if "augmentations" in configs.config and "wave" in configs.config["augmentations"]: |
| for aug_dict in configs.config["augmentations"]["wave"]: |
| wave_augs.append( |
| configs.init_obj(aug_dict, hw_asr.augmentations.wave_augmentations) |
| ) |
|
|
| spec_augs = [] |
| if "augmentations" in configs.config and "spectrogram" in configs.config["augmentations"]: |
| for aug_dict in configs.config["augmentations"]["spectrogram"]: |
| spec_augs.append( |
| configs.init_obj(aug_dict, hw_asr.augmentations.spectrogram_augmentations) |
| ) |
| return _to_function(RandomChoice, wave_augs, configs.config["augmentations"]["random_apply_p"]), _to_function(SequentialRandomApply, spec_augs, configs.config["augmentations"]["random_apply_p"]) |
|
|
|
|
| def _to_function(random_type, augs_list: List[Callable], p: float): |
| if len(augs_list) == 0: |
| return None |
| elif len(augs_list) == 1: |
| return augs_list[0] |
| else: |
| return random_type(augs_list, p) |
|
|