Commit
·
4cbdd15
1
Parent(s):
21629ac
Upload 89 files
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitattributes +1 -0
- my_gradio_app.py +63 -0
- src/glow_tts/__pycache__/attentions.cpython-37.pyc +0 -0
- src/glow_tts/__pycache__/audio_processing.cpython-37.pyc +0 -0
- src/glow_tts/__pycache__/commons.cpython-37.pyc +0 -0
- src/glow_tts/__pycache__/data_utils.cpython-37.pyc +0 -0
- src/glow_tts/__pycache__/models.cpython-37.pyc +0 -0
- src/glow_tts/__pycache__/modules.cpython-37.pyc +0 -0
- src/glow_tts/__pycache__/stft.cpython-37.pyc +0 -0
- src/glow_tts/__pycache__/utils.cpython-37.pyc +0 -0
- src/glow_tts/attentions.py +378 -0
- src/glow_tts/audio_processing.py +100 -0
- src/glow_tts/commons.py +273 -0
- src/glow_tts/data_utils.py +274 -0
- src/glow_tts/generate_mels.py +70 -0
- src/glow_tts/hifi/__init__.py +5 -0
- src/glow_tts/hifi/__pycache__/__init__.cpython-37.pyc +0 -0
- src/glow_tts/hifi/__pycache__/env.cpython-37.pyc +0 -0
- src/glow_tts/hifi/__pycache__/models.cpython-37.pyc +0 -0
- src/glow_tts/hifi/__pycache__/utils.cpython-37.pyc +0 -0
- src/glow_tts/hifi/env.py +15 -0
- src/glow_tts/hifi/models.py +403 -0
- src/glow_tts/hifi/utils.py +57 -0
- src/glow_tts/init.py +79 -0
- src/glow_tts/models.py +403 -0
- src/glow_tts/modules.py +276 -0
- src/glow_tts/monotonic_align/build/lib.linux-x86_64-cpython-37/monotonic_align/__init__.py +5 -0
- src/glow_tts/monotonic_align/build/lib.linux-x86_64-cpython-37/monotonic_align/core.cpython-37m-x86_64-linux-gnu.so +0 -0
- src/glow_tts/monotonic_align/build/lib.linux-x86_64-cpython-37/monotonic_align/mas.py +57 -0
- src/glow_tts/monotonic_align/build/temp.linux-x86_64-cpython-37/monotonic_align/core.o +3 -0
- src/glow_tts/monotonic_align/monotonic_align.egg-info/PKG-INFO +3 -0
- src/glow_tts/monotonic_align/monotonic_align.egg-info/SOURCES.txt +10 -0
- src/glow_tts/monotonic_align/monotonic_align.egg-info/dependency_links.txt +1 -0
- src/glow_tts/monotonic_align/monotonic_align.egg-info/requires.txt +1 -0
- src/glow_tts/monotonic_align/monotonic_align.egg-info/top_level.txt +1 -0
- src/glow_tts/monotonic_align/monotonic_align/__init__.py +5 -0
- src/glow_tts/monotonic_align/monotonic_align/core.c +0 -0
- src/glow_tts/monotonic_align/monotonic_align/core.pyx +45 -0
- src/glow_tts/monotonic_align/monotonic_align/mas.py +57 -0
- src/glow_tts/monotonic_align/pyproject.toml +7 -0
- src/glow_tts/monotonic_align/setup.py +23 -0
- src/glow_tts/stft.py +185 -0
- src/glow_tts/t2s_fastapi.py +63 -0
- src/glow_tts/t2s_gradio.py +24 -0
- src/glow_tts/text/__init__.py +84 -0
- src/glow_tts/text/__pycache__/__init__.cpython-37.pyc +0 -0
- src/glow_tts/text/__pycache__/cleaners.cpython-37.pyc +0 -0
- src/glow_tts/text/__pycache__/numbers.cpython-37.pyc +0 -0
- src/glow_tts/text/cleaners.py +78 -0
- src/glow_tts/text/numbers.py +69 -0
.gitattributes
CHANGED
|
@@ -37,3 +37,4 @@ checkpoints/hifi/female/do_00040000 filter=lfs diff=lfs merge=lfs -text
|
|
| 37 |
checkpoints/hifi/female/g_00040000 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
checkpoints/hifi/male/do_00060000 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
checkpoints/hifi/male/g_00060000 filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
| 37 |
checkpoints/hifi/female/g_00040000 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
checkpoints/hifi/male/do_00060000 filter=lfs diff=lfs merge=lfs -text
|
| 39 |
checkpoints/hifi/male/g_00060000 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
src/glow_tts/monotonic_align/build/temp.linux-x86_64-cpython-37/monotonic_align/core.o filter=lfs diff=lfs merge=lfs -text
|
my_gradio_app.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from tts_infer.tts import TextToMel, MelToWav
|
| 2 |
+
from tts_infer.num_to_word_on_sent import normalize_nums
|
| 3 |
+
from ai4bharat.transliteration import XlitEngine
|
| 4 |
+
import gradio as gr
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
device = 'cpu'
|
| 8 |
+
|
| 9 |
+
def create_text_to_mel(glow_model_dir):
    """Construct a TextToMel model from *glow_model_dir* on the module device."""
    model = TextToMel(glow_model_dir=glow_model_dir, device=device)
    return model
|
| 11 |
+
|
| 12 |
+
# Load both voice pipelines once at import time so each request only runs
# inference. NOTE(review): paths assume a local `checkpoints/` tree relative
# to the working directory — confirm they exist in the deployment image.
text_to_mel_female = create_text_to_mel('checkpoints/glow/female')
mel_to_wav_female = MelToWav(hifi_model_dir='checkpoints/hifi/female', device=device)

text_to_mel_male = create_text_to_mel('checkpoints/glow/male')
mel_to_wav_male = MelToWav(hifi_model_dir='checkpoints/hifi/male', device=device)
|
| 17 |
+
|
| 18 |
+
def translit(text, lang):
    """Transliterate every whitespace-separated word of *text* into *lang*.

    Uses ai4bharat's XlitEngine, taking the top-1 candidate per word.
    Engines are cached per language: the original built a fresh XlitEngine
    on every call, which reloads the transliteration model each time.
    """
    # Per-language engine cache stored on the function object so the public
    # interface and module namespace stay unchanged.
    engines = translit.__dict__.setdefault("_engines", {})
    engine = engines.get(lang)
    if engine is None:
        engine = XlitEngine(lang)
        engines[lang] = engine
    words = [engine.translit_word(word, topk=1)[lang][0] for word in text.split()]
    updated_sent = ' '.join(words)
    return updated_sent
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def run_tts(text, selected_voice):
    """Synthesize Punjabi speech for *text* with the requested voice.

    Returns (sample_rate, audio) as expected by a gradio audio output.
    """
    lang = "pa"  # Punjabi language code
    # Danda (।) is treated as a sentence terminator downstream.
    text = text.replace('।', '.')
    # Expand digits into Punjabi words, then transliterate Latin-script
    # words into the target script.
    text_num_to_word = normalize_nums(text, lang)
    text_num_to_word_and_transliterated = translit(text_num_to_word, lang)

    if selected_voice == "Male Voice":
        mel_model, vocoder = text_to_mel_male, mel_to_wav_male
    else:
        mel_model, vocoder = text_to_mel_female, mel_to_wav_female

    mel = mel_model.generate_mel(text_num_to_word_and_transliterated)
    audio, sr = vocoder.generate_wav(mel)
    return sr, audio
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
# Gradio UI: free-text input plus a voice dropdown, audio output.
# NOTE(review): `gr.inputs.Dropdown` with `default=` is the legacy
# (pre-3.x) gradio API — confirm the pinned gradio version still ships it.
iface = gr.Interface(
    fn=run_tts,
    inputs=[
        "textbox",
        gr.inputs.Dropdown(
            choices=["Male Voice", "Female Voice"],
            default="Female Voice",
            label="Select Voice"
        )
    ],
    outputs="audio",
    title="Text to Speech Punjabi Language"
)

iface.launch()
|
| 62 |
+
|
| 63 |
+
|
src/glow_tts/__pycache__/attentions.cpython-37.pyc
ADDED
|
Binary file (9.27 kB). View file
|
|
|
src/glow_tts/__pycache__/audio_processing.cpython-37.pyc
ADDED
|
Binary file (2.79 kB). View file
|
|
|
src/glow_tts/__pycache__/commons.cpython-37.pyc
ADDED
|
Binary file (8.61 kB). View file
|
|
|
src/glow_tts/__pycache__/data_utils.cpython-37.pyc
ADDED
|
Binary file (8.42 kB). View file
|
|
|
src/glow_tts/__pycache__/models.cpython-37.pyc
ADDED
|
Binary file (7.59 kB). View file
|
|
|
src/glow_tts/__pycache__/modules.cpython-37.pyc
ADDED
|
Binary file (7.39 kB). View file
|
|
|
src/glow_tts/__pycache__/stft.cpython-37.pyc
ADDED
|
Binary file (5.41 kB). View file
|
|
|
src/glow_tts/__pycache__/utils.cpython-37.pyc
ADDED
|
Binary file (8.21 kB). View file
|
|
|
src/glow_tts/attentions.py
ADDED
|
@@ -0,0 +1,378 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import torch
|
| 5 |
+
from torch import nn
|
| 6 |
+
from torch.nn import functional as F
|
| 7 |
+
|
| 8 |
+
import commons
|
| 9 |
+
import modules
|
| 10 |
+
from modules import LayerNorm
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
class Encoder(nn.Module):
    """Transformer-style encoder stack over channels-first sequences.

    Applies ``n_layers`` blocks of self-attention followed by a
    convolutional feed-forward network; each sub-layer output passes
    through dropout, a residual connection, and LayerNorm (post-norm).
    Input/output shape: [batch, hidden_channels, time], with a
    [batch, 1, time] padding mask.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=None,    # relative-attention window; None disables it
        block_length=None,   # local attention band width; None = full attention
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for i in range(self.n_layers):
            self.attn_layers.append(
                MultiHeadAttention(
                    hidden_channels,
                    hidden_channels,
                    n_heads,
                    window_size=window_size,
                    p_dropout=p_dropout,
                    block_length=block_length,
                )
            )
            self.norm_layers_1.append(LayerNorm(hidden_channels))
            self.ffn_layers.append(
                FFN(
                    hidden_channels,
                    hidden_channels,
                    filter_channels,
                    kernel_size,
                    p_dropout=p_dropout,
                )
            )
            self.norm_layers_2.append(LayerNorm(hidden_channels))

    def forward(self, x, x_mask):
        # Pairwise validity mask [b, 1, t, t]: attend only between
        # positions that are both unpadded.
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        for i in range(self.n_layers):
            x = x * x_mask
            y = self.attn_layers[i](x, x, attn_mask)  # self-attention (q == kv)
            y = self.drop(y)
            x = self.norm_layers_1[i](x + y)  # residual + post-norm

            y = self.ffn_layers[i](x, x_mask)
            y = self.drop(y)
            x = self.norm_layers_2[i](x + y)
        x = x * x_mask
        return x
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
class CouplingBlock(nn.Module):
    """Affine coupling layer (invertible flow step).

    The first half of the channels passes through unchanged and, via a
    WaveNet-like network (``modules.WN``), parameterizes an affine
    transform (shift ``m``, log-scale ``logs``) of the second half.
    ``forward`` returns (z, logdet); logdet is None in the reverse pass.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,      # global conditioning channels; 0 = unconditioned
        p_dropout=0,
        sigmoid_scale=False, # bound the scale via a sigmoid for stability
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout
        self.sigmoid_scale = sigmoid_scale

        start = torch.nn.Conv1d(in_channels // 2, hidden_channels, 1)
        start = torch.nn.utils.weight_norm(start)
        self.start = start
        # Initializing the last layer to 0 makes the affine coupling layers
        # do nothing at first (identity transform), which stabilizes training.
        end = torch.nn.Conv1d(hidden_channels, in_channels, 1)
        end.weight.data.zero_()
        end.bias.data.zero_()
        self.end = end

        self.wn = modules.WN(
            in_channels,
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels,
            p_dropout,
        )

    def forward(self, x, x_mask=None, reverse=False, g=None, **kwargs):
        b, c, t = x.size()
        if x_mask is None:
            x_mask = 1
        # Channel split: x_0 conditions the transform applied to x_1.
        x_0, x_1 = x[:, : self.in_channels // 2], x[:, self.in_channels // 2 :]

        x = self.start(x_0) * x_mask
        x = self.wn(x, x_mask, g)  # g: optional global conditioning
        out = self.end(x)

        z_0 = x_0
        m = out[:, : self.in_channels // 2, :]
        logs = out[:, self.in_channels // 2 :, :]
        if self.sigmoid_scale:
            # Strictly positive, bounded scale; +2 biases it toward 1.
            logs = torch.log(1e-6 + torch.sigmoid(logs + 2))

        if reverse:
            # Inverse affine transform; log-determinant is not needed.
            z_1 = (x_1 - m) * torch.exp(-logs) * x_mask
            logdet = None
        else:
            z_1 = (m + torch.exp(logs) * x_1) * x_mask
            # Sum of log-scales over channels and time = log|det J|.
            logdet = torch.sum(logs * x_mask, [1, 2])

        z = torch.cat([z_0, z_1], 1)
        return z, logdet

    def store_inverse(self):
        # Fuse weight-norm into plain weights for faster inference passes.
        self.wn.remove_weight_norm()
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
class MultiHeadAttention(nn.Module):
    """Multi-head attention over channels-first sequences.

    Projections are 1x1 convolutions on [batch, channels, time] tensors.
    Optional features: windowed relative position embeddings
    (``window_size``, self-attention only), banded local attention
    (``block_length``), and a distance-based proximal bias.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        window_size=None,    # relative-attention half-window; None disables
        heads_share=True,    # share relative embeddings across heads
        p_dropout=0.0,
        block_length=None,   # restrict attention to a diagonal band
        proximal_bias=False, # bias attention toward nearby positions
        proximal_init=False, # initialize key proj equal to query proj
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.p_dropout = p_dropout
        # Last attention map, stored for inspection/alignment extraction.
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels ** -0.5
            # Learned embeddings for relative offsets in [-window, +window].
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        if proximal_init:
            self.conv_k.weight.data.copy_(self.conv_q.weight.data)
            self.conv_k.bias.data.copy_(self.conv_q.bias.data)
        nn.init.xavier_uniform_(self.conv_v.weight)

    def forward(self, x, c, attn_mask=None):
        # x: query source, c: key/value source (x == c for self-attention).
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        # reshape [b, d, t] -> [b, n_h, t, d_k]
        b, d, t_s, t_t = (*key.size(), query.size(2))
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        # Scaled dot-product scores [b, n_h, t_t, t_s].
        scores = torch.matmul(query, key.transpose(-2, -1)) / math.sqrt(self.k_channels)
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(query, key_relative_embeddings)
            rel_logits = self._relative_position_to_absolute_position(rel_logits)
            scores_local = rel_logits / math.sqrt(self.k_channels)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            # -1e4 (not -inf) keeps softmax finite in fp16.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                # Keep only a diagonal band of width 2*block_length+1.
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores * block_mask + -1e4 * (1 - block_mask)
        p_attn = F.softmax(scores, dim=-1)  # [b, n_h, t_t, t_s]
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            # Add relative-position contribution to the values as well.
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        output = (
            output.transpose(2, 3).contiguous().view(b, d, t_t)
        )  # [b, n_h, t_t, d_k] -> [b, d, t_t]
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """
        x: [b, h, l, m]
        y: [h or 1, m, d]
        ret: [b, h, l, d]
        """
        ret = torch.matmul(x, y.unsqueeze(0))
        return ret

    def _matmul_with_relative_keys(self, x, y):
        """
        x: [b, h, l, d]
        y: [h or 1, m, d]
        ret: [b, h, l, m]
        """
        ret = torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))
        return ret

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Slice (and pad if needed) the learned relative embeddings to the
        2*length-1 offsets used by a sequence of the given length."""
        max_relative_position = 2 * self.window_size + 1
        # Pad first before slice to avoid using cond ops.
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.convert_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """
        x: [b, h, l, 2*l-1]
        ret: [b, h, l, l]
        """
        batch, heads, length, _ = x.size()
        # Concat columns of pad to shift from relative to absolute indexing.
        x = F.pad(x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        # Concat extra elements so to add up to shape (len+1, 2*len-1).
        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        # Reshape and slice out the padded elements.
        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """
        x: [b, h, l, l]
        ret: [b, h, l, 2*l-1]
        """
        batch, heads, length, _ = x.size()
        # Pad along column.
        x = F.pad(
            x, commons.convert_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length ** 2 + length * (length - 1)])
        # Add 0's in the beginning that will skew the elements after reshape.
        x_flat = F.pad(x_flat, commons.convert_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Bias for self-attention to encourage attention to close positions.
        Args:
          length: an integer scalar.
        Returns:
          a Tensor with shape [1, 1, length, length]
        """
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
|
| 342 |
+
|
| 343 |
+
|
| 344 |
+
class FFN(nn.Module):
    """Position-wise feed-forward block built from two 1-D convolutions.

    conv_1 -> activation (sigmoid-based GELU approximation when
    ``activation == "gelu"``, otherwise ReLU) -> dropout -> conv_2, with
    the padding mask applied before each convolution and to the output.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation

        # kernel_size // 2 padding keeps the time dimension unchanged
        # for odd kernel sizes.
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.conv_2 = nn.Conv1d(
            filter_channels, out_channels, kernel_size, padding=kernel_size // 2
        )
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        x = self.conv_1(x * x_mask)
        if self.activation == "gelu":
            # x * sigmoid(1.702 * x) is the fast GELU approximation.
            x = x * torch.sigmoid(1.702 * x)
        else:
            x = torch.relu(x)
        x = self.drop(x)
        x = self.conv_2(x * x_mask)
        return x * x_mask
|
src/glow_tts/audio_processing.py
ADDED
|
@@ -0,0 +1,100 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import numpy as np
|
| 3 |
+
from scipy.signal import get_window
|
| 4 |
+
import librosa.util as librosa_util
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def window_sumsquare(
    window,
    n_frames,
    hop_length=200,
    win_length=800,
    n_fft=800,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.

    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.

    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`

    n_frames : int > 0
        The number of analysis frames

    hop_length : int > 0
        The number of samples to advance between frames

    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.

    n_fft : int > 0
        The length of each analysis frame.

    dtype : np.dtype
        The data type of the output

    norm : [optional]
        Normalization mode passed through to `librosa.util.normalize`.

    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length.
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    # `size` is keyword-only in librosa >= 0.10; passing it by keyword is
    # also valid on older releases.
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope, clipping the last (possibly partial) frame to n.
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def griffin_lim(magnitudes, stft_fn, n_iters=30):
    """
    Recover a time-domain signal from spectrogram magnitudes by iterative
    phase re-estimation (Griffin-Lim).

    PARAMS
    ------
    magnitudes: spectrogram magnitudes
    stft_fn: STFT class with transform (STFT) and inverse (ISTFT) methods
    n_iters: number of phase refinement iterations
    """

    # Start from uniformly random phase angles.
    angles = np.angle(np.exp(2j * np.pi * np.random.rand(*magnitudes.size())))
    angles = angles.astype(np.float32)
    # torch.autograd.Variable has been a deprecated no-op since PyTorch 0.4;
    # a plain tensor is equivalent.
    angles = torch.from_numpy(angles)
    signal = stft_fn.inverse(magnitudes, angles).squeeze(1)

    # Alternate between the time domain and the given magnitudes,
    # keeping only the re-estimated phase each round.
    for i in range(n_iters):
        _, angles = stft_fn.transform(signal)
        signal = stft_fn.inverse(magnitudes, angles).squeeze(1)
    return signal
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress *x*: log(clamp(x, clip_val) * C).

    PARAMS
    ------
    C: compression factor
    clip_val: floor applied before the log to avoid log(0)
    """
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
|
| 92 |
+
|
| 93 |
+
|
| 94 |
+
def dynamic_range_decompression(x, C=1):
    """Invert `dynamic_range_compression`: exp(x) / C.

    PARAMS
    ------
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C
|
src/glow_tts/commons.py
ADDED
|
@@ -0,0 +1,273 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from torch import nn
|
| 5 |
+
from torch.nn import functional as F
|
| 6 |
+
|
| 7 |
+
from librosa.filters import mel as librosa_mel_fn
|
| 8 |
+
from audio_processing import dynamic_range_compression
|
| 9 |
+
from audio_processing import dynamic_range_decompression
|
| 10 |
+
from stft import STFT
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def intersperse(lst, item):
    """Return a new list with *item* before, between, and after every
    element of *lst*; the result has length 2*len(lst) + 1."""
    out = [item]
    for element in lst:
        out.append(element)
        out.append(item)
    return out
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
def mle_loss(z, m, logs, logdet, mask):
    """Negative log-likelihood of z under N(m, exp(logs)^2), minus the flow
    log-determinant, averaged over batch/channel/time and including the
    Gaussian normalization constant."""
    squared_error = (z - m) ** 2
    # Negative normal likelihood without the constant term.
    nll = torch.sum(logs) + 0.5 * torch.sum(torch.exp(-2 * logs) * squared_error)
    nll = nll - torch.sum(logdet)  # log Jacobian determinant
    # Average across batch, channel and time axes.
    nll = nll / torch.sum(torch.ones_like(z) * mask)
    # Add the remaining constant term.
    return nll + 0.5 * math.log(2 * math.pi)
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def duration_loss(logw, logw_, lengths):
    """Squared error between predicted and target log-durations,
    normalized by the total sequence length."""
    diff = logw - logw_
    return torch.sum(diff * diff) / torch.sum(lengths)
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    # WaveNet-style gated activation: after summing the two inputs, the
    # first n_channels channels go through tanh (filter) and the rest
    # through sigmoid (gate); the two halves are multiplied elementwise.
    # n_channels is a 1-element int tensor so the slice bound stays a
    # TorchScript-compatible scalar.
    n_channels_int = n_channels[0]
    in_act = input_a + input_b
    t_act = torch.tanh(in_act[:, :n_channels_int, :])
    s_act = torch.sigmoid(in_act[:, n_channels_int:, :])
    acts = t_act * s_act
    return acts
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def convert_pad_shape(pad_shape):
    """Flatten a per-dimension [[before, after], ...] pad spec into the
    reversed flat list that torch.nn.functional.pad expects (last
    dimension first)."""
    flat = []
    for pair in reversed(pad_shape):
        flat.extend(pair)
    return flat
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def shift_1d(x):
    """Shift a [b, c, t] tensor one step right along the last axis,
    zero-filling position 0 and dropping the final position."""
    # [1, 0, 0, 0, 0, 0] pads only the front of the last dimension.
    padded = F.pad(x, [1, 0, 0, 0, 0, 0])
    return padded[:, :, :-1]
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
def sequence_mask(length, max_length=None):
    """Boolean mask [b, max_length]: True where position < length[b].

    When *max_length* is None it defaults to length.max().
    """
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
def maximum_path(value, mask, max_neg_val=-np.inf):
    """Monotonic alignment search (numpy version, ~4x faster than torch).

    Finds, per batch item, the monotonic path through the score matrix that
    maximizes the summed value, returned as a 0/1 tensor on `value`'s device.

    value: [b, t_x, t_y] alignment scores
    mask:  [b, t_x, t_y] validity mask
    """
    value = value * mask

    device = value.device
    dtype = value.dtype
    value = value.cpu().detach().numpy()
    # np.bool was removed in NumPy 1.24; the builtin `bool` is the correct dtype.
    mask = mask.cpu().detach().numpy().astype(bool)

    b, t_x, t_y = value.shape
    direction = np.zeros(value.shape, dtype=np.int64)
    v = np.zeros((b, t_x), dtype=np.float32)
    x_range = np.arange(t_x, dtype=np.float32).reshape(1, -1)
    # Forward pass: dynamic programming over output frames j.
    for j in range(t_y):
        # v0: best score ending one text position earlier; v1: same position.
        v0 = np.pad(v, [[0, 0], [1, 0]], mode="constant", constant_values=max_neg_val)[
            :, :-1
        ]
        v1 = v
        max_mask = v1 >= v0
        v_max = np.where(max_mask, v1, v0)
        direction[:, :, j] = max_mask

        # Text positions beyond frame j are unreachable by a monotonic path.
        index_mask = x_range <= j
        v = np.where(index_mask, v_max + value[:, :, j], max_neg_val)
    direction = np.where(mask, direction, 1)

    # Backward pass: trace the argmax decisions back into a binary path.
    path = np.zeros(value.shape, dtype=np.float32)
    index = mask[:, :, 0].sum(1).astype(np.int64) - 1
    index_range = np.arange(b)
    for j in reversed(range(t_y)):
        path[index_range, index, j] = 1
        index = index + direction[index_range, index, j] - 1
    path = path * mask.astype(np.float32)
    path = torch.from_numpy(path).to(device=device, dtype=dtype)
    return path
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
def generate_path(duration, mask):
    """Expand integer durations into a binary alignment path.

    duration: [b, t_x] number of output frames assigned to each input token
    mask:     [b, t_x, t_y] validity mask

    Returns a [b, t_x, t_y] 0/1 tensor where row i is 1 exactly over the
    frames belonging to token i.

    Note: the original allocated a `torch.zeros(b, t_x, t_y)` path that was
    immediately overwritten; that dead allocation (and the unused `device`
    local) are removed here.
    """
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)

    # A token's span ends at its cumulative duration; subtracting the
    # previous row's prefix mask leaves 1s only on the token's own frames.
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    path = path - F.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
class Adam:
    """Thin wrapper around torch.optim.Adam that applies the Noam
    ("transformer") warmup schedule when scheduler == "noam", and a constant
    learning rate otherwise."""

    def __init__(
        self,
        params,
        scheduler,
        dim_model,
        warmup_steps=4000,
        lr=1e0,
        betas=(0.9, 0.98),
        eps=1e-9,
    ):
        self.params = params
        self.scheduler = scheduler
        self.dim_model = dim_model
        self.warmup_steps = warmup_steps
        self.lr = lr
        self.betas = betas
        self.eps = eps

        self.step_num = 1
        self.cur_lr = lr * self._get_lr_scale()

        self._optim = torch.optim.Adam(params, lr=self.cur_lr, betas=betas, eps=eps)

    def _get_lr_scale(self):
        # Noam schedule: linear warmup for warmup_steps, then inverse-sqrt
        # decay, all scaled by dim_model ** -0.5.
        if self.scheduler != "noam":
            return 1
        warmup_term = self.step_num * np.power(self.warmup_steps, -1.5)
        decay_term = np.power(self.step_num, -0.5)
        return np.power(self.dim_model, -0.5) * min(decay_term, warmup_term)

    def _update_learning_rate(self):
        # Advance the schedule and push the new rate into every param group.
        self.step_num += 1
        if self.scheduler == "noam":
            self.cur_lr = self.lr * self._get_lr_scale()
            for group in self._optim.param_groups:
                group["lr"] = self.cur_lr

    def get_lr(self):
        """Current learning rate."""
        return self.cur_lr

    def step(self):
        """Apply one optimizer step, then advance the schedule."""
        self._optim.step()
        self._update_learning_rate()

    def zero_grad(self):
        self._optim.zero_grad()

    def load_state_dict(self, d):
        self._optim.load_state_dict(d)

    def state_dict(self):
        return self._optim.state_dict()
|
| 180 |
+
|
| 181 |
+
|
| 182 |
+
class TacotronSTFT(nn.Module):
    """Mel-spectrogram front end: STFT magnitudes projected onto a mel
    filterbank, followed by dynamic-range (log) compression."""

    def __init__(
        self,
        filter_length=1024,
        hop_length=256,
        win_length=1024,
        n_mel_channels=80,
        sampling_rate=22050,
        mel_fmin=0.0,
        mel_fmax=8000.0,
    ):
        super(TacotronSTFT, self).__init__()
        self.n_mel_channels = n_mel_channels
        self.sampling_rate = sampling_rate
        self.stft_fn = STFT(filter_length, hop_length, win_length)
        # Keyword arguments: librosa >= 0.10 removed the positional form of
        # librosa.filters.mel; keywords also work on older releases.
        mel_basis = librosa_mel_fn(
            sr=sampling_rate,
            n_fft=filter_length,
            n_mels=n_mel_channels,
            fmin=mel_fmin,
            fmax=mel_fmax,
        )
        mel_basis = torch.from_numpy(mel_basis).float()
        # Buffer (not a Parameter): moves with .to()/.cuda(), saved in
        # state_dict, never trained.
        self.register_buffer("mel_basis", mel_basis)

    def spectral_normalize(self, magnitudes):
        """Log-compress linear magnitudes."""
        output = dynamic_range_compression(magnitudes)
        return output

    def spectral_de_normalize(self, magnitudes):
        """Invert spectral_normalize."""
        output = dynamic_range_decompression(magnitudes)
        return output

    def mel_spectrogram(self, y):
        """Computes mel-spectrograms from a batch of waves
        PARAMS
        ------
        y: Variable(torch.FloatTensor) with shape (B, T) in range [-1, 1]

        RETURNS
        -------
        mel_output: torch.FloatTensor of shape (B, n_mel_channels, T)
        """
        assert torch.min(y.data) >= -1
        assert torch.max(y.data) <= 1

        magnitudes, phases = self.stft_fn.transform(y)
        magnitudes = magnitudes.data
        mel_output = torch.matmul(self.mel_basis, magnitudes)
        mel_output = self.spectral_normalize(mel_output)
        return mel_output
|
| 229 |
+
|
| 230 |
+
|
| 231 |
+
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp every gradient into [-clip_value, clip_value] in place and
    return the total pre-clipping gradient norm of order `norm_type`."""
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grads = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    clip_value = float(clip_value)

    total_norm = 0
    for p in with_grads:
        # Accumulate the norm first so it reflects the raw gradients.
        total_norm += p.grad.data.norm(norm_type).item() ** norm_type
        p.grad.data.clamp_(min=-clip_value, max=clip_value)
    return total_norm ** (1.0 / norm_type)
|
| 246 |
+
|
| 247 |
+
|
| 248 |
+
def squeeze(x, x_mask=None, n_sqz=2):
    """Fold time into channels: [b, c, t] -> [b, c * n_sqz, t // n_sqz],
    trimming trailing frames that do not fill a whole group."""
    b, c, t = x.size()

    t_trim = (t // n_sqz) * n_sqz
    x = x[:, :, :t_trim]
    grouped = x.view(b, c, t_trim // n_sqz, n_sqz)
    x_sqz = grouped.permute(0, 3, 1, 2).contiguous().view(b, c * n_sqz, t_trim // n_sqz)

    if x_mask is None:
        x_mask = torch.ones(b, 1, t_trim // n_sqz).to(device=x.device, dtype=x.dtype)
    else:
        # Keep the last mask entry of each n_sqz-frame group.
        x_mask = x_mask[:, :, n_sqz - 1 :: n_sqz]
    return x_sqz * x_mask, x_mask
|
| 261 |
+
|
| 262 |
+
|
| 263 |
+
def unsqueeze(x, x_mask=None, n_sqz=2):
    """Inverse of squeeze: unfold channels back into time,
    [b, c, t] -> [b, c // n_sqz, t * n_sqz]."""
    b, c, t = x.size()

    unfolded = x.view(b, n_sqz, c // n_sqz, t)
    x_unsqz = unfolded.permute(0, 2, 3, 1).contiguous().view(b, c // n_sqz, t * n_sqz)

    if x_mask is None:
        x_mask = torch.ones(b, 1, t * n_sqz).to(device=x.device, dtype=x.dtype)
    else:
        # Repeat each mask entry n_sqz times along the time axis.
        x_mask = x_mask.unsqueeze(-1).repeat(1, 1, 1, n_sqz).view(b, 1, t * n_sqz)
    return x_unsqz * x_mask, x_mask
|
src/glow_tts/data_utils.py
ADDED
|
@@ -0,0 +1,274 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import torch.utils.data
|
| 5 |
+
|
| 6 |
+
import commons
|
| 7 |
+
from utils import load_wav_to_torch, load_filepaths_and_text
|
| 8 |
+
from text import text_to_sequence
|
| 9 |
+
|
| 10 |
+
class TextMelLoader(torch.utils.data.Dataset):
    """
    1) loads audio,text pairs
    2) normalizes text and converts them to sequences of one-hot vectors
    3) computes mel-spectrograms from audio files.
    """

    def __init__(self, audiopaths_and_text, hparams):
        self.audiopaths_and_text = load_filepaths_and_text(audiopaths_and_text)
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.load_mel_from_disk = hparams.load_mel_from_disk
        self.add_noise = hparams.add_noise
        self.symbols = hparams.punc + hparams.chars
        self.add_blank = getattr(hparams, "add_blank", False)  # improved version
        self.stft = commons.TacotronSTFT(
            hparams.filter_length,
            hparams.hop_length,
            hparams.win_length,
            hparams.n_mel_channels,
            hparams.sampling_rate,
            hparams.mel_fmin,
            hparams.mel_fmax,
        )
        # Fixed seed so the shuffled order is reproducible across runs.
        random.seed(1234)
        random.shuffle(self.audiopaths_and_text)

    def get_mel_text_pair(self, audiopath_and_text):
        # separate filename and text
        audiopath, text = audiopath_and_text[0], audiopath_and_text[1]
        text = self.get_text(text)
        mel = self.get_mel(audiopath)
        return (text, mel)

    def get_mel(self, filename):
        """Return a mel-spectrogram [n_mel_channels, T], computed from a wav
        file or loaded pre-computed from a .npy file."""
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_wav_to_torch(filename)
            if sampling_rate != self.stft.sampling_rate:
                # Fixed: the original format string had three {} placeholders
                # for two arguments, so raising it crashed with IndexError.
                raise ValueError(
                    "{} SR doesn't match target {} SR".format(
                        sampling_rate, self.stft.sampling_rate
                    )
                )
            if self.add_noise:
                audio = audio + torch.rand_like(audio)
            audio_norm = audio / self.max_wav_value
            audio_norm = audio_norm.unsqueeze(0)
            melspec = self.stft.mel_spectrogram(audio_norm)
            melspec = torch.squeeze(melspec, 0)
        else:
            melspec = torch.from_numpy(np.load(filename))
            assert (
                melspec.size(0) == self.stft.n_mel_channels
            ), "Mel dimension mismatch: given {}, expected {}".format(
                melspec.size(0), self.stft.n_mel_channels
            )

        return melspec

    def get_text(self, text):
        """Convert raw text into an IntTensor of symbol ids."""
        text_norm = text_to_sequence(text, self.symbols, self.text_cleaners)
        if self.add_blank:
            text_norm = commons.intersperse(
                text_norm, len(self.symbols)
            )  # add a blank token, whose id number is len(symbols)
        text_norm = torch.IntTensor(text_norm)
        return text_norm

    def __getitem__(self, index):
        return self.get_mel_text_pair(self.audiopaths_and_text[index])

    def __len__(self):
        return len(self.audiopaths_and_text)
|
| 84 |
+
|
| 85 |
+
|
| 86 |
+
class TextMelCollate:
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, n_frames_per_step=1):
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate's training batch from normalized text and mel-spectrogram
        PARAMS
        ------
        batch: [text_normalized, mel_normalized]
        """
        # Sort samples by text length, longest first, and right-pad texts.
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(item[0]) for item in batch]), dim=0, descending=True
        )
        max_input_len = input_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for row, src_idx in enumerate(ids_sorted_decreasing):
            text = batch[src_idx][0]
            text_padded[row, : text.size(0)] = text

        # Right zero-pad mels, rounding the target length up to a multiple
        # of n_frames_per_step.
        num_mels = batch[0][1].size(0)
        max_target_len = max(item[1].size(1) for item in batch)
        remainder = max_target_len % self.n_frames_per_step
        if remainder != 0:
            max_target_len += self.n_frames_per_step - remainder
        assert max_target_len % self.n_frames_per_step == 0

        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        for row, src_idx in enumerate(ids_sorted_decreasing):
            mel = batch[src_idx][1]
            mel_padded[row, :, : mel.size(1)] = mel
            output_lengths[row] = mel.size(1)

        return text_padded, input_lengths, mel_padded, output_lengths
|
| 129 |
+
|
| 130 |
+
|
| 131 |
+
"""Multi speaker version"""
|
| 132 |
+
|
| 133 |
+
|
| 134 |
+
class TextMelSpeakerLoader(torch.utils.data.Dataset):
    """
    1) loads audio, speaker_id, text pairs
    2) normalizes text and converts them to sequences of one-hot vectors
    3) computes mel-spectrograms from audio files.
    """

    def __init__(self, audiopaths_sid_text, hparams):
        self.audiopaths_sid_text = load_filepaths_and_text(audiopaths_sid_text)
        self.text_cleaners = hparams.text_cleaners
        self.max_wav_value = hparams.max_wav_value
        self.sampling_rate = hparams.sampling_rate
        self.load_mel_from_disk = hparams.load_mel_from_disk
        self.add_noise = hparams.add_noise
        self.symbols = hparams.punc + hparams.chars
        self.add_blank = getattr(hparams, "add_blank", False)  # improved version
        self.min_text_len = getattr(hparams, "min_text_len", 1)
        self.max_text_len = getattr(hparams, "max_text_len", 190)
        self.stft = commons.TacotronSTFT(
            hparams.filter_length,
            hparams.hop_length,
            hparams.win_length,
            hparams.n_mel_channels,
            hparams.sampling_rate,
            hparams.mel_fmin,
            hparams.mel_fmax,
        )

        self._filter_text_len()
        # Fixed seed so the shuffled order is reproducible across runs.
        random.seed(1234)
        random.shuffle(self.audiopaths_sid_text)

    def _filter_text_len(self):
        """Drop entries whose raw text length is outside
        [min_text_len, max_text_len]."""
        audiopaths_sid_text_new = []
        for audiopath, sid, text in self.audiopaths_sid_text:
            if self.min_text_len <= len(text) <= self.max_text_len:
                audiopaths_sid_text_new.append([audiopath, sid, text])
        self.audiopaths_sid_text = audiopaths_sid_text_new

    def get_mel_text_speaker_pair(self, audiopath_sid_text):
        # separate filename, speaker_id and text
        audiopath, sid, text = (
            audiopath_sid_text[0],
            audiopath_sid_text[1],
            audiopath_sid_text[2],
        )
        text = self.get_text(text)
        mel = self.get_mel(audiopath)
        sid = self.get_sid(sid)
        return (text, mel, sid)

    def get_mel(self, filename):
        """Return a mel-spectrogram [n_mel_channels, T], computed from a wav
        file or loaded pre-computed from a .npy file."""
        if not self.load_mel_from_disk:
            audio, sampling_rate = load_wav_to_torch(filename)
            if sampling_rate != self.stft.sampling_rate:
                # Fixed: the original format string had three {} placeholders
                # for two arguments, so raising it crashed with IndexError.
                raise ValueError(
                    "{} SR doesn't match target {} SR".format(
                        sampling_rate, self.stft.sampling_rate
                    )
                )
            if self.add_noise:
                audio = audio + torch.rand_like(audio)
            audio_norm = audio / self.max_wav_value
            audio_norm = audio_norm.unsqueeze(0)
            melspec = self.stft.mel_spectrogram(audio_norm)
            melspec = torch.squeeze(melspec, 0)
        else:
            melspec = torch.from_numpy(np.load(filename))
            assert (
                melspec.size(0) == self.stft.n_mel_channels
            ), "Mel dimension mismatch: given {}, expected {}".format(
                melspec.size(0), self.stft.n_mel_channels
            )

        return melspec

    def get_text(self, text):
        """Convert raw text into an IntTensor of symbol ids."""
        text_norm = text_to_sequence(text, self.symbols, self.text_cleaners)
        if self.add_blank:
            text_norm = commons.intersperse(
                text_norm, len(self.symbols)
            )  # add a blank token, whose id number is len(symbols)
        text_norm = torch.IntTensor(text_norm)
        return text_norm

    def get_sid(self, sid):
        """Wrap a speaker id string/int in a 1-element IntTensor."""
        sid = torch.IntTensor([int(sid)])
        return sid

    def __getitem__(self, index):
        return self.get_mel_text_speaker_pair(self.audiopaths_sid_text[index])

    def __len__(self):
        return len(self.audiopaths_sid_text)
|
| 228 |
+
|
| 229 |
+
|
| 230 |
+
class TextMelSpeakerCollate:
    """Zero-pads model inputs and targets based on number of frames per step"""

    def __init__(self, n_frames_per_step=1):
        self.n_frames_per_step = n_frames_per_step

    def __call__(self, batch):
        """Collate's training batch from normalized text and mel-spectrogram
        PARAMS
        ------
        batch: [text_normalized, mel_normalized]
        """
        # Sort samples by text length, longest first, and right-pad texts.
        input_lengths, ids_sorted_decreasing = torch.sort(
            torch.LongTensor([len(item[0]) for item in batch]), dim=0, descending=True
        )
        max_input_len = input_lengths[0]

        text_padded = torch.LongTensor(len(batch), max_input_len)
        text_padded.zero_()
        for row, src_idx in enumerate(ids_sorted_decreasing):
            text = batch[src_idx][0]
            text_padded[row, : text.size(0)] = text

        # Right zero-pad mels, rounding the target length up to a multiple
        # of n_frames_per_step.
        num_mels = batch[0][1].size(0)
        max_target_len = max(item[1].size(1) for item in batch)
        remainder = max_target_len % self.n_frames_per_step
        if remainder != 0:
            max_target_len += self.n_frames_per_step - remainder
        assert max_target_len % self.n_frames_per_step == 0

        # Pad mels and gather speaker ids in the same sorted order.
        mel_padded = torch.FloatTensor(len(batch), num_mels, max_target_len)
        mel_padded.zero_()
        output_lengths = torch.LongTensor(len(batch))
        sid = torch.LongTensor(len(batch))
        for row, src_idx in enumerate(ids_sorted_decreasing):
            mel = batch[src_idx][1]
            mel_padded[row, :, : mel.size(1)] = mel
            output_lengths[row] = mel.size(1)
            sid[row] = batch[src_idx][2]

        return text_padded, input_lengths, mel_padded, output_lengths, sid
|
src/glow_tts/generate_mels.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
import os
|
| 3 |
+
import torch
|
| 4 |
+
import commons
|
| 5 |
+
|
| 6 |
+
import models
|
| 7 |
+
import utils
|
| 8 |
+
from argparse import ArgumentParser
|
| 9 |
+
from tqdm import tqdm
|
| 10 |
+
from text import text_to_sequence
|
| 11 |
+
|
| 12 |
+
if __name__ == "__main__":
    # Batch-synthesize mel-spectrograms for every utterance in the training
    # and validation file lists of a trained Glow-TTS model, saving each as
    # a .npy file. Requires a CUDA device.
    parser = ArgumentParser()
    parser.add_argument("-m", "--model_dir", required=True, type=str)
    parser.add_argument("-s", "--mels_dir", required=True, type=str)
    args = parser.parse_args()
    MODEL_DIR = args.model_dir  # path to model dir
    SAVE_MELS_DIR = args.mels_dir  # path to save generated mels

    if not os.path.exists(SAVE_MELS_DIR):
        os.makedirs(SAVE_MELS_DIR)

    hps = utils.get_hparams_from_dir(MODEL_DIR)
    symbols = list(hps.data.punc) + list(hps.data.chars)
    checkpoint_path = utils.latest_checkpoint_path(MODEL_DIR)
    cleaner = hps.data.text_cleaners

    # Vocabulary is one larger when add_blank is set (blank id == len(symbols)).
    model = models.FlowGenerator(
        len(symbols) + getattr(hps.data, "add_blank", False),
        out_channels=hps.data.n_mel_channels,
        **hps.model
    ).to("cuda")

    utils.load_checkpoint(checkpoint_path, model)
    model.decoder.store_inverse()  # do not calculate jacobians for fast decoding
    _ = model.eval()

    def get_mel(text, fpath):
        # Synthesize a mel for `text` and save it to SAVE_MELS_DIR/fpath.
        if getattr(hps.data, "add_blank", False):
            text_norm = text_to_sequence(text, symbols, cleaner)
            text_norm = commons.intersperse(text_norm, len(symbols))
        else:  # If not using "add_blank" option during training, adding spaces at the beginning and the end of utterance improves quality
            text = " " + text.strip() + " "
            text_norm = text_to_sequence(text, symbols, cleaner)

        sequence = np.array(text_norm)[None, :]

        x_tst = torch.autograd.Variable(torch.from_numpy(sequence)).cuda().long()
        x_tst_lengths = torch.tensor([x_tst.shape[1]]).cuda()

        with torch.no_grad():
            noise_scale = 0.667
            length_scale = 1.0
            (y_gen_tst, *_), *_, (attn_gen, *_) = model(
                x_tst,
                x_tst_lengths,
                gen=True,
                noise_scale=noise_scale,
                length_scale=length_scale,
            )

        np.save(os.path.join(SAVE_MELS_DIR, fpath), y_gen_tst.cpu().detach().numpy())

    # Each list line has the form "path/to/file.wav|transcript".
    for f in [hps.data.training_files, hps.data.validation_files]:
        file_lines = open(f).read().splitlines()

        for line in tqdm(file_lines):
            fname, text = line.split("|")
            fname = os.path.basename(fname).replace(".wav", ".npy")
            get_mel(text, fname)
|
src/glow_tts/hifi/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from .env import AttrDict
|
| 2 |
+
from .models import Generator
|
| 3 |
+
|
| 4 |
+
if __name__ == "__main__":
|
| 5 |
+
pass
|
src/glow_tts/hifi/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (252 Bytes). View file
|
|
|
src/glow_tts/hifi/__pycache__/env.cpython-37.pyc
ADDED
|
Binary file (787 Bytes). View file
|
|
|
src/glow_tts/hifi/__pycache__/models.cpython-37.pyc
ADDED
|
Binary file (9.1 kB). View file
|
|
|
src/glow_tts/hifi/__pycache__/utils.cpython-37.pyc
ADDED
|
Binary file (2.01 kB). View file
|
|
|
src/glow_tts/hifi/env.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import shutil
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
class AttrDict(dict):
    """Dict whose entries are also reachable as attributes (d.key == d["key"])."""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Aliasing __dict__ to the dict itself makes attribute access and
        # item access share a single storage.
        self.__dict__ = self
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def build_env(config, config_name, path):
    """Copy the config file into `path/config_name`, creating `path` if
    needed, unless `config` already is that destination path."""
    destination = os.path.join(path, config_name)
    if config == destination:
        return
    os.makedirs(path, exist_ok=True)
    shutil.copyfile(config, destination)
|
src/glow_tts/hifi/models.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn.functional as F
|
| 3 |
+
import torch.nn as nn
|
| 4 |
+
from torch.nn import Conv1d, ConvTranspose1d, AvgPool1d, Conv2d
|
| 5 |
+
from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
|
| 6 |
+
from .utils import init_weights, get_padding
|
| 7 |
+
|
| 8 |
+
LRELU_SLOPE = 0.1
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
class ResBlock1(torch.nn.Module):
    """HiFi-GAN residual block (variant 1): three dilated convs (convs1),
    each followed by a dilation-1 conv (convs2), with leaky-ReLU
    pre-activations and a residual connection around every pair."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ResBlock1, self).__init__()
        self.h = h

        def _make_conv(d):
            # "Same"-length weight-normalized conv at dilation d.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
            )

        self.convs1 = nn.ModuleList(
            [_make_conv(dilation[0]), _make_conv(dilation[1]), _make_conv(dilation[2])]
        )
        self.convs1.apply(init_weights)

        self.convs2 = nn.ModuleList([_make_conv(1), _make_conv(1), _make_conv(1)])
        self.convs2.apply(init_weights)

    def forward(self, x):
        for dilated, plain in zip(self.convs1, self.convs2):
            residual = x
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = dilated(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = plain(x)
            x = x + residual
        return x

    def remove_weight_norm(self):
        for layer in self.convs1:
            remove_weight_norm(layer)
        for layer in self.convs2:
            remove_weight_norm(layer)
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
class ResBlock2(torch.nn.Module):
    """HiFi-GAN residual block (variant 2): two dilated convs, each wrapped
    in a leaky-ReLU pre-activation and a residual connection."""

    def __init__(self, h, channels, kernel_size=3, dilation=(1, 3)):
        super(ResBlock2, self).__init__()
        self.h = h

        def _make_conv(d):
            # "Same"-length weight-normalized conv at dilation d.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=d,
                    padding=get_padding(kernel_size, d),
                )
            )

        self.convs = nn.ModuleList([_make_conv(dilation[0]), _make_conv(dilation[1])])
        self.convs.apply(init_weights)

    def forward(self, x):
        for conv in self.convs:
            residual = x
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = conv(x)
            x = x + residual
        return x

    def remove_weight_norm(self):
        for layer in self.convs:
            remove_weight_norm(layer)
|
| 143 |
+
|
| 144 |
+
|
| 145 |
+
class Generator(torch.nn.Module):
    """HiFi-GAN generator: upsamples an 80-bin mel-spectrogram to a waveform.

    Pipeline: pre-conv -> num_upsamples x (transposed-conv upsample +
    num_kernels parallel ResBlocks averaged) -> post-conv -> tanh.
    """

    def __init__(self, h):
        super(Generator, self).__init__()
        self.h = h  # hyperparameter namespace (attribute access)
        self.num_kernels = len(h.resblock_kernel_sizes)
        self.num_upsamples = len(h.upsample_rates)
        # 80 input channels = number of mel bins.
        self.conv_pre = weight_norm(
            Conv1d(80, h.upsample_initial_channel, 7, 1, padding=3)
        )
        resblock = ResBlock1 if h.resblock == "1" else ResBlock2

        self.ups = nn.ModuleList()
        for i, (u, k) in enumerate(zip(h.upsample_rates, h.upsample_kernel_sizes)):
            # Channel count halves at every upsampling stage.
            self.ups.append(
                weight_norm(
                    ConvTranspose1d(
                        h.upsample_initial_channel // (2 ** i),
                        h.upsample_initial_channel // (2 ** (i + 1)),
                        k,
                        u,
                        padding=(k - u) // 2,
                    )
                )
            )

        # num_kernels ResBlocks per upsampling stage, stored flat; forward()
        # indexes them as resblocks[i * num_kernels + j].
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = h.upsample_initial_channel // (2 ** (i + 1))
            for j, (k, d) in enumerate(
                zip(h.resblock_kernel_sizes, h.resblock_dilation_sizes)
            ):
                self.resblocks.append(resblock(h, ch, k, d))

        # ch here is the channel count after the last upsample stage.
        self.conv_post = weight_norm(Conv1d(ch, 1, 7, 1, padding=3))
        self.ups.apply(init_weights)
        self.conv_post.apply(init_weights)

    def forward(self, x):
        x = self.conv_pre(x)
        for i in range(self.num_upsamples):
            x = F.leaky_relu(x, LRELU_SLOPE)
            x = self.ups[i](x)
            # Average the outputs of the parallel multi-kernel ResBlocks
            # belonging to stage i.
            xs = None
            for j in range(self.num_kernels):
                if xs is None:
                    xs = self.resblocks[i * self.num_kernels + j](x)
                else:
                    xs += self.resblocks[i * self.num_kernels + j](x)
            x = xs / self.num_kernels
        x = F.leaky_relu(x)
        x = self.conv_post(x)
        x = torch.tanh(x)  # waveform samples constrained to [-1, 1]

        return x

    def remove_weight_norm(self):
        # Inference-time cleanup: fold weight norm into plain weights.
        print("Removing weight norm...")
        for l in self.ups:
            remove_weight_norm(l)
        for l in self.resblocks:
            l.remove_weight_norm()
        remove_weight_norm(self.conv_pre)
        remove_weight_norm(self.conv_post)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class DiscriminatorP(torch.nn.Module):
    """Period sub-discriminator: folds 1-D audio into a 2-D grid whose second
    dimension is ``period``, then scores it with a stack of 2-D convs.

    Returns (flattened score map, list of intermediate feature maps) —
    the feature maps feed the feature-matching loss.
    """

    def __init__(self, period, kernel_size=5, stride=3, use_spectral_norm=False):
        super(DiscriminatorP, self).__init__()
        self.period = period
        norm_f = weight_norm if use_spectral_norm == False else spectral_norm
        # Fix: padding was hard-coded as get_padding(5, 1) / 2 regardless of
        # kernel_size; derive it from the actual kernel_size so non-default
        # kernels keep "same" height. Identical for the default kernel_size=5.
        pad = (get_padding(kernel_size, 1), 0)
        channel_plan = [1, 32, 128, 512, 1024]
        convs = [
            norm_f(
                Conv2d(
                    in_ch,
                    out_ch,
                    (kernel_size, 1),
                    (stride, 1),
                    padding=pad,
                )
            )
            for in_ch, out_ch in zip(channel_plan[:-1], channel_plan[1:])
        ]
        # Final stride-1 conv keeps 1024 channels.
        convs.append(norm_f(Conv2d(1024, 1024, (kernel_size, 1), 1, padding=pad)))
        self.convs = nn.ModuleList(convs)
        self.conv_post = norm_f(Conv2d(1024, 1, (3, 1), 1, padding=(1, 0)))

    def forward(self, x):
        fmap = []

        # Fold 1-D signal [b, c, t] into 2-D [b, c, t/period, period] so the
        # conv stack sees period-aligned samples along the width.
        b, c, t = x.shape
        if t % self.period != 0:  # reflect-pad so t divides evenly
            n_pad = self.period - (t % self.period)
            x = F.pad(x, (0, n_pad), "reflect")
            t = t + n_pad
        x = x.view(b, c, t // self.period, self.period)

        for l in self.convs:
            x = l(x)
            x = F.leaky_relu(x, LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        x = torch.flatten(x, 1, -1)

        return x, fmap
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
class MultiPeriodDiscriminator(torch.nn.Module):
    """Ensemble of DiscriminatorP modules over co-prime periods 2,3,5,7,11."""

    def __init__(self):
        super(MultiPeriodDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [DiscriminatorP(p) for p in (2, 3, 5, 7, 11)]
        )

    def forward(self, y, y_hat):
        """Score real audio `y` and generated audio `y_hat` with every
        period discriminator; return scores and feature maps for both."""
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for disc in self.discriminators:
            score_real, feats_real = disc(y)
            score_gen, feats_gen = disc(y_hat)
            y_d_rs.append(score_real)
            fmap_rs.append(feats_real)
            y_d_gs.append(score_gen)
            fmap_gs.append(feats_gen)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 307 |
+
|
| 308 |
+
|
| 309 |
+
class DiscriminatorS(torch.nn.Module):
    """Scale sub-discriminator: stacked strided/grouped 1-D convolutions
    over the raw (or mean-pooled) waveform."""

    def __init__(self, use_spectral_norm=False):
        super(DiscriminatorS, self).__init__()
        norm_f = spectral_norm if use_spectral_norm else weight_norm
        # (in_ch, out_ch, kernel, stride, groups, padding) per layer.
        layer_specs = [
            (1, 128, 15, 1, 1, 7),
            (128, 128, 41, 2, 4, 20),
            (128, 256, 41, 2, 16, 20),
            (256, 512, 41, 4, 16, 20),
            (512, 1024, 41, 4, 16, 20),
            (1024, 1024, 41, 1, 16, 20),
            (1024, 1024, 5, 1, 1, 2),
        ]
        self.convs = nn.ModuleList(
            [
                norm_f(Conv1d(i, o, k, s, groups=g, padding=p))
                for (i, o, k, s, g, p) in layer_specs
            ]
        )
        self.conv_post = norm_f(Conv1d(1024, 1, 3, 1, padding=1))

    def forward(self, x):
        fmap = []
        for layer in self.convs:
            x = F.leaky_relu(layer(x), LRELU_SLOPE)
            fmap.append(x)
        x = self.conv_post(x)
        fmap.append(x)
        return torch.flatten(x, 1, -1), fmap
|
| 337 |
+
|
| 338 |
+
|
| 339 |
+
class MultiScaleDiscriminator(torch.nn.Module):
    """Three DiscriminatorS copies at raw, /2 and /4 sample rates; the raw-rate
    copy uses spectral norm, the pooled copies use weight norm."""

    def __init__(self):
        super(MultiScaleDiscriminator, self).__init__()
        self.discriminators = nn.ModuleList(
            [
                DiscriminatorS(use_spectral_norm=True),  # raw-scale copy
                DiscriminatorS(),
                DiscriminatorS(),
            ]
        )
        self.meanpools = nn.ModuleList(
            [AvgPool1d(4, 2, padding=2), AvgPool1d(4, 2, padding=2)]
        )

    def forward(self, y, y_hat):
        y_d_rs, y_d_gs, fmap_rs, fmap_gs = [], [], [], []
        for idx, disc in enumerate(self.discriminators):
            if idx != 0:
                # Halve the sample rate before each non-first discriminator.
                y = self.meanpools[idx - 1](y)
                y_hat = self.meanpools[idx - 1](y_hat)
            score_real, feats_real = disc(y)
            score_gen, feats_gen = disc(y_hat)
            y_d_rs.append(score_real)
            fmap_rs.append(feats_real)
            y_d_gs.append(score_gen)
            fmap_gs.append(feats_gen)

        return y_d_rs, y_d_gs, fmap_rs, fmap_gs
|
| 370 |
+
|
| 371 |
+
|
| 372 |
+
def feature_loss(fmap_r, fmap_g):
    """L1 feature-matching loss between real and generated feature maps.

    Sums mean absolute differences over every layer of every discriminator,
    scaled by 2 (HiFi-GAN convention).
    """
    total = 0
    for feats_real, feats_gen in zip(fmap_r, fmap_g):
        for real_layer, gen_layer in zip(feats_real, feats_gen):
            total = total + torch.mean(torch.abs(real_layer - gen_layer))

    return total * 2
|
| 379 |
+
|
| 380 |
+
|
| 381 |
+
def discriminator_loss(disc_real_outputs, disc_generated_outputs):
    """LSGAN discriminator loss: real scores pushed toward 1, fake toward 0.

    Returns the summed loss tensor plus per-discriminator real/fake losses
    as Python floats (for logging).
    """
    loss = 0
    r_losses = []
    g_losses = []
    for real_score, fake_score in zip(disc_real_outputs, disc_generated_outputs):
        real_term = torch.mean((1 - real_score) ** 2)
        fake_term = torch.mean(fake_score ** 2)
        loss = loss + real_term + fake_term
        r_losses.append(real_term.item())
        g_losses.append(fake_term.item())

    return loss, r_losses, g_losses
|
| 393 |
+
|
| 394 |
+
|
| 395 |
+
def generator_loss(disc_outputs):
    """LSGAN generator loss: generated scores pushed toward 1.

    Returns the summed loss plus the list of per-discriminator loss tensors.
    """
    gen_losses = [torch.mean((1 - score) ** 2) for score in disc_outputs]
    loss = 0
    for term in gen_losses:
        loss = loss + term

    return loss, gen_losses
|
src/glow_tts/hifi/utils.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import glob
|
| 2 |
+
import os
|
| 3 |
+
import matplotlib
|
| 4 |
+
import torch
|
| 5 |
+
from torch.nn.utils import weight_norm
|
| 6 |
+
|
| 7 |
+
matplotlib.use("Agg")
|
| 8 |
+
import matplotlib.pylab as plt
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def plot_spectrogram(spectrogram):
    """Render a spectrogram to a matplotlib Figure (Agg backend) and return it."""
    fig, axis = plt.subplots(figsize=(10, 2))
    image = axis.imshow(
        spectrogram, aspect="auto", origin="lower", interpolation="none"
    )
    plt.colorbar(image, ax=axis)

    # Rasterize now so the figure stays usable after being closed.
    fig.canvas.draw()
    plt.close()

    return fig
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def init_weights(m, mean=0.0, std=0.01):
    """Initialize the weights of any Conv* module from N(mean, std); other
    module types are left untouched (suitable for Module.apply)."""
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def apply_weight_norm(m):
    """Wrap the weights of any Conv* module with weight normalization
    (suitable for Module.apply); other module types are left untouched."""
    if "Conv" in m.__class__.__name__:
        weight_norm(m)
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def get_padding(kernel_size, dilation=1):
    """Padding that keeps the output length equal to the input length
    ("same" padding) for a stride-1 dilated convolution."""
    return (kernel_size * dilation - dilation) // 2
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def load_checkpoint(filepath, device):
    """Load a serialized checkpoint dict from disk, mapping storages onto
    `device`.

    Raises:
        FileNotFoundError: if `filepath` does not exist. (Previously an
            `assert`, which is silently stripped under `python -O`.)
    """
    if not os.path.isfile(filepath):
        raise FileNotFoundError(filepath)
    print("Loading '{}'".format(filepath))
    checkpoint_dict = torch.load(filepath, map_location=device)
    print("Complete.")
    return checkpoint_dict
|
| 44 |
+
|
| 45 |
+
|
| 46 |
+
def save_checkpoint(filepath, obj):
    """Serialize `obj` (typically a state-dict bundle) to `filepath` via
    torch.save, logging progress to stdout."""
    print("Saving checkpoint to {}".format(filepath))
    torch.save(obj, filepath)
    print("Complete.")
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def scan_checkpoint(cp_dir, prefix):
    """Return the newest checkpoint path in `cp_dir` whose name is `prefix`
    followed by exactly 8 characters (zero-padded step number), or None.

    "Newest" is determined lexicographically, which matches chronological
    order for fixed-width zero-padded step numbers.
    """
    matches = glob.glob(os.path.join(cp_dir, prefix + "????????"))
    if not matches:
        return None
    return max(matches)
|
src/glow_tts/init.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import os
|
| 2 |
+
import json
|
| 3 |
+
import argparse
|
| 4 |
+
import math
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn, optim
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
from torch.utils.data import DataLoader
|
| 9 |
+
|
| 10 |
+
from data_utils import TextMelLoader, TextMelCollate
|
| 11 |
+
import models
|
| 12 |
+
import commons
|
| 13 |
+
import utils
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class FlowGenerator_DDI(models.FlowGenerator):
    """A helper for Data-dependent Initialization"""

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)
        # Flip every decoder flow that supports DDI into init mode so the
        # first forward pass can set data-dependent parameters (e.g. ActNorm).
        for f in self.decoder.flows:
            # getattr returns the bound method (truthy) only when the flow
            # actually defines set_ddi; other flows are skipped.
            if getattr(f, "set_ddi", False):
                f.set_ddi(True)
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def main():
    """Run one data-dependent-initialization pass and save the resulting
    generator checkpoint as ddi_G.pth.

    Requires a CUDA device; hyperparameters come from utils.get_hparams().
    """
    hps = utils.get_hparams()
    logger = utils.get_logger(hps.log_dir)
    logger.info(hps)
    utils.check_git_hash(hps.log_dir)

    torch.manual_seed(hps.train.seed)  # reproducible DDI batch ordering

    train_dataset = TextMelLoader(hps.data.training_files, hps.data)
    collate_fn = TextMelCollate(1)
    train_loader = DataLoader(
        train_dataset,
        num_workers=8,
        shuffle=True,
        batch_size=hps.train.batch_size,
        pin_memory=True,
        drop_last=True,
        collate_fn=collate_fn,
    )
    symbols = hps.data.punc + hps.data.chars
    # add_blank (bool) adds 1 to the vocabulary size when truthy.
    generator = FlowGenerator_DDI(
        len(symbols) + getattr(hps.data, "add_blank", False),
        out_channels=hps.data.n_mel_channels,
        **hps.model
    ).cuda()
    optimizer_g = commons.Adam(
        generator.parameters(),
        scheduler=hps.train.scheduler,
        dim_model=hps.model.hidden_channels,
        warmup_steps=hps.train.warmup_steps,
        lr=hps.train.learning_rate,
        betas=hps.train.betas,
        eps=hps.train.eps,
    )

    generator.train()
    # Exactly one forward pass: enough for the DDI-enabled flows to set
    # their data-dependent parameters.
    for batch_idx, (x, x_lengths, y, y_lengths) in enumerate(train_loader):
        x, x_lengths = x.cuda(), x_lengths.cuda()
        y, y_lengths = y.cuda(), y_lengths.cuda()

        _ = generator(x, x_lengths, y, y_lengths, gen=False)
        break

    utils.save_checkpoint(
        generator,
        optimizer_g,
        hps.train.learning_rate,
        0,  # iteration counter starts at 0 for the DDI checkpoint
        os.path.join(hps.model_dir, "ddi_G.pth"),
    )
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
if __name__ == "__main__":
|
| 79 |
+
main()
|
src/glow_tts/models.py
ADDED
|
@@ -0,0 +1,403 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
import torch
|
| 3 |
+
from torch import nn
|
| 4 |
+
from torch.nn import functional as F
|
| 5 |
+
|
| 6 |
+
import modules
|
| 7 |
+
import commons
|
| 8 |
+
import attentions
|
| 9 |
+
import monotonic_align
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class DurationPredictor(nn.Module):
    """Predicts per-token log-durations from (detached) encoder states.

    Two masked conv -> ReLU -> LayerNorm -> dropout stages followed by a
    1x1 projection to a single channel; the mask is re-applied throughout.
    """

    def __init__(self, in_channels, filter_channels, kernel_size, p_dropout):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_1 = attentions.LayerNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=kernel_size // 2
        )
        self.norm_2 = attentions.LayerNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

    def forward(self, x, x_mask):
        for conv, norm in (
            (self.conv_1, self.norm_1),
            (self.conv_2, self.norm_2),
        ):
            x = self.drop(norm(torch.relu(conv(x * x_mask))))
        return self.proj(x * x_mask) * x_mask
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
class TextEncoder(nn.Module):
    """Glow-TTS text encoder: embeds tokens, runs relative-position
    self-attention, and emits per-token prior statistics plus log-durations.

    forward() returns (x_m, x_logs, logw, x_mask): prior mean, prior log-std
    (zeros when mean_only), predicted log-duration, and the sequence mask.
    """

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        filter_channels_dp,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        window_size=None,
        block_length=None,
        mean_only=False,
        prenet=False,
        gin_channels=0,
    ):

        super().__init__()

        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.block_length = block_length
        self.mean_only = mean_only
        self.prenet = prenet
        self.gin_channels = gin_channels

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        # Scaled init: compensates for the sqrt(hidden_channels) multiply
        # applied in forward().
        nn.init.normal_(self.emb.weight, 0.0, hidden_channels ** -0.5)

        if prenet:
            self.pre = modules.ConvReluNorm(
                hidden_channels,
                hidden_channels,
                hidden_channels,
                kernel_size=5,
                n_layers=3,
                p_dropout=0.5,
            )
        self.encoder = attentions.Encoder(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
        )

        # Projections to the prior mean (and log-std unless mean_only).
        self.proj_m = nn.Conv1d(hidden_channels, out_channels, 1)
        if not mean_only:
            self.proj_s = nn.Conv1d(hidden_channels, out_channels, 1)
        # Duration predictor also sees the speaker embedding when
        # gin_channels > 0.
        self.proj_w = DurationPredictor(
            hidden_channels + gin_channels, filter_channels_dp, kernel_size, p_dropout
        )

    def forward(self, x, x_lengths, g=None):
        x = self.emb(x) * math.sqrt(self.hidden_channels)  # [b, t, h]
        x = torch.transpose(x, 1, -1)  # [b, h, t]
        x_mask = torch.unsqueeze(commons.sequence_mask(x_lengths, x.size(2)), 1).to(
            x.dtype
        )

        if self.prenet:
            x = self.pre(x, x_mask)
        x = self.encoder(x, x_mask)

        # The duration-predictor input is detached so duration-loss gradients
        # do not flow back into the encoder; the speaker embedding g is
        # broadcast across time and concatenated when provided.
        if g is not None:
            g_exp = g.expand(-1, -1, x.size(-1))
            x_dp = torch.cat([torch.detach(x), g_exp], 1)
        else:
            x_dp = torch.detach(x)

        x_m = self.proj_m(x) * x_mask
        if not self.mean_only:
            x_logs = self.proj_s(x) * x_mask
        else:
            x_logs = torch.zeros_like(x_m)  # unit-variance prior

        logw = self.proj_w(x_dp, x_mask)
        return x_m, x_logs, logw, x_mask
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
class FlowSpecDecoder(nn.Module):
    """Normalizing-flow decoder mapping mel-spectrograms <-> latent z.

    Each block is ActNorm -> invertible 1x1-ish conv (InvConvNear) ->
    affine coupling. Channels are squeezed by n_sqz along time before the
    flows and unsqueezed afterwards.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_blocks,
        n_layers,
        p_dropout=0.0,
        n_split=4,
        n_sqz=2,
        sigmoid_scale=False,
        gin_channels=0,
    ):
        super().__init__()

        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_blocks = n_blocks
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.sigmoid_scale = sigmoid_scale
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for b in range(n_blocks):
            # Flows operate on the squeezed channel count in_channels * n_sqz.
            self.flows.append(modules.ActNorm(channels=in_channels * n_sqz))
            self.flows.append(
                modules.InvConvNear(channels=in_channels * n_sqz, n_split=n_split)
            )
            self.flows.append(
                attentions.CouplingBlock(
                    in_channels * n_sqz,
                    hidden_channels,
                    kernel_size=kernel_size,
                    dilation_rate=dilation_rate,
                    n_layers=n_layers,
                    gin_channels=gin_channels,
                    p_dropout=p_dropout,
                    sigmoid_scale=sigmoid_scale,
                )
            )

    def forward(self, x, x_mask, g=None, reverse=False):
        """Forward (training) direction accumulates the log-determinant;
        reverse (generation) direction runs the flows backwards and returns
        logdet_tot=None."""
        if not reverse:
            flows = self.flows
            logdet_tot = 0
        else:
            flows = reversed(self.flows)
            logdet_tot = None

        if self.n_sqz > 1:
            x, x_mask = commons.squeeze(x, x_mask, self.n_sqz)
        for f in flows:
            if not reverse:
                x, logdet = f(x, x_mask, g=g, reverse=reverse)
                logdet_tot += logdet
            else:
                # Log-determinant is not needed when sampling.
                x, logdet = f(x, x_mask, g=g, reverse=reverse)
        if self.n_sqz > 1:
            x, x_mask = commons.unsqueeze(x, x_mask, self.n_sqz)
        return x, logdet_tot

    def store_inverse(self):
        # Precompute inverse weights in each flow for fast generation.
        for f in self.flows:
            f.store_inverse()
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class FlowGenerator(nn.Module):
    """Glow-TTS model: TextEncoder prior + FlowSpecDecoder.

    Training (gen=False): encodes mel y into latent z, aligns it to the text
    prior with monotonic alignment search (MAS), and returns everything
    needed for the MLE + duration losses.
    Generation (gen=True): expands the prior by predicted durations, samples
    z, and inverts the decoder to produce a mel-spectrogram.
    """

    def __init__(
        self,
        n_vocab,
        hidden_channels,
        filter_channels,
        filter_channels_dp,
        out_channels,
        kernel_size=3,
        n_heads=2,
        n_layers_enc=6,
        p_dropout=0.0,
        n_blocks_dec=12,
        kernel_size_dec=5,
        dilation_rate=5,
        n_block_layers=4,
        p_dropout_dec=0.0,
        n_speakers=0,
        gin_channels=0,
        n_split=4,
        n_sqz=1,
        sigmoid_scale=False,
        window_size=None,
        block_length=None,
        mean_only=False,
        hidden_channels_enc=None,
        hidden_channels_dec=None,
        prenet=False,
        **kwargs
    ):

        super().__init__()
        self.n_vocab = n_vocab
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.filter_channels_dp = filter_channels_dp
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_heads = n_heads
        self.n_layers_enc = n_layers_enc
        self.p_dropout = p_dropout
        self.n_blocks_dec = n_blocks_dec
        self.kernel_size_dec = kernel_size_dec
        self.dilation_rate = dilation_rate
        self.n_block_layers = n_block_layers
        self.p_dropout_dec = p_dropout_dec
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_split = n_split
        self.n_sqz = n_sqz
        self.sigmoid_scale = sigmoid_scale
        self.window_size = window_size
        self.block_length = block_length
        self.mean_only = mean_only
        self.hidden_channels_enc = hidden_channels_enc
        self.hidden_channels_dec = hidden_channels_dec
        self.prenet = prenet

        self.encoder = TextEncoder(
            n_vocab,
            out_channels,
            hidden_channels_enc or hidden_channels,  # fall back to shared width
            filter_channels,
            filter_channels_dp,
            n_heads,
            n_layers_enc,
            kernel_size,
            p_dropout,
            window_size=window_size,
            block_length=block_length,
            mean_only=mean_only,
            prenet=prenet,
            gin_channels=gin_channels,
        )

        self.decoder = FlowSpecDecoder(
            out_channels,
            hidden_channels_dec or hidden_channels,
            kernel_size_dec,
            dilation_rate,
            n_blocks_dec,
            n_block_layers,
            p_dropout=p_dropout_dec,
            n_split=n_split,
            n_sqz=n_sqz,
            sigmoid_scale=sigmoid_scale,
            gin_channels=gin_channels,
        )

        # Speaker embedding only exists in the multi-speaker setting; callers
        # must not pass g when n_speakers <= 1.
        if n_speakers > 1:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
            nn.init.uniform_(self.emb_g.weight, -0.1, 0.1)

    def forward(
        self,
        x,
        x_lengths,
        y=None,
        y_lengths=None,
        g=None,
        gen=False,
        noise_scale=1.0,
        length_scale=1.0,
    ):
        if g is not None:
            # L2-normalized speaker embedding, shaped [b, gin_channels, 1].
            g = F.normalize(self.emb_g(g)).unsqueeze(-1)  # [b, h]
        x_m, x_logs, logw, x_mask = self.encoder(x, x_lengths, g=g)

        if gen:
            # Predicted durations (frames per token), scaled and ceiled.
            w = torch.exp(logw) * x_mask * length_scale
            w_ceil = torch.ceil(w)
            y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
            y_max_length = None
        else:
            y_max_length = y.size(2)
        # Trim y / y_lengths to a multiple of n_sqz for the flow decoder.
        y, y_lengths, y_max_length = self.preprocess(y, y_lengths, y_max_length)
        z_mask = torch.unsqueeze(commons.sequence_mask(y_lengths, y_max_length), 1).to(
            x_mask.dtype
        )
        attn_mask = torch.unsqueeze(x_mask, -1) * torch.unsqueeze(z_mask, 2)

        if gen:
            # Hard alignment built directly from predicted durations.
            attn = commons.generate_path(
                w_ceil.squeeze(1), attn_mask.squeeze(1)
            ).unsqueeze(1)
            z_m = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)
            ).transpose(
                1, 2
            )  # [b, t', t], [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)
            ).transpose(
                1, 2
            )  # [b, t', t], [b, t, d] -> [b, d, t']
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask

            # Sample latent from the expanded prior, then invert the flow.
            z = (z_m + torch.exp(z_logs) * torch.randn_like(z_m) * noise_scale) * z_mask
            y, logdet = self.decoder(z, z_mask, g=g, reverse=True)
            return (
                (y, z_m, z_logs, logdet, z_mask),
                (x_m, x_logs, x_mask),
                (attn, logw, logw_),
            )
        else:
            # Encode mel into latent z (forward flow) for the MLE objective.
            z, logdet = self.decoder(y, z_mask, g=g, reverse=False)
            # Monotonic alignment search over the Gaussian log-likelihood;
            # no gradients flow through the alignment itself.
            with torch.no_grad():
                x_s_sq_r = torch.exp(-2 * x_logs)  # 1 / sigma^2
                logp1 = torch.sum(-0.5 * math.log(2 * math.pi) - x_logs, [1]).unsqueeze(
                    -1
                )  # [b, t, 1]
                logp2 = torch.matmul(
                    x_s_sq_r.transpose(1, 2), -0.5 * (z ** 2)
                )  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp3 = torch.matmul(
                    (x_m * x_s_sq_r).transpose(1, 2), z
                )  # [b, t, d] x [b, d, t'] = [b, t, t']
                logp4 = torch.sum(-0.5 * (x_m ** 2) * x_s_sq_r, [1]).unsqueeze(
                    -1
                )  # [b, t, 1]
                logp = logp1 + logp2 + logp3 + logp4  # [b, t, t']

                attn = (
                    monotonic_align.maximum_path(logp, attn_mask.squeeze(1))
                    .unsqueeze(1)
                    .detach()
                )
            z_m = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_m.transpose(1, 2)
            ).transpose(
                1, 2
            )  # [b, t', t], [b, t, d] -> [b, d, t']
            z_logs = torch.matmul(
                attn.squeeze(1).transpose(1, 2), x_logs.transpose(1, 2)
            ).transpose(
                1, 2
            )  # [b, t', t], [b, t, d] -> [b, d, t']
            # Ground-truth log-durations from the MAS alignment, for the
            # duration-predictor loss against logw.
            logw_ = torch.log(1e-8 + torch.sum(attn, -1)) * x_mask
            return (
                (z, z_m, z_logs, logdet, z_mask),
                (x_m, x_logs, x_mask),
                (attn, logw, logw_),
            )

    def preprocess(self, y, y_lengths, y_max_length):
        # Round lengths down to a multiple of n_sqz so squeeze() is exact.
        if y_max_length is not None:
            y_max_length = (y_max_length // self.n_sqz) * self.n_sqz
            y = y[:, :, :y_max_length]
        y_lengths = (y_lengths // self.n_sqz) * self.n_sqz
        return y, y_lengths, y_max_length

    def store_inverse(self):
        # Precompute decoder inverses for fast generation.
        self.decoder.store_inverse()
|
src/glow_tts/modules.py
ADDED
|
@@ -0,0 +1,276 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import copy
|
| 2 |
+
import math
|
| 3 |
+
import numpy as np
|
| 4 |
+
import scipy
|
| 5 |
+
import torch
|
| 6 |
+
from torch import nn
|
| 7 |
+
from torch.nn import functional as F
|
| 8 |
+
|
| 9 |
+
import commons
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class LayerNorm(nn.Module):
    """Channel-wise layer normalisation for (B, C, ...) tensors.

    Statistics are computed over the channel axis (dim 1) and the result
    is rescaled by learnable per-channel gain/offset parameters.
    """

    def __init__(self, channels, eps=1e-4):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Learnable affine parameters, one scalar per channel.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        mean = x.mean(dim=1, keepdim=True)
        var = ((x - mean) ** 2).mean(dim=1, keepdim=True)
        normed = (x - mean) * torch.rsqrt(var + self.eps)
        # Broadcast gamma/beta across every axis except the channel axis.
        param_shape = [1, -1] + [1] * (x.dim() - 2)
        return normed * self.gamma.view(*param_shape) + self.beta.view(*param_shape)
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ConvReluNorm(nn.Module):
    """Stack of Conv1d -> LayerNorm -> ReLU -> Dropout layers with a
    zero-initialised residual projection (Glow-TTS encoder pre-net).

    :param in_channels: channels of the input (and of the residual add)
    :param hidden_channels: channels used inside the stack
    :param out_channels: channels produced by the final 1x1 projection
    :param kernel_size: conv kernel width ("same" padding is applied)
    :param n_layers: number of conv layers; must be > 1
    :param p_dropout: dropout probability after each ReLU
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # BUGFIX: the message now matches the actual check (> 1, not > 0).
        assert n_layers > 1, "Number of layers should be larger than 1."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))
        # Zero-init so the residual branch starts out as the identity map.
        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """x: (B, in_channels, T); x_mask: (B, 1, T) validity mask.

        NOTE(review): the residual add `x_org + proj(x)` assumes
        in_channels == out_channels — confirm at call sites.
        """
        x_org = x
        for i in range(self.n_layers):
            x = self.conv_layers[i](x * x_mask)
            x = self.norm_layers[i](x)
            x = self.relu_drop(x)
        x = x_org + self.proj(x)
        return x * x_mask
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class WN(torch.nn.Module):
    """WaveNet-style stack of gated dilated convolutions (non-causal).

    Each layer produces 2*hidden_channels features, mixes in optional
    global conditioning `g` through the gated tanh/sigmoid unit, and
    splits the result into residual and skip contributions.

    :param in_channels: declared input channels (stored for reference)
    :param hidden_channels: internal channel width (must be even)
    :param kernel_size: dilated conv width (must be odd for "same" padding)
    :param dilation_rate: per-layer dilation base (dilation = rate ** i)
    :param n_layers: number of gated layers
    :param gin_channels: conditioning channels; 0 disables conditioning
    :param p_dropout: dropout applied to each layer's conv output
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super(WN, self).__init__()
        assert kernel_size % 2 == 1
        assert hidden_channels % 2 == 0
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        # BUGFIX: was stored as the 1-tuple `(kernel_size,)`; keep the int so
        # the attribute matches its name and the constructor argument.
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # One shared 1x1 conv produces the conditioning for all layers.
            cond_layer = torch.nn.Conv1d(
                gin_channels, 2 * hidden_channels * n_layers, 1
            )
            self.cond_layer = torch.nn.utils.weight_norm(cond_layer, name="weight")

        for i in range(n_layers):
            dilation = dilation_rate ** i
            # "same" padding for the dilated, odd-sized kernel.
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.Conv1d(
                hidden_channels,
                2 * hidden_channels,
                kernel_size,
                dilation=dilation,
                padding=padding,
            )
            in_layer = torch.nn.utils.weight_norm(in_layer, name="weight")
            self.in_layers.append(in_layer)

            # last one is not necessary
            if i < n_layers - 1:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels

            res_skip_layer = torch.nn.Conv1d(hidden_channels, res_skip_channels, 1)
            res_skip_layer = torch.nn.utils.weight_norm(res_skip_layer, name="weight")
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask=None, g=None, **kwargs):
        """x: (B, hidden_channels, T); x_mask: (B, 1, T); g: conditioning or None.

        NOTE(review): despite the default, x_mask=None would fail at the
        elementwise multiplies — callers appear to always pass a mask.
        """
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            x_in = self.drop(x_in)
            if g is not None:
                # Slice this layer's share of the shared conditioning output.
                cond_offset = i * 2 * self.hidden_channels
                g_l = g[:, cond_offset : cond_offset + 2 * self.hidden_channels, :]
            else:
                g_l = torch.zeros_like(x_in)

            # Gated activation unit: tanh(a) * sigmoid(b), fused in commons.
            acts = commons.fused_add_tanh_sigmoid_multiply(x_in, g_l, n_channels_tensor)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < self.n_layers - 1:
                # First half feeds the residual path, second half the skip sum.
                x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        # Fuse the weight-norm parametrisations in-place for inference.
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for l in self.in_layers:
            torch.nn.utils.remove_weight_norm(l)
        for l in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(l)
|
| 173 |
+
|
| 174 |
+
|
| 175 |
+
class ActNorm(nn.Module):
    """Activation normalisation flow layer (Glow-style per-channel affine).

    Applies a learned per-channel scale (exp(logs)) and shift (bias).
    With ``ddi=True`` the parameters are data-dependently initialised from
    the first batch so outputs start with zero mean and unit variance.
    """

    def __init__(self, channels, ddi=False, **kwargs):
        super().__init__()
        self.channels = channels
        # With ddi, defer parameter initialisation to the first forward call.
        self.initialized = not ddi

        self.logs = nn.Parameter(torch.zeros(1, channels, 1))
        self.bias = nn.Parameter(torch.zeros(1, channels, 1))

    def forward(self, x, x_mask=None, reverse=False, **kwargs):
        if x_mask is None:
            x_mask = torch.ones(x.size(0), 1, x.size(2)).to(
                device=x.device, dtype=x.dtype
            )
        frame_counts = torch.sum(x_mask, [1, 2])
        if not self.initialized:
            self.initialize(x, x_mask)
            self.initialized = True

        if not reverse:
            out = (self.bias + torch.exp(self.logs) * x) * x_mask
            # Per-sample log|det| of the elementwise affine map.  [b]
            logdet = torch.sum(self.logs) * frame_counts
        else:
            # Exact inverse of the forward affine; no log-det needed.
            out = (x - self.bias) * torch.exp(-self.logs) * x_mask
            logdet = None

        return out, logdet

    def store_inverse(self):
        # Nothing to cache: the inverse is computed analytically in forward().
        pass

    def set_ddi(self, ddi):
        # Re-arm (or disarm) data-dependent initialisation.
        self.initialized = not ddi

    def initialize(self, x, x_mask):
        """Set bias/logs from masked batch statistics of x."""
        with torch.no_grad():
            denom = torch.sum(x_mask, [0, 2])
            mean = torch.sum(x * x_mask, [0, 2]) / denom
            mean_sq = torch.sum(x * x * x_mask, [0, 2]) / denom
            variance = mean_sq - mean ** 2
            # Clamp avoids log(0) on constant channels.
            logs = 0.5 * torch.log(torch.clamp_min(variance, 1e-6))

            bias_init = (
                (-mean * torch.exp(-logs)).view(*self.bias.shape).to(dtype=self.bias.dtype)
            )
            logs_init = (-logs).view(*self.logs.shape).to(dtype=self.logs.dtype)

            self.bias.data.copy_(bias_init)
            self.logs.data.copy_(logs_init)
|
| 224 |
+
|
| 225 |
+
|
| 226 |
+
class InvConvNear(nn.Module):
    """Invertible 1x1 convolution over groups of ``n_split`` channels (Glow).

    The channel axis is regrouped into blocks of ``n_split`` and every block
    is mixed by a single shared learnable matrix, initialised orthogonally
    with determinant +1.
    """

    def __init__(self, channels, n_split=4, no_jacobian=False, **kwargs):
        super().__init__()
        assert n_split % 2 == 0
        self.channels = channels
        self.n_split = n_split
        self.no_jacobian = no_jacobian

        # Orthogonal init via QR; flip one column if needed so det(W) = +1.
        w_init = torch.qr(torch.FloatTensor(self.n_split, self.n_split).normal_())[0]
        if torch.det(w_init) < 0:
            w_init[:, 0] = -1 * w_init[:, 0]
        self.weight = nn.Parameter(w_init)

    def forward(self, x, x_mask=None, reverse=False, **kwargs):
        b, c, t = x.size()
        assert c % self.n_split == 0
        if x_mask is None:
            x_mask = 1
            x_len = torch.ones((b,), dtype=x.dtype, device=x.device) * t
        else:
            x_len = torch.sum(x_mask, [1, 2])

        # Regroup channels: (b, c, t) -> (b, n_split, c // n_split, t).
        grouped = x.view(b, 2, c // self.n_split, self.n_split // 2, t)
        grouped = (
            grouped.permute(0, 1, 3, 2, 4)
            .contiguous()
            .view(b, self.n_split, c // self.n_split, t)
        )

        if not reverse:
            mixing = self.weight
            if self.no_jacobian:
                logdet = 0
            else:
                # log|det W| contributes once per channel group per frame. [b]
                logdet = torch.logdet(self.weight) * (c / self.n_split) * x_len
        else:
            # Use the cached inverse when store_inverse() has been called.
            if hasattr(self, "weight_inv"):
                mixing = self.weight_inv
            else:
                mixing = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
            logdet = None

        # Apply the mixing matrix as a 1x1 conv over the grouped layout.
        mixing = mixing.view(self.n_split, self.n_split, 1, 1)
        z = F.conv2d(grouped, mixing)

        # Undo the regrouping back to (b, c, t) and re-apply the mask.
        z = z.view(b, 2, self.n_split // 2, c // self.n_split, t)
        z = z.permute(0, 1, 3, 2, 4).contiguous().view(b, c, t) * x_mask
        return z, logdet

    def store_inverse(self):
        # Cache W^{-1} so repeated inference passes skip torch.inverse.
        self.weight_inv = torch.inverse(self.weight.float()).to(dtype=self.weight.dtype)
|
src/glow_tts/monotonic_align/build/lib.linux-x86_64-cpython-37/monotonic_align/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pkg_resources
|
| 2 |
+
|
| 3 |
+
__version__ = pkg_resources.get_distribution("monotonic_align").version
|
| 4 |
+
|
| 5 |
+
from monotonic_align.mas import *
|
src/glow_tts/monotonic_align/build/lib.linux-x86_64-cpython-37/monotonic_align/core.cpython-37m-x86_64-linux-gnu.so
ADDED
|
Binary file (982 kB). View file
|
|
|
src/glow_tts/monotonic_align/build/lib.linux-x86_64-cpython-37/monotonic_align/mas.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import overload
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from monotonic_align.core import maximum_path_c
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def mask_from_len(lens: torch.Tensor, max_len=None):
    """Build a boolean padding mask from sequence lengths.

    :param lens: (B,) tensor of valid lengths
    :param max_len: mask width; defaults to ``lens.max()``

    :return:
        `mask`: (B, T) boolean, True where the position is < its length
    """
    if max_len is None:
        max_len = lens.max()
    positions = torch.arange(max_len).to(lens).unsqueeze(0)  # (1, T)
    return positions < lens.unsqueeze(1)  # (B, T)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def mask_from_lens(
    similarity: torch.Tensor,
    symbol_lens: torch.Tensor,
    mel_lens: torch.Tensor,
):
    """Joint validity mask for a (B, S, T) similarity matrix.

    :param similarity: (B, S, T)
    :param symbol_lens: (B,)
    :param mel_lens: (B,)
    :return: (B, S, T) mask cast to ``similarity``'s dtype
    """
    _, S, T = similarity.size()
    rows = mask_from_len(symbol_lens, S).unsqueeze(2)  # (B, S, 1)
    cols = mask_from_len(mel_lens, T).unsqueeze(1)     # (B, 1, T)
    return (rows * cols).to(similarity)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def maximum_path(value, mask=None):
    """Cython optimised monotonic alignment search.

    value: [b, t_x, t_y] score matrix
    mask: [b, t_x, t_y] validity mask; defaults to all-ones (all valid)
    """
    if mask is None:
        # BUGFIX: the default must mark every position valid.  The previous
        # torch.zeros_like zeroed `value` and made t_x_max/t_y_max all zero,
        # so the returned path was always empty.
        mask = torch.ones_like(value)

    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()
    # Per-sample valid extents, read from the mask's first row/column.
    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)
|
src/glow_tts/monotonic_align/build/temp.linux-x86_64-cpython-37/monotonic_align/core.o
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:dd5e43fed9aa26e4d448896b5090e5813d06ca4a8c3c635f8e7ee1d1d4bf41dc
|
| 3 |
+
size 1746200
|
src/glow_tts/monotonic_align/monotonic_align.egg-info/PKG-INFO
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: monotonic-align
|
| 3 |
+
Version: 1.1
|
src/glow_tts/monotonic_align/monotonic_align.egg-info/SOURCES.txt
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
pyproject.toml
|
| 2 |
+
setup.py
|
| 3 |
+
monotonic_align/__init__.py
|
| 4 |
+
monotonic_align/core.c
|
| 5 |
+
monotonic_align/mas.py
|
| 6 |
+
monotonic_align.egg-info/PKG-INFO
|
| 7 |
+
monotonic_align.egg-info/SOURCES.txt
|
| 8 |
+
monotonic_align.egg-info/dependency_links.txt
|
| 9 |
+
monotonic_align.egg-info/requires.txt
|
| 10 |
+
monotonic_align.egg-info/top_level.txt
|
src/glow_tts/monotonic_align/monotonic_align.egg-info/dependency_links.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
|
src/glow_tts/monotonic_align/monotonic_align.egg-info/requires.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
numpy
|
src/glow_tts/monotonic_align/monotonic_align.egg-info/top_level.txt
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
monotonic_align
|
src/glow_tts/monotonic_align/monotonic_align/__init__.py
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import pkg_resources
|
| 2 |
+
|
| 3 |
+
__version__ = pkg_resources.get_distribution("monotonic_align").version
|
| 4 |
+
|
| 5 |
+
from monotonic_align.mas import *
|
src/glow_tts/monotonic_align/monotonic_align/core.c
ADDED
|
The diff for this file is too large to render.
See raw diff
|
|
|
src/glow_tts/monotonic_align/monotonic_align/core.pyx
ADDED
|
@@ -0,0 +1,45 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import numpy as np
|
| 2 |
+
cimport numpy as np
|
| 3 |
+
cimport cython
|
| 4 |
+
from cython.parallel import prange
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
@cython.boundscheck(False)
@cython.wraparound(False)
cdef void maximum_path_each(int[:,::1] path, float[:,::1] value, int t_x, int t_y, float max_neg_val) nogil:
    # Monotonic alignment search for a single (t_x, t_y) pair.
    # Forward pass turns `value` into cumulative best scores in-place;
    # backward pass writes the chosen monotonic path into `path` as 0/1.
    cdef int x
    cdef int y
    cdef float v_prev
    cdef float v_cur
    cdef float tmp
    cdef int index = t_x - 1

    # Forward DP over the banded region reachable by a monotonic path:
    # value[x, y] = max(stay at x, advance from x-1) + value[x, y].
    for y in range(t_y):
        for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
            if x == y:
                # On the diagonal the "stay" predecessor does not exist.
                v_cur = max_neg_val
            else:
                v_cur = value[x, y-1]
            if x == 0:
                if y == 0:
                    v_prev = 0.
                else:
                    v_prev = max_neg_val
            else:
                v_prev = value[x-1, y-1]
            value[x, y] = max(v_cur, v_prev) + value[x, y]

    # Backtrack from the last symbol, moving up a row whenever the diagonal
    # predecessor scored at least as well as staying in the same row.
    for y in range(t_y - 1, -1, -1):
        path[index, y] = 1
        if index != 0 and (index == y or value[index, y-1] < value[index-1, y-1]):
            index = index - 1
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
@cython.boundscheck(False)
@cython.wraparound(False)
cpdef void maximum_path_c(int[:,:,::1] paths, float[:,:,::1] values, int[::1] t_xs, int[::1] t_ys, float max_neg_val=-1e9) nogil:
    # Batched entry point: run the per-sample alignment search for each
    # batch element in parallel via OpenMP prange (GIL released).
    cdef int b = values.shape[0]

    cdef int i
    for i in prange(b, nogil=True):
        maximum_path_each(paths[i], values[i], t_xs[i], t_ys[i], max_neg_val)
|
src/glow_tts/monotonic_align/monotonic_align/mas.py
ADDED
|
@@ -0,0 +1,57 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import overload
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
from monotonic_align.core import maximum_path_c
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def mask_from_len(lens: torch.Tensor, max_len=None):
    """Build a boolean padding mask from sequence lengths.

    :param lens: (B,) tensor of valid lengths
    :param max_len: mask width; defaults to ``lens.max()``

    :return:
        `mask`: (B, T) boolean, True where the position is < its length
    """
    if max_len is None:
        max_len = lens.max()
    positions = torch.arange(max_len).to(lens).unsqueeze(0)  # (1, T)
    return positions < lens.unsqueeze(1)  # (B, T)
|
| 21 |
+
|
| 22 |
+
|
| 23 |
+
def mask_from_lens(
    similarity: torch.Tensor,
    symbol_lens: torch.Tensor,
    mel_lens: torch.Tensor,
):
    """Joint validity mask for a (B, S, T) similarity matrix.

    :param similarity: (B, S, T)
    :param symbol_lens: (B,)
    :param mel_lens: (B,)
    :return: (B, S, T) mask cast to ``similarity``'s dtype
    """
    _, S, T = similarity.size()
    rows = mask_from_len(symbol_lens, S).unsqueeze(2)  # (B, S, 1)
    cols = mask_from_len(mel_lens, T).unsqueeze(1)     # (B, 1, T)
    return (rows * cols).to(similarity)
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def maximum_path(value, mask=None):
    """Cython optimised monotonic alignment search.

    value: [b, t_x, t_y] score matrix
    mask: [b, t_x, t_y] validity mask; defaults to all-ones (all valid)
    """
    if mask is None:
        # BUGFIX: the default must mark every position valid.  The previous
        # torch.zeros_like zeroed `value` and made t_x_max/t_y_max all zero,
        # so the returned path was always empty.
        mask = torch.ones_like(value)

    value = value * mask
    device = value.device
    dtype = value.dtype
    value = value.data.cpu().numpy().astype(np.float32)
    path = np.zeros_like(value).astype(np.int32)
    mask = mask.data.cpu().numpy()
    # Per-sample valid extents, read from the mask's first row/column.
    t_x_max = mask.sum(1)[:, 0].astype(np.int32)
    t_y_max = mask.sum(2)[:, 0].astype(np.int32)
    maximum_path_c(path, value, t_x_max, t_y_max)
    return torch.from_numpy(path).to(device=device, dtype=dtype)
|
src/glow_tts/monotonic_align/pyproject.toml
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
[build-system]
|
| 2 |
+
requires = [
|
| 3 |
+
"wheel",
|
| 4 |
+
"setuptools",
|
| 5 |
+
"cython>=0.24.0",
|
| 6 |
+
"numpy<v1.20.0",
|
| 7 |
+
]
|
src/glow_tts/monotonic_align/setup.py
ADDED
|
@@ -0,0 +1,23 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Build script for the `monotonic_align` Cython extension.
import numpy
from setuptools import Extension, find_packages
from distutils.core import setup
from Cython.Build import cythonize


_VERSION = "1.1"


# Compile the monotonic alignment search kernel (core.pyx) into a C extension.
ext_modules = cythonize(
    "monotonic_align/core.pyx",
    compiler_directives={"language_level": "3"},
)

setup(
    name="monotonic_align",
    ext_modules=ext_modules,
    # NumPy headers are required because core.pyx does `cimport numpy`.
    include_dirs=[numpy.get_include(), "monotonic_align"],
    packages=find_packages(),
    setup_requires=["numpy", "cython"],
    install_requires=["numpy"],
    version=_VERSION,
)
|
src/glow_tts/stft.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
BSD 3-Clause License
|
| 3 |
+
|
| 4 |
+
Copyright (c) 2017, Prem Seetharaman
|
| 5 |
+
All rights reserved.
|
| 6 |
+
|
| 7 |
+
* Redistribution and use in source and binary forms, with or without
|
| 8 |
+
modification, are permitted provided that the following conditions are met:
|
| 9 |
+
|
| 10 |
+
* Redistributions of source code must retain the above copyright notice,
|
| 11 |
+
this list of conditions and the following disclaimer.
|
| 12 |
+
|
| 13 |
+
* Redistributions in binary form must reproduce the above copyright notice, this
|
| 14 |
+
list of conditions and the following disclaimer in the
|
| 15 |
+
documentation and/or other materials provided with the distribution.
|
| 16 |
+
|
| 17 |
+
* Neither the name of the copyright holder nor the names of its
|
| 18 |
+
contributors may be used to endorse or promote products derived from this
|
| 19 |
+
software without specific prior written permission.
|
| 20 |
+
|
| 21 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
| 22 |
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
| 23 |
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 24 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR
|
| 25 |
+
ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
|
| 26 |
+
(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
|
| 27 |
+
LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON
|
| 28 |
+
ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
|
| 29 |
+
(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
|
| 30 |
+
SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 31 |
+
"""
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
import numpy as np
|
| 35 |
+
import torch.nn.functional as F
|
| 36 |
+
from torch.autograd import Variable
|
| 37 |
+
from scipy.signal import get_window
|
| 38 |
+
from librosa.util import pad_center, tiny
|
| 39 |
+
from librosa import stft, istft
|
| 40 |
+
from audio_processing import window_sumsquare
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class STFT(torch.nn.Module):
    """Differentiable STFT / inverse-STFT module.

    adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft

    On CUDA tensors the transforms run as strided (transposed) 1-d
    convolutions against precomputed real/imaginary Fourier basis buffers;
    on CPU tensors they fall back to librosa's stft/istft on numpy arrays.
    """

    def __init__(
        self, filter_length=800, hop_length=200, win_length=800, window="hann"
    ):
        super(STFT, self).__init__()
        self.filter_length = filter_length
        self.hop_length = hop_length
        self.win_length = win_length
        self.window = window
        self.forward_transform = None
        # Overlap ratio; rescales the pseudo-inverse basis below.
        scale = self.filter_length / self.hop_length
        # DFT of the identity: one complex sinusoid per basis row.
        fourier_basis = np.fft.fft(np.eye(self.filter_length))

        # Keep only the non-redundant half of the spectrum (real input) and
        # stack real/imaginary parts as separate conv output channels.
        cutoff = int((self.filter_length / 2 + 1))
        fourier_basis = np.vstack(
            [np.real(fourier_basis[:cutoff, :]), np.imag(fourier_basis[:cutoff, :])]
        )

        forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
        inverse_basis = torch.FloatTensor(
            np.linalg.pinv(scale * fourier_basis).T[:, None, :]
        )

        if window is not None:
            assert filter_length >= win_length
            # get window and zero center pad it to filter_length
            fft_window = get_window(window, win_length, fftbins=True)
            fft_window = pad_center(fft_window, filter_length)
            fft_window = torch.from_numpy(fft_window).float()

            # window the bases
            forward_basis *= fft_window
            inverse_basis *= fft_window

        # Buffers (not parameters): move with .to()/.cuda(), saved in state_dict.
        self.register_buffer("forward_basis", forward_basis.float())
        self.register_buffer("inverse_basis", inverse_basis.float())

    def transform(self, input_data):
        """Compute the magnitude and phase spectrograms of a waveform batch.

        :param input_data: (num_batches, num_samples) waveform tensor
        :return: (magnitude, phase), each (num_batches, n_freq, n_frames)
        """
        num_batches = input_data.size(0)
        num_samples = input_data.size(1)

        self.num_samples = num_samples

        if input_data.device.type == "cuda":
            # similar to librosa, reflect-pad the input
            input_data = input_data.view(num_batches, 1, num_samples)
            input_data = F.pad(
                input_data.unsqueeze(1),
                (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
                mode="reflect",
            )
            input_data = input_data.squeeze(1)

            # Strided conv = windowed DFT; channels are [real; imag] halves.
            forward_transform = F.conv1d(
                input_data, self.forward_basis, stride=self.hop_length, padding=0
            )

            cutoff = int((self.filter_length / 2) + 1)
            real_part = forward_transform[:, :cutoff, :]
            imag_part = forward_transform[:, cutoff:, :]
        else:
            # CPU fallback: loop over the batch with librosa's stft.
            x = input_data.detach().numpy()
            real_part = []
            imag_part = []
            for y in x:
                y_ = stft(
                    y, self.filter_length, self.hop_length, self.win_length, self.window
                )
                real_part.append(y_.real[None, :, :])
                imag_part.append(y_.imag[None, :, :])
            real_part = np.concatenate(real_part, 0)
            imag_part = np.concatenate(imag_part, 0)

            real_part = torch.from_numpy(real_part).to(input_data.dtype)
            imag_part = torch.from_numpy(imag_part).to(input_data.dtype)

        magnitude = torch.sqrt(real_part ** 2 + imag_part ** 2)
        phase = torch.atan2(imag_part.data, real_part.data)

        return magnitude, phase

    def inverse(self, magnitude, phase):
        """Reconstruct a waveform batch from magnitude and phase.

        :return: (num_batches, 1-or-0 extra dim, num_samples) waveform tensor
        """
        # Re-interleave the polar components back into [real; imag] channels.
        recombine_magnitude_phase = torch.cat(
            [magnitude * torch.cos(phase), magnitude * torch.sin(phase)], dim=1
        )

        if magnitude.device.type == "cuda":
            inverse_transform = F.conv_transpose1d(
                recombine_magnitude_phase,
                self.inverse_basis,
                stride=self.hop_length,
                padding=0,
            )

            if self.window is not None:
                # Compensate for the window's overlap-add envelope.
                window_sum = window_sumsquare(
                    self.window,
                    magnitude.size(-1),
                    hop_length=self.hop_length,
                    win_length=self.win_length,
                    n_fft=self.filter_length,
                    dtype=np.float32,
                )
                # remove modulation effects
                approx_nonzero_indices = torch.from_numpy(
                    np.where(window_sum > tiny(window_sum))[0]
                )
                window_sum = torch.from_numpy(window_sum).to(inverse_transform.device)
                inverse_transform[:, :, approx_nonzero_indices] /= window_sum[
                    approx_nonzero_indices
                ]

                # scale by hop ratio
                inverse_transform *= float(self.filter_length) / self.hop_length

            # Trim the reflect padding added in transform().
            inverse_transform = inverse_transform[:, :, int(self.filter_length / 2) :]
            inverse_transform = inverse_transform[
                :, :, : -int(self.filter_length / 2) :
            ]
            inverse_transform = inverse_transform.squeeze(1)
        else:
            # CPU fallback: rebuild complex spectra and run librosa's istft.
            x_org = recombine_magnitude_phase.detach().numpy()
            n_b, n_f, n_t = x_org.shape
            x = np.empty([n_b, n_f // 2, n_t], dtype=np.complex64)
            x.real = x_org[:, : n_f // 2]
            x.imag = x_org[:, n_f // 2 :]
            inverse_transform = []
            for y in x:
                y_ = istft(y, self.hop_length, self.win_length, self.window)
                inverse_transform.append(y_[None, :])
            inverse_transform = np.concatenate(inverse_transform, 0)
            inverse_transform = torch.from_numpy(inverse_transform).to(
                recombine_magnitude_phase.dtype
            )

        return inverse_transform

    def forward(self, input_data):
        """Analysis/resynthesis round trip; stores magnitude/phase on self."""
        self.magnitude, self.phase = self.transform(input_data)
        reconstruction = self.inverse(self.magnitude, self.phase)
        return reconstruction
|
src/glow_tts/t2s_fastapi.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from starlette.responses import StreamingResponse
|
| 2 |
+
from texttospeech import MelToWav, TextToMel
|
| 3 |
+
from typing import Optional
|
| 4 |
+
from pydantic import BaseModel
|
| 5 |
+
from fastapi import FastAPI, HTTPException
|
| 6 |
+
import uvicorn
|
| 7 |
+
import base64
|
| 8 |
+
|
| 9 |
+
app = FastAPI()
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
class TextJson(BaseModel):
    # Request body for the /TTS/ endpoint: text plus voice selection.
    text: str
    lang: Optional[str] = "hi"  # language code; combined with gender to pick a model
    gender: Optional[str] = "male"  # "male" or "female"
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
# Load the synthesis models once at startup.
# NOTE(review): empty model dirs/devices are placeholders — fill in real
# checkpoint paths and devices before deploying.
glow_hi_male = TextToMel(glow_model_dir="", device="")
glow_hi_female = TextToMel(glow_model_dir="", device="")
hifi_hi = MelToWav(hifi_model_dir="", device="")


# Maps "<lang>_<gender>" -> [TextToMel, MelToWav] pipeline pair.
available_choice = {
    "hi_male": [glow_hi_male, hifi_hi],
    "hi_female": [glow_hi_female, hifi_hi],
}
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
@app.post("/TTS/")
|
| 30 |
+
async def tts(input: TextJson):
|
| 31 |
+
text = input.text
|
| 32 |
+
lang = input.lang
|
| 33 |
+
gender = input.gender
|
| 34 |
+
|
| 35 |
+
choice = lang + "_" + gender
|
| 36 |
+
if choice in available_choice.keys():
|
| 37 |
+
t2s = available_choice[choice]
|
| 38 |
+
else:
|
| 39 |
+
raise HTTPException(
|
| 40 |
+
status_code=400, detail={"error": "Requested model not found"}
|
| 41 |
+
)
|
| 42 |
+
|
| 43 |
+
if text:
|
| 44 |
+
mel = t2s[0].generate_mel(text)
|
| 45 |
+
data, sr = t2s[1].generate_wav(mel)
|
| 46 |
+
t2s.save_audio("out.wav", data, sr)
|
| 47 |
+
else:
|
| 48 |
+
raise HTTPException(status_code=400, detail={"error": "No text"})
|
| 49 |
+
|
| 50 |
+
## to return outpur as a file
|
| 51 |
+
# audio = open('out.wav', mode='rb')
|
| 52 |
+
# return StreamingResponse(audio, media_type="audio/wav")
|
| 53 |
+
|
| 54 |
+
with open("out.wav", "rb") as audio_file:
|
| 55 |
+
encoded_bytes = base64.b64encode(audio_file.read())
|
| 56 |
+
encoded_string = encoded_bytes.decode()
|
| 57 |
+
return {"encoding": "base64", "data": encoded_string, "sr": sr}
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
if __name__ == "__main__":
|
| 61 |
+
uvicorn.run(
|
| 62 |
+
"t2s_fastapi:app", host="127.0.0.1", port=5000, log_level="info", reload=True
|
| 63 |
+
)
|
src/glow_tts/t2s_gradio.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import gradio as gr
|
| 2 |
+
from texttospeech import TextToMel, MelToWav
|
| 3 |
+
|
| 4 |
+
# Model setup — both paths are placeholders and must point at real checkpoints.
# NOTE(review): hifi_model_dir reuses the "glow-tts" path string; it should
# presumably be a HiFi-GAN checkpoint dir — confirm before running.
text_to_mel = TextToMel(
    glow_model_dir="/path/to/glow-tts/checkpoint/dir", device="cuda"
)
mel_to_wav = MelToWav(hifi_model_dir="/path/to/glow-tts/checkpoint/dir", device="cuda")
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def run_tts(text):
    """Synthesize `text` and return (sample_rate, waveform) for a Gradio Audio output."""
    spectrogram = text_to_mel.generate_mel(text)
    waveform, sample_rate = mel_to_wav.generate_wav(spectrogram)
    return (sample_rate, waveform)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
# text = " सीआईएसएफ में उप-निरीक्षक महावीर प्रसाद गोदरा को मरणोपरांत 'शौर्य चक्र' से सम्मानित किया गया। "
|
| 17 |
+
# run_tts(text)
|
| 18 |
+
|
| 19 |
+
# NOTE(review): gr.inputs / gr.outputs are the legacy pre-3.x Gradio API;
# recent Gradio versions use gr.Textbox / gr.Audio directly — confirm the
# pinned gradio version.
# NOTE(review): the placeholder says "Telugu" while the commented example
# above is Hindi — verify which language this demo targets.
textbox = gr.inputs.Textbox(
    placeholder="Enter Telugu text here", default="", label="TTS"
)
op = gr.outputs.Audio(type="numpy", label=None)
iface = gr.Interface(fn=run_tts, inputs=textbox, outputs=op)
iface.launch(share=True)
|
src/glow_tts/text/__init__.py
ADDED
|
@@ -0,0 +1,84 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
""" from https://github.com/keithito/tacotron """
|
| 2 |
+
import re
|
| 3 |
+
from text import cleaners
|
| 4 |
+
|
| 5 |
+
# Regular expression matching text enclosed in curly braces:
# group(1) = text before the braces, group(2) = brace contents (ARPAbet),
# group(3) = remainder of the string.
_curly_re = re.compile(r'(.*?)\{(.+?)\}(.*)')
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def get_arpabet(word, dictionary):
    """Look *word* up in an ARPAbet dictionary.

    Returns the first pronunciation wrapped in curly braces (the form that
    text_to_sequence recognises as ARPAbet), or the word unchanged when the
    dictionary has no entry for it.
    """
    pronunciations = dictionary.lookup(word)
    if pronunciations is None:
        return word
    return "{" + pronunciations[0] + "}"
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
def text_to_sequence(text, symbols, cleaner_names, dictionary=None):
    '''Converts a string of text to a sequence of IDs corresponding to the symbols in the text.

    The text can optionally have ARPAbet sequences enclosed in curly braces embedded
    in it. For example, "Turn left on {HH AW1 S S T AH0 N} Street."

    Args:
      text: string to convert to a sequence
      symbols: iterable of symbol strings; position in it becomes the ID
      cleaner_names: names of the cleaner functions to run the text through
      dictionary: arpabet class with arpabet dictionary

    Returns:
      List of integers corresponding to the symbols in the text
    '''
    # Mappings from symbol to numeric ID and vice versa:
    # NOTE(review): these module-level globals are rebuilt on every call, so
    # concurrent calls with different `symbols` would race — confirm this is
    # only used single-threaded.
    global _id_to_symbol, _symbol_to_id
    _symbol_to_id = {s: i for i, s in enumerate(symbols)}
    _id_to_symbol = {i: s for i, s in enumerate(symbols)}

    sequence = []

    space = _symbols_to_sequence(' ')
    # Check for curly braces and treat their contents as ARPAbet:
    while len(text):
        m = _curly_re.match(text)
        if not m:
            # No braces left: clean the remaining text and encode it all.
            clean_text = _clean_text(text, cleaner_names)
            if dictionary is not None:
                # Per-word ARPAbet lookup; hits come back wrapped in {...}
                # and are encoded as ARPAbet symbols, misses as characters.
                clean_text = [get_arpabet(w, dictionary) for w in clean_text.split(" ")]
                for i in range(len(clean_text)):
                    t = clean_text[i]
                    if t.startswith("{"):
                        sequence += _arpabet_to_sequence(t[1:-1])
                    else:
                        sequence += _symbols_to_sequence(t)
                    sequence += space
            else:
                sequence += _symbols_to_sequence(clean_text)
            break
        # Braced chunk found: clean/encode the prefix, encode the ARPAbet
        # contents, then continue with the remainder.
        sequence += _symbols_to_sequence(_clean_text(m.group(1), cleaner_names))
        sequence += _arpabet_to_sequence(m.group(2))
        text = m.group(3)

    # remove trailing space
    # NOTE(review): `sequence[-1]` raises IndexError if the dictionary path
    # produced an empty sequence (e.g. all-unknown symbols) — confirm inputs.
    if dictionary is not None:
        sequence = sequence[:-1] if sequence[-1] == space[0] else sequence
    return sequence
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
def _clean_text(text, cleaner_names):
    """Run *text* through each named cleaner from the `cleaners` module, in order."""
    for cleaner_name in cleaner_names:
        fn = getattr(cleaners, cleaner_name)
        if not fn:
            raise Exception('Unknown cleaner: %s' % cleaner_name)
        text = fn(text)
    return text
|
| 73 |
+
|
| 74 |
+
|
| 75 |
+
def _symbols_to_sequence(symbols):
    # Map each symbol to its ID, silently dropping anything _should_keep_symbol
    # rejects (out-of-vocabulary symbols and the '_' / '~' markers).
    return [_symbol_to_id[s] for s in symbols if _should_keep_symbol(s)]
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
def _arpabet_to_sequence(text):
    # ARPAbet phones are stored in the symbol table with a '@' prefix to keep
    # them distinct from ordinary text characters.
    return _symbols_to_sequence(['@' + s for s in text.split()])
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def _should_keep_symbol(s):
    """Return True if *s* is in the vocabulary and is not the '_' or '~' marker.

    BUG FIX: the original used `s is not '_'`, which tests object identity
    rather than string equality — a SyntaxWarning on modern CPython and only
    correct by accident of small-string interning. Use `!=` value comparison.
    """
    return s in _symbol_to_id and s != '_' and s != '~'
|
src/glow_tts/text/__pycache__/__init__.cpython-37.pyc
ADDED
|
Binary file (3.03 kB). View file
|
|
|
src/glow_tts/text/__pycache__/cleaners.cpython-37.pyc
ADDED
|
Binary file (2.16 kB). View file
|
|
|
src/glow_tts/text/__pycache__/numbers.cpython-37.pyc
ADDED
|
Binary file (2.17 kB). View file
|
|
|
src/glow_tts/text/cleaners.py
ADDED
|
@@ -0,0 +1,78 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import re
|
| 2 |
+
|
| 3 |
+
from unidecode import unidecode
|
| 4 |
+
from .numbers import normalize_numbers
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
# Regular expression matching whitespace:
# (compiled once at import time; used by collapse_whitespace below)
_whitespace_re = re.compile(r"\s+")
|
| 11 |
+
|
| 12 |
+
def lowercase(text):
    # NOTE(review): an identical lowercase() is defined again further down in
    # this file; one of the two definitions should be removed.
    return text.lower()
|
| 14 |
+
|
| 15 |
+
def collapse_whitespace(text):
    """Collapse every run of whitespace in *text* to a single space."""
    return _whitespace_re.sub(" ", text)
|
| 17 |
+
|
| 18 |
+
def basic_indic_cleaners(text):
    """Basic pipeline that collapses whitespace without transliteration."""
    return collapse_whitespace(text)
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def english_cleaner(text):
    """Lowercase *text* and normalise curly apostrophes to straight quotes."""
    quote_table = str.maketrans({'‘': "'", '’': "'"})
    return text.lower().translate(quote_table)
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
def lowercase(text):
    # NOTE(review): duplicate of the lowercase() defined earlier in this file;
    # this later definition is the one that wins at import time.
    return text.lower()
|
| 31 |
+
|
| 32 |
+
def convert_to_ascii(text):
    # Transliterate Unicode text to its closest ASCII form via unidecode.
    return unidecode(text)
|
| 34 |
+
|
| 35 |
+
def expand_numbers(text):
    # Spell out numerals, currency and ordinals as English words (numbers.py).
    return normalize_numbers(text)
|
| 37 |
+
|
| 38 |
+
def expand_abbreviations(text):
    """Replace abbreviations such as "dr." with their spoken expansion."""
    for pattern, expansion in _abbreviations:
        text = pattern.sub(expansion, text)
    return text
|
| 42 |
+
|
| 43 |
+
# (compiled regex, expansion) pairs used by expand_abbreviations. Each pattern
# matches the abbreviation at a word boundary followed by a period,
# case-insensitively.
_abbreviations = [(re.compile('\\b%s\\.' % x[0], re.IGNORECASE), x[1]) for x in [
    ('mrs', 'missus'),
    ('mr', 'mister'),
    ('dr', 'doctor'),
    ('st', 'saint'),
    ('co', 'company'),
    ('jr', 'junior'),
    ('maj', 'major'),
    ('gen', 'general'),
    ('drs', 'doctors'),
    ('rev', 'reverend'),
    ('lt', 'lieutenant'),
    ('hon', 'honorable'),
    ('sgt', 'sergeant'),
    ('capt', 'captain'),
    ('esq', 'esquire'),
    ('ltd', 'limited'),
    ('col', 'colonel'),
    ('ft', 'fort'),
    ('pvt', 'private'),
    ('rs', 'Rupees')
]]
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
|
| 71 |
+
def english_cleaners(text):
    '''Pipeline for English text, including number and abbreviation expansion.'''
    pipeline = (
        convert_to_ascii,
        lowercase,
        expand_numbers,
        expand_abbreviations,
        collapse_whitespace,
    )
    for step in pipeline:
        text = step(text)
    return text
|
src/glow_tts/text/numbers.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import inflect
|
| 2 |
+
import re
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
_inflect = inflect.engine()
# "1,234"-style numbers (commas stripped before further expansion)
_comma_number_re = re.compile(r'([0-9][0-9\,]+[0-9])')
# "3.14" -> "3 point 14"
_decimal_number_re = re.compile(r'([0-9]+\.[0-9]+)')
# "£10" -> "10 pounds"
_pounds_re = re.compile(r'£([0-9\,]*[0-9]+)')
# "$10.50" -> expanded by _expand_dollars
_dollars_re = re.compile(r'\$([0-9\.\,]*[0-9]+)')
# "1st", "2nd", "3rd", "4th", ...
_ordinal_re = re.compile(r'[0-9]+(st|nd|rd|th)')
# any remaining bare integer
_number_re = re.compile(r'[0-9]+')
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def _remove_commas(m):
|
| 15 |
+
return m.group(1).replace(',', '')
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
def _expand_decimal_point(m):
|
| 19 |
+
return m.group(1).replace('.', ' point ')
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _expand_dollars(m):
|
| 23 |
+
match = m.group(1)
|
| 24 |
+
parts = match.split('.')
|
| 25 |
+
if len(parts) > 2:
|
| 26 |
+
return match + ' dollars' # Unexpected format
|
| 27 |
+
dollars = int(parts[0]) if parts[0] else 0
|
| 28 |
+
cents = int(parts[1]) if len(parts) > 1 and parts[1] else 0
|
| 29 |
+
if dollars and cents:
|
| 30 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
| 31 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
| 32 |
+
return '%s %s, %s %s' % (dollars, dollar_unit, cents, cent_unit)
|
| 33 |
+
elif dollars:
|
| 34 |
+
dollar_unit = 'dollar' if dollars == 1 else 'dollars'
|
| 35 |
+
return '%s %s' % (dollars, dollar_unit)
|
| 36 |
+
elif cents:
|
| 37 |
+
cent_unit = 'cent' if cents == 1 else 'cents'
|
| 38 |
+
return '%s %s' % (cents, cent_unit)
|
| 39 |
+
else:
|
| 40 |
+
return 'zero dollars'
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
def _expand_ordinal(m):
    # inflect reads ordinals aloud: "2nd" -> "second", "21st" -> "twenty-first".
    return _inflect.number_to_words(m.group(0))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
def _expand_number(m):
    # Numbers in (1000, 3000) are treated as years and read in two-digit pairs
    # ("1984" -> "nineteen eighty four"); everything else as a plain cardinal.
    num = int(m.group(0))
    if num > 1000 and num < 3000:
        if num == 2000:
            return 'two thousand'
        elif num > 2000 and num < 2010:
            # 2001-2009: "two thousand five", not "twenty oh five"
            return 'two thousand ' + _inflect.number_to_words(num % 100)
        elif num % 100 == 0:
            # round centuries: "1900" -> "nineteen hundred"
            return _inflect.number_to_words(num // 100) + ' hundred'
        else:
            # grouped in pairs with "oh" for zero; drop the comma inflect inserts
            return _inflect.number_to_words(num, andword='', zero='oh', group=2).replace(', ', ' ')
    else:
        return _inflect.number_to_words(num, andword='')
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def normalize_numbers(text):
    """Expand currency, decimals, ordinals and plain numbers into words.

    Order matters: commas and currency are handled before the generic number
    pass so e.g. "$1,000" is not consumed as a bare "1000" first.
    """
    substitutions = (
        (_comma_number_re, _remove_commas),
        (_pounds_re, r'\1 pounds'),
        (_dollars_re, _expand_dollars),
        (_decimal_number_re, _expand_decimal_point),
        (_ordinal_re, _expand_ordinal),
        (_number_re, _expand_number),
    )
    for pattern, replacement in substitutions:
        text = pattern.sub(replacement, text)
    return text
|