| | import glob |
| | import logging |
| | import os |
| | import random |
| |
|
| | import numpy as np |
| | from scipy import signal |
| |
|
| | from TTS.encoder.models.base_encoder import BaseEncoder |
| | from TTS.encoder.models.lstm import LSTMSpeakerEncoder |
| | from TTS.encoder.models.resnet import ResNetSpeakerEncoder |
| |
|
# Module-scoped logger, named after this module.
logger = logging.getLogger(name=__name__)
| |
|
| |
|
class AugmentWAV:
    """Randomly augment waveforms with additive noise or room impulse responses (RIR).

    Enabled augmentations are driven by ``augmentation_config``:

    * ``"additive"``: a mapping with a ``"sounds_path"`` directory. Every other
      entry whose value is itself a dict names a noise type. Its files are the
      ``*.wav`` files under the ``sounds_path`` sub-directory of that name. The
      dict carries ``min_num_noises``, ``max_num_noises``, ``min_snr_in_db`` and
      ``max_snr_in_db``.
    * ``"rir"``: a mapping with a ``"rir_path"`` directory (searched recursively
      for ``*.wav``) and the ``"conv_mode"`` passed to ``scipy.signal.convolve``.

    Args:
        ap: Audio processor used to read files. Only ``ap.load_wav(path, sr=...)``
            and ``ap.sample_rate`` are used here.
        augmentation_config: Mapping described above.
    """

    # Upper bound on re-draws in ``additive_noise`` when none of the drawn noise
    # files is long enough to cover the input audio.
    max_noise_attempts = 1000

    def __init__(self, ap, augmentation_config):
        self.ap = ap
        self.use_additive_noise = False

        if "additive" in augmentation_config.keys():
            self.additive_noise_config = augmentation_config["additive"]
            additive_path = self.additive_noise_config["sounds_path"]
            if additive_path:
                self.use_additive_noise = True
                # Noise types are the config entries whose value is a dict of parameters.
                self.additive_noise_types = [
                    key
                    for key in self.additive_noise_config.keys()
                    if isinstance(self.additive_noise_config[key], dict)
                ]

                additive_files = glob.glob(os.path.join(additive_path, "**/*.wav"), recursive=True)

                # Group files by their first path component below ``additive_path``
                # (the noise type). Files in directories that are not a configured
                # noise type are ignored. ``relpath`` works whether or not
                # ``sounds_path`` ends with a separator. The previous
                # ``str.replace`` approach yielded "" (dropping every file) when
                # there was no trailing separator.
                self.noise_list = {}
                for wav_file in additive_files:
                    noise_dir = os.path.relpath(wav_file, additive_path).split(os.sep)[0]
                    if noise_dir not in self.additive_noise_types:
                        continue
                    self.noise_list.setdefault(noise_dir, []).append(wav_file)

                logger.info(
                    "Using Additive Noise Augmentation: with %d audios instances from %s",
                    len(additive_files),
                    self.additive_noise_types,
                )

        self.use_rir = False

        if "rir" in augmentation_config.keys():
            self.rir_config = augmentation_config["rir"]
            if self.rir_config["rir_path"]:
                self.rir_files = glob.glob(os.path.join(self.rir_config["rir_path"], "**/*.wav"), recursive=True)
                self.use_rir = True
                # Logged here because ``self.rir_files`` only exists when ``rir_path`` is set.
                logger.info("Using RIR Noise Augmentation: with %d audios instances", len(self.rir_files))

        self.create_augmentation_global_list()

    def create_augmentation_global_list(self):
        """Build ``self.global_noise_list``, the pool ``apply_one`` draws from.

        The pool holds every additive noise type (if additive noise is enabled),
        plus the sentinel ``"RIR_AUG"`` if RIR augmentation is enabled.
        """
        # Copy rather than alias, so appending "RIR_AUG" does not also mutate
        # ``self.additive_noise_types``.
        self.global_noise_list = list(self.additive_noise_types) if self.use_additive_noise else []
        if self.use_rir:
            self.global_noise_list.append("RIR_AUG")

    def additive_noise(self, noise_type, audio):
        """Mix randomly drawn noise files of ``noise_type`` into ``audio``.

        Each draw picks ``randint(min_num_noises, max_num_noises)`` files, capped
        at the number available. Every drawn file that covers ``audio`` is:

        * truncated to the audio's length;
        * scaled to a random SNR in ``[min_snr_in_db, max_snr_in_db]``
          relative to ``audio``;
        * summed into the noise mix.

        If no drawn file is long enough, a new draw is made, up to
        ``max_noise_attempts`` times.

        Args:
            noise_type: Key into ``self.noise_list`` and ``self.additive_noise_config``.
            audio: Numpy waveform. Its length is taken from ``audio.shape[0]``.
                Presumably 1-D; TODO confirm with callers.

        Returns:
            ``audio`` plus the summed, scaled noise.

        Raises:
            KeyError: if ``noise_type`` has no collected files or no config entry.
            RuntimeError: if no usable noise was found within ``max_noise_attempts`` draws.
        """
        noise_config = self.additive_noise_config[noise_type]
        candidates = self.noise_list[noise_type]
        audio_len = audio.shape[0]
        # Clean-signal energy in dB. The 1e-4 floor avoids log10(0) on silence.
        clean_db = 10 * np.log10(np.mean(audio**2) + 1e-4)

        for _ in range(self.max_noise_attempts):
            # Cap at the pool size: random.sample raises ValueError if k > len(population).
            num_noises = min(
                random.randint(noise_config["min_num_noises"], noise_config["max_num_noises"]),
                len(candidates),
            )
            noises_wav = None
            for noise in random.sample(candidates, num_noises):
                noiseaudio = self.ap.load_wav(noise, sr=self.ap.sample_rate)[:audio_len]

                # A noise shorter than the clean audio cannot cover it; skip it.
                if noiseaudio.shape[0] < audio_len:
                    continue

                noise_snr = random.uniform(noise_config["min_snr_in_db"], noise_config["max_snr_in_db"])
                noise_db = 10 * np.log10(np.mean(noiseaudio**2) + 1e-4)
                # Scale so the noise sits roughly ``noise_snr`` dB below the clean signal.
                noise_wav = np.sqrt(10 ** ((clean_db - noise_db - noise_snr) / 10)) * noiseaudio

                if noises_wav is None:
                    noises_wav = noise_wav
                else:
                    noises_wav += noise_wav

            if noises_wav is not None:
                return audio + noises_wav

        raise RuntimeError(
            f"No '{noise_type}' noise file at least {audio_len} samples long was found "
            f"after {self.max_noise_attempts} attempts"
        )

    def reverberate(self, audio):
        """Convolve ``audio`` with a random RIR file and truncate to ``audio.shape[0]`` samples.

        The RIR is normalised to unit energy (L2 norm of 1) before convolution.
        The result is truncated to at most the input length, and may be shorter
        depending on ``rir_config["conv_mode"]``.
        """
        audio_len = audio.shape[0]

        rir_file = random.choice(self.rir_files)
        rir = self.ap.load_wav(rir_file, sr=self.ap.sample_rate)
        rir = rir / np.sqrt(np.sum(rir**2))
        return signal.convolve(audio, rir, mode=self.rir_config["conv_mode"])[:audio_len]

    def apply_one(self, audio):
        """Apply one augmentation chosen uniformly from ``self.global_noise_list``.

        ``"RIR_AUG"`` selects reverberation; any other entry is an additive noise type.
        ``random.choice`` raises IndexError when no augmentation is enabled
        (empty pool).
        """
        noise_type = random.choice(self.global_noise_list)
        if noise_type == "RIR_AUG":
            return self.reverberate(audio)

        return self.additive_noise(noise_type, audio)
| |
|
| |
|
def setup_encoder_model(config: "Coqpit") -> BaseEncoder:
    """Instantiate the speaker encoder named by ``config.model_params["model_name"]``.

    The name is matched case-insensitively against ``"lstm"`` and ``"resnet"``.
    The remaining constructor arguments come from ``config.model_params``
    (``use_torch_spec`` and, for ResNet, ``log_input`` default to ``False``) and
    ``config.audio``.

    Args:
        config: Config exposing a ``model_params`` mapping and an ``audio`` section.

    Returns:
        An ``LSTMSpeakerEncoder`` or a ``ResNetSpeakerEncoder``.

    Raises:
        ValueError: if the model name is neither ``"lstm"`` nor ``"resnet"``.
    """
    params = config.model_params
    requested = params["model_name"].lower()

    if requested == "lstm":
        return LSTMSpeakerEncoder(
            params["input_dim"],
            params["proj_dim"],
            params["lstm_dim"],
            params["num_lstm_layers"],
            use_torch_spec=params.get("use_torch_spec", False),
            audio_config=config.audio,
        )

    if requested == "resnet":
        return ResNetSpeakerEncoder(
            input_dim=params["input_dim"],
            proj_dim=params["proj_dim"],
            log_input=params.get("log_input", False),
            use_torch_spec=params.get("use_torch_spec", False),
            audio_config=config.audio,
        )

    raise ValueError(f"Model not supported: {params['model_name']}")
| |
|