Comparative-Analysis-of-Speech-Synthesis-Models
/
TensorFlowTTS
/tensorflow_tts
/utils
/griffin_lim.py
| # -*- coding: utf-8 -*- | |
| # Copyright 2020 Minh Nguyen (@dathudeptrai) | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| """Griffin-Lim phase reconstruction algorithm from mel spectrogram.""" | |
| import os | |
| import librosa | |
| import numpy as np | |
| import soundfile as sf | |
| import tensorflow as tf | |
| from sklearn.preprocessing import StandardScaler | |
def griffin_lim_lb(
    mel_spec, stats_path, dataset_config, n_iter=32, output_dir=None, wav_name="lb"
):
    """Generate wave from mel spectrogram with Griffin-Lim algorithm using Librosa.

    The mel spectrogram is de-normalized with the statistics stored at
    ``stats_path``, raised from log10 scale back to linear magnitude, projected
    onto a linear-frequency magnitude spectrogram through the pseudo-inverse of
    the mel filterbank, and finally inverted with ``librosa.griffinlim``.

    Args:
        mel_spec (ndarray): normalized log10 mel spectrogram. Presumably shaped
            ``[frames, num_mels]`` (it is transposed before projection) -- TODO
            confirm against callers.
        stats_path (str): path to the `stats.npy` file containing norm statistics
            (loaded as a ``(mean, scale)`` pair).
        dataset_config (Dict): dataset configuration parameters. Reads the keys
            ``sampling_rate``, ``fft_size``, ``num_mels``, ``fmin``, ``fmax``,
            ``hop_size`` and ``win_length``; a falsy ``win_length`` falls back to
            ``fft_size``.
        n_iter (int): number of iterations for GL.
        output_dir (str): if given, directory where ``{wav_name}.wav`` is written
            as 16-bit PCM.
        wav_name (str): name of the output file, without extension.

    Returns:
        gl_lb (ndarray): generated wave.
    """
    scaler = StandardScaler()
    scaler.mean_, scaler.scale_ = np.load(stats_path)
    # Undo standardization, then the log10 compression; transpose to librosa's
    # [num_mels, frames] layout.
    mel_spec = np.power(10.0, scaler.inverse_transform(mel_spec)).T
    # `sr` is keyword-only since librosa 0.10; passing it by name also works on
    # older releases.
    mel_basis = librosa.filters.mel(
        sr=dataset_config["sampling_rate"],
        n_fft=dataset_config["fft_size"],
        n_mels=dataset_config["num_mels"],
        fmin=dataset_config["fmin"],
        fmax=dataset_config["fmax"],
    )
    # The pseudo-inverse projection can produce non-positive values; clamp to a
    # small positive floor so the magnitude stays valid.
    mel_to_linear = np.maximum(1e-10, np.dot(np.linalg.pinv(mel_basis), mel_spec))
    gl_lb = librosa.griffinlim(
        mel_to_linear,
        n_iter=n_iter,
        hop_length=dataset_config["hop_size"],
        win_length=dataset_config["win_length"] or dataset_config["fft_size"],
    )
    if output_dir:
        output_path = os.path.join(output_dir, f"{wav_name}.wav")
        sf.write(output_path, gl_lb, dataset_config["sampling_rate"], "PCM_16")
    return gl_lb
class TFGriffinLim(tf.keras.layers.Layer):
    """Griffin-Lim algorithm for phase reconstruction from mel spectrogram magnitude."""

    def __init__(self, stats_path, dataset_config, normalized: bool = True, **kwargs):
        """Init GL params.

        Args:
            stats_path (str): path to the `stats.npy` file containing norm
                statistics. Only read when ``normalized`` is True.
            dataset_config (Dict): dataset configuration parameters. Reads the
                keys ``sampling_rate``, ``fft_size``, ``num_mels``, ``fmin``,
                ``fmax``, ``hop_size`` and ``win_length``.
            normalized (bool): whether mel spectrograms given to :meth:`call`
                are standardized and must be de-normalized with the statistics
                loaded from ``stats_path``.
            **kwargs: forwarded to ``tf.keras.layers.Layer`` (e.g. ``name``).
        """
        super().__init__(**kwargs)
        self.normalized = normalized
        if normalized:
            scaler = StandardScaler()
            scaler.mean_, scaler.scale_ = np.load(stats_path)
            self.scaler = scaler
        self.ds_config = dataset_config
        # `sr` is keyword-only since librosa 0.10; passing it by name also works
        # on older releases.
        self.mel_basis = librosa.filters.mel(
            sr=self.ds_config["sampling_rate"],
            n_fft=self.ds_config["fft_size"],
            n_mels=self.ds_config["num_mels"],
            fmin=self.ds_config["fmin"],
            fmax=self.ds_config["fmax"],
        )  # [num_mels, fft_size // 2 + 1]

    def save_wav(self, gl_tf, output_dir, wav_name):
        """Encode reconstructed signal(s) as WAV and write them to disk.

        Args:
            gl_tf (tf.Tensor): reconstructed signal from GL algorithm, either a
                single signal ``[samples]`` or a batch ``[batch, samples]``.
            output_dir (str): output directory where audio file will be saved.
            wav_name (str or list): name(s) of the output file(s), without
                extension; for a batch, one name per batch item.
        """
        sample_rate = self.ds_config["sampling_rate"]
        # tf.audio.encode_wav expects [samples, channels]; add a mono channel axis.
        gl_tf = tf.expand_dims(gl_tf, -1)
        if not isinstance(wav_name, list):
            wav_name = [wav_name]
        if len(gl_tf.shape) > 2:
            bs, *_ = gl_tf.shape
            assert bs == len(wav_name), "Batch and 'wav_name' have different size."
            for idx, name in enumerate(wav_name):
                tf_wav = tf.audio.encode_wav(gl_tf[idx], sample_rate)
                tf.io.write_file(os.path.join(output_dir, f"{name}.wav"), tf_wav)
        else:
            tf_wav = tf.audio.encode_wav(gl_tf, sample_rate)
            tf.io.write_file(os.path.join(output_dir, f"{wav_name[0]}.wav"), tf_wav)

    def call(self, mel_spec, n_iter=32):
        """Apply GL algorithm to batched mel spectrograms.

        Args:
            mel_spec (tf.Tensor): log10 mel spectrogram (standardized when the
                layer was built with ``normalized=True``), with ``num_mels`` on
                the last axis.
            n_iter (int): number of iterations to run GL algorithm.

        Returns:
            (tf.Tensor): reconstructed signal from GL algorithm.
        """
        frame_length = self.ds_config["win_length"] or self.ds_config["fft_size"]
        frame_step = self.ds_config["hop_size"]
        fft_length = self.ds_config["fft_size"]
        # Synthesis window matching tf.signal.stft's default (Hann) analysis window.
        window_fn = tf.signal.inverse_stft_window_fn(frame_step)

        # De-normalize the mel spectrogram and undo the log10 compression.
        if self.normalized:
            mel_spec = tf.math.pow(
                10.0, mel_spec * self.scaler.scale_ + self.scaler.mean_
            )
        else:
            # TODO @dathudeptrai: check if this is OK; without it the wavs were
            # too quiet.
            mel_spec = tf.math.pow(10.0, mel_spec)
        inverse_mel = tf.linalg.pinv(self.mel_basis)
        # [..., num_mels] @ [fft_size // 2 + 1, num_mels].T
        mel_to_linear = tf.linalg.matmul(mel_spec, inverse_mel, transpose_b=True)
        # Clamp non-positive values from the pseudo-inverse projection.
        mel_to_linear = tf.cast(tf.math.maximum(1e-10, mel_to_linear), tf.complex64)

        # Random initial phase, uniformly distributed in [0, 2*pi).
        init_phase = tf.cast(
            tf.random.uniform(tf.shape(mel_to_linear), maxval=1), tf.complex64
        )
        phase = tf.math.exp(2j * np.pi * init_phase)
        for _ in tf.range(n_iter):
            inverse = tf.signal.inverse_stft(
                mel_to_linear * phase,
                frame_length=frame_length,
                frame_step=frame_step,
                fft_length=fft_length,
                window_fn=window_fn,
            )
            phase = tf.signal.stft(inverse, frame_length, frame_step, fft_length)
            # Keep only the phase: normalize to unit magnitude.
            phase /= tf.cast(tf.maximum(1e-10, tf.abs(phase)), tf.complex64)
        return tf.signal.inverse_stft(
            mel_to_linear * phase,
            frame_length=frame_length,
            frame_step=frame_step,
            fft_length=fft_length,
            window_fn=window_fn,
        )