Upload 18 files
- app.py +58 -0
- audio_processing.py +93 -0
- config.json +38 -0
- data_utils.py +111 -0
- distributed.py +173 -0
- hparams.py +106 -0
- inference.ipynb +0 -0
- layers.py +80 -0
- logger.py +48 -0
- logic.py +100 -0
- loss_function.py +19 -0
- loss_scaler.py +131 -0
- model.py +529 -0
- multiproc.py +23 -0
- plotting_utils.py +61 -0
- stft.py +141 -0
- train.py +290 -0
- utils.py +29 -0
app.py
ADDED
@@ -0,0 +1,58 @@
from flask import Flask, render_template, request, jsonify
from logic import synthesize_voice, plot_data, plot_waveforms
import base64

app = Flask(__name__)

# Hugging Face model URLs
tacotron2_model_url = "your_tacotron2_model_url"
hifi_gan_model_url = "your_hifi_gan_model_url"

# You need to replace the placeholders above with the actual URLs for the models.

@app.route('/')
def index():
    return render_template('index.html')

@app.route('/synthesize', methods=['POST'])
def synthesize():
    font_type = request.json['font_select']
    input_text = request.json['input_text']

    # Font selection logic (you can customize this based on your requirements)
    if font_type == 'Preeti':
        # Implement Preeti font logic
        pass
    elif font_type == 'Unicode':
        # Implement Unicode font logic
        pass

    # Generate mel-spectrogram using Tacotron2
    mel_output_data, mel_output_postnet_data, alignments_data = synthesize_voice(input_text, "Shruti_finetuned")

    # Convert mel-spectrogram to base64 for display in HTML
    mel_output_base64 = plot_data([mel_output_data, mel_output_postnet_data, alignments_data])

    # Save the generated audio file
    audio_file_path = 'audio_output/mel1_generated_e2e.wav'

    # Plot the waveform
    wave_base64 = plot_waveforms(audio_file_path)

    # Encode audio content as Base64
    with open(audio_file_path, 'rb') as audio_file:
        audio_base64 = base64.b64encode(audio_file.read()).decode('utf-8')

    # You can customize the response based on what information you want to send to the frontend
    response_data = {
        'mel_spectrogram': mel_output_base64,
        'audio_data': audio_base64,
        'waveform': wave_base64,
        'some_other_data': 'example_value',
    }

    return jsonify(response_data)


if __name__ == '__main__':
    app.run(debug=True, threaded=True)
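For reference, a minimal client sketch for the /synthesize route above, assuming the app is running locally on Flask's default port; the JSON keys match what synthesize() reads from request.json:

import base64
import requests

# Hypothetical local call; adjust host/port to your deployment.
resp = requests.post(
    "http://127.0.0.1:5000/synthesize",
    json={"font_select": "Unicode", "input_text": "namaste"},
)
resp.raise_for_status()
payload = resp.json()

# The audio comes back base64-encoded; decode it into a playable WAV file.
with open("synthesized.wav", "wb") as f:
    f.write(base64.b64decode(payload["audio_data"]))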
audio_processing.py
ADDED
@@ -0,0 +1,93 @@
import torch
import numpy as np
from scipy.signal import get_window
import librosa.util as librosa_util


def window_sumsquare(window, n_frames, hop_length=200, win_length=800,
                     n_fft=800, dtype=np.float32, norm=None):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm)**2
    win_sq = librosa_util.pad_center(win_sq, n_fft)

    # Fill the envelope
    for i in range(n_frames):
        sample = i * hop_length
        x[sample:min(n, sample + n_fft)] += win_sq[:max(0, min(n_fft, n - sample))]
    return x


def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    PARAMS
    ------
    magnitudes: spectrogram magnitudes
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    """

    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    angles = angles.astype(np.float32)
    angles = torch.autograd.Variable(torch.from_numpy(angles))
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    for i in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal


def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    PARAMS
    ------
    C: compression factor
    """
    return torch.log(torch.clamp(x, min=clip_val) * C)


def dynamic_range_decompression(x, C=1):
    """
    PARAMS
    ------
    C: compression factor used to compress
    """
    return torch.exp(x) / C
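As a quick sanity check, dynamic_range_compression and dynamic_range_decompression above are inverses for values above clip_val; a minimal sketch:

import torch
from audio_processing import dynamic_range_compression, dynamic_range_decompression

mags = torch.rand(80, 100) + 1e-3  # fake magnitude spectrogram, safely above clip_val
recovered = dynamic_range_decompression(dynamic_range_compression(mags))
print(torch.allclose(mags, recovered, atol=1e-5))  # True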
config.json
ADDED
@@ -0,0 +1,38 @@
{
    "resblock": "1",
    "num_gpus": 0,
    "batch_size": 16,
    "learning_rate": 0.0002,
    "adam_b1": 0.8,
    "adam_b2": 0.99,
    "lr_decay": 0.999,
    "seed": 1234,

    "upsample_rates": [8,8,2,2],
    "upsample_kernel_sizes": [16,16,4,4],
    "upsample_initial_channel": 512,
    "resblock_kernel_sizes": [3,7,11],
    "resblock_dilation_sizes": [[1,3,5], [1,3,5], [1,3,5]],
    "resblock_initial_channel": 256,

    "segment_size": 8192,
    "num_mels": 80,
    "num_freq": 1025,
    "n_fft": 1024,
    "hop_size": 256,
    "win_size": 1024,

    "sampling_rate": 22050,

    "fmin": 0,
    "fmax": 8000,
    "fmax_loss": null,

    "num_workers": 4,

    "dist_config": {
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:54321",
        "world_size": 1
    }
}
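This is a HiFi-GAN generator config; a minimal sketch of loading it with attribute access (the AttrDict here mirrors the one defined in hparams.py and is an assumption, not part of this file):

import json

class AttrDict(dict):
    # Dict whose keys double as attributes, mirroring hparams.AttrDict.
    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        self.__dict__ = self

with open("config.json") as f:
    h = AttrDict(json.load(f))

print(h.sampling_rate, h.n_fft, h.hop_size)  # 22050 1024 256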
data_utils.py
ADDED
@@ -0,0 +1,111 @@
import random
import numpy as np
import torch
import torch.utils.data

import layers
from utils import load_wav_to_torch, load_filepaths_and_text
from text import text_to_sequence


class TextMelLoader(torch.utils.data.Dataset):
    """
    1) loads audio,text pairs
    2) normalizes text and converts them to sequences of one-hot vectors
    3) computes mel-spectrograms from audio files.
    """
    def __init__(self, audiopaths_and_text, hparams):
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.load_mel_from_disk = hparams.load_mel_from_disk
        self.stft = layers.TacotronSTFT(
            hparams.filter_length, hparams.hop_length, hparams.win_length,
            hparams.n_mel_channels, hparams.sampling_rate, hparams.mel_fmin,
            hparams.mel_fmax)
        random.seed(hparams.seed)
        random.shuffle(self.audiopaths_and_text)

    def get_mel_text_pair(self, audiopath_and_text):
        # separate filename and text
        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
        text = self.get_text(text)
        mel = self.get_mel(audiopath)
        return (text, mel)

    def get_mel(self, filename):
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_wav_to_torch(filename)
            if sampling_rate != self.stft.sampling_rate:
                raise ValueError("{} SR doesn't match target {} SR".format(
                    sampling_rate, self.stft.sampling_rate))
            audio_norm = audio / self.max_wav_value
            audio_norm = audio_norm.unsqueeze(0)
            audio_norm = torch.autograd.Variable(audio_norm, requires_grad=False)
            melspec = self.stft.mel_spectrogram(audio_norm)
            melspec = torch.squeeze(melspec, 0)
        else:
            melspec = torch.from_numpy(np.load(filename))
            assert melspec.size(0) == self.stft.n_mel_channels, (
                'Mel dimension mismatch: given {}, expected {}'.format(
                    melspec.size(0), self.stft.n_mel_channels))

        return melspec

    def get_text(self, text):
        text_norm = torch.IntTensor(text_to_sequence(text, self.text_cleaners))
        return text_norm

    def __getitem__(self, index):
        return self.get_mel_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)


class TextMelCollate():
    """ Zero-pads model inputs and targets based on number of frames per step
    """
    def __init__(self, n_frames_per_step):
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collates training batch from normalized text and mel-spectrogram
        PARAMS
        ------
        batch: [text_normalized, mel_normalized]
        """
        # Right zero-pad all one-hot text sequences to max input length
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(x[0]) for x in batch]),
            dim=0, descending=True)
        max_input_len = input_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for i in range(len(ids_sorted_decreasing)):
            text = batch[ids_sorted_decreasing[i]][0]
            text_padded[i, :text.size(0)] = text

        # Right zero-pad mel-spec
        num_mels = batch[0][1].size(0)
        max_target_len = max([x[1].size(1) for x in batch])
        if max_target_len % self.n_frames_per_step != 0:
            max_target_len += self.n_frames_per_step - max_target_len % self.n_frames_per_step
            assert max_target_len % self.n_frames_per_step == 0

        # include mel padded and gate padded
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        gate_padded = torch.FloatTensor(len(batch), max_target_len)
        gate_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        for i in range(len(ids_sorted_decreasing)):
            mel = batch[ids_sorted_decreasing[i]][1]
            mel_padded[i, :, :mel.size(1)] = mel
            gate_padded[i, mel.size(1)-1:] = 1
            output_lengths[i] = mel.size(1)

        return text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths
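A minimal sketch of how these two classes are typically wired into a DataLoader (the file lists come from the hparams defined in hparams.py):

from torch.utils.data import DataLoader
from hparams import create_hparams
from data_utils import TextMelLoader, TextMelCollate

hparams = create_hparams()
trainset = TextMelLoader(hparams.training_files, hparams)
collate_fn = TextMelCollate(hparams.n_frames_per_step)

# shuffle=False here because TextMelLoader already shuffles with hparams.seed
train_loader = DataLoader(trainset, batch_size=hparams.batch_size,
                          shuffle=False, collate_fn=collate_fn,
                          num_workers=1, pin_memory=False)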
distributed.py
ADDED
@@ -0,0 +1,173 @@
import torch
import torch.distributed as dist
from torch.nn.modules import Module
from torch.autograd import Variable

def _flatten_dense_tensors(tensors):
    """Flatten dense tensors into a contiguous 1D buffer. Assume tensors are of
    same dense type.
    Since inputs are dense, the resulting tensor will be a concatenated 1D
    buffer. Element-wise operation on this buffer will be equivalent to
    operating individually.
    Arguments:
        tensors (Iterable[Tensor]): dense tensors to flatten.
    Returns:
        A contiguous 1D buffer containing input tensors.
    """
    if len(tensors) == 1:
        return tensors[0].contiguous().view(-1)
    flat = torch.cat([t.contiguous().view(-1) for t in tensors], dim=0)
    return flat

def _unflatten_dense_tensors(flat, tensors):
    """View a flat buffer using the sizes of tensors. Assume that tensors are of
    same dense type, and that flat is given by _flatten_dense_tensors.
    Arguments:
        flat (Tensor): flattened dense tensors to unflatten.
        tensors (Iterable[Tensor]): dense tensors whose sizes will be used to
            unflatten flat.
    Returns:
        Unflattened dense tensors with sizes same as tensors and values from
        flat.
    """
    outputs = []
    offset = 0
    for tensor in tensors:
        numel = tensor.numel()
        outputs.append(flat.narrow(0, offset, numel).view_as(tensor))
        offset += numel
    return tuple(outputs)


'''
This version of DistributedDataParallel is designed to be used in conjunction with the multiproc.py
launcher included with this example. It assumes that your run is using multiprocess with 1
GPU/process, that the model is on the correct device, and that torch.set_device has been
used to set the device.

Parameters are broadcast to the other processes on initialization of DistributedDataParallel,
and will be allreduced at the finish of the backward pass.
'''
class DistributedDataParallel(Module):

    def __init__(self, module):
        super(DistributedDataParallel, self).__init__()
        # fallback for PyTorch 0.3
        if not hasattr(dist, '_backend'):
            self.warn_on_half = True
        else:
            self.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

        self.module = module

        for p in self.module.state_dict().values():
            if not torch.is_tensor(p):
                continue
            dist.broadcast(p, 0)

        def allreduce_params():
            if self.needs_reduction:
                self.needs_reduction = False
                buckets = {}
                for param in self.module.parameters():
                    if param.requires_grad and param.grad is not None:
                        tp = type(param.data)
                        if tp not in buckets:
                            buckets[tp] = []
                        buckets[tp].append(param)
                if self.warn_on_half:
                    if torch.cuda.HalfTensor in buckets:
                        print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                              " It is recommended to use the NCCL backend in this case. This currently requires" +
                              " PyTorch built from top of tree master.")
                        self.warn_on_half = False

                for tp in buckets:
                    bucket = buckets[tp]
                    grads = [param.grad.data for param in bucket]
                    coalesced = _flatten_dense_tensors(grads)
                    dist.all_reduce(coalesced)
                    coalesced /= dist.get_world_size()
                    for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                        buf.copy_(synced)

        for param in list(self.module.parameters()):
            def allreduce_hook(*unused):
                param._execution_engine.queue_callback(allreduce_params)
            if param.requires_grad:
                param.register_hook(allreduce_hook)

    def forward(self, *inputs, **kwargs):
        self.needs_reduction = True
        return self.module(*inputs, **kwargs)

    '''
    def _sync_buffers(self):
        buffers = list(self.module._all_buffers())
        if len(buffers) > 0:
            # cross-node buffer sync
            flat_buffers = _flatten_dense_tensors(buffers)
            dist.broadcast(flat_buffers, 0)
            for buf, synced in zip(buffers, _unflatten_dense_tensors(flat_buffers, buffers)):
                buf.copy_(synced)
    def train(self, mode=True):
        # Clear the NCCL communicator and CUDA event cache of the default group ID.
        # These caches will be recreated at the later call. This is currently a
        # work-around for a potential NCCL deadlock.
        if dist._backend == dist.dist_backend.NCCL:
            dist._clear_group_cache()
        super(DistributedDataParallel, self).train(mode)
        self.module.train(mode)
    '''

'''
Modifies existing model to do gradient allreduce, but doesn't change class
so you don't need "module"
'''
def apply_gradient_allreduce(module):
    if not hasattr(dist, '_backend'):
        module.warn_on_half = True
    else:
        module.warn_on_half = True if dist._backend == dist.dist_backend.GLOO else False

    for p in module.state_dict().values():
        if not torch.is_tensor(p):
            continue
        dist.broadcast(p, 0)

    def allreduce_params():
        if module.needs_reduction:
            module.needs_reduction = False
            buckets = {}
            for param in module.parameters():
                if param.requires_grad and param.grad is not None:
                    tp = param.data.dtype
                    if tp not in buckets:
                        buckets[tp] = []
                    buckets[tp].append(param)
            if module.warn_on_half:
                if torch.cuda.HalfTensor in buckets:
                    print("WARNING: gloo dist backend for half parameters may be extremely slow." +
                          " It is recommended to use the NCCL backend in this case. This currently requires" +
                          " PyTorch built from top of tree master.")
                    module.warn_on_half = False

            for tp in buckets:
                bucket = buckets[tp]
                grads = [param.grad.data for param in bucket]
                coalesced = _flatten_dense_tensors(grads)
                dist.all_reduce(coalesced)
                coalesced /= dist.get_world_size()
                for buf, synced in zip(grads, _unflatten_dense_tensors(coalesced, grads)):
                    buf.copy_(synced)

    for param in list(module.parameters()):
        def allreduce_hook(*unused):
            Variable._execution_engine.queue_callback(allreduce_params)
        if param.requires_grad:
            param.register_hook(allreduce_hook)

    def set_needs_reduction(self, input, output):
        self.needs_reduction = True

    module.register_forward_hook(set_needs_reduction)
    return module
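A minimal sketch of the intended use of apply_gradient_allreduce, assuming one process per GPU launched via multiproc.py (the backend/URL values here echo the defaults in hparams.py and are placeholders):

import torch
import torch.distributed as dist
from distributed import apply_gradient_allreduce

dist.init_process_group(backend="nccl", init_method="tcp://localhost:14897",
                        world_size=1, rank=0)

model = torch.nn.Linear(10, 10).cuda()
# Gradients are now all-reduced across processes after each backward().
model = apply_gradient_allreduce(model)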
hparams.py
ADDED
@@ -0,0 +1,106 @@
from text import symbols

class AttrDict(dict):
    def __init__(self, *args, **kwargs):
        super(AttrDict, self).__init__(*args, **kwargs)
        self.__dict__ = self


def create_hparams(hparams_string=None, verbose=False):
    """Create model hyperparameters. Parse nondefault from given string."""

    hparams = AttrDict({
        ################################
        # Experiment Parameters        #
        ################################
        "epochs": 1500,
        "iters_per_checkpoint": 500,
        "seed": 1234,
        "dynamic_loss_scaling": True,
        "fp16_run": False,
        "distributed_run": False,
        "dist_backend": "nccl",
        "dist_url": "tcp://localhost:14897",
        "cudnn_enabled": True,
        "cudnn_benchmark": False,
        "ignore_layers": ['embedding.weight'],
        # "freeze_layers": ['encoder'],  # Freeze Tacotron2 layers for finetuning

        ################################
        # Data Parameters              #
        ################################
        "load_mel_from_disk": False,
        "load_phone_from_disk": True,

        "training_files": 'filelists/train_files.txt',
        "validation_files": 'filelists/val_files.txt',

        "text_cleaners": ['transliteration_cleaners'],

        ################################
        # Audio Parameters             #
        ################################
        "max_wav_value": 32768.0,
        "sampling_rate": 22050,
        "filter_length": 1024,
        "hop_length": 256,
        "win_length": 1024,
        "n_mel_channels": 80,
        "mel_fmin": 0.0,
        "mel_fmax": 8000.0,

        ################################
        # Model Parameters             #
        ################################
        "n_symbols": len(symbols),
        "symbols_embedding_dim": 512,
        "alignloss": "L2",
        "attention": "StepwiseMonotonicAttention",

        # Encoder parameters
        "encoder_kernel_size": 5,
        "encoder_n_convolutions": 3,
        "encoder_embedding_dim": 512,

        # Decoder parameters
        "n_frames_per_step": 1,  # currently only 1 is supported
        "decoder_rnn_dim": 1024,
        "prenet_dim": 256,
        "max_decoder_steps": 1000,
        "gate_threshold": 0.5,
        "p_attention_dropout": 0.1,
        "p_decoder_dropout": 0.1,

        # Attention parameters
        "attention_rnn_dim": 1024,
        "attention_dim": 128,

        # Location Layer parameters
        "attention_location_n_filters": 32,
        "attention_location_kernel_size": 31,

        # Mel-post processing network parameters
        "postnet_embedding_dim": 512,
        "postnet_kernel_size": 5,
        "postnet_n_convolutions": 5,

        ################################
        # Optimization Hyperparameters #
        ################################
        "use_saved_learning_rate": True,
        "learning_rate": 1e-3,
        "weight_decay": 1e-6,
        "grad_clip_thresh": 1.0,
        "batch_size": 8,  # per GPU
        "mask_padding": True  # set model's padded outputs to padded values
    })

    if hparams_string:
        hps = hparams_string[1:-2].split("-")
        for hp in hps:
            k, v = hp.split(":")
            if k in hparams:
                hparams[k] = v
                print("Set hparam: " + k + " to " + v)

    return hparams
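A quick usage sketch; note that AttrDict exposes keys as attributes, and that values set via the override string stay strings, so numeric fields are safest changed as attributes:

from hparams import create_hparams

hparams = create_hparams()
hparams.batch_size = 16  # attribute assignment works because __dict__ is aliased to the dict
print(hparams.sampling_rate, hparams.n_mel_channels)  # 22050 80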
inference.ipynb
ADDED
The diff for this file is too large to render. See raw diff.
layers.py
ADDED
@@ -0,0 +1,80 @@
import torch
from librosa.filters import mel as librosa_mel_fn
from audio_processing import dynamic_range_compression
from audio_processing import dynamic_range_decompression
from stft import STFT


class LinearNorm(torch.nn.Module):
    def __init__(self, in_dim, out_dim, bias=True, w_init_gain='linear'):
        super(LinearNorm, self).__init__()
        self.linear_layer = torch.nn.Linear(in_dim, out_dim, bias=bias)

        torch.nn.init.xavier_uniform_(
            self.linear_layer.weight,
            gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, x):
        return self.linear_layer(x)


class ConvNorm(torch.nn.Module):
    def __init__(self, in_channels, out_channels, kernel_size=1, stride=1,
                 padding=None, dilation=1, bias=True, w_init_gain='linear'):
        super(ConvNorm, self).__init__()
        if padding is None:
            assert(kernel_size % 2 == 1)
            padding = int(dilation * (kernel_size - 1) / 2)

        self.conv = torch.nn.Conv1d(in_channels, out_channels,
                                    kernel_size=kernel_size, stride=stride,
                                    padding=padding, dilation=dilation,
                                    bias=bias)

        torch.nn.init.xavier_uniform_(
            self.conv.weight, gain=torch.nn.init.calculate_gain(w_init_gain))

    def forward(self, signal):
        conv_signal = self.conv(signal)
        return conv_signal


class TacotronSTFT(torch.nn.Module):
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 n_mel_channels=80, sampling_rate=22050, mel_fmin=0.0,
                 mel_fmax=8000.0):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        mel_basis = librosa_mel_fn(
            sampling_rate, filter_length, n_mel_channels, mel_fmin, mel_fmax)
        mel_basis = torch.from_numpy(mel_basis).float()
        self.register_buffer('mel_basis', mel_basis)

    def spectral_normalize(self, magnitudes):
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert(torch.min(y.data) >= -1)
        assert(torch.max(y.data) <= 1)

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output
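A minimal sketch of computing a mel-spectrogram with TacotronSTFT (the defaults match the audio hparams; "example.wav" and the int16 scale are assumptions, mirroring data_utils.get_mel):

import torch
from layers import TacotronSTFT
from utils import load_wav_to_torch

stft = TacotronSTFT()                          # 1024/256/1024, 80 mels, 22050 Hz
audio, sr = load_wav_to_torch("example.wav")   # hypothetical input file
audio_norm = (audio / 32768.0).unsqueeze(0)    # scale int16 samples into [-1, 1]
mel = stft.mel_spectrogram(audio_norm)
print(mel.shape)  # torch.Size([1, 80, T])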
logger.py
ADDED
@@ -0,0 +1,48 @@
import random
import torch
from torch.utils.tensorboard import SummaryWriter
from plotting_utils import plot_alignment_to_numpy, plot_spectrogram_to_numpy
from plotting_utils import plot_gate_outputs_to_numpy


class Tacotron2Logger(SummaryWriter):
    def __init__(self, logdir):
        super(Tacotron2Logger, self).__init__(logdir)

    def log_training(self, reduced_loss, grad_norm, learning_rate, duration,
                     iteration):
        self.add_scalar("training.loss", reduced_loss, iteration)
        self.add_scalar("grad.norm", grad_norm, iteration)
        self.add_scalar("learning.rate", learning_rate, iteration)
        self.add_scalar("duration", duration, iteration)

    def log_validation(self, reduced_loss, model, y, y_pred, iteration):
        self.add_scalar("validation.loss", reduced_loss, iteration)
        _, mel_outputs, gate_outputs, alignments = y_pred
        mel_targets, gate_targets = y

        # plot distribution of parameters
        for tag, value in model.named_parameters():
            tag = tag.replace('.', '/')
            self.add_histogram(tag, value.data.cpu().numpy(), iteration)

        # plot alignment, mel target and predicted, gate target and predicted
        idx = random.randint(0, alignments.size(0) - 1)
        self.add_image(
            "alignment",
            plot_alignment_to_numpy(alignments[idx].data.cpu().numpy().T),
            iteration, dataformats='HWC')
        self.add_image(
            "mel_target",
            plot_spectrogram_to_numpy(mel_targets[idx].data.cpu().numpy()),
            iteration, dataformats='HWC')
        self.add_image(
            "mel_predicted",
            plot_spectrogram_to_numpy(mel_outputs[idx].data.cpu().numpy()),
            iteration, dataformats='HWC')
        self.add_image(
            "gate",
            plot_gate_outputs_to_numpy(
                gate_targets[idx].data.cpu().numpy(),
                torch.sigmoid(gate_outputs[idx]).data.cpu().numpy()),
            iteration, dataformats='HWC')
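A minimal sketch of driving the logger during training (the scalar values here are placeholders):

from logger import Tacotron2Logger

logger = Tacotron2Logger("logdir")
# Called once per training iteration with the current metrics.
logger.log_training(reduced_loss=0.42, grad_norm=1.3,
                    learning_rate=1e-3, duration=0.8, iteration=100)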
logic.py
ADDED
@@ -0,0 +1,100 @@
import matplotlib.pyplot as plt
import numpy as np
import torch
import base64
import io
from hparams import create_hparams
from model import Tacotron2
from train import load_model
from text import text_to_sequence
import os
import subprocess
import librosa.display

# Function to plot data
def plot_data(data, figsize=(16, 4), titles=['Mel Spectrogram (Original)', 'Mel Spectrogram (Postnet)', 'Alignment'],
              xlabel=['Time Steps', 'Time Steps', 'Decoder Time Steps'],
              ylabel=['Mel Channels', 'Mel Channels', 'Encoder Time Steps'], colorbar_labels=None):
    fig, axes = plt.subplots(1, len(data), figsize=figsize)
    for i in range(len(data)):
        im = axes[i].imshow(data[i], aspect='auto', origin='lower', interpolation='none', cmap='viridis')

        if titles:
            axes[i].set_title(titles[i])
        if xlabel:
            axes[i].set_xlabel(xlabel[i])
        if ylabel:
            axes[i].set_ylabel(ylabel[i])

        # Add color bar
        cbar = fig.colorbar(im, ax=axes[i])
        if colorbar_labels:
            cbar.set_label(colorbar_labels[i])

    plt.tight_layout()
    img_buffer = io.BytesIO()
    plt.savefig(img_buffer, format='png', bbox_inches='tight', pad_inches=0)
    plt.close()

    img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')

    return img_base64

# Function to plot the time-domain waveform
def plot_waveforms(audio_file, sr=22050):
    # Load audio waveform
    y, sr = librosa.load(audio_file, sr=sr)

    # Create time vector
    time = librosa.times_like(y, sr=sr)

    # Plot the waveform
    plt.figure(figsize=(16, 4))
    librosa.display.waveshow(y, sr=sr)
    plt.title('Time vs Amplitude')
    plt.xlabel('Time (s)')
    plt.ylabel('Amplitude')

    plt.tight_layout()
    # plt.savefig('static/waveform.png')
    img_buffer = io.BytesIO()
    plt.savefig(img_buffer, format='png', bbox_inches='tight', pad_inches=0)
    plt.close()

    img_base64 = base64.b64encode(img_buffer.getvalue()).decode('utf-8')

    return img_base64

def synthesize_voice(text_input, checkpoint_path):
    # Load Tacotron2 model
    hparams = create_hparams()
    hparams.sampling_rate = 22050

    # Load model from checkpoint
    model = load_model(hparams)
    model.load_state_dict(torch.load(checkpoint_path)['state_dict'])
    model = model.cuda().eval().half()

    # Nepali text
    sequence = np.array(text_to_sequence(text_input, ['transliteration_cleaners']))[None, :]
    sequence = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()

    # Mel-spectrogram and alignment graph
    mel_outputs, mel_outputs_postnet, _, alignments = model.inference(sequence)
    mel_output_data = mel_outputs.data.cpu().numpy()[0]
    mel_output_postnet_data = mel_outputs_postnet.data.cpu().numpy()[0]
    alignments_data = alignments.data.cpu().numpy()[0].T

    np.save('mel_files/mel1' + '.npy', mel_output_data)

    input_mels_dir = 'mel_files/'
    output_dir = 'audio_output/'
    run_hifigan_inference(input_mels_dir, output_dir)

    return mel_output_data, mel_output_postnet_data, alignments_data


def run_hifigan_inference(input_mels_dir, output_dir):
    script_path = os.path.join(os.path.dirname("hifigan/"), "inference_e2e.py")  # Assuming both scripts are in the same directory
    subprocess.run(["python", script_path, "--checkpoint_file", "generator_v1", "--input_mels_dir", input_mels_dir, "--output_dir", output_dir])
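End to end, a minimal sketch of the synthesis path these functions implement (requires a CUDA device, since the model is moved to GPU and half precision; the checkpoint name is the one app.py passes in, and the input string is a placeholder):

from logic import synthesize_voice, plot_data

mel, mel_postnet, alignments = synthesize_voice("namaste", "Shruti_finetuned")
plots_base64 = plot_data([mel, mel_postnet, alignments])
# The WAV itself is written by run_hifigan_inference into audio_output/.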
loss_function.py
ADDED
@@ -0,0 +1,19 @@
from torch import nn


class Tacotron2Loss(nn.Module):
    def __init__(self):
        super(Tacotron2Loss, self).__init__()

    def forward(self, model_output, targets):
        mel_target, gate_target = targets[0], targets[1]
        mel_target.requires_grad = False
        gate_target.requires_grad = False
        gate_target = gate_target.view(-1, 1)

        mel_out, mel_out_postnet, gate_out, _ = model_output
        gate_out = gate_out.view(-1, 1)
        mel_loss = nn.MSELoss()(mel_out, mel_target) + \
            nn.MSELoss()(mel_out_postnet, mel_target)
        gate_loss = nn.BCEWithLogitsLoss()(gate_out, gate_target)
        return mel_loss + gate_loss
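A minimal self-contained sketch of the loss on placeholder tensors, shaped like the model's (mel_out, mel_out_postnet, gate_out, alignments) outputs:

import torch
from loss_function import Tacotron2Loss

criterion = Tacotron2Loss()
B, n_mels, T = 2, 80, 100
model_output = (torch.randn(B, n_mels, T),       # mel_out
                torch.randn(B, n_mels, T),       # mel_out_postnet
                torch.randn(B, T),               # gate_out (logits)
                None)                            # alignments, unused by the loss
targets = (torch.randn(B, n_mels, T),            # mel_target
           torch.randint(0, 2, (B, T)).float())  # gate_target
print(criterion(model_output, targets).item())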
loss_scaler.py
ADDED
@@ -0,0 +1,131 @@
import torch

class LossScaler:

    def __init__(self, scale=1):
        self.cur_scale = scale

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        return False

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        pass

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()

class DynamicLossScaler:

    def __init__(self,
                 init_scale=2**32,
                 scale_factor=2.,
                 scale_window=1000):
        self.cur_scale = init_scale
        self.cur_iter = 0
        self.last_overflow_iter = -1
        self.scale_factor = scale_factor
        self.scale_window = scale_window

    # `params` is a list / generator of torch.Variable
    def has_overflow(self, params):
        # return False
        for p in params:
            if p.grad is not None and DynamicLossScaler._has_inf_or_nan(p.grad.data):
                return True

        return False

    # `x` is a torch.Tensor
    @staticmethod
    def _has_inf_or_nan(x):
        cpu_sum = float(x.float().sum())
        if cpu_sum == float('inf') or cpu_sum == -float('inf') or cpu_sum != cpu_sum:
            return True
        return False

    # `overflow` is boolean indicating whether we overflowed in gradient
    def update_scale(self, overflow):
        if overflow:
            # self.cur_scale /= self.scale_factor
            self.cur_scale = max(self.cur_scale / self.scale_factor, 1)
            self.last_overflow_iter = self.cur_iter
        else:
            if (self.cur_iter - self.last_overflow_iter) % self.scale_window == 0:
                self.cur_scale *= self.scale_factor
                # self.cur_scale = 1
        self.cur_iter += 1

    @property
    def loss_scale(self):
        return self.cur_scale

    def scale_gradient(self, module, grad_in, grad_out):
        return tuple(self.loss_scale * g for g in grad_in)

    def backward(self, loss):
        scaled_loss = loss * self.loss_scale
        scaled_loss.backward()

##############################################################
# Example usage below here -- assuming it's in a separate file
##############################################################
if __name__ == "__main__":
    import torch
    from torch.autograd import Variable
    from dynamic_loss_scaler import DynamicLossScaler

    # N is batch size; D_in is input dimension;
    # H is hidden dimension; D_out is output dimension.
    N, D_in, H, D_out = 64, 1000, 100, 10

    # Create random Tensors to hold inputs and outputs, and wrap them in Variables.
    x = Variable(torch.randn(N, D_in), requires_grad=False)
    y = Variable(torch.randn(N, D_out), requires_grad=False)

    w1 = Variable(torch.randn(D_in, H), requires_grad=True)
    w2 = Variable(torch.randn(H, D_out), requires_grad=True)
    parameters = [w1, w2]

    learning_rate = 1e-6
    optimizer = torch.optim.SGD(parameters, lr=learning_rate)
    loss_scaler = DynamicLossScaler()

    for t in range(500):
        y_pred = x.mm(w1).clamp(min=0).mm(w2)
        loss = (y_pred - y).pow(2).sum() * loss_scaler.loss_scale
        print('Iter {} loss scale: {}'.format(t, loss_scaler.loss_scale))
        print('Iter {} scaled loss: {}'.format(t, loss.item()))
        print('Iter {} unscaled loss: {}'.format(t, loss.item() / loss_scaler.loss_scale))

        # Run backprop
        optimizer.zero_grad()
        loss.backward()

        # Check for overflow
        has_overflow = loss_scaler.has_overflow(parameters)

        # If no overflow, unscale grad and update as usual
        if not has_overflow:
            for param in parameters:
                param.grad.data.mul_(1. / loss_scaler.loss_scale)
            optimizer.step()
        # Otherwise, don't do anything -- ie, skip iteration
        else:
            print('OVERFLOW!')

        # Update loss scale for next iteration
        loss_scaler.update_scale(has_overflow)
model.py
ADDED
@@ -0,0 +1,529 @@
from math import sqrt
import torch
from torch.autograd import Variable
from torch import nn
from torch.nn import functional as F
from layers import ConvNorm, LinearNorm
from utils import to_gpu, get_mask_from_lengths


class LocationLayer(nn.Module):
    def __init__(self, attention_n_filters, attention_kernel_size,
                 attention_dim):
        super(LocationLayer, self).__init__()
        padding = int((attention_kernel_size - 1) / 2)
        self.location_conv = ConvNorm(2, attention_n_filters,
                                      kernel_size=attention_kernel_size,
                                      padding=padding, bias=False, stride=1,
                                      dilation=1)
        self.location_dense = LinearNorm(attention_n_filters, attention_dim,
                                         bias=False, w_init_gain='tanh')

    def forward(self, attention_weights_cat):
        processed_attention = self.location_conv(attention_weights_cat)
        processed_attention = processed_attention.transpose(1, 2)
        processed_attention = self.location_dense(processed_attention)
        return processed_attention


class Attention(nn.Module):
    def __init__(self, attention_rnn_dim, embedding_dim, attention_dim,
                 attention_location_n_filters, attention_location_kernel_size):
        super(Attention, self).__init__()
        self.query_layer = LinearNorm(attention_rnn_dim, attention_dim,
                                      bias=False, w_init_gain='tanh')
        self.memory_layer = LinearNorm(embedding_dim, attention_dim, bias=False,
                                       w_init_gain='tanh')
        self.v = LinearNorm(attention_dim, 1, bias=False)
        self.location_layer = LocationLayer(attention_location_n_filters,
                                            attention_location_kernel_size,
                                            attention_dim)
        self.score_mask_value = -float("inf")

    def get_alignment_energies(self, query, processed_memory,
                               attention_weights_cat):
        """
        PARAMS
        ------
        query: decoder output (batch, n_mel_channels * n_frames_per_step)
        processed_memory: processed encoder outputs (B, T_in, attention_dim)
        attention_weights_cat: cumulative and prev. att weights (B, 2, max_time)

        RETURNS
        -------
        alignment (batch, max_time)
        """

        processed_query = self.query_layer(query.unsqueeze(1))
        processed_attention_weights = self.location_layer(attention_weights_cat)
        energies = self.v(torch.tanh(
            processed_query + processed_attention_weights + processed_memory))

        energies = energies.squeeze(-1)
        return energies

    def forward(self, attention_hidden_state, memory, processed_memory,
                attention_weights_cat, mask):
        """
        PARAMS
        ------
        attention_hidden_state: attention rnn last output
        memory: encoder outputs
        processed_memory: processed encoder outputs
        attention_weights_cat: previous and cumulative attention weights
        mask: binary mask for padded data
        """
        alignment = self.get_alignment_energies(
            attention_hidden_state, processed_memory, attention_weights_cat)

        if mask is not None:
            alignment.data.masked_fill_(mask, self.score_mask_value)

        attention_weights = F.softmax(alignment, dim=1)
        attention_context = torch.bmm(attention_weights.unsqueeze(1), memory)
        attention_context = attention_context.squeeze(1)

        return attention_context, attention_weights


class Prenet(nn.Module):
    def __init__(self, in_dim, sizes):
        super(Prenet, self).__init__()
        in_sizes = [in_dim] + sizes[:-1]
        self.layers = nn.ModuleList(
            [LinearNorm(in_size, out_size, bias=False)
             for (in_size, out_size) in zip(in_sizes, sizes)])

    def forward(self, x):
        for linear in self.layers:
            x = F.dropout(F.relu(linear(x)), p=0.5, training=True)
        return x


class Postnet(nn.Module):
    """Postnet
    - Five 1-d convolutions with 512 channels and kernel size 5
    """

    def __init__(self, hparams):
        super(Postnet, self).__init__()
        self.convolutions = nn.ModuleList()

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.n_mel_channels, hparams.postnet_embedding_dim,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='tanh'),
                nn.BatchNorm1d(hparams.postnet_embedding_dim))
        )

        for i in range(1, hparams.postnet_n_convolutions - 1):
            self.convolutions.append(
                nn.Sequential(
                    ConvNorm(hparams.postnet_embedding_dim,
                             hparams.postnet_embedding_dim,
                             kernel_size=hparams.postnet_kernel_size, stride=1,
                             padding=int((hparams.postnet_kernel_size - 1) / 2),
                             dilation=1, w_init_gain='tanh'),
                    nn.BatchNorm1d(hparams.postnet_embedding_dim))
            )

        self.convolutions.append(
            nn.Sequential(
                ConvNorm(hparams.postnet_embedding_dim, hparams.n_mel_channels,
                         kernel_size=hparams.postnet_kernel_size, stride=1,
                         padding=int((hparams.postnet_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='linear'),
                nn.BatchNorm1d(hparams.n_mel_channels))
        )

    def forward(self, x):
        for i in range(len(self.convolutions) - 1):
            x = F.dropout(torch.tanh(self.convolutions[i](x)), 0.5, self.training)
        x = F.dropout(self.convolutions[-1](x), 0.5, self.training)

        return x


class Encoder(nn.Module):
    """Encoder module:
    - Three 1-d convolution banks
    - Bidirectional LSTM
    """
    def __init__(self, hparams):
        super(Encoder, self).__init__()

        convolutions = []
        for _ in range(hparams.encoder_n_convolutions):
            conv_layer = nn.Sequential(
                ConvNorm(hparams.encoder_embedding_dim,
                         hparams.encoder_embedding_dim,
                         kernel_size=hparams.encoder_kernel_size, stride=1,
                         padding=int((hparams.encoder_kernel_size - 1) / 2),
                         dilation=1, w_init_gain='relu'),
                nn.BatchNorm1d(hparams.encoder_embedding_dim))
            convolutions.append(conv_layer)
        self.convolutions = nn.ModuleList(convolutions)

        self.lstm = nn.LSTM(hparams.encoder_embedding_dim,
                            int(hparams.encoder_embedding_dim / 2), 1,
                            batch_first=True, bidirectional=True)

    def forward(self, x, input_lengths):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        # PyTorch tensors are not reversible, hence the conversion
        input_lengths = input_lengths.cpu().numpy()
        x = nn.utils.rnn.pack_padded_sequence(
            x, input_lengths, batch_first=True)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        outputs, _ = nn.utils.rnn.pad_packed_sequence(
            outputs, batch_first=True)

        return outputs

    def inference(self, x):
        for conv in self.convolutions:
            x = F.dropout(F.relu(conv(x)), 0.5, self.training)

        x = x.transpose(1, 2)

        self.lstm.flatten_parameters()
        outputs, _ = self.lstm(x)

        return outputs


class Decoder(nn.Module):
    def __init__(self, hparams):
        super(Decoder, self).__init__()
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.encoder_embedding_dim = hparams.encoder_embedding_dim
        self.attention_rnn_dim = hparams.attention_rnn_dim
        self.decoder_rnn_dim = hparams.decoder_rnn_dim
        self.prenet_dim = hparams.prenet_dim
        self.max_decoder_steps = hparams.max_decoder_steps
        self.gate_threshold = hparams.gate_threshold
        self.p_attention_dropout = hparams.p_attention_dropout
        self.p_decoder_dropout = hparams.p_decoder_dropout

        self.prenet = Prenet(
            hparams.n_mel_channels * hparams.n_frames_per_step,
            [hparams.prenet_dim, hparams.prenet_dim])

        self.attention_rnn = nn.LSTMCell(
            hparams.prenet_dim + hparams.encoder_embedding_dim,
            hparams.attention_rnn_dim)

        self.attention_layer = Attention(
            hparams.attention_rnn_dim, hparams.encoder_embedding_dim,
            hparams.attention_dim, hparams.attention_location_n_filters,
            hparams.attention_location_kernel_size)

        self.decoder_rnn = nn.LSTMCell(
            hparams.attention_rnn_dim + hparams.encoder_embedding_dim,
            hparams.decoder_rnn_dim, 1)

        self.linear_projection = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim,
            hparams.n_mel_channels * hparams.n_frames_per_step)

        self.gate_layer = LinearNorm(
            hparams.decoder_rnn_dim + hparams.encoder_embedding_dim, 1,
            bias=True, w_init_gain='sigmoid')

    def get_go_frame(self, memory):
        """ Gets all zeros frames to use as first decoder input
        PARAMS
        ------
        memory: decoder outputs

        RETURNS
        -------
        decoder_input: all zeros frames
        """
        B = memory.size(0)
        decoder_input = Variable(memory.data.new(
            B, self.n_mel_channels * self.n_frames_per_step).zero_())
        return decoder_input

    def initialize_decoder_states(self, memory, mask):
        """ Initializes attention rnn states, decoder rnn states, attention
        weights, attention cumulative weights, attention context, stores memory
        and stores processed memory
        PARAMS
        ------
        memory: Encoder outputs
        mask: Mask for padded data if training, expects None for inference
        """
        B = memory.size(0)
        MAX_TIME = memory.size(1)

        self.attention_hidden = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())
        self.attention_cell = Variable(memory.data.new(
            B, self.attention_rnn_dim).zero_())

        self.decoder_hidden = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())
        self.decoder_cell = Variable(memory.data.new(
            B, self.decoder_rnn_dim).zero_())

        self.attention_weights = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_weights_cum = Variable(memory.data.new(
            B, MAX_TIME).zero_())
        self.attention_context = Variable(memory.data.new(
            B, self.encoder_embedding_dim).zero_())

        self.memory = memory
        self.processed_memory = self.attention_layer.memory_layer(memory)
        self.mask = mask

    def parse_decoder_inputs(self, decoder_inputs):
        """ Prepares decoder inputs, i.e. mel outputs
|
| 293 |
+
PARAMS
|
| 294 |
+
------
|
| 295 |
+
decoder_inputs: inputs used for teacher-forced training, i.e. mel-specs
|
| 296 |
+
|
| 297 |
+
RETURNS
|
| 298 |
+
-------
|
| 299 |
+
inputs: processed decoder inputs
|
| 300 |
+
|
| 301 |
+
"""
|
| 302 |
+
# (B, n_mel_channels, T_out) -> (B, T_out, n_mel_channels)
|
| 303 |
+
decoder_inputs = decoder_inputs.transpose(1, 2)
|
| 304 |
+
decoder_inputs = decoder_inputs.view(
|
| 305 |
+
decoder_inputs.size(0),
|
| 306 |
+
int(decoder_inputs.size(1)/self.n_frames_per_step), -1)
|
| 307 |
+
# (B, T_out, n_mel_channels) -> (T_out, B, n_mel_channels)
|
| 308 |
+
decoder_inputs = decoder_inputs.transpose(0, 1)
|
| 309 |
+
return decoder_inputs
|
| 310 |
+
|
| 311 |
+
def parse_decoder_outputs(self, mel_outputs, gate_outputs, alignments):
|
| 312 |
+
""" Prepares decoder outputs for output
|
| 313 |
+
PARAMS
|
| 314 |
+
------
|
| 315 |
+
mel_outputs:
|
| 316 |
+
gate_outputs: gate output energies
|
| 317 |
+
alignments:
|
| 318 |
+
|
| 319 |
+
RETURNS
|
| 320 |
+
-------
|
| 321 |
+
mel_outputs:
|
| 322 |
+
gate_outpust: gate output energies
|
| 323 |
+
alignments:
|
| 324 |
+
"""
|
| 325 |
+
# (T_out, B) -> (B, T_out)
|
| 326 |
+
alignments = torch.stack(alignments).transpose(0, 1)
|
| 327 |
+
# (T_out, B) -> (B, T_out)
|
| 328 |
+
gate_outputs = torch.stack(gate_outputs).transpose(0, 1)
|
| 329 |
+
gate_outputs = gate_outputs.contiguous()
|
| 330 |
+
# (T_out, B, n_mel_channels) -> (B, T_out, n_mel_channels)
|
| 331 |
+
mel_outputs = torch.stack(mel_outputs).transpose(0, 1).contiguous()
|
| 332 |
+
# decouple frames per step
|
| 333 |
+
mel_outputs = mel_outputs.view(
|
| 334 |
+
mel_outputs.size(0), -1, self.n_mel_channels)
|
| 335 |
+
# (B, T_out, n_mel_channels) -> (B, n_mel_channels, T_out)
|
| 336 |
+
mel_outputs = mel_outputs.transpose(1, 2)
|
| 337 |
+
|
| 338 |
+
return mel_outputs, gate_outputs, alignments
|
| 339 |
+
|
| 340 |
+
def decode(self, decoder_input):
|
| 341 |
+
""" Decoder step using stored states, attention and memory
|
| 342 |
+
PARAMS
|
| 343 |
+
------
|
| 344 |
+
decoder_input: previous mel output
|
| 345 |
+
|
| 346 |
+
RETURNS
|
| 347 |
+
-------
|
| 348 |
+
mel_output:
|
| 349 |
+
gate_output: gate output energies
|
| 350 |
+
attention_weights:
|
| 351 |
+
"""
|
| 352 |
+
cell_input = torch.cat((decoder_input, self.attention_context), -1)
|
| 353 |
+
self.attention_hidden, self.attention_cell = self.attention_rnn(
|
| 354 |
+
cell_input, (self.attention_hidden, self.attention_cell))
|
| 355 |
+
self.attention_hidden = F.dropout(
|
| 356 |
+
self.attention_hidden, self.p_attention_dropout, self.training)
|
| 357 |
+
|
| 358 |
+
attention_weights_cat = torch.cat(
|
| 359 |
+
(self.attention_weights.unsqueeze(1),
|
| 360 |
+
self.attention_weights_cum.unsqueeze(1)), dim=1)
|
| 361 |
+
self.attention_context, self.attention_weights = self.attention_layer(
|
| 362 |
+
self.attention_hidden, self.memory, self.processed_memory,
|
| 363 |
+
attention_weights_cat, self.mask)
|
| 364 |
+
|
| 365 |
+
self.attention_weights_cum += self.attention_weights
|
| 366 |
+
decoder_input = torch.cat(
|
| 367 |
+
(self.attention_hidden, self.attention_context), -1)
|
| 368 |
+
self.decoder_hidden, self.decoder_cell = self.decoder_rnn(
|
| 369 |
+
decoder_input, (self.decoder_hidden, self.decoder_cell))
|
| 370 |
+
self.decoder_hidden = F.dropout(
|
| 371 |
+
self.decoder_hidden, self.p_decoder_dropout, self.training)
|
| 372 |
+
|
| 373 |
+
decoder_hidden_attention_context = torch.cat(
|
| 374 |
+
(self.decoder_hidden, self.attention_context), dim=1)
|
| 375 |
+
decoder_output = self.linear_projection(
|
| 376 |
+
decoder_hidden_attention_context)
|
| 377 |
+
|
| 378 |
+
gate_prediction = self.gate_layer(decoder_hidden_attention_context)
|
| 379 |
+
return decoder_output, gate_prediction, self.attention_weights
|
| 380 |
+
|
| 381 |
+
def forward(self, memory, decoder_inputs, memory_lengths):
|
| 382 |
+
""" Decoder forward pass for training
|
| 383 |
+
PARAMS
|
| 384 |
+
------
|
| 385 |
+
memory: Encoder outputs
|
| 386 |
+
decoder_inputs: Decoder inputs for teacher forcing. i.e. mel-specs
|
| 387 |
+
memory_lengths: Encoder output lengths for attention masking.
|
| 388 |
+
|
| 389 |
+
RETURNS
|
| 390 |
+
-------
|
| 391 |
+
mel_outputs: mel outputs from the decoder
|
| 392 |
+
gate_outputs: gate outputs from the decoder
|
| 393 |
+
alignments: sequence of attention weights from the decoder
|
| 394 |
+
"""
|
| 395 |
+
|
| 396 |
+
decoder_input = self.get_go_frame(memory).unsqueeze(0)
|
| 397 |
+
decoder_inputs = self.parse_decoder_inputs(decoder_inputs)
|
| 398 |
+
decoder_inputs = torch.cat((decoder_input, decoder_inputs), dim=0)
|
| 399 |
+
decoder_inputs = self.prenet(decoder_inputs)
|
| 400 |
+
|
| 401 |
+
self.initialize_decoder_states(
|
| 402 |
+
memory, mask=~get_mask_from_lengths(memory_lengths))
|
| 403 |
+
|
| 404 |
+
mel_outputs, gate_outputs, alignments = [], [], []
|
| 405 |
+
while len(mel_outputs) < decoder_inputs.size(0) - 1:
|
| 406 |
+
decoder_input = decoder_inputs[len(mel_outputs)]
|
| 407 |
+
mel_output, gate_output, attention_weights = self.decode(
|
| 408 |
+
decoder_input)
|
| 409 |
+
mel_outputs += [mel_output.squeeze(1)]
|
| 410 |
+
gate_outputs += [gate_output.squeeze(1)]
|
| 411 |
+
alignments += [attention_weights]
|
| 412 |
+
|
| 413 |
+
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
|
| 414 |
+
mel_outputs, gate_outputs, alignments)
|
| 415 |
+
|
| 416 |
+
return mel_outputs, gate_outputs, alignments
|
| 417 |
+
|
| 418 |
+
def inference(self, memory):
|
| 419 |
+
""" Decoder inference
|
| 420 |
+
PARAMS
|
| 421 |
+
------
|
| 422 |
+
memory: Encoder outputs
|
| 423 |
+
|
| 424 |
+
RETURNS
|
| 425 |
+
-------
|
| 426 |
+
mel_outputs: mel outputs from the decoder
|
| 427 |
+
gate_outputs: gate outputs from the decoder
|
| 428 |
+
alignments: sequence of attention weights from the decoder
|
| 429 |
+
"""
|
| 430 |
+
decoder_input = self.get_go_frame(memory)
|
| 431 |
+
|
| 432 |
+
self.initialize_decoder_states(memory, mask=None)
|
| 433 |
+
|
| 434 |
+
mel_outputs, gate_outputs, alignments = [], [], []
|
| 435 |
+
while True:
|
| 436 |
+
decoder_input = self.prenet(decoder_input)
|
| 437 |
+
mel_output, gate_output, alignment = self.decode(decoder_input)
|
| 438 |
+
|
| 439 |
+
mel_outputs += [mel_output.squeeze(1)]
|
| 440 |
+
gate_outputs += [gate_output]
|
| 441 |
+
alignments += [alignment]
|
| 442 |
+
|
| 443 |
+
if torch.sigmoid(gate_output.data) > self.gate_threshold:
|
| 444 |
+
break
|
| 445 |
+
elif len(mel_outputs) == self.max_decoder_steps:
|
| 446 |
+
print("Warning! Reached max decoder steps")
|
| 447 |
+
break
|
| 448 |
+
|
| 449 |
+
decoder_input = mel_output
|
| 450 |
+
|
| 451 |
+
mel_outputs, gate_outputs, alignments = self.parse_decoder_outputs(
|
| 452 |
+
mel_outputs, gate_outputs, alignments)
|
| 453 |
+
|
| 454 |
+
return mel_outputs, gate_outputs, alignments
|
| 455 |
+
|
| 456 |
+
|
| 457 |
+
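Inference is autoregressive: each predicted frame is fed back through the prenet, and decoding stops once sigmoid(gate) exceeds gate_threshold or after max_decoder_steps frames. A standalone sketch with random encoder memory (an untrained gate will usually run to the step limit):

# Usage sketch (not part of model.py): autoregressive decoding.
import torch
from hparams import create_hparams
from model import Decoder

hparams = create_hparams()
decoder = Decoder(hparams).eval()
memory = torch.randn(1, 50, hparams.encoder_embedding_dim)  # fake encoder outputs
with torch.no_grad():
    mel, gates, alignments = decoder.inference(memory)
# mel: (1, n_mel_channels, T_dec); alignments: (1, T_dec, 50)
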
class Tacotron2(nn.Module):
    def __init__(self, hparams):
        super(Tacotron2, self).__init__()
        self.mask_padding = hparams.mask_padding
        self.fp16_run = hparams.fp16_run
        self.n_mel_channels = hparams.n_mel_channels
        self.n_frames_per_step = hparams.n_frames_per_step
        self.embedding = nn.Embedding(
            hparams.n_symbols, hparams.symbols_embedding_dim)
        std = sqrt(2.0 / (hparams.n_symbols + hparams.symbols_embedding_dim))
        val = sqrt(3.0) * std  # uniform bounds for std
        self.embedding.weight.data.uniform_(-val, val)
        self.encoder = Encoder(hparams)
        self.decoder = Decoder(hparams)
        self.postnet = Postnet(hparams)

    def parse_batch(self, batch):
        text_padded, input_lengths, mel_padded, gate_padded, \
            output_lengths = batch
        text_padded = to_gpu(text_padded).long()
        input_lengths = to_gpu(input_lengths).long()
        max_len = torch.max(input_lengths.data).item()
        mel_padded = to_gpu(mel_padded).float()
        gate_padded = to_gpu(gate_padded).float()
        output_lengths = to_gpu(output_lengths).long()

        return (
            (text_padded, input_lengths, mel_padded, max_len, output_lengths),
            (mel_padded, gate_padded))

    def parse_output(self, outputs, output_lengths=None):
        if self.mask_padding and output_lengths is not None:
            mask = ~get_mask_from_lengths(output_lengths)
            mask = mask.expand(self.n_mel_channels, mask.size(0), mask.size(1))
            mask = mask.permute(1, 0, 2)

            outputs[0].data.masked_fill_(mask, 0.0)
            outputs[1].data.masked_fill_(mask, 0.0)
            outputs[2].data.masked_fill_(mask[:, 0, :], 1e3)  # gate energies

        return outputs

    def forward(self, inputs):
        text_inputs, text_lengths, mels, max_len, output_lengths = inputs
        text_lengths, output_lengths = text_lengths.data, output_lengths.data

        embedded_inputs = self.embedding(text_inputs).transpose(1, 2)

        encoder_outputs = self.encoder(embedded_inputs, text_lengths)

        mel_outputs, gate_outputs, alignments = self.decoder(
            encoder_outputs, mels, memory_lengths=text_lengths)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        return self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments],
            output_lengths)

    def inference(self, inputs):
        embedded_inputs = self.embedding(inputs).transpose(1, 2)
        encoder_outputs = self.encoder.inference(embedded_inputs)
        mel_outputs, gate_outputs, alignments = self.decoder.inference(
            encoder_outputs)

        mel_outputs_postnet = self.postnet(mel_outputs)
        mel_outputs_postnet = mel_outputs + mel_outputs_postnet

        outputs = self.parse_output(
            [mel_outputs, mel_outputs_postnet, gate_outputs, alignments])

        return outputs

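End-to-end, Tacotron2.inference maps a sequence of symbol ids to a postnet-refined mel spectrogram, which a neural vocoder then turns into audio. A sketch with placeholder ids (in this repo the real ids come from the text front end used by data_utils.py):

# Usage sketch (not part of model.py): symbol ids -> mel spectrogram.
import torch
from hparams import create_hparams
from model import Tacotron2

hparams = create_hparams()
model = Tacotron2(hparams).eval()
sequence = torch.randint(0, hparams.n_symbols, (1, 40)).long()  # placeholder ids
with torch.no_grad():
    mel, mel_postnet, gates, alignments = model.inference(sequence)
# mel_postnet (1, n_mel_channels, T) is what the vocoder consumes
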
multiproc.py
ADDED
@@ -0,0 +1,23 @@
import time
import torch
import sys
import subprocess

argslist = list(sys.argv)[1:]
num_gpus = torch.cuda.device_count()
argslist.append('--n_gpus={}'.format(num_gpus))
workers = []
job_id = time.strftime("%Y_%m_%d-%H%M%S")
argslist.append("--group_name=group_{}".format(job_id))

for i in range(num_gpus):
    argslist.append('--rank={}'.format(i))
    stdout = None if i == 0 else open("logs/{}_GPU_{}.log".format(job_id, i),
                                      "w")
    print(argslist)
    p = subprocess.Popen([str(sys.executable)] + argslist, stdout=stdout)
    workers.append(p)
    argslist = argslist[:-1]

for p in workers:
    p.wait()
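multiproc.py relaunches the given training command once per visible GPU, appending --n_gpus, --group_name and a per-process --rank; only rank 0 prints to the console, the others write to logs/<job_id>_GPU_<i>.log, so the logs/ directory must already exist. A launch sketch with illustrative directory names:

python multiproc.py train.py -o outdir -l logdir --hparams=distributed_run=True
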
plotting_utils.py
ADDED
@@ -0,0 +1,61 @@
import matplotlib
matplotlib.use("Agg")
import matplotlib.pylab as plt
import numpy as np


def save_figure_to_numpy(fig):
    # raw RGB bytes from the rendered canvas as a uint8 array
    data = np.frombuffer(fig.canvas.tostring_rgb(), dtype=np.uint8)
    data = data.reshape(fig.canvas.get_width_height()[::-1] + (3,))
    return data


def plot_alignment_to_numpy(alignment, info=None):
    fig, ax = plt.subplots(figsize=(6, 4))
    im = ax.imshow(alignment, aspect='auto', origin='lower',
                   interpolation='none')
    fig.colorbar(im, ax=ax)
    xlabel = 'Decoder timestep'
    if info is not None:
        xlabel += '\n\n' + info
    plt.xlabel(xlabel)
    plt.ylabel('Encoder timestep')
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


def plot_spectrogram_to_numpy(spectrogram):
    fig, ax = plt.subplots(figsize=(12, 3))
    im = ax.imshow(spectrogram, aspect="auto", origin="lower",
                   interpolation='none')
    plt.colorbar(im, ax=ax)
    plt.xlabel("Frames")
    plt.ylabel("Channels")
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data


def plot_gate_outputs_to_numpy(gate_targets, gate_outputs):
    fig, ax = plt.subplots(figsize=(12, 3))
    ax.scatter(range(len(gate_targets)), gate_targets, alpha=0.5,
               color='green', marker='+', s=1, label='target')
    ax.scatter(range(len(gate_outputs)), gate_outputs, alpha=0.5,
               color='red', marker='.', s=1, label='predicted')

    plt.xlabel("Frames (Green target, Red predicted)")
    plt.ylabel("Gate State")
    plt.tight_layout()

    fig.canvas.draw()
    data = save_figure_to_numpy(fig)
    plt.close()
    return data
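Selecting the Agg backend before pyplot is imported lets these helpers run headless; each returns an (H, W, 3) uint8 array ready for an image logger. A minimal sketch with random data:

# Usage sketch (not part of plotting_utils.py).
import numpy as np
from plotting_utils import plot_alignment_to_numpy

alignment = np.random.rand(40, 200)         # (encoder steps, decoder steps)
image = plot_alignment_to_numpy(alignment)  # (H, W, 3) uint8 array
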
stft.py
ADDED
@@ -0,0 +1,141 @@
"""
BSD 3-Clause License

Copyright (c) 2017, Prem Seetharaman
All rights reserved.

* Redistribution and use in source and binary forms, with or without
modification, are permitted provided that the following conditions are met:

* Redistributions of source code must retain the above copyright notice,
this list of conditions and the following disclaimer.

* Redistributions in binary form must reproduce the above copyright notice, this
list of conditions and the following disclaimer in the
documentation and/or other materials provided with the distribution.

* Neither the name of the copyright holder nor the names of its
contributors may be used to endorse or promote products derived from this
software without specific prior written permission.

THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
"""

import torch
import numpy as np
import torch.nn.functional as F
from torch.autograd import Variable
from scipy.signal import get_window
from librosa.util import pad_center, tiny
from audio_processing import window_sumsquare


class STFT(torch.nn.Module):
    """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
    def __init__(self, filter_length=1024, hop_length=256, win_length=1024,
                 window='hann'):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        scale = self.filter_length / self.hop_length
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
                                   np.imag(fourier_basis[:cutoff, :])])

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :])

        if window is not None:
            assert(filter_length >= win_length)
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        self.register_buffer('forward_basis', forward_basis.float())
        self.register_buffer('inverse_basis', inverse_basis.float())

    def transform(self, input_data):
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        # similar to librosa, reflect-pad the input
        input_data = input_data.view(num_batches, 1, num_samples)
        input_data = F.pad(
            input_data.unsqueeze(1),
            (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
            mode='reflect')
        input_data = input_data.squeeze(1)

        forward_transform = F.conv1d(
            input_data,
            Variable(self.forward_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        cutoff = int((self.filter_length / 2) + 1)
        real_part = forward_transform[:, :cutoff, :]
        imag_part = forward_transform[:, cutoff:, :]

        magnitude = torch.sqrt(real_part**2 + imag_part**2)
        phase = torch.autograd.Variable(
            torch.atan2(imag_part.data, real_part.data))

        return magnitude, phase

    def inverse(self, magnitude, phase):
        recombine_magnitude_phase = torch.cat(
            [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)

        inverse_transform = F.conv_transpose1d(
            recombine_magnitude_phase,
            Variable(self.inverse_basis, requires_grad=False),
            stride=self.hop_length,
            padding=0)

        if self.window is not None:
            window_sum = window_sumsquare(
                self.window, magnitude.size(-1), hop_length=self.hop_length,
                win_length=self.win_length, n_fft=self.filter_length,
                dtype=np.float32)
            # remove modulation effects
            approx_nonzero_indices = torch.from_numpy(
                np.where(window_sum > tiny(window_sum))[0])
            window_sum = torch.autograd.Variable(
                torch.from_numpy(window_sum), requires_grad=False)
            window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
            inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]

            # scale by hop ratio
            inverse_transform *= float(self.filter_length) / self.hop_length

        inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
        inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2)]

        return inverse_transform

    def forward(self, input_data):
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
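transform() and inverse() form an analysis/synthesis pair: inverse() compensates the window modulation via window_sumsquare and trims the reflection padding added by transform(), so forward() approximately reconstructs its input. A round-trip sketch:

# Usage sketch (not part of stft.py): magnitude/phase round trip.
import torch
from stft import STFT

stft = STFT(filter_length=1024, hop_length=256, win_length=1024)
audio = torch.randn(1, 16000)                    # (batch, samples)
magnitude, phase = stft.transform(audio)
reconstruction = stft.inverse(magnitude, phase)  # close to the input signal
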
train.py
ADDED
@@ -0,0 +1,290 @@
import os
import time
import argparse
import math
from numpy import finfo

import torch
from distributed import apply_gradient_allreduce
import torch.distributed as dist
from torch.utils.data.distributed import DistributedSampler
from torch.utils.data import DataLoader

from model import Tacotron2
from data_utils import TextMelLoader, TextMelCollate
from loss_function import Tacotron2Loss
from logger import Tacotron2Logger
from hparams import create_hparams


def reduce_tensor(tensor, n_gpus):
    rt = tensor.clone()
    dist.all_reduce(rt, op=dist.ReduceOp.SUM)
    rt /= n_gpus
    return rt


def init_distributed(hparams, n_gpus, rank, group_name):
    assert torch.cuda.is_available(), "Distributed mode requires CUDA."
    print("Initializing Distributed")

    # Set cuda device so everything is done on the right GPU.
    torch.cuda.set_device(rank % torch.cuda.device_count())

    # Initialize distributed communication
    dist.init_process_group(
        backend=hparams.dist_backend, init_method=hparams.dist_url,
        world_size=n_gpus, rank=rank, group_name=group_name)

    print("Done initializing distributed")


def prepare_dataloaders(hparams):
    # Get data, data loaders and collate function ready
    trainset = TextMelLoader(hparams.training_files, hparams)
    valset = TextMelLoader(hparams.validation_files, hparams)
    collate_fn = TextMelCollate(hparams.n_frames_per_step)

    if hparams.distributed_run:
        train_sampler = DistributedSampler(trainset)
        shuffle = False
    else:
        train_sampler = None
        shuffle = True

    train_loader = DataLoader(trainset, num_workers=1, shuffle=shuffle,
                              sampler=train_sampler,
                              batch_size=hparams.batch_size, pin_memory=False,
                              drop_last=True, collate_fn=collate_fn)
    return train_loader, valset, collate_fn


def prepare_directories_and_logger(output_directory, log_directory, rank):
    if rank == 0:
        if not os.path.isdir(output_directory):
            os.makedirs(output_directory)
            os.chmod(output_directory, 0o775)
        logger = Tacotron2Logger(os.path.join(output_directory, log_directory))
    else:
        logger = None
    return logger


def load_model(hparams):
    model = Tacotron2(hparams).cuda()
    if hparams.fp16_run:
        model.decoder.attention_layer.score_mask_value = finfo('float16').min

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)

    return model


def warm_start_model(checkpoint_path, model, ignore_layers):
    assert os.path.isfile(checkpoint_path)
    print("Warm starting model from checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model_dict = checkpoint_dict['state_dict']
    if len(ignore_layers) > 0:
        model_dict = {k: v for k, v in model_dict.items()
                      if k not in ignore_layers}
        dummy_dict = model.state_dict()
        dummy_dict.update(model_dict)
        model_dict = dummy_dict
    model.load_state_dict(model_dict)
    return model


def load_checkpoint(checkpoint_path, model, optimizer):
    assert os.path.isfile(checkpoint_path)
    print("Loading checkpoint '{}'".format(checkpoint_path))
    checkpoint_dict = torch.load(checkpoint_path, map_location='cpu')
    model.load_state_dict(checkpoint_dict['state_dict'])
    optimizer.load_state_dict(checkpoint_dict['optimizer'])
    learning_rate = checkpoint_dict['learning_rate']
    iteration = checkpoint_dict['iteration']
    print("Loaded checkpoint '{}' from iteration {}".format(
        checkpoint_path, iteration))
    return model, optimizer, learning_rate, iteration


def save_checkpoint(model, optimizer, learning_rate, iteration, filepath):
    print("Saving model and optimizer state at iteration {} to {}".format(
        iteration, filepath))
    torch.save({'iteration': iteration,
                'state_dict': model.state_dict(),
                'optimizer': optimizer.state_dict(),
                'learning_rate': learning_rate}, filepath)


def validate(model, criterion, valset, iteration, batch_size, n_gpus,
             collate_fn, logger, distributed_run, rank):
    """Handles all the validation scoring and printing"""
    model.eval()
    with torch.no_grad():
        val_sampler = DistributedSampler(valset) if distributed_run else None
        val_loader = DataLoader(valset, sampler=val_sampler, num_workers=1,
                                shuffle=False, batch_size=batch_size,
                                pin_memory=False, collate_fn=collate_fn)

        val_loss = 0.0
        for i, batch in enumerate(val_loader):
            x, y = model.parse_batch(batch)
            y_pred = model(x)
            loss = criterion(y_pred, y)
            if distributed_run:
                reduced_val_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_val_loss = loss.item()
            val_loss += reduced_val_loss
        val_loss = val_loss / (i + 1)

    model.train()
    if rank == 0:
        print("Validation loss {}: {:9f}".format(iteration, val_loss))
        logger.log_validation(val_loss, model, y, y_pred, iteration)


def train(output_directory, log_directory, checkpoint_path, warm_start, n_gpus,
          rank, group_name, hparams):
    """Training and validation, logging results to tensorboard and stdout

    Params
    ------
    output_directory (string): directory to save checkpoints
    log_directory (string): directory to save tensorboard logs
    checkpoint_path (string): checkpoint path
    n_gpus (int): number of gpus
    rank (int): rank of current gpu
    hparams (object): hyperparameters, built from comma separated
        "name=value" pairs
    """
    if hparams.distributed_run:
        init_distributed(hparams, n_gpus, rank, group_name)

    torch.manual_seed(hparams.seed)
    torch.cuda.manual_seed(hparams.seed)

    model = load_model(hparams)
    learning_rate = hparams.learning_rate
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate,
                                 weight_decay=hparams.weight_decay)

    if hparams.fp16_run:
        from apex import amp
        model, optimizer = amp.initialize(
            model, optimizer, opt_level='O2')

    if hparams.distributed_run:
        model = apply_gradient_allreduce(model)

    criterion = Tacotron2Loss()

    logger = prepare_directories_and_logger(
        output_directory, log_directory, rank)

    train_loader, valset, collate_fn = prepare_dataloaders(hparams)

    # Load checkpoint if one exists
    iteration = 0
    epoch_offset = 0
    if checkpoint_path is not None:
        if warm_start:
            model = warm_start_model(
                checkpoint_path, model, hparams.ignore_layers)
        else:
            model, optimizer, _learning_rate, iteration = load_checkpoint(
                checkpoint_path, model, optimizer)
            if hparams.use_saved_learning_rate:
                learning_rate = _learning_rate
            iteration += 1  # next iteration is iteration + 1
            epoch_offset = max(0, int(iteration / len(train_loader)))

    model.train()
    is_overflow = False
    # ================ MAIN TRAINING LOOP ===================
    for epoch in range(epoch_offset, hparams.epochs):
        print("Epoch: {}".format(epoch))
        for i, batch in enumerate(train_loader):
            start = time.perf_counter()
            for param_group in optimizer.param_groups:
                param_group['lr'] = learning_rate

            model.zero_grad()
            x, y = model.parse_batch(batch)
            y_pred = model(x)

            loss = criterion(y_pred, y)
            if hparams.distributed_run:
                reduced_loss = reduce_tensor(loss.data, n_gpus).item()
            else:
                reduced_loss = loss.item()
            if hparams.fp16_run:
                with amp.scale_loss(loss, optimizer) as scaled_loss:
                    scaled_loss.backward()
            else:
                loss.backward()

            if hparams.fp16_run:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    amp.master_params(optimizer), hparams.grad_clip_thresh)
                is_overflow = math.isnan(grad_norm)
            else:
                grad_norm = torch.nn.utils.clip_grad_norm_(
                    model.parameters(), hparams.grad_clip_thresh)

            optimizer.step()

            if not is_overflow and rank == 0:
                duration = time.perf_counter() - start
                print("Train loss {} {:.6f} Grad Norm {:.6f} {:.2f}s/it".format(
                    iteration, reduced_loss, grad_norm, duration))
                logger.log_training(
                    reduced_loss, grad_norm, learning_rate, duration, iteration)

            if not is_overflow and (iteration % hparams.iters_per_checkpoint == 0):
                validate(model, criterion, valset, iteration,
                         hparams.batch_size, n_gpus, collate_fn, logger,
                         hparams.distributed_run, rank)
                if rank == 0:
                    checkpoint_path = os.path.join(
                        output_directory, "checkpoint_{}".format(iteration))
                    save_checkpoint(model, optimizer, learning_rate, iteration,
                                    checkpoint_path)

            iteration += 1


if __name__ == '__main__':
    parser = argparse.ArgumentParser()
    parser.add_argument('-o', '--output_directory', type=str,
                        help='directory to save checkpoints')
    parser.add_argument('-l', '--log_directory', type=str,
                        help='directory to save tensorboard logs')
    parser.add_argument('-c', '--checkpoint_path', type=str, default=None,
                        required=False, help='checkpoint path')
    parser.add_argument('--warm_start', action='store_true',
                        help='load model weights only, ignore specified layers')
    parser.add_argument('--n_gpus', type=int, default=1,
                        required=False, help='number of gpus')
    parser.add_argument('--rank', type=int, default=0,
                        required=False, help='rank of current gpu')
    parser.add_argument('--group_name', type=str, default='group_name',
                        required=False, help='Distributed group name')
    parser.add_argument('--hparams', type=str,
                        required=False, help='comma separated name=value pairs')

    args = parser.parse_args()
    hparams = create_hparams(args.hparams)

    torch.backends.cudnn.enabled = hparams.cudnn_enabled
    torch.backends.cudnn.benchmark = hparams.cudnn_benchmark

    print("FP16 Run:", hparams.fp16_run)
    print("Dynamic Loss Scaling:", hparams.dynamic_loss_scaling)
    print("Distributed Run:", hparams.distributed_run)
    print("cuDNN Enabled:", hparams.cudnn_enabled)
    print("cuDNN Benchmark:", hparams.cudnn_benchmark)

    train(args.output_directory, args.log_directory, args.checkpoint_path,
          args.warm_start, args.n_gpus, args.rank, args.group_name, hparams)
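Single-GPU launch and warm-start sketches, with illustrative paths (the checkpoint name is a placeholder):

python train.py --output_directory=outdir --log_directory=logdir
python train.py --output_directory=outdir --log_directory=logdir -c outdir/checkpoint_10000 --warm_start
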
utils.py
ADDED
@@ -0,0 +1,29 @@
import numpy as np
from scipy.io.wavfile import read
import torch


def get_mask_from_lengths(lengths):
    max_len = torch.max(lengths).item()
    # build ids on the same device as lengths so the mask works on CPU and GPU
    ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long)
    mask = (ids < lengths.unsqueeze(1)).bool()
    return mask


def load_wav_to_torch(full_path):
    sampling_rate, data = read(full_path)
    return torch.FloatTensor(data.astype(np.float32)), sampling_rate


def load_filepaths_and_text(filename, split="|"):
    with open(filename, encoding='utf-8') as f:
        filepaths_and_text = [line.strip().split(split) for line in f]
    return filepaths_and_text


def to_gpu(x):
    x = x.contiguous()

    if torch.cuda.is_available():
        x = x.cuda(non_blocking=True)
    return torch.autograd.Variable(x)
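get_mask_from_lengths marks valid positions True and padding False; callers in model.py negate it with ~ where padded positions must be masked out. A small worked example:

# Usage sketch (not part of utils.py).
import torch
from utils import get_mask_from_lengths

mask = get_mask_from_lengths(torch.tensor([3, 1]))
# tensor([[ True,  True,  True],
#         [ True, False, False]])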