backtracking committed on
Commit
c0c84cf
·
verified ·
1 Parent(s): 46dc475

Upload folder using huggingface_hub

Browse files
Files changed (40) hide show
  1. tiny_tts/__init__.py +87 -0
  2. tiny_tts/alignment/__init__.py +16 -0
  3. tiny_tts/alignment/__pycache__/__init__.cpython-310.pyc +0 -0
  4. tiny_tts/alignment/__pycache__/core.cpython-310.pyc +0 -0
  5. tiny_tts/alignment/core.py +46 -0
  6. tiny_tts/infer.py +172 -0
  7. tiny_tts/models/__init__.py +1 -0
  8. tiny_tts/models/__pycache__/__init__.cpython-310.pyc +0 -0
  9. tiny_tts/models/__pycache__/synthesizer.cpython-310.pyc +0 -0
  10. tiny_tts/models/synthesizer.py +718 -0
  11. tiny_tts/nn/__init__.py +1 -0
  12. tiny_tts/nn/__pycache__/__init__.cpython-310.pyc +0 -0
  13. tiny_tts/nn/__pycache__/attentions.cpython-310.pyc +0 -0
  14. tiny_tts/nn/__pycache__/commons.cpython-310.pyc +0 -0
  15. tiny_tts/nn/__pycache__/modules.cpython-310.pyc +0 -0
  16. tiny_tts/nn/__pycache__/transforms.cpython-310.pyc +0 -0
  17. tiny_tts/nn/attentions.py +424 -0
  18. tiny_tts/nn/commons.py +151 -0
  19. tiny_tts/nn/modules.py +578 -0
  20. tiny_tts/nn/transforms.py +209 -0
  21. tiny_tts/text/__init__.py +19 -0
  22. tiny_tts/text/__pycache__/__init__.cpython-310.pyc +0 -0
  23. tiny_tts/text/__pycache__/english.cpython-310.pyc +0 -0
  24. tiny_tts/text/__pycache__/symbols.cpython-310.pyc +0 -0
  25. tiny_tts/text/cmudict.rep +0 -0
  26. tiny_tts/text/cmudict_cache.pickle +3 -0
  27. tiny_tts/text/english.py +173 -0
  28. tiny_tts/text/english_utils/__init__.py +0 -0
  29. tiny_tts/text/english_utils/__pycache__/__init__.cpython-310.pyc +0 -0
  30. tiny_tts/text/english_utils/__pycache__/abbreviations.cpython-310.pyc +0 -0
  31. tiny_tts/text/english_utils/__pycache__/number_norm.cpython-310.pyc +0 -0
  32. tiny_tts/text/english_utils/__pycache__/time_norm.cpython-310.pyc +0 -0
  33. tiny_tts/text/english_utils/abbreviations.py +35 -0
  34. tiny_tts/text/english_utils/number_norm.py +97 -0
  35. tiny_tts/text/english_utils/time_norm.py +47 -0
  36. tiny_tts/text/symbols.py +293 -0
  37. tiny_tts/utils/__init__.py +5 -0
  38. tiny_tts/utils/__pycache__/__init__.cpython-310.pyc +0 -0
  39. tiny_tts/utils/__pycache__/config.cpython-310.pyc +0 -0
  40. tiny_tts/utils/config.py +41 -0
tiny_tts/__init__.py ADDED
@@ -0,0 +1,87 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import soundfile as sf
4
+ from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
5
+ from tiny_tts.text import phonemes_to_ids
6
+ from tiny_tts.nn import commons
7
+ from tiny_tts.models.synthesizer import VoiceSynthesizer
8
+ from tiny_tts.text.symbols import symbols
9
+ from tiny_tts.utils.config import (
10
+ SAMPLING_RATE, SEGMENT_FRAMES, ADD_BLANK, SPEC_CHANNELS,
11
+ N_SPEAKERS, SPK2ID, MODEL_PARAMS,
12
+ )
13
+ from tiny_tts.infer import load_engine
14
+
15
class TinyTTS:
    """Convenience wrapper that loads a TinyTTS checkpoint and synthesizes speech.

    Instance attributes (set in ``__init__``):
        device: Torch device string the model and all inputs are placed on.
        model: The network returned by ``load_engine``.
    """

    def __init__(self, checkpoint_path=None, device=None):
        """Pick a device, resolve the checkpoint and load the model.

        Args:
            checkpoint_path: Path to a generator checkpoint. When ``None``,
                ``checkpoints/G.pth`` in the parent of the package directory
                is used if it exists; otherwise ``G.pth`` is fetched from the
                ``backtracking/tiny-tts`` Hugging Face repository.
            device: Torch device string. When ``None``, ``'cuda'`` is chosen
                if ``torch.cuda.is_available()``, else ``'cpu'``.

        Raises:
            ImportError: ``huggingface_hub`` is needed for the download but is
                not installed.
            ValueError: The download from the Hugging Face Hub failed.
        """
        if device is None:
            self.device = 'cuda' if torch.cuda.is_available() else 'cpu'
        else:
            self.device = device

        if checkpoint_path is None:
            checkpoint_path = self._resolve_default_checkpoint()

        self.model = load_engine(checkpoint_path, self.device)

    @staticmethod
    def _resolve_default_checkpoint():
        """Return a usable ``G.pth`` path, downloading it from the Hub if needed."""
        # 1. Look for a default checkpoint next to the package directory.
        pkg_dir = os.path.dirname(os.path.abspath(__file__))
        default_ckpt = os.path.join(os.path.dirname(pkg_dir), "checkpoints", "G.pth")
        if os.path.exists(default_ckpt):
            return default_ckpt

        # 2. Fall back to the Hugging Face cache / download.
        try:
            from huggingface_hub import hf_hub_download
        except ImportError as e:
            raise ImportError(
                "huggingface_hub is required to auto-download the model. Run: pip install huggingface_hub"
            ) from e

        print("Downloading/Loading checkpoint from Hugging Face Hub (backtracking/tiny-tts)...")
        try:
            return hf_hub_download(repo_id="backtracking/tiny-tts", filename="G.pth")
        except Exception as e:
            # Keep the original cause attached so network/auth errors stay debuggable.
            raise ValueError(f"Failed to download checkpoint from Hugging Face: {e}") from e

    def speak(
        self,
        text,
        output_path="output.wav",
        speaker="LJ",
        noise_scale=0.667,
        noise_scale_w=0.8,
        length_scale=1.0,
    ):
        """Synthesize text to speech and save to output_path.

        Args:
            text: Input text; it is normalized and phonemized as English ("EN").
            output_path: Destination file written with ``soundfile.write``.
            speaker: Key into ``SPK2ID``. Unknown names fall back to ID 0
                with a printed warning.
            noise_scale: Forwarded to ``model.infer``.
            noise_scale_w: Forwarded to ``model.infer``.
            length_scale: Forwarded to ``model.infer``.

        Returns:
            The synthesized waveform as a NumPy array (``audio[0, 0]`` of the
            model output), which is also what is written to ``output_path``.
        """
        print(f"Synthesizing: {text}")

        # Normalize text
        normalized = normalize_text(text)

        # Phonemize (the word-to-phone mapping is not needed for inference)
        phones, tones, _ = grapheme_to_phoneme(normalized)

        # Convert to sequence
        phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

        # Add blanks
        if ADD_BLANK:
            phone_ids = commons.insert_blanks(phone_ids, 0)
            tone_ids = commons.insert_blanks(tone_ids, 0)
            lang_ids = commons.insert_blanks(lang_ids, 0)

        n_tokens = len(phone_ids)
        x = torch.LongTensor(phone_ids).unsqueeze(0).to(self.device)
        x_lengths = torch.LongTensor([n_tokens]).to(self.device)
        tone = torch.LongTensor(tone_ids).unsqueeze(0).to(self.device)
        language = torch.LongTensor(lang_ids).unsqueeze(0).to(self.device)

        # Speaker ID
        if speaker not in SPK2ID:
            print(f"Warning: Speaker '{speaker}' not found, using ID 0. Available: {list(SPK2ID.keys())}")
            sid = torch.LongTensor([0]).to(self.device)
        else:
            sid = torch.LongTensor([SPK2ID[speaker]]).to(self.device)

        # BERT features (disabled - using zero tensors), allocated directly on
        # the target device to avoid a CPU allocation followed by a copy.
        bert = torch.zeros(1, 1024, n_tokens, device=self.device)
        ja_bert = torch.zeros(1, 768, n_tokens, device=self.device)

        with torch.no_grad():
            audio, *_ = self.model.infer(
                x, x_lengths, sid, tone, language, bert, ja_bert,
                noise_scale=noise_scale,
                noise_scale_w=noise_scale_w,
                length_scale=length_scale,
            )

        audio_np = audio[0, 0].cpu().numpy()
        sf.write(output_path, audio_np, SAMPLING_RATE)
        print(f"Saved audio to {output_path}")
        return audio_np
tiny_tts/alignment/__init__.py ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from numpy import zeros, int32, float32
2
+ from torch import from_numpy
3
+
4
+ from .core import viterbi_decode_kernel
5
+
6
+
7
def viterbi_decode(neg_cent, mask):
    """Run the alignment kernel on CPU and return the path as a tensor.

    ``neg_cent`` is copied to a float32 NumPy array, the kernel fills an int32
    path array of the same shape, and the result is converted back to a tensor
    with the device and dtype of ``neg_cent``.
    """
    target_device, target_dtype = neg_cent.device, neg_cent.dtype

    scores = neg_cent.data.cpu().numpy().astype(float32)
    alignment = zeros(scores.shape, dtype=int32)

    # Per-batch-item extents taken from the mask: summing over dim 1 / dim 2
    # and reading index 0 of the remaining axis.
    extent_dim1 = mask.sum(1)[:, 0].data.cpu().numpy().astype(int32)
    extent_dim2 = mask.sum(2)[:, 0].data.cpu().numpy().astype(int32)

    viterbi_decode_kernel(alignment, scores, extent_dim1, extent_dim2)
    return from_numpy(alignment).to(device=target_device, dtype=target_dtype)
tiny_tts/alignment/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (754 Bytes). View file
 
tiny_tts/alignment/__pycache__/core.cpython-310.pyc ADDED
Binary file (1.01 kB). View file
 
tiny_tts/alignment/core.py ADDED
@@ -0,0 +1,46 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numba
2
+
3
+
4
# Compiled eagerly for this exact signature: C-contiguous int32 paths,
# float32 values, and two int32 length vectors; returns nothing.
@numba.jit(
    numba.void(
        numba.int32[:, :, ::1],
        numba.float32[:, :, ::1],
        numba.int32[::1],
        numba.int32[::1],
    ),
    nopython=True,
    nogil=True,
)
def viterbi_decode_kernel(paths, values, t_ys, t_xs):
    """Fill ``paths`` in place with one monotonic 0/1 path per batch item.

    Args:
        paths: Output array; entries on the chosen path are set to 1. Other
            entries are left untouched (callers pass a zero-filled array).
        values: Score array. NOTE: it is modified in place — the forward pass
            accumulates the running best score into it.
        t_ys: Per-item number of rows (first axis of each item) to process.
        t_xs: Per-item number of columns (second axis of each item) to process.
    """
    b = paths.shape[0]
    # Stand-in for "minus infinity" used to forbid a transition.
    max_neg_val = -1e9
    for i in range(int(b)):
        path = paths[i]
        value = values[i]
        t_y = t_ys[i]
        t_x = t_xs[i]

        v_prev = v_cur = 0.0
        # Backtracking starts from the last column.
        index = t_x - 1

        # Forward pass: value[y, x] becomes the best cumulative score of any
        # path that reaches (y, x) coming from (y-1, x) or (y-1, x-1).
        for y in range(t_y):
            # Only cells that can still reach the end / be reached from the start.
            for x in range(max(0, t_x + y - t_y), min(t_x, y + 1)):
                if x == y:
                    # Staying in the same column is not allowed on the diagonal.
                    v_cur = max_neg_val
                else:
                    v_cur = value[y - 1, x]
                if x == 0:
                    if y == 0:
                        v_prev = 0.0
                    else:
                        v_prev = max_neg_val
                else:
                    v_prev = value[y - 1, x - 1]
                value[y, x] += max(v_prev, v_cur)

        # Backward pass: walk from the last row to the first, marking the path
        # and moving one column left whenever that predecessor scored higher
        # (or when forced to, on the diagonal).
        for y in range(t_y - 1, -1, -1):
            path[y, index] = 1
            # NOTE: at y == 0 the `value[y - 1, ...]` reads use index -1; this
            # is the final iteration, so the updated `index` is not used again.
            if index != 0 and (
                index == y or value[y - 1, index] < value[y - 1, index - 1]
            ):
                index = index - 1
tiny_tts/infer.py ADDED
@@ -0,0 +1,172 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import re
4
+ import torch
5
+ import soundfile as sf
6
+ import argparse
7
+ from tiny_tts.text.english import normalize_text, grapheme_to_phoneme
8
+ from tiny_tts.text import phonemes_to_ids
9
+ from tiny_tts.nn import commons
10
+ from tiny_tts.models import VoiceSynthesizer
11
+ from tiny_tts.text.symbols import symbols
12
+ from tiny_tts.utils import (
13
+ SAMPLING_RATE, SEGMENT_FRAMES, ADD_BLANK, SPEC_CHANNELS,
14
+ N_SPEAKERS, SPK2ID, MODEL_PARAMS,
15
+ )
16
+
17
+
18
def load_engine(checkpoint_path, device='cuda'):
    """Build a ``VoiceSynthesizer`` and load generator weights into it.

    Args:
        checkpoint_path: File passed to ``torch.load``. Either a dict holding
            the weights under the ``'model'`` key or a bare state dict.
        device: Device the model is moved to and the checkpoint is mapped to.

    Returns:
        The model in eval mode. Tensors whose shape does not match the model
        are skipped, and the load is non-strict, so parameters absent from the
        checkpoint keep their initial values; both cases are reported on stdout.
    """
    print(f"Loading model from {checkpoint_path}")
    net_g = VoiceSynthesizer(
        len(symbols),
        SPEC_CHANNELS,
        SEGMENT_FRAMES,
        n_speakers=N_SPEAKERS,
        **MODEL_PARAMS
    ).to(device)

    # Count model parameters
    total_params = sum(p.numel() for p in net_g.parameters())
    trainable_params = sum(p.numel() for p in net_g.parameters() if p.requires_grad)
    print(f"Model parameters: {total_params/1e6:.2f}M total, {trainable_params/1e6:.2f}M trainable")

    # NOTE(security): torch.load unpickles the file, which can execute
    # arbitrary code. Only load checkpoints from sources you trust.
    checkpoint = torch.load(checkpoint_path, map_location=device)
    # Accept both {'model': state_dict, ...} and a bare state dict.
    state_dict = checkpoint['model'] if 'model' in checkpoint else checkpoint

    # Remove module. prefix and filter shape mismatches
    model_state = net_g.state_dict()
    new_state_dict = {}
    skipped = []
    for k, v in state_dict.items():
        key = k[7:] if k.startswith('module.') else k
        if key in model_state and v.shape != model_state[key].shape:
            skipped.append(f"{key}: ckpt{v.shape} vs model{model_state[key].shape}")
            continue
        # Keys unknown to the model are passed through on purpose so that
        # load_state_dict reports them as unexpected below.
        new_state_dict[key] = v

    if skipped:
        print(f"Skipped {len(skipped)} mismatched keys:")
        for s in skipped[:5]:
            print(f"  {s}")
        if len(skipped) > 5:
            print(f"  ... and {len(skipped)-5} more")

    # strict=False never raises on key mismatches, so surface them explicitly
    # instead of silently running with partially initialized weights.
    load_result = net_g.load_state_dict(new_state_dict, strict=False)
    if load_result.missing_keys:
        print(f"Warning: {len(load_result.missing_keys)} model keys not loaded from checkpoint "
              f"(e.g. {load_result.missing_keys[0]})")
    if load_result.unexpected_keys:
        print(f"Warning: {len(load_result.unexpected_keys)} checkpoint keys not used by the model "
              f"(e.g. {load_result.unexpected_keys[0]})")

    net_g.eval()
    return net_g
60
+
61
+
62
def synthesize(text, output_path, model, speaker="LJ", device='cuda'):
    """Synthesize ``text`` with ``model`` and write the audio to ``output_path``.

    Args:
        text: Input text; it is normalized and phonemized as English ("EN").
        output_path: Destination file written with ``soundfile.write``.
        model: Object exposing ``infer(...)`` as used below (see ``load_engine``).
        speaker: Key into ``SPK2ID``. Unknown names fall back to ID 0 with a
            printed warning.
        device: Device on which the input tensors are created; it should match
            the device the model lives on.

    Returns:
        The synthesized waveform as a NumPy array (``audio[0, 0]`` of the
        model output). Earlier versions returned ``None``.
    """
    print(f"Synthesizing: {text}")

    # Normalize text
    normalized = normalize_text(text)

    # Phonemize (the word-to-phone mapping is not needed for inference)
    phones, tones, _ = grapheme_to_phoneme(normalized)

    # Convert to sequence
    phone_ids, tone_ids, lang_ids = phonemes_to_ids(phones, tones, "EN")

    # Add blanks
    if ADD_BLANK:
        phone_ids = commons.insert_blanks(phone_ids, 0)
        tone_ids = commons.insert_blanks(tone_ids, 0)
        lang_ids = commons.insert_blanks(lang_ids, 0)

    n_tokens = len(phone_ids)
    x = torch.LongTensor(phone_ids).unsqueeze(0).to(device)
    x_lengths = torch.LongTensor([n_tokens]).to(device)
    tone = torch.LongTensor(tone_ids).unsqueeze(0).to(device)
    language = torch.LongTensor(lang_ids).unsqueeze(0).to(device)

    # Speaker ID
    if speaker not in SPK2ID:
        print(f"Warning: Speaker {speaker} not found, using ID 0")
        sid = torch.LongTensor([0]).to(device)
    else:
        sid = torch.LongTensor([SPK2ID[speaker]]).to(device)

    # BERT features (disabled - using zero tensors), allocated directly on the
    # target device to avoid a CPU allocation followed by a copy.
    bert = torch.zeros(1, 1024, n_tokens, device=device)
    ja_bert = torch.zeros(1, 768, n_tokens, device=device)

    with torch.no_grad():
        audio, *_ = model.infer(
            x, x_lengths, sid, tone, language, bert, ja_bert,
            noise_scale=0.667,
            noise_scale_w=0.8,
            length_scale=1.0
        )

    audio = audio[0, 0].cpu().numpy()
    sf.write(output_path, audio, SAMPLING_RATE)
    print(f"Saved audio to {output_path}")
    return audio
107
+
108
+
109
def get_latest_checkpoint(checkpoint_dir):
    """Return the path of the highest-step ``G_*.pth`` file, or ``None``.

    The step is the number matched by ``_<digits>.pth`` in the file name;
    names without such a number rank as step -1. On ties the first name
    returned by ``os.listdir`` wins.
    """
    step_pattern = re.compile(r'_(\d+)\.pth')

    best_name = None
    best_step = None
    for entry in os.listdir(checkpoint_dir):
        if not (entry.startswith('G_') and entry.endswith('.pth')):
            continue
        found = step_pattern.search(entry)
        step = int(found.group(1)) if found else -1
        if best_step is None or step > best_step:
            best_name, best_step = entry, step

    if best_name is None:
        return None
    return os.path.join(checkpoint_dir, best_name)
121
+
122
+
123
def main():
    """Command-line entry point: load a checkpoint and synthesize one sentence.

    ``--checkpoint`` may be a file or a directory; for a directory the latest
    ``G_*.pth`` is picked. Output files are always written into the
    ``infer_outputs`` folder: only the base name of ``--output`` is used, and
    the checkpoint step and speaker name are appended to it. Passing
    ``--speaker all`` synthesizes once per entry in ``SPK2ID``.

    Exits with status 1 when the checkpoint cannot be found or, for
    ``--speaker all``, when no speakers are defined.
    """
    parser = argparse.ArgumentParser(description="TinyTTS — English Text-to-Speech Inference")
    parser.add_argument("--text", "-t", type=str, default="The weather is nice today, and I feel very relaxed.", help="Text to synthesize")
    parser.add_argument("--checkpoint", "-c", type=str, required=True, help="Path to checkpoint (G_*.pth) or directory containing checkpoints")
    parser.add_argument("--output", "-o", type=str, default="english_test.wav", help="Output audio file path")
    # NOTE(review): the default "female" differs from the "LJ" default used by
    # synthesize(); if it is not a key of SPK2ID, speaker ID 0 is used.
    parser.add_argument("--speaker", "-s", type=str, default="female", help="Speaker ID")
    parser.add_argument("--device", type=str, default="cuda", help="Device to use (cuda or cpu)")

    args = parser.parse_args()

    # The default device is "cuda"; degrade gracefully on CPU-only machines
    # instead of crashing while moving the model to the device.
    if args.device.startswith("cuda") and not torch.cuda.is_available():
        print("Warning: CUDA is not available, falling back to CPU", file=sys.stderr)
        args.device = "cpu"

    if not os.path.exists(args.checkpoint):
        print(f"Error: Checkpoint or directory not found at {args.checkpoint}", file=sys.stderr)
        sys.exit(1)

    if os.path.isdir(args.checkpoint):
        latest_ckpt = get_latest_checkpoint(args.checkpoint)
        if not latest_ckpt:
            print(f"Error: No G_*.pth checkpoints found in directory {args.checkpoint}", file=sys.stderr)
            sys.exit(1)
        args.checkpoint = latest_ckpt
        print(f"Auto-detected latest checkpoint: {args.checkpoint}")

    # Extract step from checkpoint filename
    ckpt_basename = os.path.basename(args.checkpoint)
    match = re.search(r'_(\d+)\.pth', ckpt_basename)
    step_str = match.group(1) if match else "unknown"

    # Save to output folder
    out_dir = "infer_outputs"
    os.makedirs(out_dir, exist_ok=True)

    out_name = os.path.basename(args.output)
    name, ext = os.path.splitext(out_name)
    model = load_engine(args.checkpoint, args.device)

    if args.speaker.lower() == "all":
        if not SPK2ID:
            print("Error: No speakers found", file=sys.stderr)
            sys.exit(1)
        print(f"Synthesizing for all {len(SPK2ID)} speakers...")
        for spk in SPK2ID:
            final_output = os.path.join(out_dir, f"{name}_step{step_str}_spk{spk}{ext}")
            synthesize(args.text, final_output, model, speaker=spk, device=args.device)
    else:
        final_output = os.path.join(out_dir, f"{name}_step{step_str}_spk{args.speaker}{ext}")
        synthesize(args.text, final_output, model, speaker=args.speaker, device=args.device)


if __name__ == "__main__":
    main()
tiny_tts/models/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .synthesizer import VoiceSynthesizer
tiny_tts/models/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (208 Bytes). View file
 
tiny_tts/models/__pycache__/synthesizer.cpython-310.pyc ADDED
Binary file (14.8 kB). View file
 
tiny_tts/models/synthesizer.py ADDED
@@ -0,0 +1,718 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from tiny_tts.nn import commons
7
+ from tiny_tts.nn import modules
8
+ from tiny_tts.nn import attentions
9
+
10
+ from torch.nn import Conv1d, ConvTranspose1d
11
+ from torch.nn.utils import weight_norm, remove_weight_norm
12
+
13
+ from tiny_tts.nn.commons import initialize_weights, compute_padding
14
+ import tiny_tts.alignment as alignment
15
+
16
+
17
class AttentionFlowBlock(nn.Module):
    """Chain of ``n_flows`` transformer coupling layers, each followed by a flip.

    When ``share_parameter`` is true a single ``attentions.FeedForward`` is
    created and handed to every coupling layer via ``wn_sharing_parameter``;
    otherwise ``self.wn`` is ``None``.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
        share_parameter=False,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()

        shared_wn = None
        if share_parameter:
            shared_wn = attentions.FeedForward(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=self.gin_channels,
            )
        self.wn = shared_wn

        for _ in range(n_flows):
            coupling = modules.TransformerCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                n_layers,
                n_heads,
                p_dropout,
                filter_channels,
                mean_only=True,
                wn_sharing_parameter=self.wn,
                gin_channels=self.gin_channels,
            )
            self.flows.append(coupling)
            self.flows.append(modules.FlipTransform())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flows in order, or in reversed order when ``reverse`` is set.

        In the forward direction each flow returns a pair and only its first
        element is kept; in reverse each flow returns the tensor directly.
        """
        if reverse:
            for layer in reversed(self.flows):
                x = layer(x, x_mask, g=g, reverse=reverse)
            return x

        for layer in self.flows:
            x, _ = layer(x, x_mask, g=g, reverse=reverse)
        return x
81
+
82
+
83
class VariationalDurationModel(nn.Module):
    """Flow-based duration model.

    With ``reverse=False`` ``forward`` returns a per-batch-item loss
    (``nll + logq``) for the given durations ``w``; with ``reverse=True`` it
    draws noise and runs the flows backwards to produce ``logw``.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        p_dropout,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        # NOTE: the filter_channels argument is overwritten here, so the value
        # passed by the caller is ignored and in_channels is used instead.
        filter_channels = in_channels
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        # Main flow stack: one affine coupling, then n_flows x (conv flow + flip).
        self.log_flow = modules.LogTransform()
        self.flows = nn.ModuleList()
        self.flows.append(modules.AffineCoupling(2))
        for i in range(n_flows):
            self.flows.append(
                modules.ConvolutionalFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.flows.append(modules.FlipTransform())

        # Posterior branch, used only in the non-reverse path; it encodes w and
        # always has 4 conv flows regardless of n_flows.
        self.post_pre = nn.Conv1d(1, filter_channels, 1)
        self.post_proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.post_convs = modules.DepthwiseSepConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        self.post_flows = nn.ModuleList()
        self.post_flows.append(modules.AffineCoupling(2))
        for i in range(4):
            self.post_flows.append(
                modules.ConvolutionalFlow(2, filter_channels, kernel_size, n_layers=3)
            )
            self.post_flows.append(modules.FlipTransform())

        # Conditioning branch applied to the input x in both directions.
        self.pre = nn.Conv1d(in_channels, filter_channels, 1)
        self.proj = nn.Conv1d(filter_channels, filter_channels, 1)
        self.convs = modules.DepthwiseSepConv(
            filter_channels, kernel_size, n_layers=3, p_dropout=p_dropout
        )
        # self.cond only exists when gin_channels != 0; forward() uses it
        # whenever g is not None.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, filter_channels, 1)

    def forward(self, x, x_mask, w=None, g=None, reverse=False, noise_scale=1.0):
        """Return ``nll + logq`` (reverse=False) or sampled ``logw`` (reverse=True).

        Args:
            x: Conditioning input; detached, so no gradient flows back through it.
            x_mask: Mask multiplied into intermediate tensors.
            w: Durations; required (asserted) when ``reverse`` is False.
            g: Optional global conditioning; detached before use.
            reverse: Selects sampling instead of likelihood evaluation.
            noise_scale: Multiplier on the sampled noise in the reverse path.
        """
        x = torch.detach(x)
        x = self.pre(x)
        if g is not None:
            g = torch.detach(g)
            x = x + self.cond(g)
        x = self.convs(x, x_mask)
        x = self.proj(x) * x_mask

        if not reverse:
            flows = self.flows
            assert w is not None

            # Posterior: transform Gaussian noise e_q conditioned on x and w,
            # accumulating the log-determinants.
            logdet_tot_q = 0
            h_w = self.post_pre(w)
            h_w = self.post_convs(h_w, x_mask)
            h_w = self.post_proj(h_w) * x_mask
            e_q = (
                torch.randn(w.size(0), 2, w.size(2)).to(device=x.device, dtype=x.dtype)
                * x_mask
            )
            z_q = e_q
            for flow in self.post_flows:
                z_q, logdet_q = flow(z_q, x_mask, g=(x + h_w))
                logdet_tot_q += logdet_q
            z_u, z1 = torch.split(z_q, [1, 1], 1)
            # u lies in (0, 1) via the sigmoid and is subtracted from w.
            u = torch.sigmoid(z_u) * x_mask
            z0 = (w - u) * x_mask
            # Log-determinant contribution of the sigmoid.
            logdet_tot_q += torch.sum(
                (F.logsigmoid(z_u) + F.logsigmoid(-z_u)) * x_mask, [1, 2]
            )
            # Standard-normal log-density of e_q minus the accumulated logdet.
            logq = (
                torch.sum(-0.5 * (math.log(2 * math.pi) + (e_q**2)) * x_mask, [1, 2])
                - logdet_tot_q
            )

            # Prior: push [z0, z1] through the main flows and score the result
            # under a standard normal.
            logdet_tot = 0
            z0, logdet = self.log_flow(z0, x_mask)
            logdet_tot += logdet
            z = torch.cat([z0, z1], 1)
            for flow in flows:
                z, logdet = flow(z, x_mask, g=x, reverse=reverse)
                logdet_tot = logdet_tot + logdet
            nll = (
                torch.sum(0.5 * (math.log(2 * math.pi) + (z**2)) * x_mask, [1, 2])
                - logdet_tot
            )
            return nll + logq
        else:
            flows = list(reversed(self.flows))
            # Drop the second-to-last entry of the reversed list, i.e. the
            # first ConvolutionalFlow of self.flows, keeping the affine
            # coupling as the final step.
            flows = flows[:-2] + [flows[-1]]
            z = (
                torch.randn(x.size(0), 2, x.size(2)).to(device=x.device, dtype=x.dtype)
                * noise_scale
            )
            for flow in flows:
                z = flow(z, x_mask, g=x, reverse=reverse)
            # Only the first channel is returned; z1 is discarded.
            z0, z1 = torch.split(z, [1, 1], 1)
            logw = z0
            return logw
192
+
193
+
194
class DurationEstimator(nn.Module):
    """Two conv/ReLU/norm/dropout stages followed by a 1x1 projection to one channel."""

    def __init__(
        self, in_channels, filter_channels, kernel_size, p_dropout, gin_channels=0
    ):
        super().__init__()

        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        half_kernel = kernel_size // 2

        self.drop = nn.Dropout(p_dropout)
        self.conv_1 = nn.Conv1d(
            in_channels, filter_channels, kernel_size, padding=half_kernel
        )
        self.norm_1 = modules.ChannelNorm(filter_channels)
        self.conv_2 = nn.Conv1d(
            filter_channels, filter_channels, kernel_size, padding=half_kernel
        )
        self.norm_2 = modules.ChannelNorm(filter_channels)
        self.proj = nn.Conv1d(filter_channels, 1, 1)

        # Conditioning projection exists only when a global channel count is given.
        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, in_channels, 1)

    def forward(self, x, x_mask, g=None):
        """Return the masked single-channel prediction.

        Both ``x`` and ``g`` are detached, so gradients do not flow back into
        whatever produced them.
        """
        h = torch.detach(x)
        if g is not None:
            h = h + self.cond(torch.detach(g))

        stages = ((self.conv_1, self.norm_1), (self.conv_2, self.norm_2))
        for conv, norm in stages:
            h = conv(h * x_mask)
            h = torch.relu(h)
            h = norm(h)
            h = self.drop(h)

        return self.proj(h * x_mask) * x_mask
235
+
236
+
237
class PhonemeEncoder(nn.Module):
    """Embed phoneme, tone and language IDs plus BERT features and encode them.

    ``forward`` returns the encoder output together with the two halves of a
    ``2 * out_channels`` projection and the length mask.
    """

    def __init__(
        self,
        n_vocab,
        out_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        gin_channels=0,
        num_languages=None,
        num_tones=None,
    ):
        super().__init__()
        # Missing table sizes fall back to the package-level defaults.
        if num_languages is None:
            from tiny_tts.text import num_languages
        if num_tones is None:
            from tiny_tts.text import num_tones

        self.n_vocab = n_vocab
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        # All three embedding tables share the same init standard deviation.
        init_std = hidden_channels**-0.5

        self.emb = nn.Embedding(n_vocab, hidden_channels)
        nn.init.normal_(self.emb.weight, 0.0, init_std)
        self.tone_emb = nn.Embedding(num_tones, hidden_channels)
        nn.init.normal_(self.tone_emb.weight, 0.0, init_std)
        self.language_emb = nn.Embedding(num_languages, hidden_channels)
        nn.init.normal_(self.language_emb.weight, 0.0, init_std)

        self.bert_proj = nn.Conv1d(1024, hidden_channels, 1)
        self.ja_bert_proj = nn.Conv1d(768, hidden_channels, 1)

        self.encoder = attentions.TransformerBlock(
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.gin_channels,
        )
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, tone, language, bert, ja_bert, g=None):
        """Return ``(encoded, m, logs, x_mask)`` for a batch of token sequences."""
        bert_emb = self.bert_proj(bert).transpose(1, 2)
        ja_bert_emb = self.ja_bert_proj(ja_bert).transpose(1, 2)

        summed = (
            self.emb(x)
            + self.tone_emb(tone)
            + self.language_emb(language)
            + bert_emb
            + ja_bert_emb
        )
        h = summed * math.sqrt(self.hidden_channels)
        # Swap the last axis with axis 1 so channels come before time.
        h = torch.transpose(h, 1, -1)

        length_mask = commons.create_length_mask(x_lengths, h.size(2))
        x_mask = torch.unsqueeze(length_mask, 1).to(h.dtype)

        h = self.encoder(h * x_mask, x_mask, g=g)
        stats = self.proj(h) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        return h, m, logs, x_mask
308
+
309
+
310
class FlowBlock(nn.Module):
    """Chain of ``n_flows`` coupling layers, each followed by a flip transform."""

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        n_flows=4,
        gin_channels=0,
    ):
        super().__init__()
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.n_flows = n_flows
        self.gin_channels = gin_channels

        self.flows = nn.ModuleList()
        for _ in range(n_flows):
            coupling = modules.FlowCouplingLayer(
                channels,
                hidden_channels,
                kernel_size,
                dilation_rate,
                n_layers,
                gin_channels=gin_channels,
                mean_only=True,
            )
            self.flows.append(coupling)
            self.flows.append(modules.FlipTransform())

    def forward(self, x, x_mask, g=None, reverse=False):
        """Apply the flows in order, or in reversed order when ``reverse`` is set.

        In the forward direction each flow returns a pair and only its first
        element is kept; in reverse each flow returns the tensor directly.
        """
        if reverse:
            for layer in reversed(self.flows):
                x = layer(x, x_mask, g=g, reverse=reverse)
            return x

        for layer in self.flows:
            x, _ = layer(x, x_mask, g=g, reverse=reverse)
        return x
353
+
354
+
355
class LatentEncoder(nn.Module):
    """Encode the input into a mean/log-scale pair and draw a latent sample from it."""

    def __init__(
        self,
        in_channels,
        out_channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels

        self.pre = nn.Conv1d(in_channels, hidden_channels, 1)
        self.enc = modules.WaveNet(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            gin_channels=gin_channels,
        )
        # Projects to twice out_channels: one half for m, one half for logs.
        self.proj = nn.Conv1d(hidden_channels, out_channels * 2, 1)

    def forward(self, x, x_lengths, g=None, tau=1.0):
        """Return ``(z, m, logs, x_mask)``.

        ``z`` is ``m`` plus standard-normal noise scaled by ``tau`` and
        ``exp(logs)``, with the mask applied.
        """
        length_mask = commons.create_length_mask(x_lengths, x.size(2))
        x_mask = torch.unsqueeze(length_mask, 1).to(x.dtype)

        h = self.pre(x) * x_mask
        h = self.enc(h, x_mask, g=g)
        stats = self.proj(h) * x_mask

        m, logs = torch.split(stats, self.out_channels, dim=1)
        noise = torch.randn_like(m)
        z = (m + noise * tau * torch.exp(logs)) * x_mask
        return z, m, logs, x_mask
395
+
396
+
397
class WaveformDecoder(torch.nn.Module):
    """Upsampling decoder: transposed convs, each followed by averaged residual blocks.

    The output is a single channel passed through ``tanh``.
    """

    def __init__(
        self,
        initial_channel,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        gin_channels=0,
    ):
        super().__init__()
        self.num_kernels = len(resblock_kernel_sizes)
        self.num_upsamples = len(upsample_rates)
        self.conv_pre = Conv1d(
            initial_channel, upsample_initial_channel, 7, 1, padding=3
        )

        # "1" selects the full residual block, anything else the light variant.
        if resblock == "1":
            block_cls = modules.ConvResBlock
        else:
            block_cls = modules.ConvResBlockLight

        # Each upsampling stage halves the channel count.
        self.ups = nn.ModuleList()
        for i, (rate, ksize) in enumerate(zip(upsample_rates, upsample_kernel_sizes)):
            in_ch = upsample_initial_channel // (2**i)
            out_ch = upsample_initial_channel // (2 ** (i + 1))
            upsampler = ConvTranspose1d(
                in_ch,
                out_ch,
                ksize,
                rate,
                padding=(ksize - rate) // 2,
            )
            self.ups.append(weight_norm(upsampler))

        # num_kernels residual blocks per upsampling stage, stored flat.
        self.resblocks = nn.ModuleList()
        for i in range(len(self.ups)):
            ch = upsample_initial_channel // (2 ** (i + 1))
            for k, d in zip(resblock_kernel_sizes, resblock_dilation_sizes):
                self.resblocks.append(block_cls(ch, k, d))

        # `ch` is the channel count of the last stage.
        self.conv_post = Conv1d(ch, 1, 7, 1, padding=3, bias=False)
        self.ups.apply(initialize_weights)

        if gin_channels != 0:
            self.cond = nn.Conv1d(gin_channels, upsample_initial_channel, 1)

    def forward(self, x, g=None):
        """Decode ``x`` (optionally conditioned on ``g``) into a tanh-bounded signal."""
        h = self.conv_pre(x)
        if g is not None:
            h = h + self.cond(g)

        for stage in range(self.num_upsamples):
            h = F.leaky_relu(h, modules.LRELU_SLOPE)
            h = self.ups[stage](h)

            # Average the outputs of this stage's residual blocks.
            first = stage * self.num_kernels
            acc = None
            for offset in range(self.num_kernels):
                branch = self.resblocks[first + offset](h)
                if acc is None:
                    acc = branch
                else:
                    acc += branch
            h = acc / self.num_kernels

        # The final activation uses F.leaky_relu's default slope.
        h = F.leaky_relu(h)
        h = self.conv_post(h)
        return torch.tanh(h)

    def remove_weight_norm(self):
        """Strip weight normalization from the upsamplers and residual blocks."""
        for up in self.ups:
            # Resolves to the module-level torch helper, not this method.
            remove_weight_norm(up)
        for block in self.resblocks:
            block.remove_weight_norm()
471
+
472
+
473
class StyleEncoder(nn.Module):
    """Strided 2-D conv stack followed by a GRU and a linear projection.

    ``forward`` returns the projected final GRU hidden state, one vector of
    size ``gin_channels`` per batch item.
    """

    def __init__(self, spec_channels, gin_channels=0, layernorm=False):
        super().__init__()
        self.spec_channels = spec_channels

        ref_enc_filters = [32, 32, 64, 64, 128, 128]
        n_convs = len(ref_enc_filters)
        channel_plan = [1] + ref_enc_filters

        stack = []
        for c_in, c_out in zip(channel_plan[:-1], channel_plan[1:]):
            conv = nn.Conv2d(
                in_channels=c_in,
                out_channels=c_out,
                kernel_size=(3, 3),
                stride=(2, 2),
                padding=(1, 1),
            )
            stack.append(weight_norm(conv))
        self.convs = nn.ModuleList(stack)

        # Size of the spec_channels axis after all strided convolutions.
        reduced_bins = self.calculate_channels(spec_channels, 3, 2, 1, n_convs)
        self.gru = nn.GRU(
            input_size=ref_enc_filters[-1] * reduced_bins,
            hidden_size=256 // 2,
            batch_first=True,
        )
        self.proj = nn.Linear(128, gin_channels)

        self.layernorm = nn.LayerNorm(self.spec_channels) if layernorm else None

    def forward(self, inputs, mask=None):
        """Encode ``inputs``; ``mask`` is accepted but not used."""
        batch = inputs.size(0)

        # Reshape to (batch, 1, -1, spec_channels) for the 2-D convolutions.
        feats = inputs.view(batch, 1, -1, self.spec_channels)
        if self.layernorm is not None:
            feats = self.layernorm(feats)

        for conv in self.convs:
            feats = F.relu(conv(feats))

        # Move the time axis ahead of channels, then flatten everything after it.
        feats = feats.transpose(1, 2)
        feats = feats.contiguous().view(feats.size(0), feats.size(1), -1)

        self.gru.flatten_parameters()
        _, final_hidden = self.gru(feats)

        return self.proj(final_hidden.squeeze(0))

    def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
        """Return the length left after ``n_convs`` identical strided convolutions."""
        remaining = L
        for _ in range(n_convs):
            remaining = (remaining - kernel_size + 2 * pad) // stride + 1
        return remaining
531
+
532
+
533
class VoiceSynthesizer(nn.Module):
    """Voice synthesis model for inference.

    Wires together a phoneme encoder (``enc_p``), a waveform decoder
    (``dec``), a latent encoder (``enc_q``), a flow (``flow``), two duration
    predictors (``sdp`` and ``dp``) and a speaker conditioning source: an
    embedding table when ``n_speakers > 0``, otherwise a ``StyleEncoder``.
    """

    def __init__(
        self,
        n_vocab,
        spec_channels,
        segment_size,
        inter_channels,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size,
        p_dropout,
        resblock,
        resblock_kernel_sizes,
        resblock_dilation_sizes,
        upsample_rates,
        upsample_initial_channel,
        upsample_kernel_sizes,
        n_speakers=256,
        gin_channels=256,
        use_sdp=True,
        n_flow_layer=4,
        n_layers_trans_flow=6,
        flow_share_parameter=False,
        use_transformer_flow=True,
        use_vc=False,
        num_languages=None,
        num_tones=None,
        norm_refenc=False,
        **kwargs
    ):
        super().__init__()

        # Keep the hyper-parameters around as plain attributes.
        self.n_vocab = n_vocab
        self.spec_channels = spec_channels
        self.inter_channels = inter_channels
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.resblock = resblock
        self.resblock_kernel_sizes = resblock_kernel_sizes
        self.resblock_dilation_sizes = resblock_dilation_sizes
        self.upsample_rates = upsample_rates
        self.upsample_initial_channel = upsample_initial_channel
        self.upsample_kernel_sizes = upsample_kernel_sizes
        self.segment_size = segment_size
        self.n_speakers = n_speakers
        self.gin_channels = gin_channels
        self.n_layers_trans_flow = n_layers_trans_flow
        self.use_sdp = use_sdp

        # Optional settings delivered through **kwargs.
        self.use_spk_conditioned_encoder = kwargs.get(
            "use_spk_conditioned_encoder", True
        )
        self.use_noise_scaled_mas = kwargs.get("use_noise_scaled_mas", False)
        self.mas_noise_scale_initial = kwargs.get("mas_noise_scale_initial", 0.01)
        self.noise_scale_delta = kwargs.get("noise_scale_delta", 2e-6)
        self.current_mas_noise_scale = self.mas_noise_scale_initial

        # The text encoder only receives speaker conditioning when enabled.
        speaker_conditioned = self.use_spk_conditioned_encoder and gin_channels > 0
        self.enc_gin_channels = gin_channels if speaker_conditioned else 0

        self.enc_p = PhonemeEncoder(
            n_vocab,
            inter_channels,
            hidden_channels,
            filter_channels,
            n_heads,
            n_layers,
            kernel_size,
            p_dropout,
            gin_channels=self.enc_gin_channels,
            num_languages=num_languages,
            num_tones=num_tones,
        )
        self.dec = WaveformDecoder(
            inter_channels,
            resblock,
            resblock_kernel_sizes,
            resblock_dilation_sizes,
            upsample_rates,
            upsample_initial_channel,
            upsample_kernel_sizes,
            gin_channels=gin_channels,
        )
        self.enc_q = LatentEncoder(
            spec_channels,
            inter_channels,
            hidden_channels,
            5,
            1,
            16,
            gin_channels=gin_channels,
        )

        if use_transformer_flow:
            self.flow = AttentionFlowBlock(
                inter_channels,
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers_trans_flow,
                5,
                p_dropout,
                n_flow_layer,
                gin_channels=gin_channels,
                share_parameter=flow_share_parameter,
            )
        else:
            self.flow = FlowBlock(
                inter_channels,
                hidden_channels,
                5,
                1,
                n_flow_layer,
                gin_channels=gin_channels,
            )

        # Both duration predictors are always built; ``infer`` blends them.
        self.sdp = VariationalDurationModel(
            hidden_channels, 192, 3, 0.5, 4, gin_channels=gin_channels
        )
        self.dp = DurationEstimator(
            hidden_channels, 256, 3, 0.5, gin_channels=gin_channels
        )

        if n_speakers > 0:
            self.emb_g = nn.Embedding(n_speakers, gin_channels)
        else:
            self.ref_enc = StyleEncoder(
                spec_channels, gin_channels, layernorm=norm_refenc
            )
        self.use_vc = use_vc

    def infer(
        self,
        x,
        x_lengths,
        sid,
        tone,
        language,
        bert,
        ja_bert,
        noise_scale=0.667,
        length_scale=1,
        noise_scale_w=0.8,
        max_len=None,
        sdp_ratio=0,
        y=None,
        g=None,
    ):
        """Synthesize audio from encoded text.

        Args:
            x, x_lengths, tone, language, bert, ja_bert: Forwarded unchanged
                to ``self.enc_p``.
            sid: Speaker ids, looked up in ``self.emb_g`` when ``g`` is None
                and ``n_speakers > 0``.
            noise_scale: Scale of the Gaussian noise added to the prior mean.
            length_scale: Multiplier applied to the predicted durations.
            noise_scale_w: Noise scale passed to ``self.sdp``.
            max_len: Upper bound on the number of latent frames decoded.
            sdp_ratio: Weight of ``self.sdp`` in the duration blend; ``self.dp``
                gets ``1 - sdp_ratio``.
            y: Reference input for ``self.ref_enc``; read only when ``g`` is
                None and ``n_speakers`` is not positive.
            g: Optional precomputed conditioning tensor.

        Returns:
            ``(o, attn, y_mask, (z, z_p, m_p, logs_p))``.
        """
        if g is None:
            if self.n_speakers > 0:
                g = self.emb_g(sid).unsqueeze(-1)
            else:
                g = self.ref_enc(y.transpose(1, 2)).unsqueeze(-1)

        # With use_vc the text encoder runs without speaker conditioning.
        g_p = None if self.use_vc else g
        x, m_p, logs_p, x_mask = self.enc_p(
            x, x_lengths, tone, language, bert, ja_bert, g=g_p
        )

        # Blend the two duration predictions in the log domain.
        logw_sdp = self.sdp(x, x_mask, g=g, reverse=True, noise_scale=noise_scale_w)
        logw_dp = self.dp(x, x_mask, g=g)
        logw = logw_sdp * (sdp_ratio) + logw_dp * (1 - sdp_ratio)
        w = torch.exp(logw) * x_mask * length_scale

        # Round durations up and derive the output length (at least 1).
        w_ceil = torch.ceil(w)
        y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
        y_mask = torch.unsqueeze(
            commons.create_length_mask(y_lengths, None), 1
        ).to(x_mask.dtype)
        attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
        attn = commons.compute_alignment_path(w_ceil, attn_mask)

        # Expand the prior statistics along the alignment.
        alignment = attn.squeeze(1)
        m_p = torch.matmul(alignment, m_p.transpose(1, 2)).transpose(1, 2)
        logs_p = torch.matmul(alignment, logs_p.transpose(1, 2)).transpose(1, 2)

        z_p = m_p + torch.randn_like(m_p) * torch.exp(logs_p) * noise_scale
        z = self.flow(z_p, y_mask, g=g, reverse=True)
        o = self.dec((z * y_mask)[:, :, :max_len], g=g)
        return o, attn, y_mask, (z, z_p, m_p, logs_p)
tiny_tts/nn/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ # Neural network building blocks
tiny_tts/nn/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (152 Bytes). View file
 
tiny_tts/nn/__pycache__/attentions.cpython-310.pyc ADDED
Binary file (10.7 kB). View file
 
tiny_tts/nn/__pycache__/commons.cpython-310.pyc ADDED
Binary file (5.58 kB). View file
 
tiny_tts/nn/__pycache__/modules.cpython-310.pyc ADDED
Binary file (12.6 kB). View file
 
tiny_tts/nn/__pycache__/transforms.cpython-310.pyc ADDED
Binary file (3.86 kB). View file
 
tiny_tts/nn/attentions.py ADDED
@@ -0,0 +1,424 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from . import commons
7
+ import logging
8
+
9
+ logger = logging.getLogger(__name__)
10
+
11
+
12
class ChannelLayerNorm(nn.Module):
    """Layer normalization applied over dimension 1 of the input.

    Dimension 1 is swapped with the last dimension, ``F.layer_norm`` is run
    over that last dimension with a learnable scale (``gamma``) and shift
    (``beta``), and the dimensions are swapped back.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Normalize ``x`` along dimension 1 and return a tensor of the same shape."""
        channels_last = x.transpose(1, -1)
        normed = F.layer_norm(
            channels_last, (self.channels,), self.gamma, self.beta, self.eps
        )
        return normed.transpose(1, -1)
25
+
26
+
27
@torch.jit.script
def gated_activation(input_a, input_b, n_channels):
    """Gated tanh/sigmoid unit over the sum of two inputs.

    The inputs are added, the sum is split along dimension 1 at
    ``n_channels[0]``, and ``tanh`` of the first part is multiplied by
    ``sigmoid`` of the second part.
    """
    n_channels_int = n_channels[0]
    summed = input_a + input_b
    tanh_part = torch.tanh(summed[:, :n_channels_int, :])
    gate_part = torch.sigmoid(summed[:, n_channels_int:, :])
    return tanh_part * gate_part
35
+
36
+
37
class TransformerBlock(nn.Module):
    """Stack of self-attention + feed-forward layers with optional conditioning.

    When ``gin_channels`` is passed through ``kwargs`` and is non-zero, a
    linear projection of ``g`` is added to the hidden states right before the
    layer with index ``cond_layer_idx`` (default 2).
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        window_size=4,
        isflow=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.window_size = window_size

        # Default index equals n_layers, i.e. it is never reached in forward().
        self.cond_layer_idx = self.n_layers
        if "gin_channels" in kwargs:
            self.gin_channels = kwargs["gin_channels"]
            if self.gin_channels != 0:
                self.spk_emb_linear = nn.Linear(
                    self.gin_channels, self.hidden_channels
                )
                self.cond_layer_idx = kwargs.get("cond_layer_idx", 2)
                assert (
                    self.cond_layer_idx < self.n_layers
                ), "cond_layer_idx should be less than n_layers"

        self.drop = nn.Dropout(p_dropout)
        self.attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()

        for _ in range(self.n_layers):
            attention = MultiHeadSelfAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                window_size=window_size,
            )
            self.attn_layers.append(attention)
            self.norm_layers_1.append(ChannelLayerNorm(hidden_channels))
            feed_forward = FeedForward(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(ChannelLayerNorm(hidden_channels))

    def forward(self, x, x_mask, g=None):
        """Run all layers on ``x``; masked positions are zeroed in the result."""
        attn_mask = x_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        layers = zip(
            self.attn_layers, self.norm_layers_1, self.ffn_layers, self.norm_layers_2
        )
        for idx, (attention, norm_1, feed_forward, norm_2) in enumerate(layers):
            if idx == self.cond_layer_idx and g is not None:
                # Project the conditioning to hidden size and add it once.
                g = self.spk_emb_linear(g.transpose(1, 2))
                g = g.transpose(1, 2)
                x = (x + g) * x_mask

            residual = self.drop(attention(x, x, attn_mask))
            x = norm_1(x + residual)

            residual = self.drop(feed_forward(x, x_mask))
            x = norm_2(x + residual)
        return x * x_mask
116
+
117
+
118
class TransformerDecoder(nn.Module):
    """Decoder stack: masked self-attention, attention over ``h``, feed-forward.

    Self-attention uses the lower-triangular mask from
    ``commons.subsequent_mask``; the feed-forward layers are built with
    ``causal=True``.
    """

    def __init__(
        self,
        hidden_channels,
        filter_channels,
        n_heads,
        n_layers,
        kernel_size=1,
        p_dropout=0.0,
        proximal_bias=False,
        proximal_init=True,
        **kwargs
    ):
        super().__init__()
        self.hidden_channels = hidden_channels
        self.filter_channels = filter_channels
        self.n_heads = n_heads
        self.n_layers = n_layers
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init

        self.drop = nn.Dropout(p_dropout)
        self.self_attn_layers = nn.ModuleList()
        self.norm_layers_0 = nn.ModuleList()
        self.encdec_attn_layers = nn.ModuleList()
        self.norm_layers_1 = nn.ModuleList()
        self.ffn_layers = nn.ModuleList()
        self.norm_layers_2 = nn.ModuleList()
        for _ in range(self.n_layers):
            self_attention = MultiHeadSelfAttention(
                hidden_channels,
                hidden_channels,
                n_heads,
                p_dropout=p_dropout,
                proximal_bias=proximal_bias,
                proximal_init=proximal_init,
            )
            self.self_attn_layers.append(self_attention)
            self.norm_layers_0.append(ChannelLayerNorm(hidden_channels))

            cross_attention = MultiHeadSelfAttention(
                hidden_channels, hidden_channels, n_heads, p_dropout=p_dropout
            )
            self.encdec_attn_layers.append(cross_attention)
            self.norm_layers_1.append(ChannelLayerNorm(hidden_channels))

            feed_forward = FeedForward(
                hidden_channels,
                hidden_channels,
                filter_channels,
                kernel_size,
                p_dropout=p_dropout,
                causal=True,
            )
            self.ffn_layers.append(feed_forward)
            self.norm_layers_2.append(ChannelLayerNorm(hidden_channels))

    def forward(self, x, x_mask, h, h_mask):
        """Decode ``x`` while attending to ``h``; masked positions are zeroed."""
        self_attn_mask = commons.subsequent_mask(x_mask.size(2)).to(
            device=x.device, dtype=x.dtype
        )
        encdec_attn_mask = h_mask.unsqueeze(2) * x_mask.unsqueeze(-1)
        x = x * x_mask
        layers = zip(
            self.self_attn_layers,
            self.norm_layers_0,
            self.encdec_attn_layers,
            self.norm_layers_1,
            self.ffn_layers,
            self.norm_layers_2,
        )
        for self_attention, norm_0, cross_attention, norm_1, feed_forward, norm_2 in layers:
            residual = self.drop(self_attention(x, x, self_attn_mask))
            x = norm_0(x + residual)

            residual = self.drop(cross_attention(x, h, encdec_attn_mask))
            x = norm_1(x + residual)

            residual = self.drop(feed_forward(x, x_mask))
            x = norm_2(x + residual)
        return x * x_mask
198
+
199
+
200
class MultiHeadSelfAttention(nn.Module):
    """Multi-head attention built from 1x1 ``Conv1d`` projections.

    Queries come from ``x`` and keys/values from ``c`` (see ``forward``).
    Optional extras, each only valid when query and key lengths match:

    * ``window_size``: learned relative-position embeddings for keys and
      values, covering offsets in ``[-window_size, window_size]``.
    * ``proximal_bias``: adds ``-log1p(|i - j|)`` to the attention scores.
    * ``block_length``: masks scores outside a band of that half-width.

    Args:
        channels: Input channel count; must be divisible by ``n_heads``.
        out_channels: Channel count produced by the output projection.
        n_heads: Number of attention heads.
        p_dropout: Dropout probability applied to the attention weights.
        window_size: Relative-position window, or None to disable.
        heads_share: Share one relative embedding table across heads.
        block_length: Band half-width for local attention, or None.
        proximal_bias: Enable the proximity bias described above.
        proximal_init: Initialise the key projection as a copy of the
            query projection.
    """

    def __init__(
        self,
        channels,
        out_channels,
        n_heads,
        p_dropout=0.0,
        window_size=None,
        heads_share=True,
        block_length=None,
        proximal_bias=False,
        proximal_init=False,
    ):
        super().__init__()
        assert channels % n_heads == 0

        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        self.window_size = window_size
        self.heads_share = heads_share
        self.block_length = block_length
        self.proximal_bias = proximal_bias
        self.proximal_init = proximal_init
        # Attention weights of the most recent forward() call.
        self.attn = None

        self.k_channels = channels // n_heads
        self.conv_q = nn.Conv1d(channels, channels, 1)
        self.conv_k = nn.Conv1d(channels, channels, 1)
        self.conv_v = nn.Conv1d(channels, channels, 1)
        self.conv_o = nn.Conv1d(channels, out_channels, 1)
        self.drop = nn.Dropout(p_dropout)

        if window_size is not None:
            n_heads_rel = 1 if heads_share else n_heads
            rel_stddev = self.k_channels**-0.5
            self.emb_rel_k = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )
            self.emb_rel_v = nn.Parameter(
                torch.randn(n_heads_rel, window_size * 2 + 1, self.k_channels)
                * rel_stddev
            )

        nn.init.xavier_uniform_(self.conv_q.weight)
        nn.init.xavier_uniform_(self.conv_k.weight)
        nn.init.xavier_uniform_(self.conv_v.weight)
        if proximal_init:
            with torch.no_grad():
                self.conv_k.weight.copy_(self.conv_q.weight)
                self.conv_k.bias.copy_(self.conv_q.bias)

    def forward(self, x, c, attn_mask=None):
        """Attend from ``x`` to ``c`` and return the projected result.

        The attention weights are stored on ``self.attn`` as a side effect.
        """
        q = self.conv_q(x)
        k = self.conv_k(c)
        v = self.conv_v(c)

        x, self.attn = self.attention(q, k, v, mask=attn_mask)

        x = self.conv_o(x)
        return x

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention; returns ``(output, attention_weights)``."""
        b, d, t_s = key.size()
        t_t = query.size(2)
        # Split channels into heads: (b, d, t) -> (b, heads, t, k_channels).
        query = query.view(b, self.n_heads, self.k_channels, t_t).transpose(2, 3)
        key = key.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)
        value = value.view(b, self.n_heads, self.k_channels, t_s).transpose(2, 3)

        # The scaled query feeds both the content and the relative scores.
        scaled_query = query / math.sqrt(self.k_channels)
        scores = torch.matmul(scaled_query, key.transpose(-2, -1))
        if self.window_size is not None:
            assert (
                t_s == t_t
            ), "Relative attention is only available for self-attention."
            key_relative_embeddings = self._get_relative_embeddings(self.emb_rel_k, t_s)
            rel_logits = self._matmul_with_relative_keys(
                scaled_query, key_relative_embeddings
            )
            scores_local = self._relative_position_to_absolute_position(rel_logits)
            scores = scores + scores_local
        if self.proximal_bias:
            assert t_s == t_t, "Proximal bias is only available for self-attention."
            scores = scores + self._attention_bias_proximal(t_s).to(
                device=scores.device, dtype=scores.dtype
            )
        if mask is not None:
            # A large negative fill (rather than -inf) for masked positions.
            scores = scores.masked_fill(mask == 0, -1e4)
            if self.block_length is not None:
                assert (
                    t_s == t_t
                ), "Local attention is only available for self-attention."
                block_mask = (
                    torch.ones_like(scores)
                    .triu(-self.block_length)
                    .tril(self.block_length)
                )
                scores = scores.masked_fill(block_mask == 0, -1e4)
        p_attn = F.softmax(scores, dim=-1)
        p_attn = self.drop(p_attn)
        output = torch.matmul(p_attn, value)
        if self.window_size is not None:
            relative_weights = self._absolute_position_to_relative_position(p_attn)
            value_relative_embeddings = self._get_relative_embeddings(
                self.emb_rel_v, t_s
            )
            output = output + self._matmul_with_relative_values(
                relative_weights, value_relative_embeddings
            )
        # Merge heads back: (b, heads, t, k_channels) -> (b, d, t).
        output = output.transpose(2, 3).contiguous().view(b, d, t_t)
        return output, p_attn

    def _matmul_with_relative_values(self, x, y):
        """Multiply relative attention weights ``x`` by value embeddings ``y``."""
        return torch.matmul(x, y.unsqueeze(0))

    def _matmul_with_relative_keys(self, x, y):
        """Multiply queries ``x`` by the transposed key embeddings ``y``."""
        return torch.matmul(x, y.unsqueeze(0).transpose(-2, -1))

    def _get_relative_embeddings(self, relative_embeddings, length):
        """Return the ``2 * length - 1`` embeddings needed for ``length`` steps.

        The table is zero-padded along dimension 1 when ``length`` exceeds
        ``window_size + 1``, and sliced when it is shorter.
        """
        pad_length = max(length - (self.window_size + 1), 0)
        slice_start_position = max((self.window_size + 1) - length, 0)
        slice_end_position = slice_start_position + 2 * length - 1
        if pad_length > 0:
            padded_relative_embeddings = F.pad(
                relative_embeddings,
                commons.flatten_pad_shape([[0, 0], [pad_length, pad_length], [0, 0]]),
            )
        else:
            padded_relative_embeddings = relative_embeddings
        used_relative_embeddings = padded_relative_embeddings[
            :, slice_start_position:slice_end_position
        ]
        return used_relative_embeddings

    def _relative_position_to_absolute_position(self, x):
        """Re-index ``(b, h, l, 2l-1)`` relative scores to ``(b, h, l, l)``."""
        batch, heads, length, _ = x.size()
        # Pad one column, flatten, pad again and reshape so rows are skewed.
        x = F.pad(x, commons.flatten_pad_shape([[0, 0], [0, 0], [0, 0], [0, 1]]))

        x_flat = x.view([batch, heads, length * 2 * length])
        x_flat = F.pad(
            x_flat, commons.flatten_pad_shape([[0, 0], [0, 0], [0, length - 1]])
        )

        x_final = x_flat.view([batch, heads, length + 1, 2 * length - 1])[
            :, :, :length, length - 1 :
        ]
        return x_final

    def _absolute_position_to_relative_position(self, x):
        """Re-index ``(b, h, l, l)`` weights to ``(b, h, l, 2l-1)``."""
        batch, heads, length, _ = x.size()
        x = F.pad(
            x, commons.flatten_pad_shape([[0, 0], [0, 0], [0, 0], [0, length - 1]])
        )
        x_flat = x.view([batch, heads, length**2 + length * (length - 1)])
        x_flat = F.pad(x_flat, commons.flatten_pad_shape([[0, 0], [0, 0], [length, 0]]))
        x_final = x_flat.view([batch, heads, length, 2 * length])[:, :, :, 1:]
        return x_final

    def _attention_bias_proximal(self, length):
        """Return a ``(1, 1, length, length)`` bias of ``-log1p(|i - j|)``."""
        r = torch.arange(length, dtype=torch.float32)
        diff = torch.unsqueeze(r, 0) - torch.unsqueeze(r, 1)
        return torch.unsqueeze(torch.unsqueeze(-torch.log1p(torch.abs(diff)), 0), 0)
367
+
368
+
369
class FeedForward(nn.Module):
    """Two-layer convolutional feed-forward network with input masking.

    Args:
        in_channels: Channels of the input.
        out_channels: Channels of the output.
        filter_channels: Channels of the hidden layer.
        kernel_size: Kernel size of both convolutions.
        p_dropout: Dropout probability after the activation.
        activation: ``"gelu"`` selects ``x * sigmoid(1.702 * x)``; any other
            value selects ReLU.
        causal: Pad only on the left so no future frames are seen.
    """

    def __init__(
        self,
        in_channels,
        out_channels,
        filter_channels,
        kernel_size,
        p_dropout=0.0,
        activation=None,
        causal=False,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.activation = activation
        self.causal = causal

        # Select the padding routine once instead of branching per call.
        self.padding = self._causal_padding if causal else self._same_padding

        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size)
        self.drop = nn.Dropout(p_dropout)

    def forward(self, x, x_mask):
        """Apply conv -> activation -> dropout -> conv, masking before each conv."""
        hidden = self.conv_1(self.padding(x * x_mask))
        if self.activation == "gelu":
            hidden = hidden * torch.sigmoid(1.702 * hidden)
        else:
            hidden = torch.relu(hidden)
        hidden = self.drop(hidden)
        out = self.conv_2(self.padding(hidden * x_mask))
        return out * x_mask

    def _causal_padding(self, x):
        """Left-pad the last dimension by ``kernel_size - 1``."""
        if self.kernel_size == 1:
            return x
        left = self.kernel_size - 1
        right = 0
        pad_spec = commons.flatten_pad_shape([[0, 0], [0, 0], [left, right]])
        return F.pad(x, pad_spec)

    def _same_padding(self, x):
        """Pad the last dimension on both sides so its length is preserved."""
        if self.kernel_size == 1:
            return x
        left = (self.kernel_size - 1) // 2
        right = self.kernel_size // 2
        pad_spec = commons.flatten_pad_shape([[0, 0], [0, 0], [left, right]])
        return F.pad(x, pad_spec)
tiny_tts/nn/commons.py ADDED
@@ -0,0 +1,151 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch.nn import functional as F
4
+
5
+
6
def initialize_weights(m, mean=0.0, std=0.01):
    """Re-draw ``m.weight`` from N(mean, std) when ``m`` is a Conv-type module.

    A module qualifies when its class name contains ``"Conv"``; every other
    module is left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
10
+
11
+
12
def compute_padding(kernel_size, dilation=1):
    """Return ``int((kernel_size * dilation - dilation) / 2)``.

    For an odd ``kernel_size`` this is the per-side padding that keeps the
    length of a stride-1 dilated convolution unchanged.
    """
    span = kernel_size * dilation - dilation
    return int(span / 2)
14
+
15
+
16
def flatten_pad_shape(pad_shape):
    """Flatten per-dimension ``[before, after]`` pairs, last dimension first.

    This is the argument order ``torch.nn.functional.pad`` expects.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
20
+
21
+
22
def insert_blanks(lst, item):
    """Return a new list with ``item`` before, between and after the elements of ``lst``."""
    interleaved = [item for _ in range(len(lst) * 2 + 1)]
    # Odd positions receive the original elements; even ones keep ``item``.
    interleaved[1::2] = lst
    return interleaved
26
+
27
+
28
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Element-wise KL(P || Q) between two diagonal Gaussians.

    ``m_*`` are means and ``logs_*`` are log standard deviations.
    """
    kl = (logs_q - logs_p) - 0.5
    variance_plus_shift = torch.exp(2.0 * logs_p) + ((m_p - m_q) ** 2)
    kl += 0.5 * variance_plus_shift * torch.exp(-2.0 * logs_q)
    return kl
34
+
35
+
36
def rand_gumbel(shape):
    """Draw Gumbel noise of the given shape.

    Uniform samples are squeezed into [0.00001, 0.99999) so neither
    logarithm is evaluated at 0.
    """
    uniform = torch.rand(shape) * 0.99998 + 0.00001
    inner = -torch.log(uniform)
    return -torch.log(inner)
39
+
40
+
41
def rand_gumbel_like(x):
    """Draw Gumbel noise with the shape, dtype and device of ``x``."""
    return rand_gumbel(x.size()).to(dtype=x.dtype, device=x.device)
44
+
45
+
46
def extract_segments(x, ids_str, segment_size=4):
    """Cut a window of ``segment_size`` steps from the last dimension per item.

    Args:
        x: Tensor indexed as ``x[batch, channel, time]``.
        ids_str: Per-item start indices; negative values are treated as 0.
        segment_size: Window length.

    Returns:
        A tensor shaped like ``x[:, :, :segment_size]``. Windows that run past
        the end of ``x`` are zero-filled on the right.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    total_steps = x.size(2)
    for item in range(x.size(0)):
        start = max(0, ids_str[item].item())
        remaining = total_steps - start
        if remaining <= 0:
            # Start lies beyond the data: the row stays all zeros.
            continue
        width = min(segment_size, remaining)
        segments[item, :, :width] = x[item, :, start : start + width]
    return segments
57
+
58
+
59
def random_segments(x, x_lengths=None, segment_size=4):
    """Cut one randomly positioned window of ``segment_size`` steps per item.

    Args:
        x: Tensor whose ``size()`` unpacks to ``(batch, channels, time)``.
        x_lengths: Per-item valid lengths as a tensor. When None, every item
            is taken to span the full time dimension of ``x``.
        segment_size: Window length.

    Returns:
        ``(segments, ids_str)`` where ``ids_str`` holds the start indices
        that were passed to ``extract_segments``.
    """
    b, d, t = x.size()
    if x_lengths is None:
        # torch.clamp below needs a tensor; a bare int raises TypeError.
        x_lengths = torch.full((b,), t, dtype=torch.long, device=x.device)
    ids_str_max = torch.clamp(x_lengths - segment_size + 1, min=0)
    ids_str = (torch.rand([b]).to(device=x.device) * ids_str_max).to(dtype=torch.long)
    ret = extract_segments(x, ids_str, segment_size)
    return ret, ids_str
67
+
68
+
69
def get_timing_signal_1d(length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal position signal of shape ``(1, channels, length)``.

    The first ``channels // 2`` rows are sines and the next ``channels // 2``
    rows are cosines, at geometrically spaced timescales between
    ``min_timescale`` and ``max_timescale``. When ``channels`` is odd, one
    all-zero row is appended.

    Args:
        length: Number of positions.
        channels: Number of output channels.
        min_timescale: Smallest timescale.
        max_timescale: Largest timescale.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # Guard the divisor: with a single timescale (channels 2 or 3) the
    # unguarded ``num_timescales - 1`` is zero and the division fails.
    log_timescale_increment = math.log(
        float(max_timescale) / float(min_timescale)
    ) / max(num_timescales - 1, 1)
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment
    )
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Append a zero row when ``channels`` is odd.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
83
+
84
+
85
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Add the sinusoidal position signal to ``x``.

    ``x.size()`` must unpack into three values; the signal is built for the
    last two and cast to the dtype and device of ``x``.
    """
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return x + timing
89
+
90
+
91
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the sinusoidal position signal to ``x`` along ``axis``.

    ``x.size()`` must unpack into three values; the signal is built for the
    last two and cast to the dtype and device of ``x``.
    """
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)
95
+
96
+
97
def subsequent_mask(length):
    """Return a ``(1, 1, length, length)`` lower-triangular mask of ones."""
    lower_triangle = torch.tril(torch.ones(length, length))
    return lower_triangle.unsqueeze(0).unsqueeze(0)
100
+
101
+
102
@torch.jit.script
def gated_activation(input_a, input_b, n_channels):
    """Fused tanh/sigmoid gate.

    ``input_a + input_b`` is split along dimension 1 at ``n_channels[0]``:
    the first part goes through ``tanh``, the rest through ``sigmoid``, and
    the two are multiplied element-wise.
    """
    n_channels_int = n_channels[0]
    pre_activation = input_a + input_b
    filter_out = torch.tanh(pre_activation[:, :n_channels_int, :])
    gate_out = torch.sigmoid(pre_activation[:, n_channels_int:, :])
    return filter_out * gate_out
110
+
111
+
112
def shift_1d(x):
    """Shift a 3-D tensor one step right along its last dimension.

    A zero is inserted at the front and the final step is dropped, so the
    shape is unchanged.
    """
    # Pad spec is last-dimension-first: one leading zero on the last dimension.
    padded = F.pad(x, [1, 0, 0, 0, 0, 0])
    return padded[:, :, :-1]
115
+
116
+
117
def create_length_mask(length, max_length=None):
    """Return a boolean mask that is True where ``position < length``.

    Args:
        length: 1-D tensor of lengths.
        max_length: Width of the mask; defaults to ``length.max()``.

    Returns:
        Boolean tensor of shape ``(len(length), max_length)``.
    """
    if max_length is None:
        max_length = length.max()
    positions = torch.arange(max_length, dtype=length.dtype, device=length.device)
    return positions.unsqueeze(0) < length.unsqueeze(1)
122
+
123
+
124
def compute_alignment_path(duration, mask):
    """Expand per-token durations into a hard monotonic alignment.

    Args:
        duration: Durations whose cumulative sum along the last dimension is
            reshaped to ``b * t_x`` entries.
        mask: Tensor of shape ``(b, 1, t_y, t_x)`` (read from ``mask.shape``).

    Returns:
        Tensor shaped like ``mask``, multiplied by ``mask``.
    """
    b, _, t_y, t_x = mask.shape
    cumulative = torch.cumsum(duration, -1)

    # One row per token: ones up to the token's cumulative end frame.
    filled = create_length_mask(cumulative.view(b * t_x), t_y).to(mask.dtype)
    filled = filled.view(b, t_x, t_y)
    # Subtracting the previous token's row leaves only this token's frames.
    previous = F.pad(filled, flatten_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = filled - previous
    return path.unsqueeze(1).transpose(2, 3) * mask
134
+
135
+
136
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients in place and return their total norm before clamping.

    Args:
        parameters: A tensor or an iterable of tensors; entries without a
            gradient are ignored.
        clip_value: Gradients are clamped to ``[-clip_value, clip_value]``;
            None skips the clamping.
        norm_type: Order of the norm used for the returned total.

    Returns:
        The ``norm_type``-norm over all gradients, as a Python number.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    with_grad = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)

    accumulated = 0
    for p in with_grad:
        # The norm is taken before this parameter's gradient is clamped.
        accumulated += p.grad.data.norm(norm_type).item() ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    return accumulated ** (1.0 / norm_type)
tiny_tts/nn/modules.py ADDED
@@ -0,0 +1,578 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import torch
3
+ from torch import nn
4
+ from torch.nn import functional as F
5
+
6
+ from torch.nn import Conv1d
7
+ from torch.nn.utils import weight_norm, remove_weight_norm
8
+
9
+ from . import commons
10
+ from .commons import initialize_weights, compute_padding
11
+ from .transforms import spline_transform
12
+ from .attentions import TransformerBlock
13
+
14
+ LRELU_SLOPE = 0.1
15
+
16
+
17
class ChannelNorm(nn.Module):
    """Layer normalization over dimension 1 with learnable ``gamma`` and ``beta``.

    The input is transposed so dimension 1 becomes the last dimension,
    normalized there with ``F.layer_norm``, and transposed back.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps

        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        """Return ``x`` normalized along dimension 1; the shape is unchanged."""
        swapped = x.transpose(1, -1)
        swapped = F.layer_norm(
            swapped, (self.channels,), self.gamma, self.beta, self.eps
        )
        return swapped.transpose(1, -1)
30
+
31
+
32
class ConvReluNorm(nn.Module):
    """Residual stack of Conv1d -> ChannelNorm -> ReLU -> Dropout layers.

    The stack output goes through a 1x1 projection that is initialised to
    zero and is added to the block input, so the block starts as an
    identity mapping (up to masking).

    Args:
        in_channels: Channels of the input.
        hidden_channels: Channels inside the stack.
        out_channels: Channels produced by the projection; it is added to the
            input, so it has to be compatible with ``in_channels``.
        kernel_size: Kernel size of the stacked convolutions.
        n_layers: Number of conv layers; must be greater than 1.
        p_dropout: Dropout probability after each ReLU.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # NOTE(review): the message says "larger than 0" but the check
        # requires at least two layers.
        assert n_layers > 1, "Number of layers should be larger than 0."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))

        # First layer maps in_channels -> hidden_channels; the rest keep width.
        layer_inputs = [in_channels] + [hidden_channels] * (n_layers - 1)
        for width_in in layer_inputs:
            self.conv_layers.append(
                nn.Conv1d(
                    width_in,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(ChannelNorm(hidden_channels))

        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask):
        """Return ``(x + proj(stack(x))) * x_mask``."""
        residual = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            x = self.relu_drop(norm(conv(x * x_mask)))
        return (residual + self.proj(x)) * x_mask
82
+
83
+
84
class DepthwiseSepConv(nn.Module):
    """Dilated and Depth-Separable Convolution

    Each of the ``n_layers`` residual layers is a depthwise dilated Conv1d
    (dilation ``kernel_size ** i``) followed by a 1x1 Conv1d, each with
    ChannelNorm and GELU, then dropout.
    """

    def __init__(self, channels, kernel_size, n_layers, p_dropout=0.0):
        super().__init__()
        self.channels = channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout

        self.drop = nn.Dropout(p_dropout)
        self.convs_sep = nn.ModuleList()
        self.convs_1x1 = nn.ModuleList()
        self.norms_1 = nn.ModuleList()
        self.norms_2 = nn.ModuleList()
        for layer_idx in range(n_layers):
            # Dilation grows geometrically with depth.
            dilation = kernel_size**layer_idx
            padding = (kernel_size * dilation - dilation) // 2
            depthwise = nn.Conv1d(
                channels,
                channels,
                kernel_size,
                groups=channels,
                dilation=dilation,
                padding=padding,
            )
            self.convs_sep.append(depthwise)
            self.convs_1x1.append(nn.Conv1d(channels, channels, 1))
            self.norms_1.append(ChannelNorm(channels))
            self.norms_2.append(ChannelNorm(channels))

    def forward(self, x, x_mask, g=None):
        """Run the residual layers on ``x`` (plus ``g`` when given), then mask."""
        if g is not None:
            x = x + g
        layers = zip(self.convs_sep, self.norms_1, self.convs_1x1, self.norms_2)
        for depthwise, norm_1, pointwise, norm_2 in layers:
            y = F.gelu(norm_1(depthwise(x * x_mask)))
            y = F.gelu(norm_2(pointwise(y)))
            x = x + self.drop(y)
        return x * x_mask
129
+
130
+
131
class WaveNet(torch.nn.Module):
    """Stack of dilated, gated, weight-normed convolutions with skip outputs.

    Args:
        hidden_channels: Channels of the input and of the returned tensor.
        kernel_size: Odd kernel size of the dilated convolutions.
        dilation_rate: Base of the per-layer dilation ``dilation_rate ** i``.
        n_layers: Number of layers.
        gin_channels: Channels of the conditioning ``g``; 0 builds no
            conditioning layer.
        p_dropout: Dropout probability applied to the gated activations.
    """

    def __init__(
        self,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        gin_channels=0,
        p_dropout=0,
    ):
        super().__init__()
        assert kernel_size % 2 == 1
        self.hidden_channels = hidden_channels
        # Stored as a one-element tuple.
        self.kernel_size = (kernel_size,)
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.gin_channels = gin_channels
        self.p_dropout = p_dropout

        self.in_layers = torch.nn.ModuleList()
        self.res_skip_layers = torch.nn.ModuleList()
        self.drop = nn.Dropout(p_dropout)

        if gin_channels != 0:
            # A single conv yields the conditioning slices for all layers.
            self.cond_layer = torch.nn.utils.weight_norm(
                torch.nn.Conv1d(gin_channels, 2 * hidden_channels * n_layers, 1),
                name="weight",
            )

        last_layer = n_layers - 1
        for layer_idx in range(n_layers):
            dilation = dilation_rate**layer_idx
            padding = int((kernel_size * dilation - dilation) / 2)
            in_layer = torch.nn.utils.weight_norm(
                torch.nn.Conv1d(
                    hidden_channels,
                    2 * hidden_channels,
                    kernel_size,
                    dilation=dilation,
                    padding=padding,
                ),
                name="weight",
            )
            self.in_layers.append(in_layer)

            # The last layer emits only the skip part, no residual part.
            if layer_idx < last_layer:
                res_skip_channels = 2 * hidden_channels
            else:
                res_skip_channels = hidden_channels
            res_skip_layer = torch.nn.utils.weight_norm(
                torch.nn.Conv1d(hidden_channels, res_skip_channels, 1),
                name="weight",
            )
            self.res_skip_layers.append(res_skip_layer)

    def forward(self, x, x_mask, g=None, **kwargs):
        """Return the masked sum of the skip outputs of all layers."""
        output = torch.zeros_like(x)
        n_channels_tensor = torch.IntTensor([self.hidden_channels])

        if g is not None:
            g = self.cond_layer(g)

        gate_width = 2 * self.hidden_channels
        last_layer = self.n_layers - 1
        for i in range(self.n_layers):
            x_in = self.in_layers[i](x)
            if g is None:
                g_l = torch.zeros_like(x_in)
            else:
                # Slice this layer's share of the conditioning tensor.
                cond_offset = i * gate_width
                g_l = g[:, cond_offset : cond_offset + gate_width, :]

            acts = commons.gated_activation(x_in, g_l, n_channels_tensor)
            acts = self.drop(acts)

            res_skip_acts = self.res_skip_layers[i](acts)
            if i < last_layer:
                # First half feeds the residual path, second half the skip sum.
                x = (x + res_skip_acts[:, : self.hidden_channels, :]) * x_mask
                output = output + res_skip_acts[:, self.hidden_channels :, :]
            else:
                output = output + res_skip_acts
        return output * x_mask

    def remove_weight_norm(self):
        """Strip weight normalization from every convolution of the stack."""
        if self.gin_channels != 0:
            torch.nn.utils.remove_weight_norm(self.cond_layer)
        for in_layer in self.in_layers:
            torch.nn.utils.remove_weight_norm(in_layer)
        for res_skip_layer in self.res_skip_layers:
            torch.nn.utils.remove_weight_norm(res_skip_layer)
216
+
217
+
218
class ConvResBlock(torch.nn.Module):
    """Residual block of three (dilated conv, plain conv) pairs.

    ``convs1`` holds three weight-normalised convolutions using the first
    three entries of ``dilation``; ``convs2`` holds three with dilation 1.
    Both lists are initialised with ``initialize_weights``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3, 5)):
        super(ConvResBlock, self).__init__()

        def build_conv(rate):
            # Stride-1 convolution padded by ``compute_padding`` for ``rate``.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=compute_padding(kernel_size, rate),
                )
            )

        self.convs1 = nn.ModuleList(
            [
                build_conv(dilation[0]),
                build_conv(dilation[1]),
                build_conv(dilation[2]),
            ]
        )
        self.convs1.apply(initialize_weights)

        self.convs2 = nn.ModuleList([build_conv(1), build_conv(1), build_conv(1)])
        self.convs2.apply(initialize_weights)

    def forward(self, x, x_mask=None):
        """Apply each conv pair as a residual update, masking when asked."""
        use_mask = x_mask is not None
        for dilated_conv, plain_conv in zip(self.convs1, self.convs2):
            branch = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            branch = F.leaky_relu(dilated_conv(branch), LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            x = plain_conv(branch) + x
        if use_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from every convolution in the block."""
        for conv in self.convs1:
            remove_weight_norm(conv)
        for conv in self.convs2:
            remove_weight_norm(conv)
313
+
314
+
315
class ConvResBlockLight(torch.nn.Module):
    """Lighter residual block with two dilated convolutions.

    The convolutions use the first two entries of ``dilation`` and are
    weight-normalised and initialised with ``initialize_weights``.
    """

    def __init__(self, channels, kernel_size=3, dilation=(1, 3)):
        super(ConvResBlockLight, self).__init__()

        def build_conv(rate):
            # Stride-1 convolution padded by ``compute_padding`` for ``rate``.
            return weight_norm(
                Conv1d(
                    channels,
                    channels,
                    kernel_size,
                    1,
                    dilation=rate,
                    padding=compute_padding(kernel_size, rate),
                )
            )

        self.convs = nn.ModuleList([build_conv(dilation[0]), build_conv(dilation[1])])
        self.convs.apply(initialize_weights)

    def forward(self, x, x_mask=None):
        """Apply each convolution as a residual update, masking when asked."""
        use_mask = x_mask is not None
        for conv in self.convs:
            branch = F.leaky_relu(x, LRELU_SLOPE)
            if use_mask:
                branch = branch * x_mask
            x = conv(branch) + x
        if use_mask:
            x = x * x_mask
        return x

    def remove_weight_norm(self):
        """Strip weight normalisation from both convolutions."""
        for conv in self.convs:
            remove_weight_norm(conv)
358
+
359
+
360
class LogTransform(nn.Module):
    """Element-wise logarithm flow step.

    Forward returns ``(log(clamp_min(x, 1e-5)) * x_mask, logdet)`` where
    ``logdet`` is the sum of ``-y`` over dimensions 1 and 2. Reverse returns
    ``exp(x) * x_mask`` only.
    """

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return torch.exp(x) * x_mask
        log_x = torch.log(torch.clamp_min(x, 1e-5)) * x_mask
        return log_x, torch.sum(-log_x, [1, 2])
369
+
370
+
371
class FlipTransform(nn.Module):
    """Flow step that reverses the order of dimension 1.

    Forward returns the flipped tensor together with a zero log-determinant
    of length ``x.size(0)``; reverse returns only the flipped tensor. Extra
    positional and keyword arguments are ignored.
    """

    def forward(self, x, *args, reverse=False, **kwargs):
        flipped = torch.flip(x, [1])
        if reverse:
            return flipped
        zero_logdet = torch.zeros(flipped.size(0)).to(
            dtype=flipped.dtype, device=flipped.device
        )
        return flipped, zero_logdet
379
+
380
+
381
class AffineCoupling(nn.Module):
    """Learned per-channel affine flow step.

    Holds a shift ``m`` and a log-scale ``logs``, both of shape
    ``(channels, 1)`` and initialised to zero. Forward computes
    ``(m + exp(logs) * x) * x_mask`` and the log-determinant
    ``sum(logs * x_mask)`` over dimensions 1 and 2; reverse undoes the
    transform and returns only the tensor.
    """

    def __init__(self, channels):
        super().__init__()
        self.channels = channels
        self.m = nn.Parameter(torch.zeros(channels, 1))
        self.logs = nn.Parameter(torch.zeros(channels, 1))

    def forward(self, x, x_mask, reverse=False, **kwargs):
        if reverse:
            return (x - self.m) * torch.exp(-self.logs) * x_mask
        shifted = self.m + torch.exp(self.logs) * x
        shifted = shifted * x_mask
        return shifted, torch.sum(self.logs * x_mask, [1, 2])
397
+
398
+
399
class FlowCouplingLayer(nn.Module):
    """Coupling layer whose shift and log-scale come from a ``WaveNet``.

    The input is split in half along dimension 1. The first half is passed
    through ``pre`` -> ``enc`` -> ``post`` to predict a shift ``m`` (and a
    log-scale ``logs`` unless ``mean_only``) that transforms the second half.
    ``post`` starts with zero weights and bias.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        dilation_rate,
        n_layers,
        p_dropout=0,
        gin_channels=0,
        mean_only=False,
    ):
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        half = channels // 2
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.dilation_rate = dilation_rate
        self.n_layers = n_layers
        self.half_channels = half
        self.mean_only = mean_only

        self.pre = nn.Conv1d(half, hidden_channels, 1)
        self.enc = WaveNet(
            hidden_channels,
            kernel_size,
            dilation_rate,
            n_layers,
            p_dropout=p_dropout,
            gin_channels=gin_channels,
        )
        # One output group (shift) when mean_only, otherwise two (shift, log-scale).
        stat_channels = half * (2 - mean_only)
        self.post = nn.Conv1d(hidden_channels, stat_channels, 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Return ``(x, logdet)`` going forward, or just ``x`` in reverse."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        hidden = self.pre(x0) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.post(hidden) * x_mask

        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        joined = torch.cat([x0, x1], 1)
        return joined, torch.sum(logs, [1, 2])
454
+
455
+
456
class ConvolutionalFlow(nn.Module):
    """Coupling layer that transforms half the channels with a spline.

    The first half of the channels goes through ``pre`` -> ``convs`` ->
    ``proj`` to produce, per element, ``num_bins`` widths, ``num_bins``
    heights and ``num_bins - 1`` derivatives. These parameterise
    ``spline_transform`` (linear tails, ``tail_bound``) applied to the second
    half. ``proj`` starts with zero weights and bias.
    """

    def __init__(
        self,
        in_channels,
        filter_channels,
        kernel_size,
        n_layers,
        num_bins=10,
        tail_bound=5.0,
    ):
        super().__init__()
        half = in_channels // 2
        self.in_channels = in_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.num_bins = num_bins
        self.tail_bound = tail_bound
        self.half_channels = half

        self.pre = nn.Conv1d(half, filter_channels, 1)
        self.convs = DepthwiseSepConv(filter_channels, kernel_size, n_layers, p_dropout=0.0)
        self.proj = nn.Conv1d(filter_channels, half * (num_bins * 3 - 1), 1)
        self.proj.weight.data.zero_()
        self.proj.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Return ``(x, logdet)`` going forward, or just ``x`` in reverse."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        params = self.pre(x0)
        params = self.convs(params, x_mask, g=g)
        params = self.proj(params) * x_mask

        # Move the per-element spline parameters to the last dimension.
        batch, chans, steps = x0.shape
        params = params.reshape(batch, chans, -1, steps).permute(0, 1, 3, 2)

        bins = self.num_bins
        scale = math.sqrt(self.filter_channels)
        unnormalized_widths = params[..., :bins] / scale
        unnormalized_heights = params[..., bins : 2 * bins] / scale
        unnormalized_derivatives = params[..., 2 * bins :]

        x1, logabsdet = spline_transform(
            x1,
            unnormalized_widths,
            unnormalized_heights,
            unnormalized_derivatives,
            inverse=reverse,
            tails="linear",
            tail_bound=self.tail_bound,
        )

        merged = torch.cat([x0, x1], 1) * x_mask
        logdet = torch.sum(logabsdet * x_mask, [1, 2])
        if reverse:
            return merged
        return merged, logdet
514
+
515
+
516
class TransformerCouplingLayer(nn.Module):
    """Coupling layer whose shift and log-scale come from a transformer.

    Works like the WaveNet-based coupling layer, but ``enc`` is either a
    freshly built ``TransformerBlock`` or the module passed in as
    ``wn_sharing_parameter``. The constructor asserts ``n_layers == 3``.
    ``post`` starts with zero weights and bias.
    """

    def __init__(
        self,
        channels,
        hidden_channels,
        kernel_size,
        n_layers,
        n_heads,
        p_dropout=0,
        filter_channels=0,
        mean_only=False,
        wn_sharing_parameter=None,
        gin_channels=0,
    ):
        assert n_layers == 3, n_layers
        assert channels % 2 == 0, "channels should be divisible by 2"
        super().__init__()
        half = channels // 2
        self.channels = channels
        self.hidden_channels = hidden_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.half_channels = half
        self.mean_only = mean_only

        self.pre = nn.Conv1d(half, hidden_channels, 1)
        if wn_sharing_parameter is None:
            encoder = TransformerBlock(
                hidden_channels,
                filter_channels,
                n_heads,
                n_layers,
                kernel_size,
                p_dropout,
                isflow=True,
                gin_channels=gin_channels,
            )
        else:
            # Reuse the encoder supplied by the caller instead of building one.
            encoder = wn_sharing_parameter
        self.enc = encoder

        # One output group (shift) when mean_only, otherwise two (shift, log-scale).
        stat_channels = half * (2 - mean_only)
        self.post = nn.Conv1d(hidden_channels, stat_channels, 1)
        self.post.weight.data.zero_()
        self.post.bias.data.zero_()

    def forward(self, x, x_mask, g=None, reverse=False):
        """Return ``(x, logdet)`` going forward, or just ``x`` in reverse."""
        x0, x1 = torch.split(x, [self.half_channels] * 2, 1)
        hidden = self.pre(x0) * x_mask
        hidden = self.enc(hidden, x_mask, g=g)
        stats = self.post(hidden) * x_mask

        if self.mean_only:
            m = stats
            logs = torch.zeros_like(m)
        else:
            m, logs = torch.split(stats, [self.half_channels] * 2, 1)

        if reverse:
            x1 = (x1 - m) * torch.exp(-logs) * x_mask
            return torch.cat([x0, x1], 1)

        x1 = m + x1 * torch.exp(logs) * x_mask
        joined = torch.cat([x0, x1], 1)
        return joined, torch.sum(logs, [1, 2])
tiny_tts/nn/transforms.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch.nn import functional as F
3
+
4
+ import numpy as np
5
+
6
+
7
+ DEFAULT_MIN_BIN_WIDTH = 1e-3
8
+ DEFAULT_MIN_BIN_HEIGHT = 1e-3
9
+ DEFAULT_MIN_DERIVATIVE = 1e-3
10
+
11
+
12
def spline_transform(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails=None,
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Dispatch to the bounded or unbounded spline implementation.

    With ``tails=None`` the call goes to ``quadratic_spline``; otherwise it
    goes to ``unbounded_spline`` with ``tails`` and ``tail_bound`` forwarded.
    Returns ``(outputs, logabsdet)``.
    """
    shared_kwargs = {
        "inputs": inputs,
        "unnormalized_widths": unnormalized_widths,
        "unnormalized_heights": unnormalized_heights,
        "unnormalized_derivatives": unnormalized_derivatives,
        "inverse": inverse,
        "min_bin_width": min_bin_width,
        "min_bin_height": min_bin_height,
        "min_derivative": min_derivative,
    }
    if tails is None:
        outputs, logabsdet = quadratic_spline(**shared_kwargs)
    else:
        outputs, logabsdet = unbounded_spline(
            tails=tails, tail_bound=tail_bound, **shared_kwargs
        )
    return outputs, logabsdet
43
+
44
+
45
def searchsorted(bin_locations, inputs, eps=1e-6):
    """Return, for each input, the index of the bin that contains it.

    The index is the number of bin edges that are ``<=`` the input, minus
    one. The last edge is nudged up by ``eps`` so that an input sitting
    exactly on it still falls into the final bin.

    The nudge is applied to a copy: the caller's ``bin_locations`` tensor is
    left untouched (it used to be modified in place on every call).
    """
    edges = bin_locations.clone()
    edges[..., -1] += eps
    return torch.sum(inputs[..., None] >= edges, dim=-1) - 1
48
+
49
+
50
def unbounded_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    tails="linear",
    tail_bound=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Apply the spline inside ``[-tail_bound, tail_bound]``, identity outside.

    Elements outside the interval are copied through unchanged with a zero
    log-determinant. Elements inside are transformed by ``quadratic_spline``
    on the square ``[-tail_bound, tail_bound]^2``. The derivative parameters
    are padded at both ends with a constant chosen so that the boundary
    derivative after ``min_derivative + softplus(.)`` equals 1.

    Returns:
        ``(outputs, logabsdet)``, both shaped like ``inputs``.

    Raises:
        RuntimeError: if ``tails`` is anything other than ``"linear"``.
    """
    inside_interval_mask = (inputs >= -tail_bound) & (inputs <= tail_bound)
    outside_interval_mask = ~inside_interval_mask

    outputs = torch.zeros_like(inputs)
    logabsdet = torch.zeros_like(inputs)

    if tails != "linear":
        raise RuntimeError("{} tails are not implemented.".format(tails))

    unnormalized_derivatives = F.pad(unnormalized_derivatives, pad=(1, 1))
    constant = np.log(np.exp(1 - min_derivative) - 1)
    unnormalized_derivatives[..., 0] = constant
    unnormalized_derivatives[..., -1] = constant

    outputs[outside_interval_mask] = inputs[outside_interval_mask]
    logabsdet[outside_interval_mask] = 0

    # With nothing inside the interval the spline would receive empty
    # tensors and its domain check (torch.min / torch.max) would raise.
    # The identity tails above already cover every element in that case.
    if inside_interval_mask.any():
        (
            outputs[inside_interval_mask],
            logabsdet[inside_interval_mask],
        ) = quadratic_spline(
            inputs=inputs[inside_interval_mask],
            unnormalized_widths=unnormalized_widths[inside_interval_mask, :],
            unnormalized_heights=unnormalized_heights[inside_interval_mask, :],
            unnormalized_derivatives=unnormalized_derivatives[inside_interval_mask, :],
            inverse=inverse,
            left=-tail_bound,
            right=tail_bound,
            bottom=-tail_bound,
            top=tail_bound,
            min_bin_width=min_bin_width,
            min_bin_height=min_bin_height,
            min_derivative=min_derivative,
        )

    return outputs, logabsdet
98
+
99
+
100
def quadratic_spline(
    inputs,
    unnormalized_widths,
    unnormalized_heights,
    unnormalized_derivatives,
    inverse=False,
    left=0.0,
    right=1.0,
    bottom=0.0,
    top=1.0,
    min_bin_width=DEFAULT_MIN_BIN_WIDTH,
    min_bin_height=DEFAULT_MIN_BIN_HEIGHT,
    min_derivative=DEFAULT_MIN_DERIVATIVE,
):
    """Evaluate a monotonic piecewise spline (or its inverse) element-wise.

    Bin widths and heights come from a softmax over the unnormalised
    parameters, floored at ``min_bin_width`` / ``min_bin_height`` and scaled
    to ``[left, right]`` x ``[bottom, top]``. Knot derivatives are
    ``min_derivative + softplus(unnormalized_derivatives)``.

    Returns:
        ``(outputs, logabsdet)``; in inverse mode the log-determinant is
        negated.

    Raises:
        ValueError: if an input lies outside ``[left, right]`` or the
            minimum bin size is too large for the number of bins.
    """
    if torch.min(inputs) < left or torch.max(inputs) > right:
        raise ValueError("Input to a transform is not within its domain")

    num_bins = unnormalized_widths.shape[-1]

    if min_bin_width * num_bins > 1.0:
        raise ValueError("Minimal bin width too large for the number of bins")
    if min_bin_height * num_bins > 1.0:
        raise ValueError("Minimal bin height too large for the number of bins")

    def _bin_edges(unnormalized, min_size, low, high):
        # Softmax sizes -> cumulative edges pinned exactly to [low, high].
        sizes = F.softmax(unnormalized, dim=-1)
        sizes = min_size + (1 - min_size * num_bins) * sizes
        edges = torch.cumsum(sizes, dim=-1)
        edges = F.pad(edges, pad=(1, 0), mode="constant", value=0.0)
        edges = (high - low) * edges + low
        edges[..., 0] = low
        edges[..., -1] = high
        return edges, edges[..., 1:] - edges[..., :-1]

    cumwidths, widths = _bin_edges(unnormalized_widths, min_bin_width, left, right)
    derivatives = min_derivative + F.softplus(unnormalized_derivatives)
    cumheights, heights = _bin_edges(unnormalized_heights, min_bin_height, bottom, top)

    # The inverse looks inputs up along the height axis, the forward pass
    # along the width axis.
    if inverse:
        bin_idx = searchsorted(cumheights, inputs)[..., None]
    else:
        bin_idx = searchsorted(cumwidths, inputs)[..., None]

    def _select(values):
        # Pick, for every input, the entry belonging to its bin.
        return values.gather(-1, bin_idx)[..., 0]

    input_cumwidths = _select(cumwidths)
    input_bin_widths = _select(widths)

    input_cumheights = _select(cumheights)
    delta = heights / widths
    input_delta = _select(delta)

    input_derivatives = _select(derivatives)
    input_derivatives_plus_one = _select(derivatives[..., 1:])

    input_heights = _select(heights)

    # Term shared by the quadratic coefficients and the denominator.
    curvature = input_derivatives + input_derivatives_plus_one - 2 * input_delta

    if inverse:
        offset = inputs - input_cumheights
        a = offset * curvature + input_heights * (input_delta - input_derivatives)
        b = input_heights * input_derivatives - offset * curvature
        c = -input_delta * offset

        discriminant = b.pow(2) - 4 * a * c
        assert (discriminant >= 0).all()

        theta = (2 * c) / (-b - torch.sqrt(discriminant))
        outputs = theta * input_bin_widths + input_cumwidths
        theta_one_minus_theta = theta * (1 - theta)
        denominator = input_delta + curvature * theta_one_minus_theta
    else:
        theta = (inputs - input_cumwidths) / input_bin_widths
        theta_one_minus_theta = theta * (1 - theta)

        numerator = input_heights * (
            input_delta * theta.pow(2) + input_derivatives * theta_one_minus_theta
        )
        denominator = input_delta + curvature * theta_one_minus_theta
        outputs = input_cumheights + numerator / denominator

    derivative_numerator = input_delta.pow(2) * (
        input_derivatives_plus_one * theta.pow(2)
        + 2 * input_delta * theta_one_minus_theta
        + input_derivatives * (1 - theta).pow(2)
    )
    logabsdet = torch.log(derivative_numerator) - 2 * torch.log(denominator)

    if inverse:
        return outputs, -logabsdet
    return outputs, logabsdet
tiny_tts/text/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from .symbols import *
2
+
3
+
4
+ _symbol_to_id = {s: i for i, s in enumerate(symbols)}
5
+
6
+
7
def phonemes_to_ids(cleaned_text, tones, language, symbol_to_id=None):
    """Convert phoneme symbols, tones and a language into integer IDs.

    Symbols are looked up in ``symbol_to_id`` (or the module default when it
    is missing or empty). If that table has an ``"UNK"`` entry, unknown
    symbols map to it; otherwise an unknown symbol raises ``KeyError``.
    Tones are shifted by the language's tone offset, and one language id is
    produced per phone.

    Returns:
        ``(phones, tones, lang_ids)`` as three lists.
    """
    lookup = symbol_to_id or _symbol_to_id
    fallback = lookup.get("UNK")

    phones = []
    for symbol in cleaned_text:
        if fallback is None:
            phones.append(lookup[symbol])
        else:
            phones.append(lookup.get(symbol, fallback))

    tone_offset = language_tone_start_map[language]
    shifted_tones = [tone + tone_offset for tone in tones]

    language_id = language_id_map[language]
    lang_ids = [language_id] * len(phones)
    return phones, shifted_tones, lang_ids
tiny_tts/text/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (1.29 kB). View file
 
tiny_tts/text/__pycache__/english.cpython-310.pyc ADDED
Binary file (4.67 kB). View file
 
tiny_tts/text/__pycache__/symbols.cpython-310.pyc ADDED
Binary file (2.92 kB). View file
 
tiny_tts/text/cmudict.rep ADDED
The diff for this file is too large to render. See raw diff
 
tiny_tts/text/cmudict_cache.pickle ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b9b21b20325471934ba92f2e4a5976989e7d920caa32e7a286eacb027d197949
3
+ size 6212655
tiny_tts/text/english.py ADDED
@@ -0,0 +1,173 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import pickle
2
+ import os
3
+ import re
4
+ from g2p_en import G2p
5
+
6
+ from . import symbols
7
+
8
+ from .english_utils.abbreviations import expand_abbreviations
9
+ from .english_utils.time_norm import expand_time_english
10
+ from .english_utils.number_norm import normalize_numbers
11
+
12
+
13
def distribute_phone(n_phone, n_word):
    """Spread ``n_phone`` phones over ``n_word`` slots as evenly as possible.

    Phones are handed out one at a time to a slot holding the current
    minimum; among tied slots the one at the middle of the tie list is used.
    Returns the per-slot counts.
    """
    counts = [0] * n_word
    for _ in range(n_phone):
        lowest = min(counts)
        tied = [slot for slot, held in enumerate(counts) if held == lowest]
        counts[tied[len(tied) // 2]] += 1
    return counts
23
+
24
+
25
+ from transformers import AutoTokenizer
26
+
27
+ current_file_path = os.path.dirname(__file__)
28
+ CMU_DICT_PATH = os.path.join(current_file_path, "cmudict.rep")
29
+ CACHE_PATH = os.path.join(current_file_path, "cmudict_cache.pickle")
30
+ _g2p = G2p()
31
+
32
+ arpa = {
33
+ "AH0", "S", "AH1", "EY2", "AE2", "EH0", "OW2", "UH0", "NG", "B",
34
+ "G", "AY0", "M", "AA0", "F", "AO0", "ER2", "UH1", "IY1", "AH2",
35
+ "DH", "IY0", "EY1", "IH0", "K", "N", "W", "IY2", "T", "AA1",
36
+ "ER1", "EH2", "OY0", "UH2", "UW1", "Z", "AW2", "AW1", "V", "UW2",
37
+ "AA2", "ER", "AW0", "UW0", "R", "OW1", "EH1", "ZH", "AE0", "IH2",
38
+ "IH", "Y", "JH", "P", "AY1", "EY0", "OY2", "TH", "HH", "D",
39
+ "ER0", "CH", "AO1", "AE1", "AO2", "OY1", "AY2", "IH1", "OW0", "L", "SH",
40
+ }
41
+
42
+
43
def map_phoneme(ph):
    """Normalise a phoneme or punctuation token to a known symbol.

    The token is first passed through a small replacement table; the result
    is returned if it is one of ``symbols``, otherwise ``"UNK"``.
    """
    rep_map = {
        ":": ",", ";": ",", ",": ",", "。": ".", "!": "!",
        "?": "?", "\n": ".", "·": ",", "、": ",", "...": "…", "v": "V",
    }
    candidate = rep_map.get(ph, ph)
    return candidate if candidate in symbols else "UNK"
55
+
56
+
57
def read_dict(path=None, start_line=49):
    """Parse a syllabified CMU-style pronunciation dictionary.

    Each entry line is expected to hold a word, whitespace, and then its
    phones, with syllables separated by ``" - "``, for example
    ``HELLO  HH AH0 - L OW1``. The word is split off at the first space and
    the rest is stripped, so the parser works whether the word is followed
    by one space or several. Blank lines and lines without a pronunciation
    are skipped instead of raising.

    Args:
        path: Dictionary file to read; defaults to ``CMU_DICT_PATH``.
        start_line: 1-based number of the first line that holds an entry;
            earlier lines are treated as header and ignored.

    Returns:
        Dict mapping each word to a list of syllables, each syllable being a
        list of phone strings. A later entry for the same word replaces an
        earlier one.
    """
    if path is None:
        path = CMU_DICT_PATH

    g2p_dict = {}
    with open(path) as f:
        for line_index, line in enumerate(f, start=1):
            if line_index < start_line:
                continue
            word, _, pronunciation = line.strip().partition(" ")
            pronunciation = pronunciation.strip()
            if not word or not pronunciation:
                continue
            g2p_dict[word] = [
                syllable.split() for syllable in pronunciation.split(" - ")
            ]
    return g2p_dict
79
+
80
+
81
def cache_dict(g2p_dict, file_path):
    """Pickle ``g2p_dict`` to ``file_path``, overwriting any existing file."""
    with open(file_path, "wb") as cache_file:
        pickle.dump(g2p_dict, cache_file)
84
+
85
+
86
def get_dict():
    """Return the pronunciation dictionary, using the pickle cache if usable.

    If ``CACHE_PATH`` exists and unpickles cleanly, its content is returned.
    If the cache is missing or cannot be unpickled (for example because only
    a git-LFS pointer file was checked out), the dictionary is rebuilt with
    ``read_dict`` and written back to the cache on a best-effort basis: a
    failure to write, such as a read-only install directory, only triggers a
    warning.
    """
    # Function-scope import keeps this block free of new file-level imports.
    import warnings

    if os.path.exists(CACHE_PATH):
        try:
            # NOTE(security): unpickling executes code from the file, so the
            # cache must only ever come from a trusted source.
            with open(CACHE_PATH, "rb") as pickle_file:
                return pickle.load(pickle_file)
        except (
            pickle.UnpicklingError,
            EOFError,
            AttributeError,
            ImportError,
            IndexError,
            KeyError,
            ValueError,
        ) as err:
            warnings.warn(
                "Ignoring unreadable pronunciation cache %r (%s); rebuilding."
                % (CACHE_PATH, err)
            )

    g2p_dict = read_dict()
    try:
        cache_dict(g2p_dict, CACHE_PATH)
    except OSError as err:
        # The cache is only an optimisation; the parsed dictionary is valid.
        warnings.warn(
            "Could not write pronunciation cache %r (%s)." % (CACHE_PATH, err)
        )
    return g2p_dict
95
+
96
+
97
+ eng_dict = get_dict()
98
+
99
+
100
def parse_phoneme(phn):
    """Split an ARPAbet phone into ``(lower-cased symbol, tone)``.

    A trailing stress digit is removed and reported as ``digit + 1``; phones
    without a trailing digit get tone ``0``.
    """
    if re.search(r"\d$", phn) is None:
        return phn.lower(), 0
    tone = int(phn[-1]) + 1
    return phn[:-1].lower(), tone
106
+
107
+
108
def parse_syllables(syllables):
    """Flatten a list of syllables into parallel phoneme and tone lists.

    Every phone of every syllable is passed through ``parse_phoneme``; the
    resulting symbols and tones are collected in order.
    """
    phonemes = []
    tones = []
    for syllable in syllables:
        for raw_phone in syllable:
            symbol, tone = parse_phoneme(raw_phone)
            phonemes.append(symbol)
            tones.append(tone)
    return phonemes, tones
118
+
119
+
120
def normalize_text(text):
    """Lower-case ``text`` and expand times, numbers and abbreviations.

    The three expansions run in that order, each on the previous output.
    """
    normalized = text.lower()
    for expand in (expand_time_english, normalize_numbers, expand_abbreviations):
        normalized = expand(normalized)
    return normalized
126
+
127
+
128
+ model_id = 'bert-base-uncased'
129
+ tokenizer = AutoTokenizer.from_pretrained(model_id)
130
+
131
+
132
def grapheme_to_phoneme(text, pad_start_end=True, tokenized=None):
    """Convert text into phones, tones and a token-to-phone count list.

    Tokens are grouped into words: a token starting with ``"#"`` is treated
    as a continuation piece and appended (with ``"#"`` removed) to the
    previous group. Each word is looked up upper-cased in ``eng_dict``;
    words not found there are sent to the ``_g2p`` model. The phones of a
    word are then spread over its tokens with ``distribute_phone``.

    Args:
        text: Input text; only used when ``tokenized`` is ``None``.
        pad_start_end: When true, add a ``"_"`` phone with tone 0 and a
            count of 1 at both ends.
        tokenized: Optional pre-computed token list.

    Returns:
        ``(phones, tones, word2ph)``.
    """
    if tokenized is None:
        tokenized = tokenizer.tokenize(text)

    ph_groups = []
    for t in tokenized:
        # A leading "#" token has no previous group to join. Start a new
        # group for it instead of indexing an empty list (IndexError).
        if t.startswith("#") and ph_groups:
            ph_groups[-1].append(t.replace("#", ""))
        else:
            ph_groups.append([t])

    phones = []
    tones = []
    word2ph = []
    for group in ph_groups:
        w = "".join(group)
        phone_len = 0
        word_len = len(group)
        if w.upper() in eng_dict:
            phns, tns = parse_syllables(eng_dict[w.upper()])
            phones += phns
            tones += tns
            phone_len += len(phns)
        else:
            # Fallback model output; the " " separators are not phones.
            phone_list = [p for p in _g2p(w) if p != " "]
            for ph in phone_list:
                if ph in arpa:
                    ph, tn = parse_phoneme(ph)
                    phones.append(ph)
                    tones.append(tn)
                else:
                    phones.append(ph)
                    tones.append(0)
                phone_len += 1
        word2ph += distribute_phone(phone_len, word_len)
    phones = [map_phoneme(i) for i in phones]

    if pad_start_end:
        phones = ["_"] + phones + ["_"]
        tones = [0] + tones + [0]
        word2ph = [1] + word2ph + [1]
    return phones, tones, word2ph
tiny_tts/text/english_utils/__init__.py ADDED
File without changes
tiny_tts/text/english_utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (159 Bytes). View file
 
tiny_tts/text/english_utils/__pycache__/abbreviations.cpython-310.pyc ADDED
Binary file (952 Bytes). View file
 
tiny_tts/text/english_utils/__pycache__/number_norm.cpython-310.pyc ADDED
Binary file (2.77 kB). View file
 
tiny_tts/text/english_utils/__pycache__/time_norm.cpython-310.pyc ADDED
Binary file (1.42 kB). View file
 
tiny_tts/text/english_utils/abbreviations.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ # List of (regular expression, replacement) pairs for abbreviations in english:
4
+ abbreviations_en = [
5
+ (re.compile("\\b%s\\." % x[0], re.IGNORECASE), x[1])
6
+ for x in [
7
+ ("mrs", "misess"),
8
+ ("mr", "mister"),
9
+ ("dr", "doctor"),
10
+ ("st", "saint"),
11
+ ("co", "company"),
12
+ ("jr", "junior"),
13
+ ("maj", "major"),
14
+ ("gen", "general"),
15
+ ("drs", "doctors"),
16
+ ("rev", "reverend"),
17
+ ("lt", "lieutenant"),
18
+ ("hon", "honorable"),
19
+ ("sgt", "sergeant"),
20
+ ("capt", "captain"),
21
+ ("esq", "esquire"),
22
+ ("ltd", "limited"),
23
+ ("col", "colonel"),
24
+ ("ft", "fort"),
25
+ ]
26
+ ]
27
+
28
def expand_abbreviations(text, lang="en"):
    """Replace known abbreviations in ``text`` with their spoken form.

    Only ``lang="en"`` is supported; any other value raises
    ``NotImplementedError``.
    """
    if lang != "en":
        raise NotImplementedError()
    for pattern, expansion in abbreviations_en:
        text = pattern.sub(expansion, text)
    return text
tiny_tts/text/english_utils/number_norm.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """ from https://github.com/keithito/tacotron """
2
+
3
+ import re
4
+ from typing import Dict
5
+
6
+ import inflect
7
+
8
+ _inflect = inflect.engine()
9
+ _comma_number_re = re.compile(r"([0-9][0-9\,]+[0-9])")
10
+ _decimal_number_re = re.compile(r"([0-9]+\.[0-9]+)")
11
# Currency symbol followed by an amount. "€" is included so that the euro
# entry of the currency table in _expand_currency can actually be reached.
_currency_re = re.compile(r"(£|\$|¥|€)([0-9\,\.]*[0-9]+)")
12
+ _ordinal_re = re.compile(r"[0-9]+(st|nd|rd|th)")
13
+ _number_re = re.compile(r"-?[0-9]+")
14
+
15
+
16
def _remove_commas(m):
    """Return the first group of the match with all commas removed."""
    grouped_digits = m.group(1)
    return grouped_digits.replace(",", "")
18
+
19
+
20
def _expand_decimal_point(m):
    """Return the first group of the match with "." spelled as " point "."""
    decimal_text = m.group(1)
    return decimal_text.replace(".", " point ")
22
+
23
+
24
def __expand_currency(value: str, inflection: Dict[float, str]) -> str:
    """Turn a numeric currency amount into "<n> <unit> <n> <subunit>" text.

    Args:
        value: Amount as written, e.g. ``"1,234.50"``. Commas are dropped.
        inflection: Unit names keyed by count: ``1`` / ``2`` for the
            singular / plural main unit and ``0.01`` / ``0.02`` for the
            singular / plural minor unit. The singular keys are optional.

    Returns:
        The amount with unit names attached; digits are left for the later
        number-expansion pass. Amounts with more than one ``"."`` are
        returned unchanged with the plural main unit appended, and a zero
        amount becomes ``"zero <plural unit>"``.
    """
    parts = value.replace(",", "").split(".")
    if len(parts) > 2:
        return f"{value} {inflection[2]}"  # Unexpected format
    text = []
    integer = int(parts[0]) if parts[0] else 0
    if integer > 0:
        integer_unit = inflection.get(integer, inflection[2])
        text.append(f"{integer} {integer_unit}")
    fraction_digits = parts[1] if len(parts) > 1 else ""
    if len(fraction_digits) == 1:
        # "1.5" means fifty minor units, not five.
        fraction_digits += "0"
    fraction = int(fraction_digits) if fraction_digits else 0
    if fraction > 0:
        fraction_unit = inflection.get(fraction / 100, inflection[0.02])
        text.append(f"{fraction} {fraction_unit}")
    if len(text) == 0:
        return f"zero {inflection[2]}"
    return " ".join(text)
40
+
41
+
42
def _expand_currency(m: "re.Match") -> str:
    """Expand a currency match (symbol in group 1, amount in group 2).

    The symbol selects a table of unit names, which is handed together with
    the amount to ``__expand_currency``.
    """
    dollar_units = {0.01: "cent", 0.02: "cents", 1: "dollar", 2: "dollars"}
    euro_units = {0.01: "cent", 0.02: "cents", 1: "euro", 2: "euros"}
    pound_units = {
        0.01: "penny",
        0.02: "pence",
        1: "pound sterling",
        2: "pounds sterling",
    }
    # TODO rin
    yen_units = {0.02: "sen", 2: "yen"}

    units_by_symbol = {
        "$": dollar_units,
        "€": euro_units,
        "£": pound_units,
        "¥": yen_units,
    }
    symbol = m.group(1)
    unit_names = units_by_symbol[symbol]
    amount = m.group(2)
    return __expand_currency(amount, unit_names)
72
+
73
+
74
def _expand_ordinal(m):
    """Spell out the whole matched ordinal using the inflect engine."""
    ordinal_text = m.group(0)
    return _inflect.number_to_words(ordinal_text)
76
+
77
+
78
def _expand_number(m):
    """Spell out a matched integer, reading 1001-2999 in year style.

    Within that range: 2000 is "two thousand", 2001-2009 are
    "two thousand <n>", whole hundreds are "<n> hundred", and everything
    else is read in two-digit groups with "oh" for zero. Other numbers are
    spelled out normally without "and".
    """
    num = int(m.group(0))
    if not 1000 < num < 3000:
        return _inflect.number_to_words(num, andword="")
    if num == 2000:
        return "two thousand"
    if 2000 < num < 2010:
        return "two thousand " + _inflect.number_to_words(num % 100)
    if num % 100 == 0:
        return _inflect.number_to_words(num // 100) + " hundred"
    paired = _inflect.number_to_words(num, andword="", zero="oh", group=2)
    return paired.replace(", ", " ")
89
+
90
+
91
def normalize_numbers(text):
    """Rewrite numeric expressions in ``text`` as words.

    Runs, in order: comma removal, currency, decimal point, ordinal and
    plain-number expansion, each on the output of the previous step.
    """
    text = _comma_number_re.sub(_remove_commas, text)
    text = _currency_re.sub(_expand_currency, text)
    text = _decimal_number_re.sub(_expand_decimal_point, text)
    text = _ordinal_re.sub(_expand_ordinal, text)
    text = _number_re.sub(_expand_number, text)
    return text
tiny_tts/text/english_utils/time_norm.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import re
2
+
3
+ import inflect
4
+
5
+ _inflect = inflect.engine()
6
+
7
# Matches clock times such as "7:05", "19:30" or "3:15 p.m".
# Groups: 1 = hour, 6 = minutes, 7 = optional am/pm marker.
# The dots are escaped as "\." because this is a raw string: the earlier
# "\\." meant "a backslash followed by any character", so the dotted
# "a.m."/"p.m." spellings could never match.
_time_re = re.compile(
    r"""\b
    ((0?[0-9])|(1[0-1])|(1[2-9])|(2[0-3]))  # hours
    :
    ([0-5][0-9])                            # minutes
    \s*(a\.m\.|am|pm|p\.m\.|a\.m|p\.m)?     # am/pm
    \b""",
    re.IGNORECASE | re.X,
)
16
+
17
+
18
def _expand_num(n: int) -> str:
    """Spell out ``n`` in words using the module's inflect engine."""
    words = _inflect.number_to_words(n)
    return words
20
+
21
+
22
def _expand_time_english(match: "re.Match") -> str:
    """Spell out a matched clock time on a 12-hour clock.

    Group 1 is the hour, group 6 the minutes and group 7 the optional am/pm
    marker. Minutes below ten are read with a leading "oh" and zero minutes
    are omitted. When the text carries no marker, "a m" or "p m" is derived
    from the 24-hour value; an explicit marker is spelled letter by letter
    with its dots removed.
    """
    hour = int(match.group(1))
    past_noon = hour >= 12
    if hour > 12:
        hour -= 12
    elif hour == 0:
        # 00:xx is the hour after midnight: read as "twelve ... a m".
        # It was previously flagged as past noon and read as "p m".
        hour = 12
    time = [_expand_num(hour)]

    minute = int(match.group(6))
    if minute > 0:
        if minute < 10:
            time.append("oh")
        time.append(_expand_num(minute))

    am_pm = match.group(7)
    if am_pm is None:
        time.append("p m" if past_noon else "a m")
    else:
        time.extend(list(am_pm.replace(".", "")))
    return " ".join(time)
44
+
45
+
46
def expand_time_english(text: str) -> str:
    """Replace every clock time found in ``text`` with its spoken form."""
    return _time_re.sub(_expand_time_english, text)
tiny_tts/text/symbols.py ADDED
@@ -0,0 +1,293 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Padding token.
pad = "_"

# Punctuation marks kept as standalone symbols (includes the inverted
# Spanish marks).
punctuation = ["!", "?", "…", ",", ".", "'", "-", "¿", "¡"]

# Punctuation plus the two special tokens "SP" and "UNK".
pu_symbols = [*punctuation, "SP", "UNK"]
5
+
6
# Chinese symbol inventory (rows grouped by leading letter; order preserved).
zh_symbols = [
    "E", "En",
    "a", "ai", "an", "ang", "ao",
    "b", "c", "ch", "d",
    "e", "ei", "en", "eng", "er",
    "f", "g", "h",
    "i", "i0", "ia", "ian", "iang", "iao", "ie", "in", "ing", "iong", "ir", "iu",
    "j", "k", "l", "m", "n",
    "o", "ong", "ou",
    "p", "q", "r", "s", "sh", "t",
    "u", "ua", "uai", "uan", "uang", "ui", "un", "uo",
    "v", "van", "ve", "vn",
    "w", "x", "y", "z", "zh",
    "AA", "EE", "OO",
]
num_zh_tones = 6
75
+
76
# Japanese symbol inventory (order preserved; a trailing ":" is part of the
# symbol itself).
ja_symbols = [
    "N",
    "a", "a:", "b", "by", "ch", "d", "dy",
    "e", "e:", "f", "g", "gy", "h", "hy",
    "i", "i:", "j", "k", "ky", "m", "my", "n", "ny",
    "o", "o:", "p", "py", "q", "r", "ry", "s", "sh", "t", "ts", "ty",
    "u", "u:", "w", "y", "z", "zy",
]
num_ja_tones = 1
122
+
123
# English symbol inventory (order preserved).
# NOTE(review): "V" is upper-case while every other entry is lower-case.
# It is kept exactly as-is because symbol ids are derived from these strings.
en_symbols = [
    "aa", "ae", "ah", "ao", "aw", "ay",
    "b", "ch", "d", "dh",
    "eh", "er", "ey",
    "f", "g", "hh",
    "ih", "iy",
    "jh", "k", "l", "m", "n", "ng",
    "ow", "oy",
    "p", "r", "s", "sh", "t", "th",
    "uh", "uw",
    "V", "w", "y", "z", "zh",
]
num_en_tones = 4
166
+
167
# Korean symbol inventory: Hangul jamo plus a handful of ASCII marks
# (order preserved).
kr_symbols = [
    "ᄌ", "ᅥ", "ᆫ", "ᅦ", "ᄋ", "ᅵ", "ᄅ", "ᅴ", "ᄀ", "ᅡ",
    "ᄎ", "ᅪ", "ᄑ", "ᅩ", "ᄐ", "ᄃ", "ᅢ", "ᅮ", "ᆼ", "ᅳ",
    "ᄒ", "ᄆ", "ᆯ", "ᆷ", "ᄂ", "ᄇ", "ᄉ", "ᆮ", "ᄁ", "ᅬ",
    "ᅣ", "ᄄ", "ᆨ", "ᄍ", "ᅧ", "ᄏ", "ᆸ", "ᅭ", "(", "ᄊ",
    ")", "ᅲ", "ᅨ", "ᄈ", "ᅱ", "ᅯ", "ᅫ", "ᅰ", "ᅤ", "~",
    "\\", "[", "]", "/", "^", ":", "ㄸ", "*",
]
num_kr_tones = 1
170
+
171
# Spanish symbol inventory: ASCII letters followed by IPA characters and
# stress/length marks (order preserved).
es_symbols = [
    "N", "Q",
    "a", "b", "d", "e", "f", "g", "h", "i", "j", "k", "l", "m", "n", "o",
    "p", "s", "t", "u", "v", "w", "x", "y", "z",
    "ɑ", "æ", "ʃ", "ʑ", "ç", "ɯ", "ɪ", "ɔ", "ɛ", "ɹ", "ð", "ə",
    "ɫ", "ɥ", "ɸ", "ʊ", "ɾ", "ʒ", "θ", "β", "ŋ", "ɦ", "ɡ",
    "r", "ɲ", "ʝ", "ɣ", "ʎ",
    "ˈ", "ˌ", "ː",
]
num_es_tones = 1
231
+
232
# French additions to the symbol inventory. The first entry is the combining
# tilde, written as an escape so it stays visible in the source.
fr_symbols = ["\u0303", "œ", "ø", "ʁ", "ɒ", "ʌ", "ɜ", "ɐ"]
num_fr_tones = 1
244
+
245
# German additions to the symbol inventory; the second entry is a lone
# combining mark.
de_symbols = ["ʏ", "̩"]
num_de_tones = 1
251
+
252
# Russian additions to the symbol inventory (includes the ASCII marks '"'
# and "^").
ru_symbols = ["ɭ", "ʲ", "ɕ", "\"", "ɵ", "^", "ɬ"]
num_ru_tones = 1
263
+
264
# Merge every language's inventory into one sorted, de-duplicated list.
normal_symbols = sorted(
    {
        *zh_symbols, *ja_symbols, *en_symbols, *kr_symbols,
        *es_symbols, *fr_symbols, *de_symbols, *ru_symbols,
    }
)
# Final table layout: pad, then the merged symbols, then punctuation/specials.
symbols = [pad, *normal_symbols, *pu_symbols]
# Positions of the punctuation/special tokens inside ``symbols``.
sil_phonemes_ids = [symbols.index(mark) for mark in pu_symbols]

# Total number of tone slots across all languages.
num_tones = sum(
    (
        num_zh_tones, num_ja_tones, num_en_tones, num_kr_tones,
        num_es_tones, num_fr_tones, num_de_tones, num_ru_tones,
    )
)

# Language code -> integer id ("ES" and "SP" share the same id).
language_id_map = {
    "ZH": 0, "JP": 1, "EN": 2, "ZH_MIX_EN": 3, "KR": 4,
    "ES": 5, "SP": 5, "FR": 6, "DE": 7, "RU": 8, "VI": 9,
}
num_languages = 10

# Language code -> index of that language's first tone. Each language's block
# starts where the previous one ends, so the offsets are accumulated in order.
# NOTE(review): "VI" starts at the sum of all tone counts, which equals
# ``num_tones``; no tone slots are counted for it here — confirm this is intended.
language_tone_start_map = {}
_offset = 0
for _codes, _tone_count in (
    (("ZH", "ZH_MIX_EN"), num_zh_tones),
    (("JP",), num_ja_tones),
    (("EN",), num_en_tones),
    (("KR",), num_kr_tones),
    (("ES", "SP"), num_es_tones),
    (("FR",), num_fr_tones),
    (("DE",), num_de_tones),
    (("RU",), num_ru_tones),
    (("VI",), 0),
):
    language_tone_start_map.update(dict.fromkeys(_codes, _offset))
    _offset += _tone_count
del _offset, _codes, _tone_count
289
+
290
if __name__ == "__main__":
    # Debug aid: list the symbols shared by the Chinese and English inventories.
    shared = set(zh_symbols).intersection(en_symbols)
    print(sorted(shared))
tiny_tts/utils/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from .config import (
2
+ SAMPLING_RATE, FILTER_LENGTH, HOP_LENGTH, SEGMENT_FRAMES,
3
+ ADD_BLANK, SPEC_CHANNELS, N_SPEAKERS, SPK2ID,
4
+ MODEL_PARAMS, NUM_LANGUAGES, NUM_TONES,
5
+ )
tiny_tts/utils/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (418 Bytes). View file
 
tiny_tts/utils/__pycache__/config.cpython-310.pyc ADDED
Binary file (1.21 kB). View file
 
tiny_tts/utils/config.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Audio
SAMPLING_RATE = 44100
FILTER_LENGTH = 2048
HOP_LENGTH = 512
SEGMENT_FRAMES = 32
ADD_BLANK = True
SPEC_CHANNELS = 1 + FILTER_LENGTH // 2  # 1025

# Speakers
SPK2ID = {"LJ": 0}
N_SPEAKERS = 1

# Model
# NOTE(review): the product of ``upsample_rates`` (8*8*2*2*2 = 512) equals
# HOP_LENGTH; presumably the two must be kept in sync — confirm before
# changing either.
MODEL_PARAMS = {
    "use_spk_conditioned_encoder": True,
    "use_noise_scaled_mas": True,
    "inter_channels": 80,
    "hidden_channels": 80,
    "filter_channels": 320,
    "n_heads": 2,
    "n_layers": 3,
    "n_layers_trans_flow": 3,
    "kernel_size": 3,
    "p_dropout": 0.1,
    "resblock": "1",
    "resblock_kernel_sizes": [3, 7, 11],
    "resblock_dilation_sizes": [[1, 3, 5], [1, 3, 5], [1, 3, 5]],
    "upsample_rates": [8, 8, 2, 2, 2],
    "upsample_initial_channel": 256,
    "upsample_kernel_sizes": [16, 16, 8, 2, 2],
    "n_layers_q": 3,
    "use_spectral_norm": False,
    "gin_channels": 80,
    "use_sdp": True,
    "mas_noise_scale_initial": 0.01,
    "noise_scale_delta": 2e-06,
}

# Language / Tone
NUM_LANGUAGES = 10
NUM_TONES = 16