rasenganai commited on
Commit
41bc8a8
·
0 Parent(s):
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +16 -0
  2. .gitmodules +6 -0
  3. README.md +95 -0
  4. S2A/__pycache__/Diffusion.cpython-310.pyc +0 -0
  5. S2A/__pycache__/diff_model.cpython-310.pyc +0 -0
  6. S2A/__pycache__/flow_matching.cpython-310.pyc +0 -0
  7. S2A/__pycache__/inference.cpython-310.pyc +0 -0
  8. S2A/__pycache__/mel_spec.cpython-310.pyc +0 -0
  9. S2A/__pycache__/modules.cpython-310.pyc +0 -0
  10. S2A/__pycache__/stft.cpython-310.pyc +0 -0
  11. S2A/__pycache__/utilities.cpython-310.pyc +0 -0
  12. S2A/diff_model.py +335 -0
  13. S2A/flow_matching.py +123 -0
  14. S2A/inference.py +73 -0
  15. S2A/mel_spec.py +176 -0
  16. S2A/modules.py +683 -0
  17. S2A/train.py +661 -0
  18. S2A/utilities.py +175 -0
  19. Semantic_tokens/convert_factorize.py +273 -0
  20. Semantic_tokens/extract_m4t_tokens_multi.py +228 -0
  21. Semantic_tokens/seamless_communication +1 -0
  22. T2S/__pycache__/autoregressive.cpython-310.pyc +0 -0
  23. T2S/__pycache__/gpt_inference.cpython-310.pyc +0 -0
  24. T2S/__pycache__/mel_spec.cpython-310.pyc +0 -0
  25. T2S/__pycache__/t2s_modules.cpython-310.pyc +0 -0
  26. T2S/__pycache__/utilities.cpython-310.pyc +0 -0
  27. T2S/autoregressive.py +296 -0
  28. T2S/dataset.py +562 -0
  29. T2S/gpt_inference.py +198 -0
  30. T2S/gpt_model_train.py +586 -0
  31. T2S/mel_spec.py +178 -0
  32. T2S/meta_stats.py +282 -0
  33. T2S/plot_embed.py +41 -0
  34. T2S/stft.py +109 -0
  35. T2S/stream_generator.py +1002 -0
  36. T2S/t2s_modules.py +609 -0
  37. T2S/utilities.py +126 -0
  38. Text/__init__.py +1 -0
  39. Text/__pycache__/__init__.cpython-310.pyc +0 -0
  40. Text/__pycache__/symbols.cpython-310.pyc +0 -0
  41. Text/symbols.py +88 -0
  42. Text/symbols_final.pkl +0 -0
  43. Text/text_meta.txt +62 -0
  44. app.py +67 -0
  45. bigvgan_v2_24khz_100band_256x +1 -0
  46. config.py +111 -0
  47. generated_samples/0_hindi.wav +3 -0
  48. inference.py +217 -0
  49. mel_norms.pt +0 -0
  50. pretrained_checkpoint/700_580k_multilingual_infer_ready/bigvgan_generator.pt +3 -0
.gitattributes ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ pretrained_checkpoint/m2.pt filter=lfs diff=lfs merge=lfs -text
2
+ speakers/female1/train_hindifemale_02794.wav filter=lfs diff=lfs merge=lfs -text
3
+ speakers/male1 filter=lfs diff=lfs merge=lfs -text
4
+ pretrained_checkpoint/700_580k_multilingual_infer_ready/bigvgan_generator.pt filter=lfs diff=lfs merge=lfs -text
5
+ speakers/female1/train_hindifemale_02794_proc.wav filter=lfs diff=lfs merge=lfs -text
6
+ *.wav filter=lfs diff=lfs merge=lfs -text
7
+ pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt filter=lfs diff=lfs merge=lfs -text
8
+ speakers/female1/train_hindifemale_02795.wav filter=lfs diff=lfs merge=lfs -text
9
+ speakers/female1/train_hindifemale_04167_proc.wav filter=lfs diff=lfs merge=lfs -text
10
+ speakers/female1/train_hindifemale_04167.wav filter=lfs diff=lfs merge=lfs -text
11
+ *.mp3 filter=lfs diff=lfs merge=lfs -text
12
+ *.flac filter=lfs diff=lfs merge=lfs -text
13
+ speakers/female1 filter=lfs diff=lfs merge=lfs -text
14
+ pretrained_checkpoint/700_580k_multilingual_infer_ready/config.json filter=lfs diff=lfs merge=lfs -text
15
+ speakers/female1/train_hindifemale_02795_proc.wav filter=lfs diff=lfs merge=lfs -text
16
+ pretrained_checkpoint/700_580k_multilingual_infer_ready filter=lfs diff=lfs merge=lfs -text
.gitmodules ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ [submodule "Semantic_tokens/seamless_communication"]
2
+ path = Semantic_tokens/seamless_communication
3
+ url = https://github.com/facebookresearch/seamless_communication.git
4
+ [submodule "bigvgan_v2_24khz_100band_256x"]
5
+ path = bigvgan_v2_24khz_100band_256x
6
+ url = https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x
README.md ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div align="center">
2
+
3
+ <a href="https://ibb.co/wN1LS7K"><img width="320" height="173" alt="Screenshot-2024-01-15-at-8-14-08-PM" src="https://github.com/user-attachments/assets/af22f00d-e9d6-49e1-98b1-7efeac900f9a" /></a>
4
+
5
+ <h1>MahaTTS v2: An Open-Source Large Speech Generation Model</h1>
6
+ a <a href = "https://black.dubverse.ai">Dubverse Black</a> initiative <br> <br>
7
+
8
+ <!-- [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1qkZz2km-PX75P0f6mUb2y5e-uzub27NW?usp=sharing) -->
9
+ </div>
10
+
11
+ ------
12
+ ## Description
13
+ We introduce MahaTTS v2, a multi-speaker text-to-speech (TTS) system that has been trained on 50k hours of Indic and global languages.
14
+ We have followed a text-to-semantic-to-acoustic approach, leveraging wav2vec2 tokens, which gives out-of-the-box generalization to unseen low-resourced languages.
15
+ We have open sourced the first version (MahaTTS), which was trained on English and Indic languages as two separate models on 9k and 400 hours of open source datasets.
16
+ In MahaTTS v2, we have collected over 20k+ hours of training data into a single multilingual cross-lingual model.
17
+ We have used Gemma as the backbone for text-to-semantic modeling and a conditional flow model for semantic-to-mel-spectrogram generation, using a BigVGAN vocoder to generate the final audio waveform.
18
+ The model has shown great robustness and quality results compared to the previous version.
19
+ We are also open sourcing the ability to finetune on your own voice.
20
+
21
+ ### With this release:
22
+ - generate voices in multiple seen and unseen speaker identities (voice cloning)
23
+ - generate voices in multiple languages (multilingual and cross-lingual voice cloning)
24
+ - copy the style of speech from one speaker to another (cross-lingual voice cloning with prosody and intonation transfer)
25
+ - Train your own large scale pretraining or finetuning Models.
26
+
27
+ ### MahaTTS Architecture
28
+
29
+ <img width="1023" height="859" alt="Screenshot 2025-07-10 at 4 04 08 PM" src="https://github.com/user-attachments/assets/4d44cc35-4b66-41a1-b4fd-415af35eda87" />
30
+
31
+
32
+
33
+
34
+ <!-- ## Installation -->
35
+ <!--
36
+ ```bash
37
+ pip install git+https://github.com/dubverse-ai/MahaTTSv2.git
38
+ ``` -->
39
+
40
+
41
+ ### Model Params
42
+ | Model | Parameters | Model Type | Output |
43
+ |:-------------------------:|:----------:|------------|:-----------------:|
44
+ | Text to Semantic (M1) | 510 M | Causal LM | 10,001 Tokens |
45
+ | Semantic to MelSpec(M2) | 71 M | FLOW | 100x Melspec |
46
+ | BigVGAN Vocoder | 112 M | GAN | Audio Waveform |
47
+
48
+
49
+ ## 🌐 Supported Languages
50
+
51
+ The following languages are currently supported:
52
+
53
+ | Language | Status |
54
+ |------------------|:------:|
55
+ | English (en) | ✅ |
56
+ | Hindi (in) | ✅ |
57
+ | Assamese (in) | ✅ |
58
+ | Gujarati (in) | ✅ |
59
+ | Telugu (in) | ✅ |
60
+ | Punjabi (in) | ✅ |
61
+ | Marathi (in) | ✅ |
62
+ | Tamil (in) | ✅ |
63
+ | Bengali (in) | ✅ |
64
+ | Odia (in) | ✅ |
65
+ | Manipuri (in) | ✅ |
66
+ | Bhojpuri (in) | ✅ |
67
+ | Sanskrit (in) | ✅ |
68
+ | Bodo (in) | ✅ |
69
+ | Malayalam (in) | ✅ |
70
+ | Kannada (in) | ✅ |
71
+ | Dogri (in) | ✅ |
72
+ | Rajasthani (in) | ✅ |
73
+ | Thai (th) | ✅ |
74
+ | Japanese (ja) | ✅ |
75
+ | French (fr) | ✅ |
76
+ | German (de) | ✅ |
77
+ | Italian (it) | ✅ |
78
+ | Spanish (es) | ✅ |
79
+
80
+
81
+ ## TODO:
82
+ 1. Adding training instructions.
83
+ 2. Add a colab for the same.
84
+
85
+
86
+ ## License
87
+ MahaTTS is licensed under the Apache 2.0 License.
88
+
89
+ ## 🙏 Appreciation
90
+
91
+ - [Tortoise-tts](https://github.com/neonbjb/tortoise-tts) for inspiring the architecture
92
+ - [M4t Seamless](https://github.com/facebookresearch/seamless_communication) [AudioLM](https://arxiv.org/abs/2209.03143) and many other ground-breaking papers that enabled the development of MahaTTS
93
+ - [BIGVGAN](https://github.com/NVIDIA/BigVGAN) out of the box vocoder
94
+ - [Flow training](https://github.com/shivammehta25/Matcha-TTS) for training Flow model
95
+ - [Huggingface](https://huggingface.co/docs/transformers/index) for related training and inference code
S2A/__pycache__/Diffusion.cpython-310.pyc ADDED
Binary file (35.4 kB). View file
 
S2A/__pycache__/diff_model.cpython-310.pyc ADDED
Binary file (8.19 kB). View file
 
S2A/__pycache__/flow_matching.cpython-310.pyc ADDED
Binary file (3.84 kB). View file
 
S2A/__pycache__/inference.cpython-310.pyc ADDED
Binary file (1.9 kB). View file
 
S2A/__pycache__/mel_spec.cpython-310.pyc ADDED
Binary file (4.77 kB). View file
 
S2A/__pycache__/modules.cpython-310.pyc ADDED
Binary file (15.2 kB). View file
 
S2A/__pycache__/stft.cpython-310.pyc ADDED
Binary file (3.23 kB). View file
 
S2A/__pycache__/utilities.cpython-310.pyc ADDED
Binary file (4.05 kB). View file
 
S2A/diff_model.py ADDED
@@ -0,0 +1,335 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ from torch import autocast
7
+
8
+ from config import config
9
+
10
+ from .modules import GST, AttentionBlock, mySequential, normalization
11
+
12
+
13
def timestep_embedding(timesteps, dim, max_period=10000):
    """Build sinusoidal timestep embeddings.

    :param timesteps: 1-D tensor of N (possibly fractional) timestep indices.
    :param dim: output embedding dimension.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] tensor of positional embeddings.
    """
    half = dim // 2
    # Geometric frequency ladder from 1 down to 1/max_period.
    exponents = torch.arange(start=0, end=half, dtype=torch.float32) / half
    freqs = torch.exp(-math.log(max_period) * exponents).to(device=timesteps.device)
    angles = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat((torch.cos(angles), torch.sin(angles)), dim=-1)
    # Odd dims get one zero column of padding so the output is exactly `dim` wide.
    if dim % 2:
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat((embedding, pad), dim=-1)
    return embedding
34
+
35
+
36
class TimestepBlock(nn.Module):
    """Marker base class for modules whose forward pass also receives a
    timestep embedding."""

    def forward(self, x, emb):
        """Apply the module to ``x`` conditioned on the timestep embedding
        ``emb``.  Subclasses override this; the base implementation is a no-op.
        """
41
+
42
+
43
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """A Sequential container that forwards the timestep embedding to any
    child that is a TimestepBlock; other children receive only ``x``."""

    def forward(self, x, emb):
        for module in self:
            x = module(x, emb) if isinstance(module, TimestepBlock) else module(x)
        return x
51
+
52
+
53
class QuartzNetBlock(TimestepBlock):
    """Residual 1-D conv block conditioned on a timestep embedding.

    A pointwise projection feeds a conv stack; the timestep embedding
    modulates activations via FiLM-style scale/shift when
    ``use_scale_shift_norm`` is True, otherwise by plain addition.  A 1x1
    conv (or identity) carries the residual path when channel counts differ.

    NOTE(review): ``R``, ``ratio`` and ``separable`` are accepted for
    interface compatibility but unused here; ``se`` is stored but never read.
    The normalization is GroupNorm (see ``normalization``), not BatchNorm.
    If this is the last layer, set ``se=False`` and ``separable=False`` and
    use a projection layer on top (per the original author's note).
    """

    def __init__(
        self,
        nin,
        nout,
        emb_channels,
        kernel_size=3,
        dropout=0.1,
        R=1,
        se=True,
        ratio=8,
        separable=False,
        bias=True,
        use_scale_shift_norm=True,
    ):
        super(QuartzNetBlock, self).__init__()
        self.use_scale_shift_norm = use_scale_shift_norm
        self.se = se
        self.in_layers = mySequential(
            nn.Conv1d(nin, nout, kernel_size=1, padding="same", bias=bias),
            nn.SiLU(),
            normalization(nout),
        )

        # Residual path: identity when channel counts match, 1x1 conv otherwise.
        if nin == nout:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Conv1d(
                nin, nout, kernel_size=1, padding="same", bias=bias
            )

        nin = nout
        self.model = nn.Sequential(
            nn.Conv1d(nin, nout, kernel_size, padding="same"),
            nn.SiLU(),
            normalization(nout),
            nn.Dropout(p=dropout),
        )

        # Timestep-embedding projection: twice the channels when producing a
        # (scale, shift) pair, otherwise a single additive bias.
        self.emb_layers = nn.Sequential(
            nn.Linear(
                emb_channels,
                2 * nout if use_scale_shift_norm else nout,
            ),
            nn.SiLU(),
        )

    def forward(self, x, emb, mask=None):
        """Run the block.

        :param x: input features, shape (B, nin, T).
        :param emb: timestep embedding, shape (B, emb_channels).
        :param mask: unused; kept for interface compatibility.
        :return: output features, shape (B, nout, T).
        """
        x_new = self.in_layers(x)
        emb = self.emb_layers(emb)
        # Broadcast the embedding over the trailing (time) axis.
        while len(emb.shape) < len(x_new.shape):
            emb = emb[..., None]
        if self.use_scale_shift_norm:
            scale, shift = torch.chunk(emb, 2, dim=1)
            x_new = x_new * (1 + scale) + shift
        else:
            # BUGFIX: the original always chunked ``emb`` into (scale, shift),
            # which crashed when use_scale_shift_norm=False because
            # ``emb_layers`` then emits only ``nout`` channels.  Fall back to
            # additive conditioning in that case.
            x_new = x_new + emb
        y = self.model(x_new)

        return y + self.residual(x)
114
+
115
+
116
class QuartzAttn(TimestepBlock):
    """A timestep-conditioned QuartzNet block followed by self-attention."""

    def __init__(self, model_channels, dropout, num_heads):
        super().__init__()
        self.resblk = QuartzNetBlock(
            model_channels,
            model_channels,
            model_channels,
            dropout=dropout,
            use_scale_shift_norm=True,
        )
        self.attn = AttentionBlock(
            model_channels, num_heads, relative_pos_embeddings=True
        )

    def forward(self, x, time_emb):
        """Apply the residual block conditioned on ``time_emb``, then attention."""
        hidden = self.resblk(x, time_emb)
        return self.attn(hidden)
133
+
134
+
135
class QuartzNet9x5(nn.Module):
    """Stack of QuartzNet-style blocks interleaved with self-attention,
    followed by three more conv blocks and a projection to 100 mel channels.

    NOTE(review): ``kernels = [3] * 10`` builds ten block/attention pairs even
    though the name suggests 9x5, and ``enable_fp16`` is stored but never read
    here — both kept as-is for checkpoint compatibility.
    """

    def __init__(self, model_channels, num_heads, dropout=0.1, enable_fp16=False):
        super(QuartzNet9x5, self).__init__()
        self.enable_fp16 = enable_fp16
        kernels = [3] * 10
        quartznet = []
        attn = []
        for k in kernels:
            quartznet.append(
                QuartzNetBlock(
                    model_channels,
                    model_channels,
                    model_channels,
                    kernel_size=k,
                    dropout=dropout,
                    R=5,
                    se=True,
                )
            )
            attn.append(
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
            )

        self.quartznet = nn.ModuleList(quartznet)
        self.attn = nn.ModuleList(attn)
        self.conv2 = nn.ModuleList(
            [
                QuartzNetBlock(
                    model_channels,
                    model_channels,
                    model_channels,
                    kernel_size=3,
                    dropout=dropout,
                    R=3,
                    separable=False,
                )
                for _ in range(3)
            ]
        )
        self.conv3 = nn.Sequential(
            nn.Conv1d(model_channels, model_channels, 3, padding="same"),
            nn.SiLU(),
            normalization(model_channels),
            nn.Conv1d(model_channels, 100, 1, padding="same"),
        )

    def forward(self, x, time_emb):
        """Map (B, model_channels, T) features to a (B, 100, T) output."""
        for block, attention in zip(self.quartznet, self.attn):
            x = attention(block(x, time_emb))
        for block in self.conv2:
            x = block(x, time_emb)
        return self.conv3(x)
190
+
191
+
192
class DiffModel(nn.Module):
    """Flow/diffusion decoder: predicts a 100-bin mel-spectrogram output from
    a noisy input, a timestep, and semantic-code conditioning, optionally
    conditioned on reference speaker clips via a GST module.
    """

    def __init__(
        self,
        input_channels=80,
        output_channels=160,
        model_channels=256,
        num_heads=8,
        dropout=0.1,
        num_layers=8,
        multispeaker=True,
        style_tokens=100,
        enable_fp16=False,
        condition_free_per=0.1,
        training=False,
        ar_active=False,
        in_latent_channels=10004,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.model_channels = model_channels
        self.output_channels = output_channels
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = num_layers
        self.enable_fp16 = enable_fp16
        # Fraction of the batch whose conditioning is dropped during training
        # (classifier-free guidance).
        self.condition_free_per = condition_free_per
        # NOTE(review): this overwrites nn.Module's built-in ``training`` flag,
        # which .train()/.eval() also toggle — kept for behavior compatibility,
        # but worth renaming upstream.
        self.training = training
        self.multispeaker = multispeaker
        self.ar_active = ar_active
        self.in_latent_channels = in_latent_channels

        if not self.ar_active:
            # Discrete semantic tokens -> embedding -> attention stack.
            self.code_emb = nn.Embedding(
                config.semantic_model_centroids + 1, model_channels
            )
            self.code_converter = mySequential(
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            )
        else:
            # Continuous AR latents -> conv projection -> attention stack.
            self.code_converter = mySequential(
                nn.Conv1d(
                    self.in_latent_channels, model_channels, 3, padding=1, bias=True
                ),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            )
        if self.multispeaker:
            self.GST = GST(
                model_channels, style_tokens, num_heads, in_channels=input_channels
            )

        self.code_norm = normalization(model_channels)
        self.time_norm = normalization(model_channels)
        self.code_time_norm = normalization(model_channels)

        self.time_embed = mySequential(
            nn.Linear(model_channels, model_channels),
            nn.SiLU(),
            nn.Linear(model_channels, model_channels),
        )

        self.input_block = nn.Conv1d(input_channels, model_channels, 3, 1, 1, bias=True)
        # Learned embedding substituted when conditioning is dropped.
        self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
        self.integrating_conv = nn.Conv1d(
            model_channels * 2, model_channels, kernel_size=1
        )

        self.code_time = TimestepEmbedSequential(
            QuartzAttn(model_channels, dropout, num_heads),
            QuartzAttn(model_channels, dropout, num_heads),
            QuartzAttn(model_channels, dropout, num_heads),
        )

        # BUGFIX: arguments were previously passed positionally as
        # (..., self.enable_fp16, self.dropout), which bound enable_fp16 to
        # QuartzNet9x5's ``dropout`` parameter and dropout to ``enable_fp16``.
        self.layers = QuartzNet9x5(
            model_channels,
            num_heads,
            dropout=self.dropout,
            enable_fp16=self.enable_fp16,
        )

    def get_speaker_latent(self, ref_mels):
        """Compute a speaker latent by averaging GST embeddings over clips.

        :param ref_mels: reference mels; a clip axis is added if absent
            (assumes (B, n_mels, T) or (B, n_clips, n_mels, T) — TODO confirm).
        :return: speaker latent with a trailing singleton time axis.
        """
        ref_mels = ref_mels.unsqueeze(1) if len(ref_mels.shape) == 3 else ref_mels

        conds = []
        for j in range(ref_mels.shape[1]):
            conds.append(self.GST(ref_mels[:, j, :, :]))

        conds = torch.cat(conds, dim=-1)
        conds = conds.mean(dim=-1)

        return conds.unsqueeze(2)

    def forward(
        self,
        x,
        t,
        code_emb,
        ref_clips=None,
        speaker_latents=None,
        conditioning_free=False,
    ):
        """Predict the decoder output for noisy input ``x`` at timestep ``t``.

        :param x: noisy input, (B, input_channels, T).
        :param t: timesteps, shape (B,).
        :param code_emb: semantic codes — token ids when ``ar_active`` is
            False, otherwise continuous latents with ``in_latent_channels``.
        :param ref_clips: optional reference mels to derive speaker latents.
        :param speaker_latents: precomputed latents (alternative to ref_clips).
        :param conditioning_free: replace code conditioning with the learned
            unconditional embedding (classifier-free guidance branch).
        :return: (B, 100, T) output tensor.
        """
        time_embed = self.time_norm(
            self.time_embed(
                timestep_embedding(t.unsqueeze(-1), self.model_channels)
            ).permute(0, 2, 1)
        ).squeeze(2)
        if conditioning_free:
            code_embed = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
        else:
            if not self.ar_active:
                code_embed = self.code_norm(
                    self.code_converter(self.code_emb(code_emb).permute(0, 2, 1))
                )
            else:
                code_embed = self.code_norm(self.code_converter(code_emb))
            if self.multispeaker:
                assert speaker_latents is not None or ref_clips is not None
                if ref_clips is not None:
                    speaker_latents = self.get_speaker_latent(ref_clips)
                # FiLM-style speaker conditioning on the code embedding.
                cond_scale, cond_shift = torch.chunk(speaker_latents, 2, dim=1)
                code_embed = code_embed * (1 + cond_scale) + cond_shift

            # Randomly drop conditioning for part of the batch while training
            # (classifier-free guidance).
            if self.training and self.condition_free_per > 0:
                unconditioned_batches = (
                    torch.rand((code_embed.shape[0], 1, 1), device=code_embed.device)
                    < self.condition_free_per
                )
                code_embed = torch.where(
                    unconditioned_batches,
                    self.unconditioned_embedding.repeat(code_embed.shape[0], 1, 1),
                    code_embed,
                )

        # Stretch code conditioning to the output time resolution.
        expanded_code_emb = F.interpolate(code_embed, size=x.shape[-1], mode="linear")

        x_cond = self.code_time_norm(self.code_time(expanded_code_emb, time_embed))

        x = self.input_block(x)
        x = torch.cat([x, x_cond], dim=1)
        x = self.integrating_conv(x)
        out = self.layers(x, time_embed)

        return out
S2A/flow_matching.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from abc import ABC
2
+
3
+ import torch
4
+ import torch.nn.functional as F
5
+
6
+ # log = get_pylogger(__name__)
7
+
8
+
9
class BASECFM(torch.nn.Module):
    """Conditional flow matching (OT-CFM): Euler-ODE sampling plus the
    flow-matching training loss, both driven by an external decoder ``model``
    called as ``model(x, t, code_emb=..., ref_clips=...)``.
    """

    def __init__(self):
        super().__init__()
        # Minimum noise level of the optimal-transport probability path.
        self.sigma_min = 1e-4

    @torch.inference_mode()
    def forward(
        self, model, code, output_shape, ref_mels, n_timesteps=20, temperature=1.0
    ):
        """Sample a mel-spectrogram by integrating the learned flow.

        Args:
            model: decoder network.
            code: semantic-code conditioning tensor (also fixes the device).
            output_shape: shape of the generated sample, e.g. (B, n_feats, T).
            ref_mels: reference mels forwarded to the decoder.
            n_timesteps (int): number of Euler steps.
            temperature (float): scale applied to the initial Gaussian noise.

        Returns:
            Generated sample of shape ``output_shape``.
        """
        noise = torch.randn(output_shape, device=code.device) * temperature
        t_span = torch.linspace(0, 1, n_timesteps + 1, device=code.device)
        return self.solve_euler(
            model, noise, t_span=t_span, code=code, ref_mels=ref_mels
        )

    def solve_euler(self, model, x, t_span, code, ref_mels):
        """Fixed-step Euler ODE solver from t_span[0] to t_span[-1].

        Args:
            model: decoder network providing the velocity field.
            x: initial noise sample.
            t_span: (n_timesteps + 1,) tensor of integration times.
            code: semantic-code conditioning passed to the decoder.
            ref_mels: reference clips passed to the decoder.
        """
        t = t_span[0]
        dt = t_span[1] - t_span[0]

        # Every intermediate state is kept; handy when debugging trajectories.
        states = []

        for step in range(1, len(t_span)):
            velocity = model(x, t.unsqueeze(0), code_emb=code, ref_clips=ref_mels)

            x = x + dt * velocity
            t = t + dt
            states.append(x)
            if step < len(t_span) - 1:
                dt = t_span[step + 1] - t

        return states[-1]

    def compute_loss(self, model, x1, mask, code, ref_mels):
        """Conditional flow-matching loss.

        Args:
            x1: target mel, shape (B, n_feats, T).
            mask: frame mask, shape (B, T).
            code: semantic-code conditioning passed to the decoder.
            ref_mels: reference clips passed to the decoder.

        Returns:
            (loss, y, t): scalar masked-MSE loss, the interpolated sample y
            of shape (B, n_feats, T), and sampled timesteps t of (B, 1, 1).
        """
        batch = x1.shape[0]

        # One random timestep per example.
        t = torch.rand([batch, 1, 1], device=x1.device, dtype=x1.dtype)
        # Noise sample x_0 ~ p(x_0).
        z = torch.randn_like(x1)

        # OT-path interpolation and its target velocity field.
        y = (1 - (1 - self.sigma_min) * t) * z + t * x1
        u = x1 - (1 - self.sigma_min) * z

        prediction = model(y, t.squeeze(), code_emb=code, ref_clips=ref_mels)
        # Masked MSE, normalized by valid frames times feature channels.
        loss = torch.sum(
            F.mse_loss(prediction, u, reduction="none") * mask.unsqueeze(1)
        ) / (mask.sum() * u.shape[1])

        return loss, y, t
S2A/inference.py ADDED
@@ -0,0 +1,73 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(
5
+ os.path.abspath(
6
+ os.path.join(os.path.dirname(__file__), "../bigvgan_v2_24khz_100band_256x/")
7
+ )
8
+ )
9
+
10
+ import bigvgan
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ from pydub import AudioSegment
16
+ from tqdm import tqdm
17
+
18
+ from config import config
19
+
20
+ from .flow_matching import BASECFM
21
+ from .utilities import denormalize_tacotron_mel, normalize_tacotron_mel
22
+
23
+
24
def infer(model, timeshapes, code_embs, ref_mels, epoch=0):
    """Generate audio from semantic codes with the S2A flow model + BigVGAN.

    Args:
        model: trained flow decoder; its parameters' device is reused.
        timeshapes: per-sample output lengths in mel frames.
        code_embs: per-sample semantic-code tensors.
        ref_mels: per-sample reference mels for speaker conditioning.
        epoch: tag used in the output filenames.

    Returns:
        (audio_paths, mels): written wav paths and the generated mel tensors.
    """
    # BUGFIX: the directory created here was "Samples/..." but files were
    # previously exported to "../Samples/..."; both now use the created dir.
    out_dir = os.path.join("Samples", config.model_name, "S2A")
    os.makedirs(out_dir, exist_ok=True)
    FM = BASECFM()
    device = next(model.parameters()).device

    hifi = bigvgan.BigVGAN.from_pretrained(
        "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
    )
    hifi.remove_weight_norm()
    hifi = hifi.eval().to(device)

    audio_paths = []
    mels = []
    for n, (timeshape, code_emb, ref_mel) in enumerate(
        zip(timeshapes, code_embs, ref_mels)
    ):
        with torch.no_grad():
            mel = FM(
                model,
                code_emb.unsqueeze(0).to(device),
                (1, 100, timeshape),
                ref_mel.unsqueeze(0).to(device),
                n_timesteps=20,
                temperature=1.0,
            )
            mel = denormalize_tacotron_mel(mel)
            mels.append(mel)
            audio = hifi(mel)
        audio = audio.squeeze(0).detach().cpu()
        # Scale [-1, 1] float audio to the int16 range.
        audio = audio * 32767.0
        audio = audio.numpy().reshape(-1).astype(np.int16)

        audio_path = os.path.join(out_dir, f"{epoch}_{n}.wav")
        AudioSegment(
            audio.tobytes(),
            frame_rate=24000,
            sample_width=audio.dtype.itemsize,
            channels=1,
        ).export(audio_path, format="wav")
        audio_paths.append(audio_path)

    return audio_paths, mels
S2A/mel_spec.py ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ if __name__ == "__main__":
8
+ import os
9
+ import sys
10
+
11
+ sys.path.append("../")
12
+
13
+ import math
14
+ import os
15
+ import pathlib
16
+ import random
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.utils.data
21
+ from librosa.filters import mel as librosa_mel_fn
22
+ from librosa.util import normalize
23
+ from scipy.io.wavfile import read
24
+ from tqdm import tqdm
25
+
26
+ from config import config
27
+
28
+ MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
29
+
30
+
31
def load_wav(full_path, sr_target):
    """Read a wav file and verify its sampling rate.

    Returns ``(data, sampling_rate)``; raises RuntimeError if the file's rate
    does not match ``sr_target``.
    """
    sampling_rate, data = read(full_path)
    if sampling_rate == sr_target:
        return data, sampling_rate
    raise RuntimeError(
        f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
    )
38
+
39
+
40
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` (NumPy), clamping below at ``clip_val`` first."""
    clipped = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(clipped * C)
42
+
43
+
44
def dynamic_range_decompression(x, C=1):
    """Invert :func:`dynamic_range_compression` (NumPy): exp then un-scale."""
    return np.exp(x) / C
46
+
47
+
48
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress ``x`` (torch), clamping below at ``clip_val`` first."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
50
+
51
+
52
def dynamic_range_decompression_torch(x, C=1):
    """Invert :func:`dynamic_range_compression_torch`: exp then un-scale."""
    return torch.exp(x) / C
54
+
55
+
56
def spectral_normalize_torch(magnitudes):
    """Apply dynamic-range (log) compression to spectrogram magnitudes."""
    return dynamic_range_compression_torch(magnitudes)
58
+
59
+
60
def spectral_de_normalize_torch(magnitudes):
    """Undo dynamic-range (log) compression of spectrogram magnitudes."""
    return dynamic_range_decompression_torch(magnitudes)
62
+
63
+
64
+ mel_basis_cache = {}
65
+ hann_window_cache = {}
66
+
67
+
68
def mel_spectrogram(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: int = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Calculate the mel spectrogram of an input signal.
    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).

    Args:
        y (torch.Tensor): Input signal.
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Whether to pad the input to center the frames. Default is False.

    Returns:
        torch.Tensor: Mel spectrogram.
    """
    # Warn (not raise) on out-of-range samples; the caller may still proceed.
    if torch.min(y) < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
    if torch.max(y) > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")

    device = y.device
    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"

    # Lazily build and cache the mel filterbank and Hann window per
    # (STFT config, device) combination so repeated calls avoid re-allocation.
    if key not in mel_basis_cache:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
        hann_window_cache[key] = torch.hann_window(win_size).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    # Manual reflect padding (center=False below), matching the BigVGAN recipe.
    padding = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Magnitude with a small epsilon for numerical stability of the sqrt.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = spectral_normalize_torch(mel_spec)

    return mel_spec
138
+
139
+
140
def get_mel_spectrogram(wav, sr):
    """
    Generate mel spectrogram from a waveform using the global ``config``
    hyperparameters (filter_length, n_mel_channels, hop/win length, fmin/fmax).

    Args:
        wav (torch.Tensor): Input waveform.
        sr (int): Sampling rate of ``wav``; must equal ``config.sampling_rate``.

    Returns:
        torch.Tensor: Mel spectrogram.
    """

    assert sr == config.sampling_rate

    return mel_spectrogram(
        wav,
        config.filter_length,
        config.n_mel_channels,
        config.sampling_rate,
        config.hop_length,
        config.win_length,
        config.mel_fmin,
        config.mel_fmax,
    )
164
+
165
+
166
+ if __name__ == "__main__":
167
+ import torchaudio
168
+
169
+ path = "/delta/NeuralSpeak_cfm_conv/Samples/IITM_cfm_bigv_harsh/S2A/orig/0_test.wav"
170
+ wav, sr = torchaudio.load(path)
171
+
172
+ wav = wav[:, :sr]
173
+ print(wav.shape)
174
+ mel_spec = get_mel_spectrogram(wav, sr)
175
+ duration = wav.shape[-1] / sr
176
+ print(duration, mel_spec.shape)
S2A/modules.py ADDED
@@ -0,0 +1,683 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.nn.init as init
7
+ from einops import rearrange, repeat
8
+ from torch.nn.utils import weight_norm
9
+
10
+
11
+ def zero_module(module):
12
+ """
13
+ Zero out the parameters of a module and return it.
14
+ Using it for Zero Convolutions
15
+ """
16
+ for p in module.parameters():
17
+ p.detach().zero_()
18
+ return module
19
+
20
+
21
+ class GroupNorm32(nn.GroupNorm):
22
+ def forward(self, x):
23
+ return super().forward(x.float()).type(x.dtype)
24
+
25
+
26
+ def normalization(channels):
27
+ """
28
+ Make a standard normalization layer. of groups ranging from 2 to 32.
29
+
30
+ :param channels: number of input channels.
31
+ :return: an nn.Module for normalization.
32
+ """
33
+ # return nn.LayerNorm(normalized_shape)
34
+ groups = 32
35
+ if channels <= 16:
36
+ groups = 8
37
+ elif channels <= 64:
38
+ groups = 16
39
+ while channels % groups != 0:
40
+ groups = int(groups / 2)
41
+ assert groups > 2
42
+ return GroupNorm32(groups, channels)
43
+
44
+
45
+ class mySequential(nn.Sequential):
46
+ """Using this to pass mask variable to nn layers"""
47
+
48
+ def forward(self, *inputs):
49
+ for module in self._modules.values():
50
+ if type(inputs) == tuple:
51
+ inputs = module(*inputs)
52
+ else:
53
+ inputs = module(inputs)
54
+ return inputs
55
+
56
+
57
+ class SepConv1D(nn.Module):
58
+ """Depth wise separable Convolution layer with mask"""
59
+
60
+ def __init__(
61
+ self,
62
+ nin,
63
+ nout,
64
+ kernel_size,
65
+ stride=1,
66
+ dilation=1,
67
+ padding_mode="same",
68
+ bias=False,
69
+ ):
70
+ super(SepConv1D, self).__init__()
71
+ self.conv1 = nn.Conv1d(
72
+ nin,
73
+ nin,
74
+ kernel_size=kernel_size,
75
+ stride=stride,
76
+ groups=nin,
77
+ dilation=dilation,
78
+ padding=padding_mode,
79
+ bias=bias,
80
+ )
81
+ self.conv2 = nn.Conv1d(
82
+ nin, nout, kernel_size=1, stride=1, padding=padding_mode, bias=bias
83
+ )
84
+
85
+ def forward(self, x, mask=None):
86
+ if mask is not None:
87
+ x = x * mask.unsqueeze(1).to(device=x.device)
88
+ x = self.conv1(x)
89
+ x = self.conv2(x)
90
+ return x, mask
91
+
92
+
93
+ class Conv1DBN(nn.Module):
94
+ def __init__(
95
+ self,
96
+ nin,
97
+ nout,
98
+ kernel_size,
99
+ stride=1,
100
+ dilation=1,
101
+ dropout=0.1,
102
+ padding_mode="same",
103
+ bias=False,
104
+ ):
105
+ super(Conv1DBN, self).__init__()
106
+ self.conv1 = nn.Conv1d(
107
+ nin,
108
+ nout,
109
+ kernel_size=kernel_size,
110
+ stride=stride,
111
+ padding=padding_mode,
112
+ dilation=dilation,
113
+ bias=bias,
114
+ )
115
+ self.bn = nn.BatchNorm1d(nout)
116
+ self.drop = nn.Dropout(dropout)
117
+
118
+ def forward(self, x, mask=None):
119
+ if mask is not None:
120
+ x = x * mask.unsqueeze(1).to(device=x.device)
121
+ x = self.conv1(x)
122
+ x = self.bn(x)
123
+ x = F.silu(x)
124
+ x = self.drop(x)
125
+ return x, mask
126
+
127
+
128
+ class Conv1d(nn.Module):
129
+ """normal conv1d with mask"""
130
+
131
+ def __init__(self, nin, nout, kernel_size, padding, bias=False):
132
+ super(Conv1d, self).__init__()
133
+ self.l = nn.Conv1d(nin, nout, kernel_size, padding=padding, bias=bias)
134
+
135
+ def forward(self, x, mask):
136
+ if mask is not None:
137
+ x = x * mask.unsqueeze(1).to(device=x.device)
138
+ y = self.l(x)
139
+ return y, mask
140
+
141
+
142
+ class SqueezeExcite(nn.Module):
143
+ """Let the CNN decide how to add across channels"""
144
+
145
+ def __init__(self, nin, ratio=8):
146
+ super(SqueezeExcite, self).__init__()
147
+ self.nin = nin
148
+ self.ratio = ratio
149
+
150
+ self.fc = mySequential(
151
+ nn.Linear(nin, nin // ratio, bias=True),
152
+ nn.SiLU(inplace=True),
153
+ nn.Linear(nin // ratio, nin, bias=True),
154
+ )
155
+
156
+ def forward(self, x, mask=None):
157
+ if mask is None:
158
+ mask = torch.ones((x.shape[0], x.shape[-1]), dtype=torch.bool).to(x.device)
159
+ mask = ~mask
160
+ x = x.float()
161
+ x.masked_fill_(mask.unsqueeze(1), 0.0)
162
+ mask = ~mask
163
+ y = (
164
+ torch.sum(x, dim=-1, keepdim=True)
165
+ / mask.unsqueeze(1).sum(dim=-1, keepdim=True)
166
+ ).type(x.dtype)
167
+ # y=torch.mean(x,-1,keepdim=True)
168
+ y = y.transpose(1, -1)
169
+ y = self.fc(y)
170
+ y = torch.sigmoid(y)
171
+ y = y.transpose(1, -1)
172
+ y = x * y
173
+ return y, mask
174
+
175
+
176
+ class SCBD(nn.Module):
177
+ """SeparableConv1D + Batchnorm + Dropout, Generally use it for middle layers and resnet"""
178
+
179
+ def __init__(
180
+ self, nin, nout, kernel_size, p=0.1, rd=True, separable=True, bias=False
181
+ ):
182
+ super(SCBD, self).__init__()
183
+ if separable:
184
+ self.SC = SepConv1D(nin, nout, kernel_size, bias=bias)
185
+ else:
186
+ self.SC = Conv1d(nin, nout, kernel_size, padding="same", bias=bias)
187
+
188
+ if rd: # relu and Dropout
189
+ self.mout = mySequential(
190
+ normalization(nout),
191
+ nn.SiLU(), # nn.BatchNorm1d(nout,eps)
192
+ nn.Dropout(p),
193
+ )
194
+ else:
195
+ self.mout = normalization(nout) # nn.BatchNorm1d(nout,eps)
196
+
197
+ def forward(self, x, mask=None):
198
+ if mask is not None:
199
+ x = x * mask.unsqueeze(1).to(device=x.device)
200
+ x, _ = self.SC(x, mask)
201
+ y = self.mout(x)
202
+ return y, mask
203
+
204
+
205
+ class QuartzNetBlock(nn.Module):
206
+ """Similar to Resnet block with Batchnorm and dropout, and using Separable conv in the middle.
207
+ if its the last layer,set se = False and separable = False, and use a projection layer on top of this.
208
+ """
209
+
210
+ def __init__(
211
+ self,
212
+ nin,
213
+ nout,
214
+ kernel_size,
215
+ dropout=0.1,
216
+ R=5,
217
+ se=False,
218
+ ratio=8,
219
+ separable=False,
220
+ bias=False,
221
+ ):
222
+ super(QuartzNetBlock, self).__init__()
223
+ self.se = se
224
+ self.residual = mySequential(
225
+ nn.Conv1d(nin, nout, kernel_size=1, padding="same", bias=bias),
226
+ normalization(nout), # nn.BatchNorm1d(nout,eps)
227
+ )
228
+ model = []
229
+
230
+ for i in range(R - 1):
231
+ model.append(SCBD(nin, nout, kernel_size, dropout, eps=0.001, bias=bias))
232
+ nin = nout
233
+
234
+ if separable:
235
+ model.append(
236
+ SCBD(nin, nout, kernel_size, dropout, eps=0.001, rd=False, bias=bias)
237
+ )
238
+ else:
239
+ model.append(
240
+ SCBD(
241
+ nin,
242
+ nout,
243
+ kernel_size,
244
+ dropout,
245
+ eps=0.001,
246
+ rd=False,
247
+ separable=False,
248
+ bias=bias,
249
+ )
250
+ )
251
+ self.model = mySequential(*model)
252
+
253
+ if self.se:
254
+ # model.append(SqueezeExcite(nin,ratio))
255
+ self.se_layer = SqueezeExcite(nin, ratio)
256
+
257
+ self.mout = mySequential(nn.SiLU(), nn.Dropout(dropout))
258
+
259
+ def forward(self, x, mask=None):
260
+ if mask is not None:
261
+ x = x * mask.unsqueeze(1).to(device=x.device)
262
+ y, _ = self.model(x, mask)
263
+ if self.se:
264
+ y, _ = self.se_layer(y, mask)
265
+ y += self.residual(x)
266
+ y = self.mout(y)
267
+ return y, mask
268
+
269
+
270
+ class QKVAttentionLegacy(nn.Module):
271
+ """
272
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
273
+ """
274
+
275
+ def __init__(self, n_heads):
276
+ super().__init__()
277
+ self.n_heads = n_heads
278
+
279
+ def forward(self, qkv, mask=None, rel_pos=None):
280
+ """
281
+ Apply QKV attention.
282
+
283
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
284
+ :return: an [N x (H * C) x T] tensor after attention.
285
+ """
286
+ bs, width, length = qkv.shape
287
+ assert width % (3 * self.n_heads) == 0
288
+ ch = width // (3 * self.n_heads)
289
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
290
+ scale = 1 / math.sqrt(math.sqrt(ch))
291
+ weight = torch.einsum(
292
+ "bct,bcs->bts", q * scale, k * scale
293
+ ) # More stable with f16 than dividing afterwards
294
+ if rel_pos is not None:
295
+ weight = rel_pos(
296
+ weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])
297
+ ).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
298
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
299
+ if mask is not None:
300
+ # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
301
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
302
+ weight = weight * mask
303
+ a = torch.einsum("bts,bcs->bct", weight, v)
304
+
305
+ return a.reshape(bs, -1, length)
306
+
307
+
308
+ class AttentionBlock(nn.Module):
309
+ """
310
+ An attention block that allows spatial positions to attend to each other.
311
+
312
+ Originally ported from here, but adapted to the N-d case.
313
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
314
+ """
315
+
316
+ def __init__(
317
+ self,
318
+ channels,
319
+ num_heads=1,
320
+ num_head_channels=-1,
321
+ do_checkpoint=True,
322
+ relative_pos_embeddings=False,
323
+ ):
324
+ super().__init__()
325
+ self.channels = channels
326
+ self.do_checkpoint = do_checkpoint
327
+ if num_head_channels == -1:
328
+ self.num_heads = num_heads
329
+ else:
330
+ assert channels % num_head_channels == 0, (
331
+ f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
332
+ )
333
+ self.num_heads = channels // num_head_channels
334
+ self.norm = normalization(channels)
335
+ self.qkv = nn.Conv1d(channels, channels * 3, 1, bias=False)
336
+ # split heads before split qkv
337
+ self.attention = QKVAttentionLegacy(self.num_heads)
338
+
339
+ self.proj_out = zero_module(
340
+ nn.Conv1d(channels, channels, 1, bias=False)
341
+ ) # no effect of attention in the inital stages.
342
+ # if relative_pos_embeddings:
343
+ self.relative_pos_embeddings = RelativePositionBias(
344
+ scale=(channels // self.num_heads) ** 0.5,
345
+ causal=False,
346
+ heads=num_heads,
347
+ num_buckets=64,
348
+ max_distance=128,
349
+ )
350
+
351
+ def forward(self, x, mask=None):
352
+ b, c, *spatial = x.shape
353
+ x = x.reshape(b, c, -1)
354
+ qkv = self.qkv(self.norm(x))
355
+ h = self.attention(qkv, mask, self.relative_pos_embeddings)
356
+ h = self.proj_out(h)
357
+ return (x + h).reshape(b, c, *spatial)
358
+
359
+
360
+ class AbsolutePositionalEmbedding(nn.Module):
361
+ def __init__(self, dim, max_seq_len):
362
+ super().__init__()
363
+ self.scale = dim**-0.5
364
+ self.emb = nn.Embedding(max_seq_len, dim)
365
+
366
+ def forward(self, x):
367
+ n = torch.arange(x.shape[1], device=x.device)
368
+ pos_emb = self.emb(n)
369
+ pos_emb = rearrange(pos_emb, "n d -> () n d")
370
+ return pos_emb * self.scale
371
+
372
+
373
+ class FixedPositionalEmbedding(nn.Module):
374
+ def __init__(self, dim):
375
+ super().__init__()
376
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
377
+ self.register_buffer("inv_freq", inv_freq)
378
+
379
+ def forward(self, x, seq_dim=1, offset=0):
380
+ t = (
381
+ torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
382
+ + offset
383
+ )
384
+ sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
385
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
386
+ return rearrange(emb, "n d -> () n d")
387
+
388
+
389
+ class RelativePositionBias(nn.Module):
390
+ def __init__(self, scale, causal=False, num_buckets=16, max_distance=32, heads=8):
391
+ super().__init__()
392
+ self.scale = scale
393
+ self.causal = causal
394
+ self.num_buckets = num_buckets
395
+ self.max_distance = max_distance
396
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
397
+
398
+ @staticmethod
399
+ def _relative_position_bucket(
400
+ relative_position, causal=True, num_buckets=16, max_distance=32
401
+ ):
402
+ ret = 0
403
+ n = -relative_position
404
+ if not causal:
405
+ num_buckets //= 2
406
+ ret += (n < 0).long() * num_buckets
407
+ n = torch.abs(n)
408
+ else:
409
+ n = torch.max(n, torch.zeros_like(n))
410
+
411
+ max_exact = num_buckets // 2
412
+ is_small = n < max_exact
413
+
414
+ val_if_large = (
415
+ max_exact
416
+ + (
417
+ torch.log(n.float() / max_exact)
418
+ / math.log(max_distance / max_exact)
419
+ * (num_buckets - max_exact)
420
+ ).long()
421
+ )
422
+ val_if_large = torch.min(
423
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
424
+ )
425
+
426
+ ret += torch.where(is_small, n, val_if_large)
427
+ return ret
428
+
429
+ def forward(self, qk_dots):
430
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
431
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
432
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
433
+ rel_pos = k_pos[None, :] - q_pos[:, None]
434
+ rp_bucket = self._relative_position_bucket(
435
+ rel_pos,
436
+ causal=self.causal,
437
+ num_buckets=self.num_buckets,
438
+ max_distance=self.max_distance,
439
+ )
440
+ values = self.relative_attention_bias(rp_bucket)
441
+ bias = rearrange(values, "i j h -> () h i j")
442
+ return qk_dots + (bias * self.scale)
443
+
444
+
445
+ class MultiHeadAttention(nn.Module):
446
+ """
447
+ only for GST
448
+ input:
449
+ query --- [N, T_q, query_dim]
450
+ key --- [N, T_k, key_dim]
451
+ output:
452
+ out --- [N, T_q, num_units]
453
+ """
454
+
455
+ def __init__(self, query_dim, key_dim, num_units, num_heads):
456
+ super().__init__()
457
+ self.num_units = num_units
458
+ self.num_heads = num_heads
459
+ self.key_dim = key_dim
460
+
461
+ self.W_query = nn.Linear(
462
+ in_features=query_dim, out_features=num_units, bias=False
463
+ )
464
+ self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
465
+ self.W_value = nn.Linear(
466
+ in_features=key_dim, out_features=num_units, bias=False
467
+ )
468
+
469
+ def forward(self, query, key):
470
+ querys = self.W_query(query) # [N, T_q, num_units]
471
+ keys = self.W_key(key) # [N, T_k, num_units]
472
+ values = self.W_value(key)
473
+
474
+ split_size = self.num_units // self.num_heads
475
+ querys = torch.stack(
476
+ torch.split(querys, split_size, dim=2), dim=0
477
+ ) # [h, N, T_q, num_units/h]
478
+ keys = torch.stack(
479
+ torch.split(keys, split_size, dim=2), dim=0
480
+ ) # [h, N, T_k, num_units/h]
481
+ values = torch.stack(
482
+ torch.split(values, split_size, dim=2), dim=0
483
+ ) # [h, N, T_k, num_units/h]
484
+
485
+ # score = softmax(QK^T / (d_k ** 0.5))
486
+ scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
487
+ scores = scores / (self.key_dim**0.5)
488
+ scores = F.softmax(scores, dim=3)
489
+
490
+ # out = score * V
491
+ out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
492
+ out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
493
+ 0
494
+ ) # [N, T_q, num_units]
495
+
496
+ return out
497
+
498
+
499
+ class GST(nn.Module):
500
+ def __init__(
501
+ self, model_channels=512, style_tokens=100, num_heads=8, in_channels=100
502
+ ):
503
+ super(GST, self).__init__()
504
+ self.model_channels = model_channels
505
+ self.style_tokens = style_tokens
506
+ self.num_heads = num_heads
507
+
508
+ # self.reference_encoder=nn.Sequential(
509
+ # nn.Conv2d(1,32,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(32),nn.ReLU(inplace=True),
510
+ # nn.Conv2d(32,32,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(32),nn.ReLU(inplace=True),
511
+ # nn.Conv2d(32,64,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(64),nn.ReLU(inplace=True),
512
+ # nn.Conv2d(64,64,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(64),nn.ReLU(inplace=True),
513
+ # AttentionBlock(64, 8, relative_pos_embeddings=True),
514
+ # nn.Conv2d(64,128,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(128),nn.ReLU(inplace=True),
515
+ # AttentionBlock(128, 8, relative_pos_embeddings=True),
516
+ # nn.Conv2d(128,128,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(128),nn.ReLU(inplace=True),
517
+ # AttentionBlock(128, 8, relative_pos_embeddings=True),
518
+ # nn.Conv2d(128,model_channels,kernel_size=(3,3),stride=(1,1),padding=(1, 1)),normalization(model_channels),nn.ReLU(inplace=True),
519
+ # AttentionBlock(model_channels, 16, relative_pos_embeddings=True)
520
+ # )
521
+
522
+ # self.reference_encoder=nn.Sequential(
523
+ # nn.Conv1d(80,model_channels,3,padding=1,stride=2),
524
+ # nn.Conv1d(model_channels, model_channels,3,padding=1,stride=2),
525
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
526
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
527
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
528
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
529
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False)
530
+ # )
531
+
532
+ # in_channels=1
533
+ # num_heads = 8
534
+ self.reference_encoder = nn.Sequential(
535
+ nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2, bias=False),
536
+ nn.Conv1d(
537
+ model_channels, model_channels * 2, 3, padding=1, stride=2, bias=False
538
+ ),
539
+ AttentionBlock(
540
+ model_channels * 2,
541
+ num_heads,
542
+ relative_pos_embeddings=True,
543
+ do_checkpoint=False,
544
+ ),
545
+ AttentionBlock(
546
+ model_channels * 2,
547
+ num_heads,
548
+ relative_pos_embeddings=True,
549
+ do_checkpoint=False,
550
+ ),
551
+ AttentionBlock(
552
+ model_channels * 2,
553
+ num_heads,
554
+ relative_pos_embeddings=True,
555
+ do_checkpoint=False,
556
+ ),
557
+ AttentionBlock(
558
+ model_channels * 2,
559
+ num_heads,
560
+ relative_pos_embeddings=True,
561
+ do_checkpoint=False,
562
+ ),
563
+ AttentionBlock(
564
+ model_channels * 2,
565
+ num_heads,
566
+ relative_pos_embeddings=True,
567
+ do_checkpoint=False,
568
+ ),
569
+ # nn.Conv1d(model_channels*2, 64,3,padding=1,stride=2),
570
+ # nn.Conv1d(64, model_channels*2,3,padding=1,stride=2) #added bottleneck
571
+ )
572
+ # bottleneck = 64
573
+ # self.bottleneck = nn.Sequential(nn.Conv1d(model_channels*2,bottleneck,3,padding=1,stride=1),nn.SiLU(),
574
+ # nn.Conv1d(bottleneck,model_channels*2,3,padding=1,stride=1),nn.SiLU())
575
+ # self.gru=nn.GRU(128*2,256,batch_first=True,bidirectional=True)
576
+ # self.attention = MultiHeadAttention(query_dim=model_channels, key_dim=model_channels//num_heads, num_units=model_channels*2, num_heads=num_heads)
577
+ # self.style_tokens = nn.parameter.Parameter(torch.FloatTensor(style_tokens,model_channels//num_heads))
578
+
579
+ # init.normal_(self.style_tokens, mean=0, std=0.5)
580
+
581
+ def forward(self, x):
582
+ # add masking
583
+ # batch=x.size(0)
584
+ # x=x.view(batch,1,-1,80) # (N,1,t,80)
585
+ x = self.reference_encoder(x) # (N,128,t,80//x)
586
+ # print(x.shape)
587
+ # x = self.bottleneck(x)
588
+ # print(x.shape)
589
+ # print(x.shape,'encoder')
590
+ # x = x.mean(dim=-1)#.mean(dim=-1)
591
+ # # x=x.transpose(1,2).contiguous() #(N,t,128,80//x)
592
+ # # time=x.size(1)
593
+ # # x=x.view(batch,time,-1)
594
+ # # _,x=self.gru(x)
595
+ # # print(x.shape,'gru')
596
+ # x=x.view(batch,1,-1)
597
+ # keys = self.style_tokens.unsqueeze(0).expand(batch, -1, -1) # [N, token_num, E // num_heads]
598
+ # # print(keys.shape,'keys')
599
+ # style_embed = self.attention(x, keys)
600
+ # # print(style_embed.shape,'gst tokens')
601
+
602
+ # add normalization?
603
+
604
+ return x
605
+
606
+
607
+ # class GST(nn.Module):
608
+ # """
609
+ # inputs --- [N, Ty/r, n_mels*r] mels
610
+ # outputs --- [N, ref_enc_gru_size]
611
+ # """
612
+
613
+ # def __init__(self, spec_channels=80, gin_channels=512, layernorm=True):
614
+ # super().__init__()
615
+ # self.spec_channels = spec_channels
616
+ # ref_enc_filters = [32, 32, 64, 64, 128, 128]
617
+ # K = len(ref_enc_filters)
618
+ # filters = [1] + ref_enc_filters
619
+ # convs = [
620
+ # weight_norm(
621
+ # nn.Conv2d(
622
+ # in_channels=filters[i],
623
+ # out_channels=filters[i + 1],
624
+ # kernel_size=(3, 3),
625
+ # stride=(2, 2),
626
+ # padding=(1, 1),
627
+ # )
628
+ # )
629
+ # for i in range(K)
630
+ # ]
631
+ # self.convs = nn.ModuleList(convs)
632
+
633
+ # out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
634
+ # self.gru = nn.GRU(
635
+ # input_size=ref_enc_filters[-1] * out_channels,
636
+ # hidden_size=256 // 2,
637
+ # batch_first=True,
638
+ # )
639
+ # self.proj = nn.Linear(128, gin_channels)
640
+ # if layernorm:
641
+ # self.layernorm = nn.LayerNorm(self.spec_channels)
642
+ # else:
643
+ # self.layernorm = None
644
+
645
+ # def forward(self, inputs, mask=None):
646
+ # N = inputs.size(0)
647
+
648
+ # out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
649
+ # if self.layernorm is not None:
650
+ # out = self.layernorm(out)
651
+
652
+ # for conv in self.convs:
653
+ # out = conv(out)
654
+ # # out = wn(out)
655
+ # out = F.silu(out) # [N, 128, Ty//2^K, n_mels//2^K]
656
+
657
+ # out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
658
+ # T = out.size(1)
659
+ # N = out.size(0)
660
+ # out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
661
+
662
+ # self.gru.flatten_parameters()
663
+ # memory, out = self.gru(out) # out --- [1, N, 128]
664
+
665
+ # return self.proj(out.squeeze(0))
666
+
667
+ # def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
668
+ # for i in range(n_convs):
669
+ # L = (L - kernel_size + 2 * pad) // stride + 1
670
+ # return L
671
+
672
+
673
+ if __name__ == "__main__":
674
+ device = torch.device("cpu")
675
+ m = GST(512, 10).to(device)
676
+ mels = torch.rand((16, 80, 1000)).to(device)
677
+
678
+ o = m(mels)
679
+ print(o.shape, "final output")
680
+
681
+ from torchinfo import summary
682
+
683
+ summary(m, input_data={"x": torch.randn(16, 80, 500).to(device)})
S2A/train.py ADDED
@@ -0,0 +1,661 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ import time
4
+
5
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
6
+ import linecache
7
+ import mmap
8
+ import pickle as pkl
9
+ import random
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.optim as optim
16
+ import torchaudio
17
+ from accelerate import Accelerator, DistributedDataParallelKwargs
18
+ from mel_spec import get_mel_spectrogram
19
+ from torch.distributed import destroy_process_group, init_process_group
20
+ from torch.distributed.elastic.utils.data import ElasticDistributedSampler
21
+ from torch.nn.parallel import DistributedDataParallel as DDP
22
+ from torch.utils.data import DataLoader, Dataset
23
+ from torch.utils.data.distributed import DistributedSampler
24
+ from tqdm.auto import tqdm
25
+
26
+ import wandb
27
+ from config import config
28
+ from S2A.diff_model import DiffModel
29
+ from S2A.flow_matching import BASECFM
30
+ from S2A.inference import infer
31
+ from S2A.utilities import (dynamic_range_compression, get_mask,
32
+ get_mask_from_lengths, load_wav_to_torch,
33
+ normalize_tacotron_mel)
34
+ from Text import code_labels, labels, text_labels
35
+
36
+ # import torch
37
+ torch.backends.cuda.matmul.allow_tf32 = True
38
+ torch.backends.cudnn.allow_tf32 = True
39
+
40
+ torch.manual_seed(config.seed_value)
41
+ np.random.seed(config.seed_value)
42
+ random.seed(config.seed_value)
43
+
44
+ CLIP_LENGTH = config.CLIP_LENGTH
45
+
46
+ # code encdec
47
+ text_enc = {j: i for i, j in enumerate(text_labels)}
48
+ text_dec = {i: j for i, j in enumerate(text_labels)}
49
+
50
+ # text encdec
51
+ code_enc = {j: i for i, j in enumerate(code_labels)}
52
+ code_dec = {i: j for i, j in enumerate(code_labels)}
53
+
54
+
55
+ def read_specific_line(filename, line_number):
56
+ line = linecache.getline(filename, line_number)
57
+ return line.strip()
58
+
59
+
60
+ class Acoustic_dataset(Dataset):
61
+ def __init__(
62
+ self,
63
+ transcript_path,
64
+ semantic_path=None,
65
+ ref_mels_path=None,
66
+ ref_k=1,
67
+ scale=True,
68
+ ar_active=False,
69
+ clip=True,
70
+ dur_=None,
71
+ ):
72
+ super(Acoustic_dataset).__init__()
73
+ self.scale = scale
74
+ self.ar_active = ar_active
75
+ self.clip = clip
76
+ self.dur_ = dur_
77
+ if self.dur_ is None:
78
+ self.dur_ = 2
79
+ if not scale:
80
+ with open(transcript_path, "r") as file:
81
+ data = file.read().strip("\n").split("\n")[:]
82
+
83
+ with open(semantic_path, "r") as file:
84
+ semb = file.read().strip("\n").split("\n")
85
+
86
+ with open(ref_mels_path, "rb") as file:
87
+ self.ref_mels = pkl.load(file)
88
+
89
+ semb = {
90
+ i.split("\t")[0]: [j for j in i.split("\t")[1].split()] for i in semb
91
+ }
92
+ data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}
93
+
94
+ self.data = [[i, semb[i], data[i]] for i in data.keys()][:]
95
+
96
+ else:
97
+ self.transcript_path = transcript_path
98
+ line_index = {}
99
+ with open(transcript_path, "rb") as file:
100
+ mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
101
+ line_number = 0
102
+ offset = 0
103
+ progress_bar = tqdm(desc="processing:")
104
+ while offset < len(mmapped_file):
105
+ line_index[line_number] = offset
106
+ offset = mmapped_file.find(b"\n", offset) + 1
107
+ line_number += 1
108
+ progress_bar.update(1)
109
+ progress_bar.close()
110
+ self.mmapped_file = mmapped_file
111
+ self.data_len = len(line_index)
112
+ self.line_index = line_index
113
+
114
+ self.ref_k = ref_k
115
+ self.max_wav_value = config.MAX_WAV_VALUE
116
+
117
+ def get_mel(self, filepath, semb_ids=None, align=False, ref_clip=False):
118
+ audio_norm, sampling_rate = torchaudio.load(filepath)
119
+ dur = audio_norm.shape[-1] / sampling_rate
120
+
121
+ if self.clip and dur > self.dur_ and align:
122
+ max_audio_start = int(dur - self.dur_)
123
+ if max_audio_start <= 0:
124
+ audio_start = 0
125
+ else:
126
+ audio_start = np.random.randint(0, max_audio_start)
127
+
128
+ audio_norm = audio_norm[
129
+ :,
130
+ audio_start * sampling_rate : (audio_start + self.dur_) * sampling_rate,
131
+ ]
132
+ semb_ids = semb_ids[audio_start * 50 : ((audio_start + self.dur_) * 50) - 1]
133
+
134
+ # 86 mel -> 1s for 22050 setting
135
+ # 93 mel -> 1s for 24000 setting
136
+
137
+ if ref_clip == True:
138
+ dur_ = 6
139
+ max_audio_start = int(dur - dur_)
140
+ if max_audio_start <= 0:
141
+ audio_start = 0
142
+ else:
143
+ audio_start = np.random.randint(0, max_audio_start)
144
+ audio_norm = audio_norm[
145
+ :, audio_start * sampling_rate : (audio_start + dur_) * sampling_rate
146
+ ]
147
+
148
+ melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
149
+ energy = []
150
+ if align:
151
+ return melspec, list(energy), semb_ids
152
+ return melspec, list(energy)
153
+
154
+ def __len__(self):
155
+ if self.scale:
156
+ return self.data_len
157
+ return len(self.data)
158
+
159
+ def __getitem__(self, index):
160
+ """
161
+ mel_spec,semb
162
+ """
163
+ if not self.scale:
164
+ lang, path, semb, text = self.data[index]
165
+ ref_mels = self.ref_mels[path][: self.ref_k]
166
+ semb_ids = [int(i) + 1 for i in semb] # 0 for pad
167
+
168
+ else:
169
+ self.mmapped_file.seek(self.line_index[index])
170
+ line = self.mmapped_file.readline().decode("utf-8")
171
+
172
+ lang, path, text, semb_ids = line.split("|")
173
+ semb_ids = [int(i) + 1 for i in semb_ids.split()]
174
+ ref_mels = [path][: self.ref_k]
175
+
176
+ try:
177
+ mel_spec, energy, semb_ids = self.get_mel(path, semb_ids, align=True)
178
+ if len(semb_ids) == 0:
179
+ raise Exception("Sorry, no semb ids" + str(line))
180
+ except Exception as e:
181
+ print(index, e)
182
+ if index + 1 < self.data_len:
183
+ return self.__getitem__(index + 1)
184
+ return self.__getitem__(0)
185
+
186
+ if len(ref_mels) == 0:
187
+ print(index, e, "no ref mels present")
188
+ if index + 1 < self.data_len:
189
+ return self.__getitem__(index + 1)
190
+ return self.__getitem__(0)
191
+
192
+ while len(ref_mels) < self.ref_k:
193
+ ref_mels.append(ref_mels[-1])
194
+
195
+ if mel_spec is None:
196
+ print(index, e, "mel_spec error present")
197
+ if index + 1 < self.data_len:
198
+ return self.__getitem__(index + 1)
199
+ return self.__getitem__(0)
200
+
201
+ def get_random_portion(mel, mask_lengths):
202
+ clip = mask_lengths <= CLIP_LENGTH
203
+ ref_mel = mel[:, :, :CLIP_LENGTH].clone()
204
+ for n, z in enumerate(clip):
205
+ if not z:
206
+ start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
207
+ ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
208
+ return ref_mel
209
+
210
+ try:
211
+ ref_mels = [self.get_mel(path, ref_clip=True)[0] for path in ref_mels]
212
+ except Exception as e:
213
+ print(index, e, "ref_mels mel_spec error")
214
+ if index + 1 < self.data_len:
215
+ return self.__getitem__(index + 1)
216
+ return self.__getitem__(0)
217
+
218
+ ref_c = []
219
+ for i in range(self.ref_k):
220
+ if ref_mels[i] is None:
221
+ continue
222
+ ref_c.append(ref_mels[i])
223
+
224
+ if len(ref_c) == 0:
225
+ print("no refs mel spec found")
226
+ if index + 1 < self.data_len:
227
+ return self.__getitem__(index + 1)
228
+ return self.__getitem__(0)
229
+
230
+ if len(ref_c) != self.ref_k:
231
+ while len(ref_c) < self.ref_k:
232
+ ref_c.append(ref_c[-1])
233
+
234
+ ref_mels = ref_c
235
+ max_target_len = max([x.size(1) for x in ref_mels])
236
+ ref_mels_padded = (
237
+ torch.randn((self.ref_k, config.n_mel_channels, max_target_len)) * 1e-9
238
+ )
239
+ mel_length = []
240
+ for i, mel in enumerate(ref_mels):
241
+ ref_mels_padded[i, :, : mel.size(1)] = mel
242
+ mel_length.append(mel.shape[-1])
243
+
244
+ ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))
245
+
246
+ text_ids = (
247
+ [text_enc["<S>"]]
248
+ + [text_enc[i] for i in text.strip() if i in text_enc]
249
+ + [text_enc["<E>"]]
250
+ )
251
+ if self.ar_active:
252
+ semb_ids = (
253
+ [code_enc["<SST>"]]
254
+ + [code_enc[str(i - 1)] for i in semb_ids]
255
+ + [code_enc["<EST>"]]
256
+ )
257
+
258
+ return {
259
+ "mel": mel_spec,
260
+ "code": semb_ids,
261
+ "path": path,
262
+ "ref_mels": ref_mels,
263
+ "text": text_ids,
264
+ } # , 'ref_mel_length':mel_length}
265
+
266
+
267
def get_padded_seq(sequences, pad_random, before=False, pad__=0):
    """Pad every sequence in-place to the length of the longest one.

    Args:
        sequences: list of python lists; each entry is replaced by its
            padded version (the outer list is mutated).
        pad_random: when True, pad with tiny random floats (~1e-9) instead
            of the constant ``pad__``.
        before: when True, prepend the padding instead of appending it.
        pad__: constant padding value used when ``pad_random`` is False.

    Returns:
        (padded sequences, list of original lengths).
    """
    longest = max(len(seq) for seq in sequences)
    seq_len = [len(seq) for seq in sequences]
    for idx, seq in enumerate(sequences):
        deficit = longest - len(seq)
        if pad_random:
            filler = list(np.random.rand(deficit) * 1e-9)
        else:
            filler = [pad__] * deficit
        sequences[idx] = filler + seq if before else seq + filler
    return sequences, seq_len
282
+
283
+
284
def collate(batch):
    """Collate dataset items into padded batch tensors.

    Returns (normalized mels, codes, mel lengths, code lengths,
    padded reference mels, text ids, text lengths, file paths).
    """
    mel_specs = [item["mel"] for item in batch]
    code = [item["code"] for item in batch]
    paths = [item["path"] for item in batch]
    ref_mels = [item["ref_mels"] for item in batch]
    text_ids = [item["text"] for item in batch]

    # AR-style code sequences end with <EST>; those get the code <PAD> token,
    # otherwise codes are padded with 0.
    if code[-1][-1] == code_enc["<EST>"]:
        code, code_len = get_padded_seq(code, pad_random=False, pad__=code_enc["<PAD>"])
    else:
        code, code_len = get_padded_seq(code, pad_random=False)

    # Text is left-padded so every sequence ends at the same position.
    text_ids, text_len = get_padded_seq(
        text_ids, pad_random=False, before=True, pad__=text_enc["<PAD>"]
    )

    # Reference mels -> (batch, ref_k, n_mels, T) filled with tiny noise.
    ref_max_target_len = max(x.size(-1) for x in ref_mels)
    ref_mels_padded = (
        torch.randn(
            (
                len(batch),
                ref_mels[0].shape[0],
                config.n_mel_channels,
                ref_max_target_len,
            )
        )
        * 1e-9
    )
    for i, mel in enumerate(ref_mels):
        ref_mels_padded[i, :, :, : mel.size(-1)] = mel

    # Target mels -> (batch, n_mels, T) filled with tiny noise.
    max_target_len = max(x.size(-1) for x in mel_specs)
    mel_padded = torch.randn((len(batch), config.n_mel_channels, max_target_len)) * 1e-9
    mel_length = []
    for i, mel in enumerate(mel_specs):
        mel_padded[i, :, : mel.size(-1)] = mel
        mel_length.append(mel.shape[-1])

    return (
        normalize_tacotron_mel(mel_padded),
        torch.tensor(code),
        torch.tensor(mel_length),
        torch.tensor(code_len),
        ref_mels_padded,
        torch.tensor(text_ids),
        torch.tensor(text_len),
        paths,
    )
339
+
340
+
341
def train(
    model,
    diffuser,
    train_dataloader,
    val_dataloader,
    schedule_sampler=None,
    rank=0,
    ar_active=False,
    m1=None,
    checkpoint_initial=None,
):
    """Flow-matching training loop driven by HF Accelerate.

    Args:
        model: DiffModel to train.
        diffuser: unused; kept for interface compatibility.
        train_dataloader / val_dataloader: loaders built with ``collate``.
        schedule_sampler: unused; kept for interface compatibility.
        rank: unused; Accelerate determines the main process itself.
        ar_active: when True, move the AR model ``m1`` to the training
            device and forward it to validation inference.
        m1: optional autoregressive model used only during validation.
        checkpoint_initial: optional checkpoint path to resume from.

    Fix vs. original: the resume checkpoint is now loaded from disk once
    (the original called ``torch.load`` on the same file four times, once
    per field, re-reading a potentially multi-GB file each time).
    """
    ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
    accelerator = Accelerator(
        gradient_accumulation_steps=config.gradient_accumulation_steps,
        kwargs_handlers=[ddp_kwargs],
    )
    if config.sa_wandb_logs and accelerator.is_local_main_process:
        # wandb config values must be strings.
        conf_ = {str(k): str(v) for k, v in config.__dict__.items()}
        wandb_log = wandb.init(
            project=config.wandb_project,
            entity=config.user_name,
            name=config.model_name,
            config=conf_,
        )
        wandb_log.watch(model, log_freq=100)
    else:
        wandb_log = None

    model.train()
    optimizer = optim.AdamW(
        model.parameters(), lr=config.sa_lr, weight_decay=config.sa_weight_decay
    )
    min_val_loss = 1000
    step_num = 0
    start_epoch = 0
    if checkpoint_initial is not None:
        print(checkpoint_initial)
        # Load the checkpoint once; read every field from the same dict.
        ckpt = torch.load(checkpoint_initial, map_location=torch.device("cpu"))
        model.load_state_dict(ckpt["model"], strict=True)
        model.train()
        optimizer.load_state_dict(ckpt["optimizer"])
        # NOTE: the step counter is intentionally reset on resume (the
        # original loaded ckpt["step"] and immediately overwrote it with 0);
        # only the epoch index is carried over.
        step_num = 0
        start_epoch = int(ckpt["epoch"]) + 1
        print(f"resuming training from epoch {start_epoch} and step {step_num}")

    train_dataloader, model, optimizer = accelerator.prepare(
        train_dataloader, model, optimizer
    )

    FM = BASECFM()
    device = next(model.parameters()).device
    if ar_active:
        m1 = m1.to(device)

    for i in range(start_epoch, config.sa_epochs):
        epoch_loss = {"vlb": [], "mse": [], "loss": []}
        if accelerator.is_local_main_process:
            train_loader = tqdm(train_dataloader, desc="Training epoch %d" % (i))
        else:
            train_loader = train_dataloader

        for inputs in train_loader:
            with accelerator.accumulate(model):
                optimizer.zero_grad()
                x1, code_emb, mask_lengths, _, ref_mels, text_ids, _, _ = inputs
                mask = get_mask_from_lengths(mask_lengths).unsqueeze(1)
                mask = mask.squeeze(1).float()

                loss, _, t = FM.compute_loss(model, x1, mask, code_emb, ref_mels)

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                step_num += 1

                epoch_loss["loss"].append(loss.item())

                # Log the mean loss once per optimizer step (accounting for
                # gradient accumulation), averaged across processes.
                if step_num % config.gradient_accumulation_steps == 0:
                    epoch_training_loss = torch.tensor(
                        sum(epoch_loss["loss"]) / len(epoch_loss["loss"])
                    ).to(device)
                    epoch_loss = {"vlb": [], "mse": [], "loss": []}
                    epoch_training_loss = (
                        accelerator.gather_for_metrics(epoch_training_loss)
                        .mean()
                        .item()
                    )

                    if config.sa_wandb_logs and accelerator.is_local_main_process:
                        wandb_log.log({"training_loss": epoch_training_loss})

                # Periodic evaluation + checkpointing.
                if (
                    step_num % (config.sa_eval_step * config.gradient_accumulation_steps)
                    == 0
                ):
                    print(f"evaluation at step_num {step_num}")
                    if accelerator.is_local_main_process:
                        # save the latest checkpoint
                        unwrapped_model = accelerator.unwrap_model(model)
                        checkpoint = {
                            "epoch": i,
                            "step": step_num // config.gradient_accumulation_steps,
                            "model": unwrapped_model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                            "norms": config.norms,
                        }
                        torch.save(
                            checkpoint,
                            os.path.join(config.save_root_dir, "latest.pt"),
                        )

                    if accelerator.is_local_main_process:
                        val_loss, val_mse, val_vlb, time_steps_mean = val(
                            model,
                            FM,
                            val_dataloader,
                            infer_=config.sa_infer,
                            epoch=i,
                            rank=accelerator.is_local_main_process,
                            ar_active=ar_active,
                            m1=m1,
                        )
                        model.train()

                        print(
                            "validation loss : ",
                            val_loss,
                            "\nvalidation mse loss : ",
                            val_mse,
                            "\nvalidation vlb loss : ",
                            val_vlb,
                        )
                        if config.sa_wandb_logs:
                            wandb_log.log({"val_loss": val_loss})
                        # Keep a separate "best" checkpoint.
                        if val_loss < min_val_loss:
                            unwrapped_model = accelerator.unwrap_model(model)
                            checkpoint = {
                                "epoch": i,
                                "step": step_num // config.gradient_accumulation_steps,
                                "model": unwrapped_model.state_dict(),
                                "optimizer": optimizer.state_dict(),
                                "norms": config.norms,
                            }
                            torch.save(
                                checkpoint,
                                os.path.join(config.save_root_dir, "_best.pt"),
                            )
                            min_val_loss = val_loss

        # Hard stop after 12 epochs past the resume point.
        if i == start_epoch + 12:
            exit()
    if config.sa_wandb_logs and accelerator.is_local_main_process:
        wandb_log.finish()
514
+
515
+
516
def val(
    model,
    FM,
    val_dataloader,
    infer_=False,
    epoch=0,
    rank=False,
    ar_active=False,
    m1=None,
):
    """
    Return the loss value
    """
    model.eval()
    epoch_loss = {"vlb": [], "mse": [], "loss": []}
    code_emb = None
    x = None
    mask_lengths = None
    time_steps_mean = []
    device = next(model.parameters()).device
    if rank:
        val_dataloader = tqdm(val_dataloader, desc="validation epoch %d" % (epoch))

    with torch.no_grad():
        for inputs in val_dataloader:
            x1, code_emb, mask_lengths, code_len, ref_mels, text_ids, _, _ = inputs

            mask = get_mask_from_lengths(mask_lengths).unsqueeze(1).to(device)
            mask = mask.squeeze(1).float()
            x1 = x1.to(device)
            code_emb = code_emb.to(device)
            text_ids = text_ids.to(device)
            ref_mels = ref_mels.to(device)

            loss, _, t = FM.compute_loss(model, x1, mask, code_emb, ref_mels)
            time_steps_mean.extend(t.detach().cpu().squeeze(-1).squeeze(-1).tolist())
            # Flow matching yields a single loss; it is tracked under all
            # three keys (loss/mse/vlb) to keep the logging interface stable.
            for key in ("loss", "mse", "vlb"):
                epoch_loss[key].append(loss.item())

    epoch_vlb_loss = sum(epoch_loss["vlb"]) / len(epoch_loss["vlb"])
    epoch_training_loss = sum(epoch_loss["loss"]) / len(epoch_loss["loss"])
    epoch_mse_loss = sum(epoch_loss["mse"]) / len(epoch_loss["mse"])

    # Periodically synthesize a few samples (last batch) for logging.
    if rank and infer_ and epoch % config.sa_infer_epoch == 0:
        k = 4
        if ar_active:
            code_embs = [code_emb[i, :, : code_len[i]] for i in range(k)]
        else:
            code_embs = [code_emb[i, : code_len[i]] for i in range(k)]
        audio_paths, mels = infer(
            model, mask_lengths[:k], code_embs, ref_mels[:k, :], epoch
        )

        if config.sa_wandb_logs:
            images = [
                wandb.Image(mel[0], caption="epoch: " + str(epoch)) for mel in mels
            ]
            x = [
                wandb.Image(x1[i, :, : mask_lengths[i]], caption="Actual: ")
                for i in range(k)
            ]
            wandb.log(
                {
                    "predicted audio": [
                        wandb.Audio(audio_path) for audio_path in audio_paths
                    ],
                    "predicted melspec": images,
                    "actual melspec": x,
                    "epoch": epoch,
                }
            )

    return (
        epoch_training_loss,
        epoch_mse_loss,
        epoch_vlb_loss,
        sum(time_steps_mean) / len(time_steps_mean),
    )
599
+
600
+
601
if __name__ == "__main__":
    os.makedirs(os.path.join(config.save_root_dir, config.model_name, "S2A"), exist_ok=True)

    # Flow-matching acoustic model: semantic codes + reference mels -> mel.
    model = DiffModel(
        input_channels=100,
        output_channels=100,
        model_channels=512,  # 1024
        num_heads=8,  # 16
        dropout=0.10,
        num_layers=8,
        enable_fp16=False,
        condition_free_per=0.0,
        multispeaker=True,
        style_tokens=100,
        training=True,
        ar_active=False,
        in_latent_channels=len(code_labels),
    )
    m1 = None
    checkpoint = None
    print("Model Loaded")
    print("batch_size :", config.sa_batch_size)
    print("Diffusion timesteps:", config.sa_timesteps_max)

    train_dataset = Acoustic_dataset(config.train_file, scale=config.scale)
    train_dataloader = DataLoader(
        train_dataset,
        pin_memory=True,
        persistent_workers=True,
        num_workers=config.sa_num_workers,
        batch_size=config.sa_batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=collate,
    )

    # Validation clips are capped at 5 seconds via dur_.
    val_dataset = Acoustic_dataset(config.val_file, scale=config.scale, dur_=5)
    val_dataloader = DataLoader(
        val_dataset,
        pin_memory=True,
        persistent_workers=True,
        num_workers=config.sa_num_workers,
        batch_size=config.sa_batch_size,
        shuffle=True,
        drop_last=True,
        collate_fn=collate,
    )

    train(
        model,
        diffuser=None,
        train_dataloader=train_dataloader,
        val_dataloader=val_dataloader,
        rank=0,
        ar_active=False,
        m1=m1,
        checkpoint_initial=checkpoint,
    )
S2A/utilities.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa.util as librosa_util
2
+ import numpy as np
3
+ import torch
4
+ from scipy.io.wavfile import read
5
+ from scipy.signal import get_window
6
+
7
+ # import librosa
8
+ from config import config
9
+
10
+ # find these values
11
+ # TACOTRON_MEL_MAX = 2.4
12
+ # TACOTRON_MEL_MIN = -11.5130
13
+ # tensor(-11.5129) tensor(2.0743) indic tts
14
+ # tensor(-11.5129) tensor(2.3314) LibriTTS
15
+ # tensor(-11.5129) tensor(2.3996) Eng10k scale
16
+
17
+
18
def normalize_tacotron_mel(data, mu=config.mu, std=config.std):
    """Standardize a mel spectrogram: (data - mu) / std.

    ``mu``/``std`` may be scalars or per-channel statistics (list, ndarray
    or tensor). Non-scalar stats are moved to ``data``'s device and get a
    trailing axis so they broadcast over the time dimension.
    """
    def _broadcastable(stat):
        # Scalars broadcast as-is.
        if isinstance(stat, (float, int)):
            return stat
        if isinstance(stat, list):
            stat = torch.tensor(stat, dtype=data.dtype, device=data.device)
        elif isinstance(stat, torch.Tensor):
            stat = stat.to(data.device)
        elif isinstance(stat, np.ndarray):
            stat = torch.from_numpy(stat).to(data.device)
        # Trailing axis: (n_mels,) -> (n_mels, 1) so stats broadcast over time.
        return stat.unsqueeze(-1)

    return (data - _broadcastable(mu)) / _broadcastable(std)
39
+
40
+
41
def denormalize_tacotron_mel(data, mu=config.mu, std=config.std):
    """Invert ``normalize_tacotron_mel``: data * std + mu.

    ``mu``/``std`` may be scalars or per-channel statistics (list, ndarray
    or tensor); non-scalar stats get a trailing axis to broadcast over time.

    Fix: the scalar check was ``isinstance(mu, float)``, so integer stats
    fell through and crashed on ``int.unsqueeze``. Ints are now treated as
    scalars, matching ``normalize_tacotron_mel``.
    """
    if not isinstance(mu, (float, int)):
        if isinstance(mu, list):
            mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
        elif isinstance(mu, torch.Tensor):
            mu = mu.to(data.device)
        elif isinstance(mu, np.ndarray):
            mu = torch.from_numpy(mu).to(data.device)
        mu = mu.unsqueeze(-1)

    if not isinstance(std, (float, int)):
        if isinstance(std, list):
            std = torch.tensor(std, dtype=data.dtype, device=data.device)
        elif isinstance(std, torch.Tensor):
            std = std.to(data.device)
        elif isinstance(std, np.ndarray):
            std = torch.from_numpy(std).to(data.device)
        std = std.unsqueeze(-1)

    return data * std + mu
62
+
63
+
64
+ # def denormalize_tacotron_mel(norm_mel):
65
+ # return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
66
+
67
+
68
+ # def normalize_tacotron_mel(mel):
69
+ # return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
70
+
71
+ # min_max to global scaling
72
+
73
+
74
def get_mask_from_lengths(lengths, max_len=None):
    """Boolean mask of shape (batch, max_len): True where position < length.

    ``max_len`` defaults to the largest length in the batch.
    """
    if not max_len:
        max_len = torch.max(lengths).item()
    positions = torch.arange(max_len, device=lengths.device, dtype=torch.long)
    return (positions < lengths.unsqueeze(1)).bool()
80
+
81
+
82
def get_mask(lengths, max_len=None):
    """Boolean mask of shape (batch, max_len): True where position < length.

    Fix: the position ids are now allocated on ``lengths``'s device
    (previously they were always created on CPU, which raised a device
    mismatch for CUDA inputs), matching ``get_mask_from_lengths``.
    """
    if not max_len:
        max_len = torch.max(lengths).item()
    lens = torch.arange(max_len, device=lengths.device)
    mask = lens[:max_len].unsqueeze(0) < lengths.unsqueeze(1)
    return mask
90
+
91
+
92
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress ``x``: log(clamp(x, clip_val) * C).

    PARAMS
    ------
    C: compression factor
    clip_val: lower bound applied before the log to avoid log(0)
    """
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
99
+
100
+
101
def dynamic_range_decompression(x, C=1):
    """Invert ``dynamic_range_compression``: exp(x) / C.

    PARAMS
    ------
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C
108
+
109
+
110
def window_sumsquare(
    window,
    n_frames,
    hop_length=200,
    win_length=800,
    n_fft=800,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.
    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.
    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output
    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    total = n_fft + hop_length * (n_frames - 1)
    envelope = np.zeros(total, dtype=dtype)

    # Squared, normalized window padded to the FFT size.
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Overlap-add the squared window at every frame offset, clipping the
    # final frames to the envelope length.
    for frame in range(n_frames):
        offset = frame * hop_length
        envelope[offset : min(total, offset + n_fft)] += win_sq[
            : max(0, min(n_fft, total - offset))
        ]
    return envelope
159
+
160
+
161
def load_wav_to_torch(full_path):
    """Read a wav file with scipy and return (float tensor, sample rate)."""
    sampling_rate, samples = read(full_path)
    return torch.FloatTensor(samples), sampling_rate
169
+
170
+
171
if __name__ == "__main__":
    # Quick smoke test of the masking helper.
    example_lengths = torch.tensor([2, 3, 7, 5, 4])
    example_mask = get_mask(example_lengths)
    print(example_mask)
    print(example_mask.shape)
Semantic_tokens/convert_factorize.py ADDED
@@ -0,0 +1,273 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import glob
3
+ import math
4
+ import multiprocessing as mp
5
+ import os
6
+ import statistics
7
+ import sys
8
+ from multiprocessing import Pool
9
+
10
+ import numpy as np
11
+ import pandas as pd
12
+ import soundfile as sf
13
+ import torchaudio
14
+ from pydub import AudioSegment
15
+ from tqdm import tqdm
16
+
17
+
18
def check_sample_rate(file_path):
    """Return ``file_path`` when its sample rate differs from 24 kHz.

    Files already at 24 kHz — and files that cannot be read at all —
    yield None so callers can simply filter them out.
    """
    try:
        if sf.info(file_path).samplerate != 24000:
            return file_path
    except Exception:
        return None  # In case of error, return None
25
+
26
+
27
def process_files(file_list):
    """Check every file's sample rate in parallel.

    Returns the subset of ``file_list`` whose sample rate is not 24 kHz.
    """
    with Pool() as pool:
        checked = list(
            tqdm(pool.imap(check_sample_rate, file_list), total=len(file_list))
        )
    return [path for path in checked if path is not None]
33
+
34
+
35
def read_paths_from_file(filename):
    """Read a 'lang|path|text' metadata file and return the audio paths."""
    path_list = []
    with open(filename, "r", encoding="utf8") as file:
        for line in tqdm(file):
            # Second pipe-delimited field is the audio path.
            path_list.append(line.split("|")[1].strip("\n"))
    return path_list
43
+
44
+
45
def gather_paths_from_glob():
    """Recursively collect every .wav file under the current directory."""
    pattern = "./**/*.wav"
    return glob.glob(pattern, recursive=True)
47
+
48
+
49
def detect_leading_silence(sound, silence_threshold=-50, chunk_size=64):
    """Return the offset (ms) of the first non-silent chunk of ``sound``.

    Scans fixed-size chunks from the start until one rises to or above
    ``silence_threshold`` dBFS, or the clip is exhausted.
    """
    assert chunk_size > 0
    offset = 0
    while (
        offset < len(sound)
        and sound[offset : offset + chunk_size].dBFS < silence_threshold
    ):
        offset += chunk_size
    return offset
57
+
58
+
59
def preprocess_audio(path, target_dBFS, frame_rate):
    """Mono-ize, resample, trim edge silence, re-pad, and export in place.

    Returns ([original dBFS], [processed duration in seconds]).

    NOTE(review): ``target_dBFS`` is accepted but never used, and the
    silence padding is built at a hard-coded 22050 Hz frame rate while the
    audio itself uses ``frame_rate`` — confirm whether that is intended.
    """
    audio = AudioSegment.from_file(path)
    dbfs = [audio.dBFS]
    audio = audio.set_channels(1)
    audio = audio.set_frame_rate(frame_rate).set_sample_width(2)

    lead = detect_leading_silence(audio)
    tail = detect_leading_silence(audio.reverse())

    total_ms = len(audio)
    audio = audio[lead : total_ms - tail]
    # Re-pad both ends with 256 ms of silence.
    audio = (
        AudioSegment.silent(duration=256, frame_rate=22050)
        + audio
        + AudioSegment.silent(duration=256, frame_rate=22050)
    )

    # Keep the recognized container/extension; anything else becomes wav.
    if path[-4:] == ".wav":
        audio.export(path[:-4] + ".wav", format="wav")
    elif path[-5:] == ".flac":
        audio.export(path[:-5] + ".flac", format="flac")
    else:
        audio.export(path[:-4] + ".wav", format="wav")

    durations = [audio.duration_seconds]
    return dbfs, durations
88
+
89
+
90
def preprocess_audio_chunk(args):
    """Worker: preprocess a chunk of audio paths.

    Args:
        args: (path list, target dBFS, frame rate, worker index) tuple.

    Returns:
        (list of original dBFS values, list of processed durations).

    Fix: each file is no longer decoded twice — the original probed every
    path with an extra ``AudioSegment.from_file`` whose result was thrown
    away before ``preprocess_audio`` loaded it again. Unreadable files
    still raise inside ``preprocess_audio`` and are logged and skipped.
    """
    path_list_chunk, target_dBFS, frame_rate, n = args
    dbfs = []
    durations = []
    for path in tqdm(path_list_chunk, desc="preprocess " + str(n)):
        try:
            dbfs_i, durations_i = preprocess_audio(path, target_dBFS, frame_rate)
            dbfs.extend(dbfs_i)
            durations.extend(durations_i)
        except Exception as e:
            print(n, path, e)

    return dbfs, durations
104
+
105
+
106
def preprocess_audio_paths(path_list, target_dBFS, frame_rate, num_workers):
    """Preprocess audio paths in parallel chunks.

    Returns the concatenated (dBFS, durations) lists from all workers.

    Fix: chunk size is clamped to at least 1, so passing fewer paths than
    workers no longer raises ``ValueError: range() arg 3 must not be zero``.
    """
    chunk_size = max(1, len(path_list) // num_workers)
    path_chunks = [
        path_list[i : i + chunk_size] for i in range(0, len(path_list), chunk_size)
    ]

    with Pool(num_workers) as pool:
        results = pool.map(
            preprocess_audio_chunk,
            [
                (chunk, target_dBFS, frame_rate, n)
                for n, chunk in enumerate(path_chunks)
            ],
        )

    dbfs = []
    durations = []
    for dbfs_i, durations_i in results:
        dbfs.extend(dbfs_i)
        durations.extend(durations_i)
    return dbfs, durations
127
+
128
+
129
def gather_metadata_chunk(args):
    """Worker: collect loudness/duration stats for a chunk of metadata lines.

    Each line is a 'path|...' record. Fully silent (-inf dBFS) or
    unreadable files are reported and skipped.

    Returns (dbfs list, durations list, [(duration, original line), ...]).
    """
    path_list_chunk = args
    dbfs = []
    durations = []
    files = []
    for line in tqdm(path_list_chunk):
        try:
            path = line.split("|")[0]
            audio = AudioSegment.from_file(path)
            if audio.dBFS == -math.inf:
                # Completely silent clip — flag it and move on.
                print("=====================")
                print(path)
                print("=====================")
                continue
            dbfs.append(audio.dBFS)
            durations.append(audio.duration_seconds)
            files.append((audio.duration_seconds, line))
            if audio.duration_seconds == 0:
                print(line)
        except Exception as e:
            print(e, line)

    return dbfs, durations, files
152
+
153
+
154
def gather_metadata(path_list, num_workers):
    """Gather dBFS/duration stats in parallel and write duration manifests.

    Writes ``files_duration.txt`` ('line|duration', sorted by duration)
    and ``files.txt`` (only files between 2 and 15 seconds) into the
    module-global ``data_dir`` set in ``__main__``.

    Fix: chunk size is clamped to at least 1, so small inputs no longer
    raise ``ValueError: range() arg 3 must not be zero``.
    """
    chunk_size = max(1, len(path_list) // num_workers)
    path_chunks = [
        path_list[i : i + chunk_size] for i in range(0, len(path_list), chunk_size)
    ]

    with Pool(num_workers) as pool:
        results = pool.map(gather_metadata_chunk, path_chunks)

    dbfs = []
    durations = []
    files = []
    for dbfs_i, durations_i, files_i in results:
        dbfs.extend(dbfs_i)
        durations.extend(durations_i)
        files.extend(files_i)

    files = sorted(files, key=lambda x: x[0])
    with open(os.path.join(data_dir, "files_duration.txt"), "w") as file:
        file.write("\n".join([i[1] + "|" + str(i[0]) for i in files]))

    with open(os.path.join(data_dir, "files.txt"), "w") as file:
        file.write("\n".join([i[1] for i in files if 2.0 < float(i[0]) < 15.0]))

    return dbfs, durations
179
+
180
+
181
def process_audio_data(input_file, mode, num_workers, data_dir):
    """Entry point: preprocess audio files or gather duration statistics.

    Args:
        input_file: optional metadata file; when falsy, wav files are
            globbed recursively from the current directory.
        mode: 'preprocess' (normalize/trim/export files) or 'metadata'
            (collect duration stats and write manifests).
        num_workers: parallel worker count.
        data_dir: unused here; ``gather_metadata`` writes its manifests via
            the module-level global of the same name set in ``__main__``.

    Fix: the bare ``except:`` around speaker extraction now catches only
    ``IndexError`` (paths without a parent directory component), instead of
    masking every possible error including KeyboardInterrupt.
    """
    if input_file:
        path_list = read_paths_from_file(input_file)
    else:
        path_list = gather_paths_from_glob()

    # Speaker id is assumed to be the parent directory name.
    speakers = []
    for n, path in enumerate(path_list):
        try:
            speakers.append(path.split("/")[-2])
        except IndexError:
            print(n, path)

    print("total audio files:", len(path_list))

    if mode == "preprocess":
        print("Preprocessing!")
        target_dBFS = -24.196741  # not using

        frame_rate = 24000
        dbfs, durations = preprocess_audio_paths(
            path_list[:], target_dBFS, frame_rate, num_workers
        )

        print("min duration : ", min(durations))
        print("max duration : ", max(durations))
        print("avg duration : ", sum(durations) / len(durations))
        print("Standard Deviation of durations % s" % (statistics.stdev(durations)))
        print("total duration : ", sum(durations))
        print("DONE")

    if mode == "metadata":
        print("Gathering metadata")
        dbfs, durations = gather_metadata(path_list[:], num_workers)

        print("min duration : ", min(durations))
        print("max duration : ", max(durations))
        print("avg duration : ", sum(durations) / len(durations))
        print("total duration : ", sum(durations))
        print("Standard Deviation of sample is % s" % (statistics.stdev(durations)))
        print("DONE")

    # pd.DataFrame({'dBFS': dbfs, 'duration': durations, 'files': [i[1] for i in files]}).to_csv("meta.csv", index=False)
224
+
225
+
226
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Audio processing script for preprocessing and metadata gathering",
        formatter_class=argparse.RawDescriptionHelpFormatter,
        epilog="""
Examples:
  # Preprocess audio files from a text file
  python convert_factorize.py --input paths.txt --mode preprocess

  # Gather metadata from audio files in current directory
  python convert_factorize.py --mode metadata

  # Preprocess audio files found recursively in current directory
  python convert_factorize.py --mode preprocess
        """
    )

    parser.add_argument(
        "--input", "-i",
        type=str,
        help="Path to text file containing audio file paths in 'language|abspath|text' format"
    )

    parser.add_argument(
        "--mode", "-m",
        type=str,
        choices=["preprocess", "metadata"],
        required=True,
        help="Processing mode: 'preprocess' to process audio files, 'metadata' to gather statistics"
    )

    parser.add_argument(
        "--workers", "-w",
        type=int,
        default=4,
        help="Number of worker processes for parallel processing (default: 4)"
    )

    args = parser.parse_args()

    # --input is optional (paths can be globbed from the cwd); guard against
    # None before deriving the data directory — the original crashed with
    # AttributeError when --input was omitted.
    data_dir = "/".join(args.input.split("/")[:-1]) if args.input else "."
    print(f"Data directory: {data_dir}")

    print(f"Input file: {args.input}")
    print(f"Mode: {args.mode}")
    print(f"Number of workers: {args.workers}")

    process_audio_data(args.input, args.mode, args.workers, data_dir)
Semantic_tokens/extract_m4t_tokens_multi.py ADDED
@@ -0,0 +1,228 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import argparse
2
+ import multiprocessing
3
+ import os
4
+ import random
5
+ import shutil
6
+ import sys
7
+ import time
8
+ from functools import partial
9
+
10
+ import numpy as np
11
+ import torch
12
+ import torchaudio
13
+ from tqdm import tqdm
14
+
15
+ from seamless_communication.models.unit_extractor import (
16
+ KmeansModel, UnitExtractor, Wav2Vec2LayerOutputModel)
17
+
18
+
19
def train_test_split_large_file(input_file, train_file, test_file, test_ratio=0.2, seed=42):
    """
    Memory-efficient train-test split for large files.
    Performs two passes: first to count lines, second to split.
    """
    random.seed(seed)

    # Pass 1: count the lines without loading the file into memory.
    with open(input_file, "r", encoding="utf-8") as f:
        total_lines = sum(1 for _ in f)

    n_test = int(total_lines * test_ratio)
    # Sample the line indices that go into the test split.
    test_indices = set(random.sample(range(total_lines), n_test))

    # Pass 2: stream each line into its split file.
    with (
        open(input_file, "r", encoding="utf-8") as in_f,
        open(train_file, "w", encoding="utf-8") as train_f,
        open(test_file, "w", encoding="utf-8") as test_f,
    ):
        for idx, line in enumerate(in_f):
            target = test_f if idx in test_indices else train_f
            target.write(line)

    print(
        f"Split {total_lines} lines into {total_lines - n_test} train and {n_test} test lines."
    )
+
51
+
52
+ lock = multiprocessing.Lock()
53
+
54
+
55
def process_data(data, device, out_layer_number, process_no, kmeans_uri, model_name, data_dir, batch_size=10000):
    """Worker: extract semantic units for a chunk of 'lang|path|text' records.

    Each successful record is appended as 'lang|path|text|units' to a
    per-process shard under ``data_dir``/SEMANTICS_, flushed to disk every
    ``batch_size`` results and once more at the end.
    """
    # Serialize extractor construction across processes — model download /
    # load is guarded by the module-level lock.
    lock.acquire()
    unit_extractor = UnitExtractor(
        model_name, kmeans_uri, device=torch.device(f"cuda:{device}")
    )
    lock.release()

    out_path = os.path.join(data_dir, "SEMANTICS_/" + str(process_no) + "_semt.txt")

    def _flush(buffered):
        # Append the buffered lines to this worker's shard file.
        with open(out_path, "a") as file:
            file.write("".join(buffered))

    results = []
    for record in tqdm(data, desc="process no : " + str(process_no)):
        try:
            audio, sr = torchaudio.load(record[1])
            audio = (
                torchaudio.functional.resample(audio, sr, 16000)
                .squeeze(0)
                .unsqueeze(-1)
            )
            with torch.no_grad():
                units = (
                    unit_extractor.predict(audio.to(device), out_layer_number - 1)
                    .detach()
                    .cpu()
                    .numpy()
                )
            text = " ".join([str(k) for k in units]).strip()
            results.append("|".join(record) + "|" + text + "\n")
        except Exception as e:
            print(record, e)

        if len(results) == batch_size:
            _flush(results)
            results = []

    if results:
        _flush(results)
+
96
+
97
def parse_arguments():
    """Parse command line arguments."""
    p = argparse.ArgumentParser(
        description="Extract M4T semantic tokens from audio files using multi-GPU processing"
    )
    p.add_argument(
        "input_file",
        help="Path to the input file containing audio file paths and metadata",
    )
    p.add_argument(
        "--kmeans-uri",
        default="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
        help="URI for the kmeans model (default: M4T kmeans model)",
    )
    p.add_argument(
        "--model-name",
        default="xlsr2_1b_v2",
        help="Model name for unit extraction (default: xlsr2_1b_v2)",
    )
    p.add_argument(
        "--out-layer-number",
        type=int,
        default=35,
        help="Output layer number for feature extraction (default: 35)",
    )
    p.add_argument(
        "--gpu-multiplier",
        type=int,
        default=1,
        help="Multiplier for number of GPUs to use (default: 1)",
    )
    p.add_argument(
        "--test-ratio",
        type=float,
        default=0.1,
        help="Ratio of data to use for validation set (default: 0.1)",
    )
    p.add_argument(
        "--seed",
        type=int,
        default=42,
        help="Random seed for reproducibility (default: 42)",
    )
    p.add_argument(
        "--batch-size",
        type=int,
        default=10000,
        help="Number of results to accumulate before writing to file (default: 10000)",
    )
    return p.parse_args()
156
+
157
+
158
if __name__ == "__main__":
    start = time.time()

    args = parse_arguments()

    # Unpack CLI parameters.
    kmeans_uri = args.kmeans_uri
    model_name = args.model_name
    out_layer_number = args.out_layer_number
    k = args.gpu_multiplier
    test_ratio = args.test_ratio
    seed = args.seed
    batch_size = args.batch_size

    # k workers are launched per physical GPU.
    num_gpus_to_use = torch.cuda.device_count() * k
    data_dir = "/".join(args.input_file.split("/")[:-1])

    # Start from a fresh scratch directory for per-process shards.
    if os.path.exists(os.path.join(data_dir, "SEMANTICS_")):
        shutil.rmtree(os.path.join(data_dir, "SEMANTICS_"))
    os.makedirs(os.path.join(data_dir, "SEMANTICS_"), exist_ok=True)

    with open(args.input_file, "r") as file:
        data = file.read().strip("\n").split("\n")[:]
    data = [i.split("|") for i in data][:]

    # Split the records into one chunk per worker; a trailing remainder
    # chunk is folded into the previous one.
    chunk_size = len(data) // num_gpus_to_use
    data_chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
    if len(data_chunks) != num_gpus_to_use:
        data_chunks[-2] += data_chunks[-1]
        data_chunks = data_chunks[:-1]

    processes = []
    for i in range(num_gpus_to_use):
        p = multiprocessing.Process(
            target=process_data,
            args=(data_chunks[i], i // k, out_layer_number, i, kmeans_uri, model_name, data_dir, batch_size),
        )
        processes.append(p)
        p.start()

    for p in processes:
        p.join()

    stem = os.path.basename(args.input_file).split(".")[0]
    combined = os.path.join(data_dir, "SEMANTICS_/" + stem + "_semt.txt")

    # Concatenate the per-process shards into one combined file.
    for i in range(num_gpus_to_use):
        with open(os.path.join(data_dir, "SEMANTICS_/" + str(i) + "_semt.txt"), "r") as file:
            shard = file.read()
        with open(combined, "a") as file:
            file.write(shard)

    train_file = os.path.join(data_dir, "SEMANTICS_/" + stem + "_semt_train.txt")
    test_file = os.path.join(data_dir, "SEMANTICS_/" + stem + "_semt_val.txt")

    train_test_split_large_file(
        combined, train_file, test_file, test_ratio=test_ratio, seed=seed
    )
    print("processing took: with %d instances" % (num_gpus_to_use), time.time() - start)
    print(
        "data ready at:",
        combined,
    )
Semantic_tokens/seamless_communication ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit 90e2b57ac4d82fa2bfaa25caeffe39ceb8b2ebec
T2S/__pycache__/autoregressive.cpython-310.pyc ADDED
Binary file (8.35 kB). View file
 
T2S/__pycache__/gpt_inference.cpython-310.pyc ADDED
Binary file (3.96 kB). View file
 
T2S/__pycache__/mel_spec.cpython-310.pyc ADDED
Binary file (4.81 kB). View file
 
T2S/__pycache__/t2s_modules.cpython-310.pyc ADDED
Binary file (15.2 kB). View file
 
T2S/__pycache__/utilities.cpython-310.pyc ADDED
Binary file (3.46 kB). View file
 
T2S/autoregressive.py ADDED
@@ -0,0 +1,296 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import functools
2
+ import os
3
+ import sys
4
+ from typing import Any
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.optim as optim
10
+ from torch.utils.data import DataLoader, Dataset
11
+ from tqdm import tqdm
12
+ from transformers import (GemmaConfig, GemmaModel, GPT2Config, GPT2LMHeadModel,
13
+ GPT2Model, GPT2Tokenizer)
14
+
15
+ from config import config
16
+ from Text import code_labels, labels, text_labels
17
+
18
+ from .gpt_inference import GPT2InferenceModel
19
+ from .t2s_modules import GST
20
+
21
# Text token <-> id lookup tables (symbol -> index and index -> symbol).
text_enc = {j: i for i, j in enumerate(text_labels)}
text_dec = {i: j for i, j in enumerate(text_labels)}

# Semantic-code token <-> id lookup tables.
code_enc = {j: i for i, j in enumerate(code_labels)}
code_dec = {i: j for i, j in enumerate(code_labels)}
28
+
29
+
30
def null_position_embeddings(range, dim):
    """Return an all-zero positional-embedding tensor.

    Matches the (batch, seq) layout of *range*, produces shape
    (batch, seq, dim) on the same device — effectively disabling
    learned positional embeddings when patched in.
    """
    batch, seq = range.shape[0], range.shape[1]
    return torch.zeros((batch, seq, dim), device=range.device)
32
+
33
+
34
class TS_model(nn.Module):
    """Autoregressive text-to-semantic (T2S) model.

    A Gemma decoder stack (its own token embedding table removed) runs over
    the concatenated sequence [language embedding, speaker embedding, text
    embeddings, code embeddings]; two linear heads predict the next text
    token and the next semantic-code token respectively.

    Fixes vs. the original: ``k`` (the MLP intermediate-size multiplier) was
    only assigned inside the odd-vocabulary branch, raising NameError for an
    even vocabulary; the deprecated ``reduce=False`` flag of
    ``F.cross_entropy`` is replaced by its exact modern equivalent
    ``reduction="none"``; ``generate`` no longer uses a mutable default list.
    """

    def __init__(
        self, n_embed=1024, n_layer=30, n_head=16, n_positions=config.t2s_position
    ):
        super(TS_model, self).__init__()
        assert (n_embed / n_head) % 2 == 0, "n_embed n_head not a division of 2"
        self.vocab_size = len(labels)
        self.n_positions = n_positions
        self.n_embed = n_embed
        self.n_layer = n_layer
        self.n_head = n_head

        # MLP intermediate-size multiplier.  Always define it: the original
        # code assigned it only when the vocabulary size was odd, which made
        # an even vocabulary crash with NameError below.
        k = 1
        if self.vocab_size % 2 != 0:
            self.vocab_size += 1

        self.config = GemmaConfig(
            vocab_size=self.vocab_size,
            hidden_size=self.n_embed,
            intermediate_size=self.n_embed * k,
            num_hidden_layers=self.n_layer,
            num_attention_heads=self.n_head,
            num_key_value_heads=self.n_head,
            head_dim=int(self.n_embed / self.n_head),
            hidden_act="gelu_pytorch_tanh",
            hidden_activation=None,
            max_position_embeddings=self.n_positions,
            initializer_range=0.02,
            rms_norm_eps=1e-06,
            use_cache=True,
            pad_token_id=0,
            eos_token_id=1,
            bos_token_id=2,
            tie_word_embeddings=True,
            rope_theta=10000.0,
            attention_bias=False,
            attention_dropout=0.0,
        )

        self.gpt = GemmaModel(self.config)
        # Token embeddings are supplied manually through inputs_embeds.
        del self.gpt.embed_tokens

        # Global-style-token encoder: reference mels -> speaker latent.
        self.GST = GST(
            model_channels=self.n_embed,
            num_heads=self.n_head,
            in_channels=config.n_mel_channels,
        )
        self.text_head = nn.Linear(self.n_embed, len(text_labels))
        self.code_head = nn.Linear(self.n_embed, len(code_labels))

        self.text_embed = nn.Embedding(len(text_labels), self.n_embed)
        self.code_embed = nn.Embedding(len(code_labels), self.n_embed)
        self.language_embed = nn.Embedding(len(config.lang_index), self.n_embed)
        self.final_norm = nn.LayerNorm(self.n_embed)

    def init_random_embeddings(self):
        """Re-initialise text/code embedding tables uniformly in [-1, 1]."""
        self.text_embed.weight.data.uniform_(-1, 1)
        self.code_embed.weight.data.uniform_(-1, 1)

    def get_speaker_latent(self, ref_mels):
        """Encode reference mel clips into a single speaker conditioning latent.

        ref_mels: (batch, n_clips, n_mels, frames); a missing clip axis is
        added when a 3-D tensor is passed.  Each clip is encoded by the GST
        module, the clip outputs are concatenated on the last axis and
        averaged over it, and a singleton sequence axis is added.
        NOTE(review): the exact output shape depends on what GST returns —
        confirm against the GST implementation.
        """
        ref_mels = ref_mels.unsqueeze(1) if len(ref_mels.shape) == 3 else ref_mels

        conds = []
        for j in range(ref_mels.shape[1]):
            conds.append(self.GST(ref_mels[:, j, :, :]))

        conds = torch.cat(conds, dim=-1)
        conds = conds.mean(dim=-1)

        return conds.unsqueeze(1)

    def forward(
        self,
        text_ids,
        codes_ids=None,
        speaker_embed=None,
        ref_clips=None,
        language=torch.tensor(0),
        attn_mask=None,
        return_loss=False,
    ):
        """Run the decoder over [language, speaker, text, codes].

        Exactly one of *speaker_embed* / *ref_clips* must be given; when
        *ref_clips* is provided the speaker latent is computed from it.

        Returns:
            (text_probs, code_probs) logits shaped (batch, vocab, seq), or
            (loss_text, loss_mel, code_probs) when return_loss is True,
            where the losses are per-token (unreduced) next-token
            cross-entropies.  return_loss=True requires codes_ids.
        """
        assert speaker_embed is not None or ref_clips is not None
        text_embed = self.text_embed(text_ids)

        lanugage_embed = self.language_embed(language).unsqueeze(1)
        code_embed = None
        code_probs = None

        if codes_ids is not None:
            code_embed = self.code_embed(codes_ids)

        if ref_clips is not None:
            speaker_embed = self.get_speaker_latent(ref_clips)

        text_embed, code_embed = self.get_logits(
            lanugage_embed=lanugage_embed,
            speaker_embed=speaker_embed,
            text_embed=text_embed,
            code_embed=code_embed,
            attn_mask=attn_mask,
        )
        text_probs = self.text_head(text_embed).permute(0, 2, 1)

        if codes_ids is not None:
            code_probs = self.code_head(code_embed).permute(0, 2, 1)

        if return_loss:
            # Shift by one for next-token prediction; `reduction="none"` is
            # the modern equivalent of the removed `reduce=False` flag.
            loss_text = F.cross_entropy(
                text_probs[:, :, :-1], text_ids[:, 1:].long(), reduction="none"
            )
            loss_mel = F.cross_entropy(
                code_probs[:, :, :-1], codes_ids[:, 1:].long(), reduction="none"
            )
            return loss_text, loss_mel, code_probs

        return text_probs, code_probs

    def get_logits(
        self, lanugage_embed, speaker_embed, text_embed, code_embed=None, attn_mask=None
    ):
        """Run the Gemma stack and split its output back into text/code parts.

        (Parameter name `lanugage_embed` keeps the original spelling for
        backward compatibility with keyword callers.)

        Position ids restart from 0 for the text segment and again for the
        code segment; the two leading conditioning positions are pinned to 0.
        The two conditioning outputs are dropped before returning.
        """
        if code_embed is not None:
            embed = torch.cat(
                [lanugage_embed, speaker_embed, text_embed, code_embed], dim=1
            )
            position_ids = torch.zeros(
                (embed.shape[0], embed.shape[1]), device=embed.device
            )
            indices = torch.tensor(
                [0, 0]
                + list(range(text_embed.shape[1]))
                + list(range(code_embed.shape[1])),
                device=embed.device,
            )
            position_ids[:, : indices.size(0)] = indices
        else:
            embed = torch.cat([lanugage_embed, speaker_embed, text_embed], dim=1)
            position_ids = torch.zeros(
                (embed.shape[0], embed.shape[1]), device=embed.device
            )
            indices = torch.tensor(
                [0, 0] + list(range(text_embed.shape[1])), device=embed.device
            )
            position_ids[:, : indices.size(0)] = indices

        if attn_mask is None:
            attn_mask = torch.ones_like(embed).to(embed.device)
        else:
            # Callers pass a mask for the text/code part only; prepend ones
            # for the two conditioning positions.
            attn_mask = torch.cat(
                [torch.ones((embed.shape[0], 2)).to(embed.device), attn_mask], dim=1
            )
        gpt_output = self.gpt(
            inputs_embeds=embed,
            attention_mask=attn_mask,
            position_ids=position_ids,
            return_dict=True,
        )
        enc = gpt_output.last_hidden_state[:, 2:]
        enc = self.final_norm(enc)
        if code_embed is not None:
            return enc[:, : text_embed.shape[1]], enc[:, -code_embed.shape[1] :]

        return enc[:, : text_embed.shape[1]], None

    def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
        """Build the HF-generate-compatible wrapper around the trained stack.

        Re-attaches the code embedding table as the Gemma input embeddings so
        `generate` can feed token ids.  Optionally wraps the model with a
        DeepSpeed inference engine.
        """
        self.gpt_inference = GPT2InferenceModel(
            self.config,
            self.gpt,
            None,
            self.code_embed,
            self.final_norm,
            self.code_head,
            kv_cache=kv_cache,
        )
        self.gpt.embed_tokens = self.code_embed

        if use_deepspeed:
            import deepspeed

            # NOTE(review): model is cast to half but dtype=torch.float32 is
            # requested — confirm the intended inference precision.
            self.ds_engine = deepspeed.init_inference(
                model=self.gpt_inference.half(),  # Transformers models
                mp_size=1,  # Number of GPU
                dtype=torch.float32,  # desired data type of output
                replace_method="auto",  # Lets DS automatically identify the layer to replace
                replace_with_kernel_inject=True,  # replace the model with the kernel injector
            )
            self.gpt_inference = self.ds_engine.module.eval()

    def compute_embeddings(self, language, cond_latents, text_inputs, code_inputs):
        """Cache the conditioning prefix and build fake ids for generation.

        Returns (gpt_inputs, position_ids): placeholder token ids whose tail
        holds *code_inputs* (the prefix positions are filled with 1 and are
        replaced by the cached embeddings inside the inference model), plus
        the matching segment-restarting position ids.
        """
        text_embed = self.text_embed(text_inputs)
        lanugage_embed = self.language_embed(language).unsqueeze(1)

        emb = torch.cat([lanugage_embed, cond_latents, text_embed], dim=1)

        position_ids = torch.zeros(
            (emb.shape[0], emb.shape[1] + len(code_inputs)), device=emb.device
        )
        indices = torch.tensor(
            [0, 0] + list(range(text_embed.shape[1])) + list(range(len(code_inputs))),
            device=emb.device,
        )
        position_ids[:, : indices.size(0)] = indices

        self.gpt_inference.store_prefix_emb(emb)
        gpt_inputs = torch.full(
            (
                emb.shape[0],
                emb.shape[1] + len(code_inputs),  # +1 for the start_audio_token
            ),
            fill_value=1,
            dtype=torch.long,
            device=text_inputs.device,
        )
        gpt_inputs[:, -len(code_inputs) :] = torch.tensor(code_inputs)
        return (gpt_inputs, position_ids)

    def generate(
        self,
        language,
        cond_latents,
        text_inputs,
        code_inputs=None,
        **hf_generate_kwargs,
    ):
        """Autoregressively generate semantic codes for *text_inputs*.

        code_inputs defaults to the start-of-semantic-tokens marker; the
        None sentinel replaces the original shared mutable default list.
        Returns the generated code ids (and the full generate output when
        return_dict_in_generate is requested).
        """
        if code_inputs is None:
            code_inputs = [code_enc["<SST>"]]
        gpt_inputs, position_ids = self.compute_embeddings(
            language, cond_latents, text_inputs, code_inputs
        )
        gen = self.gpt_inference.generate(
            gpt_inputs,
            bos_token_id=code_enc["<SST>"],
            pad_token_id=code_enc["<PAD>"],
            eos_token_id=code_enc["<EST>"],
            max_length=self.n_positions,
            position_ids=position_ids,
            **hf_generate_kwargs,
        )
        if "return_dict_in_generate" in hf_generate_kwargs:
            return gen.sequences[:, gpt_inputs.shape[1] :], gen
        return gen[:, gpt_inputs.shape[1] - len(code_inputs) + 1 :]

    def get_generator(self, fake_inputs, **hf_generate_kwargs):
        """Return a streaming token generator over *fake_inputs*."""
        return self.gpt_inference.generate_stream(
            fake_inputs,
            bos_token_id=code_enc["<SST>"],
            pad_token_id=code_enc["<PAD>"],
            eos_token_id=code_enc["<EST>"],
            max_length=self.n_positions,
            do_stream=True,
            **hf_generate_kwargs,
        )
283
+
284
+
285
class LearnedPositionEmbeddings(nn.Module):
    """Trainable absolute position embeddings up to a fixed maximum length."""

    def __init__(self, seq_len, model_dim, init=0.02):
        super().__init__()
        table = nn.Embedding(seq_len, model_dim)
        # Normal init with a small std, GPT-style.
        table.weight.data.normal_(mean=0.0, std=init)
        self.emb = table

    def forward(self, x):
        """Return embeddings for positions 0..x.shape[1]-1, shape (seq, dim)."""
        positions = torch.arange(0, x.shape[1], device=x.device)
        return self.emb(positions)

    def get_fixed_embedding(self, ind, dev):
        """Return the embedding of a single position *ind*, shape (1, 1, dim)."""
        idx = torch.tensor([ind], device=dev)
        return self.emb(idx).unsqueeze(0)
T2S/dataset.py ADDED
@@ -0,0 +1,562 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+ from typing import Any
4
+
5
+ sys.path.append("../")
6
+ import linecache
7
+ import mmap
8
+ import pickle as pkl
9
+ import random
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.optim as optim
16
+ import torchaudio
17
+ import transformers
18
+ from accelerate import Accelerator, DistributedDataParallelKwargs
19
+ from autoregressive import TS_model
20
+ from cleaners import english_cleaners
21
+ from librosa.filters import mel as librosa_mel_fn
22
+ from mel_spec import get_mel_spectrogram
23
+ from meta_stats import process_file, process_file_for_heads
24
+ from stft import STFT
25
+ from torch.utils.data import (DataLoader, Dataset, WeightedRandomSampler,
26
+ get_worker_info)
27
+ from tqdm import tqdm
28
+ from utilities import get_mask_from_lengths
29
+
30
+ import wandb
31
+ from config import config
32
+ from Text import code_labels, labels, text_labels
33
+
34
# Seed every RNG used by the sampler so dataset iteration is reproducible.
torch.manual_seed(config.seed_value)
np.random.seed(config.seed_value)
random.seed(config.seed_value)
print(text_labels)
# add semantic tokens:
# tok_enc = {j:i for i,j in enumerate(labels)}
# tok_dec = {j:i for i,j in enumerate(labels)}

# Text token <-> id lookup tables (symbol -> index and index -> symbol).
text_enc = {j: i for i, j in enumerate(text_labels)}
text_dec = {i: j for i, j in enumerate(text_labels)}

# Semantic-code token <-> id lookup tables.
code_enc = {j: i for i, j in enumerate(code_labels)}
code_dec = {i: j for i, j in enumerate(code_labels)}
49
+
50
+
51
def read_specific_line(filename, line_number):
    """Return the 1-indexed line *line_number* of *filename*, stripped.

    Uses ``linecache`` so repeated lookups into the same file are cheap;
    out-of-range line numbers yield an empty string.
    """
    raw = linecache.getline(filename, line_number)
    return raw.strip()
54
+
55
+
56
# Length of the reference-clip crop taken in __getitem__ — presumably mel
# frames; confirm against config.CLIP_LENGTH's definition.
CLIP_LENGTH = config.CLIP_LENGTH
57
+
58
+
59
class semantic_dataset_batch(Dataset):
    """Dataset pairing transcripts, semantic-token sequences and reference mels.

    Two modes:
      * scale=False — transcripts, semantic tokens and pickled reference mels
        are all loaded into memory up front.
      * scale=True  — the transcript file is mmapped and indexed by byte
        offset; items are drawn through a duration-bucketed weighted iterator
        (``get_head``) whose buckets are sharded across dataloader workers.

    Each item yields text ids, semantic-code ids, ``ref_k`` cropped reference
    mel clips and a language index (see ``__getitem__``).
    """

    def __init__(
        self,
        transcript_path,
        semantic_path=None,
        ref_mels_path=None,
        ref_k=3,
        scale=False,
        process_id=None,
        total_processes=None,
    ):
        super().__init__()
        self.scale = scale
        if not scale:
            with open(transcript_path, "r") as file:
                data = file.read().strip("\n").split("\n")[:]

            with open(semantic_path, "r") as file:
                semb = file.read().strip("\n").split("\n")

            with open(ref_mels_path, "rb") as file:
                self.ref_mels = pkl.load(file)

            # "<key>\t<space-separated tokens>" -> {key: [tokens]}
            semb = {
                i.split("\t")[0]: [j for j in i.split("\t")[1].split()] for i in semb
            }
            # "<key>|<transcript>" -> {key: lower-cased transcript}
            data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}

            # NOTE(review): entries hold 3 fields [key, tokens, text] but
            # __getitem__ unpacks 4 (lang, path, semb, text) in this mode —
            # confirm before using scale=False.
            self.data = [[i, semb[i], data[i]] for i in data.keys()]

        else:
            # with open(transcript_path,'r') as file:
            # get meta for dataset
            # for count, line in enumerate(file):
            # pass
            # count = 80
            print(transcript_path)
            # self.weights,self.count = process_file(transcript_path)
            # Duration-bucket "heads" (lists of line numbers), per-bucket
            # sampling weights, and total line count for this process shard.
            self.heads, self.weights, self.count = process_file_for_heads(
                transcript_path, total_processes, process_id
            )
            print("length :", self.count)
            self.data_len = self.count
            self.transcript_path = transcript_path
            # Build a line-number -> byte-offset index so any transcript line
            # can be fetched in O(1) through the mmap.
            line_index = {}
            with open(transcript_path, "rb") as file:
                mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
                line_number = 0
                offset = 0
                while offset < len(mmapped_file):
                    line_index[line_number] = offset
                    offset = mmapped_file.find(b"\n", offset) + 1
                    # print(line_number,offset)
                    line_number += 1
            self.mmapped_file = mmapped_file
            self.line_index = line_index

            self.process_id = process_id
            self.total_processes = total_processes
            # Created lazily on first __getitem__ so each dataloader worker
            # gets its own iterator over its own shard of the heads.
            self.iterator = None

        self.ref_k = ref_k
        self.max_wav_value = config.MAX_WAV_VALUE
        self.stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)

        mel_basis = librosa_mel_fn(
            sr=config.sampling_rate,
            n_fft=config.filter_length,
            n_mels=config.n_mel_channels,
            fmin=config.mel_fmin,
            fmax=config.mel_fmax,
        )

        self.mel_basis = torch.from_numpy(mel_basis).float()

    def get_mel(self, filepath):
        """Load an audio file and return (mel_spectrogram, energy_list).

        Energy is not computed in the current code path and comes back as an
        empty list; the commented-out block below is the previous
        STFT-based implementation kept for reference.
        """
        # audio, sampling_rate = load_wav_to_torch(filepath)
        # audio_norm = audio / self.max_wav_value
        audio_norm, sampling_rate = torchaudio.load(filepath)

        # dur = audio_norm.shape[-1]/sampling_rate

        # if dur<0.5:
        # return None,None,None

        # if self.clip and dur>10 and align:
        # # print('big file',dur)
        # max_audio_start = int(dur - 10)
        # audio_start = random.randint(0, max_audio_start)

        # audio_norm = audio_norm[:,audio_start*sampling_rate:(audio_start+10)*sampling_rate]
        # semb_ids = semb_ids[audio_start*50:(audio_start+10)*50 -1]

        # 86 mel -> 1s for 22050 setting
        # ` 93 mel ->`1s for 24000 setting

        # add 64ms of values to start and end
        # audio_norm += torch.randn(audio_norm.shape[0])*1e-8
        # audio_norm = torch.concat([torch.randn(1412)*1e-8,audio_norm,torch.randn(1412)*1e-8])
        # audio_norm = audio_norm.unsqueeze(0)
        # y = torch.autograd.Variable(audio_norm, requires_grad=False)

        # assert(torch.min(y.data) >= -1)
        # assert(torch.max(y.data) <= 1)
        # magnitudes, phases = self.stft_fn.transform(y)
        # magnitudes = magnitudes.data
        # mel_output = torch.matmul(self.mel_basis, magnitudes)
        # mel_output = dynamic_range_compression(mel_output)
        # melspec = torch.squeeze(mel_output, 0)
        # energy = torch.norm(magnitudes, dim=1).squeeze(0)
        # melspec,energy = mel_spectrogram(audio_norm)
        melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
        energy = []
        # if align:
        # return melspec,list(energy),semb_ids
        return melspec, list(energy)

    def __len__(self):
        # In scale mode the length is the shard's line count; otherwise the
        # in-memory item list.
        if self.scale:
            return self.data_len
        return len(self.data)

    # def get_process_heads(self,):
    # '''
    # divide data and heads based on the batch_size and weights
    # '''

    # new_heads ={}
    # new_weights =[]
    # process_batch_size = config.ts_batch_size*config.ts_gradient_accumulation_steps
    # sm=0
    # for i,j in zip(self.heads,self.weights):

    # if sm + j > process_batch_size:
    # if sm+j == process_batch_size:
    # new_heads[i] = self.heads[i]
    # new_weights.append(j)
    # else:
    # new_heads[i] = self.heads[i][:len(self.heads[i])*(process_batch_size-sm)//process_batch_size]
    # new_weights.append(process_batch_size-sm)
    # else:
    # new_heads[i] = self.heads[i]
    # new_weights.append(j)

    # self.get_worker_heads()

    # old heads and weights
    # new_heads = {}
    # for i in self.heads:
    # segment_size = (len(self.heads[i]) + self.total_processes - 1) // self.total_processes
    # start_idx = self.process_id * segment_size
    # end_idx = start_idx + segment_size

    # if end_idx > len(self.heads[i]):
    # # Create a list that wraps around to the beginning
    # segment = self.heads[i][start_idx:] + self.heads[i][:end_idx - len(self.heads[i])]
    # else:
    # segment = self.heads[i][start_idx:end_idx]
    # new_heads[i]=segment
    # self.heads = new_heads
    # print(self.process_id,[len(self.heads[i]) for i in self.heads])
    # self.get_worker_heads()

    def get_worker_heads(
        self,
    ):
        """Restrict each duration bucket to this dataloader worker's slice.

        Slices wrap around the end of the bucket so every worker gets a
        full-sized segment.  Must run inside a worker process (relies on
        ``get_worker_info()``).
        """
        self.worker_id = get_worker_info().id
        self.num_worker = get_worker_info().num_workers
        new_heads = {}
        for i in self.heads:
            segment_size = (len(self.heads[i]) + self.num_worker - 1) // self.num_worker
            start_idx = self.worker_id * segment_size
            end_idx = start_idx + segment_size

            if end_idx > len(self.heads[i]):
                # Create a list that wraps around to the beginning
                segment = (
                    self.heads[i][start_idx:]
                    + self.heads[i][: end_idx - len(self.heads[i])]
                )
            else:
                segment = self.heads[i][start_idx:end_idx]
            new_heads[i] = segment
        self.heads = new_heads
        # print("worker:",self.worker_id,self.process_id,[len(self.heads[i]) for i in self.heads],self.weights)

    def get_head(self):
        """Yield transcript line numbers forever, weighted per duration bucket.

        For each bucket, ``weight`` consecutive line numbers are emitted
        before moving to the next bucket; an exhausted bucket is reshuffled
        and restarted.
        """
        # self.get_process_heads()
        self.get_worker_heads()
        # print("weights:",self.weights,[h for h in self.heads])
        self.indices = [0] * len(self.heads)
        # self.process_heads = [{i:self.heads[i][self.process_id:]}for i in self.heads]
        while True:
            for (
                n,
                (head, weight),
            ) in enumerate(zip(self.heads, self.weights)):
                # if process_id == 0:
                # print(weight,head)
                for i in range(weight):
                    if self.indices[n] < len(self.heads[head]):
                        # print(self.heads[head][self.indices[n]],worker_id,self.indices)
                        yield self.heads[head][self.indices[n]]
                        self.indices[n] += 1
                    else:
                        self.indices[n] = 0
                        random.shuffle(self.heads[head])
                        # shuffle the indices

    def __getitem__(self, index) -> Any:
        """Return one training item as a dict.

        In scale mode the requested *index* is ignored and the next line
        number from the bucketed iterator is used instead.  Items with fewer
        than 25 semantic tokens, unreadable reference clips, or no usable
        references are skipped by recursing to the next index.
        """
        if self.iterator is None:
            self.iterator = self.get_head()
        if not self.scale:
            lang, path, semb, text = self.data[index]
            ref_mels = self.ref_mels[path][: self.ref_k]

        else:
            # line = read_specific_line(self.transcript_path,index+1)

            index = next(self.iterator)
            # print(self.worker_id,self.process_id,index)
            self.mmapped_file.seek(self.line_index[index])
            line = self.mmapped_file.readline().decode("utf-8")

            lang, path, text, semb_ids, ref_mels = line.split("|")
            # a=5/0
            # semb_ids = [int(i)+1 for i in semb_ids.split()]
            semb = semb_ids.split()
            # "<path>,<score>\t<path>,<score>..." -> [[path, score], ...]
            ref_mels = [i.split(",") for i in ref_mels.split("\t")][: self.ref_k]

        # Skip very short token sequences.
        if len(semb) < 25:
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        # No reference clips: fall back to the utterance's own audio.
        if len(ref_mels) == 0:
            ref_mels.append((path, 1))
            ref_mels.append((path, 1))
            ref_mels.append((path, 1))

        while len(ref_mels) < self.ref_k:
            ref_mels.append(ref_mels[-1])

        text = text.lower().strip()
        # try:
        text_ids = [text_enc["<S>"]] + [text_enc[i] for i in text] + [text_enc["<E>"]]
        semb_ids = (
            [code_enc["<SST>"]] + [code_enc[i] for i in semb] + [code_enc["<EST>"]]
        )

        # except Exception as e:
        # print(e)
        # print(lang,path,text,index)
        # exit
        # input_ids = text_ids+semb_ids
        # pad_length = config.t2s_position-(len(text_ids)+len(semb_ids))

        # token_type_ids = [0]*len(text_ids)+[1]*len(semb_ids)+[0]*pad_length
        # positional_ids = [i for i in range(len(text_ids))]+[i for i in range(len(semb_ids))]+[0]*pad_length
        # labels = [-100]*len(text_ids)+semb_ids+[-100]*pad_length
        # attention_mask = [1]*len(input_ids)+[0]*pad_length
        # input_ids += [tok_enc['<PAD>']]*pad_length

        def get_random_portion(mel, mask_lengths):
            # Crop each reference mel to CLIP_LENGTH frames, taking a random
            # window when the clip is longer than CLIP_LENGTH.
            clip = mask_lengths <= CLIP_LENGTH
            ref_mel = mel[:, :, :CLIP_LENGTH].clone()
            for n, z in enumerate(clip):
                if not z:
                    start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
                    ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
            return ref_mel

        try:
            ref_mels = [self.get_mel(path)[0] for path, score in ref_mels]
        except Exception as e:
            print(index, e)
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        ref_c = []
        for i in range(self.ref_k):
            if ref_mels[i] is None:
                continue
            ref_c.append(ref_mels[i])

        if len(ref_c) == 0:
            # print('no refs worthy')
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        if len(ref_c) != self.ref_k:
            # print('less refs found',len(ref_c))
            while len(ref_c) < self.ref_k:
                ref_c.append(ref_c[-1])

        ref_mels = ref_c

        # Pad all references to the longest clip with near-zero noise, then
        # crop each to CLIP_LENGTH frames.
        max_target_len = max([x.size(1) for x in ref_mels])
        ref_mels_padded = (
            torch.randn((self.ref_k, config.n_mel_channels, max_target_len))
        ) * 1e-9
        mel_length = []
        for i, mel in enumerate(ref_mels):
            ref_mels_padded[i, :, : mel.size(1)] = mel
            mel_length.append(mel.shape[-1])

        ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))

        return {
            "text_ids": text_ids,
            "semb_ids": semb_ids,
            "ref_mels": ref_mels,
            "lang": torch.tensor(config.lang_index[lang]),
        }
375
+
376
+
377
+ # def get_padded_seq(sequences):
378
+
379
+ # max_len=max([len(s) for s in sequences])
380
+ # for i in range(len(sequences)):
381
+ # sequences[i]=sequences[i]+tok_enc['<PAD>']*(max_len-len(sequences[i]))
382
+
383
+ # return sequences
384
+
385
+
386
def get_padded_seq(sequences, pad_random, before=False, pad__=0):
    """Pad every sequence in *sequences* (in place) to the batch maximum length.

    Args:
        sequences: list of lists; each entry is replaced by its padded version
            (the input list itself is mutated and also returned).
        pad_random: if True pad with tiny random floats (~1e-9) instead of a
            constant value.
        before: prepend the padding instead of appending it.
        pad__: constant pad value used when pad_random is False.

    Returns:
        (sequences, seq_len): the padded list and each sequence's original
        length.
    """
    max_len = max(len(s) for s in sequences)
    seq_len = []
    for i, seq in enumerate(sequences):
        seq_len.append(len(seq))
        if pad_random:
            # Near-zero noise padding (fixes the original's duplicated
            # `pad_ = pad_ = ...` assignment typo).
            pad_ = list(np.random.rand(max_len - len(seq)) * 1e-9)
        else:
            pad_ = [pad__] * (max_len - len(seq))
        sequences[i] = pad_ + seq if before else seq + pad_

    return sequences, seq_len
401
+
402
+
403
def collate(batch):
    """Collate dataset items into padded batch tensors.

    Returns (text_ids, code_ids, text_len, code_len, ref_mels, langs); text is
    padded with <E>, codes with <EST>, and reference mels are padded along the
    final (time) axis with near-zero noise.
    """
    text_ids = [item["text_ids"] for item in batch]
    semb_ids = [item["semb_ids"] for item in batch]
    ref_mels = [item["ref_mels"] for item in batch]
    langs = [item["lang"] for item in batch]

    text_ids, text_len = get_padded_seq(
        text_ids, pad_random=False, before=False, pad__=text_enc["<E>"]
    )
    code, code_len = get_padded_seq(semb_ids, pad_random=False, pad__=code_enc["<EST>"])

    # Pad every item's reference-clip stack to the longest time axis with
    # near-zero noise.
    max_frames = max(mel.size(-1) for mel in ref_mels)
    ref_shape = (len(batch), ref_mels[0].shape[0], config.n_mel_channels, max_frames)
    ref_mels_padded = torch.randn(ref_shape) * 1e-9
    for i, mel in enumerate(ref_mels):
        ref_mels_padded[i, :, :, : mel.size(-1)] = mel

    return (
        torch.tensor(text_ids),
        torch.tensor(code),
        torch.tensor(text_len),
        torch.tensor(code_len),
        ref_mels_padded,
        torch.tensor(langs),
    )
449
+
450
+
451
def get_dataset(transcript_path, get_process_id, total_processes):
    """Build a scale-mode semantic_dataset_batch for one distributed process."""
    dataset = semantic_dataset_batch(
        transcript_path,
        scale=True,
        process_id=get_process_id,
        total_processes=total_processes,
    )
    return dataset
458
+
459
+
460
+ if __name__ == "__main__":
461
+ accelerator = Accelerator(
462
+ gradient_accumulation_steps=config.ts_gradient_accumulation_steps
463
+ ) # ,kwargs_handlers=[ddp_kwargs]) mixed_precision="fp16",
464
+
465
+ get_process_id = accelerator.process_index
466
+ total_processes = accelerator.num_processes
467
+
468
+ # train_dataset_ = semantic_dataset_batch(config.data_path+'/transcript_train_20s_final_normalized_filtered.txt','../'+config.data_path+'/semt.txt','../'+config.data_path+'/ref_clips.pkl',
469
+ # scale=True,process_id=get_process_id,total_processes = total_processes)
470
+ # train_dataset_ = semantic_dataset_batch(config.data_path+'/transcript_train_20s_final_normalized_filtered.txt','../'+config.data_path+'/semt.txt','../'+config.data_path+'/ref_clips.pkl',
471
+ # scale=True,process_id=get_process_id,total_processes = total_processes)
472
+ # train_dataset_ = semantic_dataset_batch(config.data_path+'/transcript_train_20s_final_normalized_filtered.txt','../'+config.data_path+'/semt.txt','../'+config.data_path+'/ref_clips.pkl',
473
+ # scale=True,process_id=get_process_id,total_processes = total_processes)
474
+ train_dataset_ = semantic_dataset_batch(
475
+ config.data_path + "/transcript_train_20s_final_normalized_filtered.txt",
476
+ "../" + config.data_path + "/semt.txt",
477
+ "../" + config.data_path + "/ref_clips.pkl",
478
+ scale=True,
479
+ process_id=get_process_id,
480
+ total_processes=total_processes,
481
+ )
482
+ # sampler = WeightedRandomSampler(
483
+ # train_dataset_.weights,
484
+ # train_dataset_.count,
485
+ # replacement=False)
486
+ train_dataset = DataLoader(
487
+ train_dataset_,
488
+ pin_memory=True,
489
+ persistent_workers=True,
490
+ num_workers=config.ts_num_workers,
491
+ batch_size=config.ts_batch_size,
492
+ shuffle=False,
493
+ drop_last=False,
494
+ collate_fn=collate,
495
+ sampler=None,
496
+ )
497
+ print("batch", config.ts_batch_size)
498
+ # val_dataset = DataLoader(semantic_dataset_batch(config.data_path+'/transcript_test_20_final_normalized.txt','../'+config.data_path+'/semt.txt','../'+config.data_path+'/ref_clips.pkl',scale=True,process_id=get_process_id,total_processes = total_processes),pin_memory=True,
499
+ # persistent_workers=True,num_workers=2,batch_size=config.ts_batch_size,shuffle=True,drop_last=False,collate_fn=collate)
500
+
501
+ train_dataloader = accelerator.prepare(train_dataset)
502
+ # if accelerator.is_local_main_process:
503
+ # from IPython import embed
504
+ # embed()
505
+
506
+ # checkiong the sampler working
507
+ import math
508
+ from collections import defaultdict
509
+
510
+ def calculate_duration(code_len):
511
+ return math.ceil(((code_len + 1) / 50) * 2) / 2
512
+
513
+ sampling = defaultdict(int)
514
+ dataset = []
515
+ batch_data = {}
516
+ batch = 0
517
+ batch_data[batch] = defaultdict(int)
518
+ for n, data in enumerate(tqdm(train_dataloader)):
519
+ # break
520
+ text_ids, code, text_len, code_len, ref_clips, langs = data
521
+ # print(text_ids)
522
+ # print('=====')
523
+ # # break
524
+ for i, j in zip(code_len, text_ids):
525
+ dur = calculate_duration(i - 2)
526
+ # print(dur,i,code.shape)
527
+ # sampling[calculate_duration(i)]+=1
528
+ dataset.append(list(j.detach().cpu().numpy()))
529
+
530
+ if dur > 19.5:
531
+ batch_data[batch]["20_sentence"] += 1
532
+ continue
533
+ if dur <= 5:
534
+ batch_data[batch]["5s"] += 1
535
+ continue
536
+ elif dur <= 10:
537
+ batch_data[batch]["10s"] += 1
538
+ continue
539
+ elif dur <= 15:
540
+ batch_data[batch]["15s"] += 1
541
+ continue
542
+ elif dur <= 20:
543
+ batch_data[batch]["20s"] += 1
544
+ continue
545
+ # print(batch)
546
+ if (n + 1) % config.ts_gradient_accumulation_steps == 0:
547
+ batch += 1
548
+ batch_data[batch] = defaultdict(int)
549
+ # break
550
+ # if n==20:
551
+ # break
552
+ # # print(sampling)
553
+ with open(
554
+ f"Sampling_data_meta/sampling_{accelerator.process_index}.pkl", "wb"
555
+ ) as file:
556
+ pkl.dump(batch_data, file)
557
+ with open(
558
+ f"Sampling_data_meta/sampling_dataset_{accelerator.process_index}.pkl", "wb"
559
+ ) as file:
560
+ pkl.dump(dataset, file)
561
+ print(batch_data[0])
562
+ # # # return 0
T2S/gpt_inference.py ADDED
@@ -0,0 +1,198 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # copied from coqui
2
+ import math
3
+
4
+ import torch
5
+ from torch import nn
6
+ from transformers import (GemmaForCausalLM, GemmaPreTrainedModel,
7
+ GPT2PreTrainedModel)
8
+ from transformers.modeling_outputs import (CausalLMOutputWithCrossAttentions,
9
+ CausalLMOutputWithPast)
10
+
11
+
12
class GPT2InferenceModel(GemmaForCausalLM):
    """Override GPT2LMHeadModel to allow for prefix conditioning.

    Wraps an externally built decoder stack (`gpt`) so HuggingFace's
    `generate()` machinery can run with a pre-computed conditioning prefix
    embedding (set via `store_prefix_emb`) prepended to the generated-token
    embeddings.

    NOTE(review): despite the class name, the base class is GemmaForCausalLM,
    not GPT2LMHeadModel — confirm this is intentional.
    """

    def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
        super().__init__(config)
        self.transformer = gpt          # decoder stack actually run in forward()
        self.pos_embedding = pos_emb    # kept but unused: addition is commented out in forward()
        self.embeddings = embeddings    # token-id -> embedding lookup
        self.final_norm = norm
        # Output head = final norm followed by the vocabulary projection.
        self.lm_head = nn.Sequential(norm, linear)
        self.kv_cache = kv_cache
        # Absolute position of the most recent token during incremental
        # decoding; initialized on the first (full-sequence) forward call.
        self.cur_pos = None

    def store_prefix_emb(self, prefix_emb):
        # Cache the conditioning prefix; must be called before generate().
        self.cached_prefix_emb = prefix_emb

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        inputs_embeds=None,
        cache_position=None,
        use_cache=True,
        **kwargs,
    ):
        """Build the per-step `model_inputs` dict for `generate()`.

        Largely follows the stock HuggingFace implementation: trims
        `input_ids` to the not-yet-processed suffix, derives `position_ids`
        from the attention mask, and computes `cache_position`.

        NOTE(review): assumes `past_key_values` is always a `Cache` instance;
        a `None` value (the parameter default) would crash on
        `.get_seq_length()`. Confirm the generation path always supplies a
        cache object.
        """
        past_length = 0
        if past_key_values.get_seq_length() != 0:
            # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
            past_length = (
                cache_position[0]
                if cache_position is not None
                else past_key_values.get_seq_length()
            )
        # NOTE(review): get_seq_length() returns an int, so the `is not None`
        # branch below is always taken; `max_cache_length` is therefore the
        # *current* cache length, not a true maximum. Verify this matches the
        # intended cropping behavior.
        max_cache_length = (
            torch.tensor(past_key_values.get_seq_length(), device=input_ids.device)
            if past_key_values.get_seq_length() is not None
            else None
        )
        cache_length = (
            past_length
            if max_cache_length is None
            else torch.min(max_cache_length, past_length)
        )

        # Keep only the unprocessed tokens:
        # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
        # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
        if (
            attention_mask is not None
            and attention_mask.shape[1] > input_ids.shape[1]
        ):
            input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
        # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
        # input_ids based on the past_length.
        elif past_length < input_ids.shape[1]:
            input_ids = input_ids[:, past_length:]
        # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.

        # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
        if (
            max_cache_length is not None
            and attention_mask is not None
            and cache_length + input_ids.shape[1] > max_cache_length
        ):
            attention_mask = attention_mask[:, -max_cache_length:]

        position_ids = kwargs.get("position_ids", None)
        if attention_mask is not None and position_ids is None:
            # create position_ids on the fly for batch generation
            position_ids = attention_mask.long().cumsum(-1) - 1
            position_ids.masked_fill_(attention_mask == 0, 1)
            if past_key_values:
                position_ids = position_ids[:, -input_ids.shape[1] :]

        # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
        if inputs_embeds is not None and past_length == 0:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
            # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
            # TODO: use `next_tokens` directly instead.
            model_inputs = {"input_ids": input_ids.contiguous()}

        input_length = (
            position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
        )
        if cache_position is None:
            cache_position = torch.arange(
                past_length, past_length + input_length, device=input_ids.device
            )
        elif use_cache:
            cache_position = cache_position[-input_length:]

        model_inputs.update(
            {
                "position_ids": position_ids,
                "cache_position": cache_position,
                "past_key_values": past_key_values,
                "use_cache": use_cache,
                "attention_mask": attention_mask,
            }
        )
        return model_inputs

    def forward(
        self,
        input_ids=None,
        past_key_values=None,
        attention_mask=None,
        token_type_ids=None,
        position_ids=None,
        head_mask=None,
        inputs_embeds=None,
        encoder_hidden_states=None,
        encoder_attention_mask=None,
        labels=None,
        use_cache=None,
        output_attentions=None,
        output_hidden_states=None,
        return_dict=None,
        cache_position=None,
    ):
        """Run one decoding step with the cached prefix embedding.

        On the first (multi-token) call, the tokens after the prefix span are
        embedded and concatenated behind `cached_prefix_emb`; on subsequent
        single-token calls only the new token is embedded and the tracked
        position counter is advanced.
        """
        assert self.cached_prefix_emb is not None
        assert inputs_embeds is None  # Not supported by this inference model.
        assert labels is None  # Training not supported by this inference model.
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # Create embedding
        prefix_len = self.cached_prefix_emb.shape[1]
        if input_ids.shape[1] != 1:
            # First step: input_ids contains placeholder ids for the prefix
            # span followed by the real generated tokens; embed only the
            # latter and prepend the cached prefix embedding.
            gen_inputs = input_ids[:, prefix_len:]
            gen_emb = self.embeddings(gen_inputs)
            gen_emb = gen_emb  # + self.pos_embedding(gen_emb)
            if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
                # Expand the prefix across a larger batch (e.g. num_return_sequences).
                prefix_emb = self.cached_prefix_emb.repeat_interleave(
                    gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
                )
            else:
                prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
            emb = torch.cat([prefix_emb, gen_emb], dim=1)
            self.cur_pos = position_ids[:, -1:]
        else:
            # Incremental step: embed the single new token and advance the
            # tracked position.
            # NOTE(review): `+= 1.0` adds a float to the position tensor —
            # confirm downstream position handling accepts a float dtype.
            emb = self.embeddings(input_ids)
            self.cur_pos += 1.0
            position_ids = self.cur_pos

        transformer_outputs = self.transformer(
            inputs_embeds=emb,
            past_key_values=past_key_values,
            attention_mask=attention_mask,
            position_ids=position_ids,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
        )
        hidden_states = transformer_outputs[0]
        lm_logits = self.lm_head(hidden_states)

        if not return_dict:
            return (lm_logits,) + transformer_outputs[1:]

        return CausalLMOutputWithPast(
            loss=None,
            logits=lm_logits,
            past_key_values=transformer_outputs.past_key_values,
            hidden_states=transformer_outputs.hidden_states,
            attentions=transformer_outputs.attentions,
        )

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        # Reorder legacy tuple-format caches for beam search.
        # NOTE(review): only valid for tuple caches; modern `Cache` objects
        # provide their own reorder method — confirm which format is used here.
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx.to(past_state.device))
                    for past_state in layer_past
                ),
            )
        return reordered_past
T2S/gpt_model_train.py ADDED
@@ -0,0 +1,586 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import sys
3
+
4
+ sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
5
+ import linecache
6
+ import mmap
7
+ import pickle as pkl
8
+ import random
9
+ from typing import Any
10
+
11
+ import numpy as np
12
+ import torch
13
+ import torch.nn as nn
14
+ import torch.nn.functional as F
15
+ import torch.optim as optim
16
+ import torchaudio
17
+ import transformers
18
+ from accelerate import Accelerator, DistributedDataParallelKwargs
19
+ from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
20
+ from tqdm import tqdm
21
+
22
+ import wandb
23
+ from config import config
24
+ from T2S.autoregressive import TS_model
25
+ from T2S.mel_spec import get_mel_spectrogram
26
+ from T2S.utilities import get_mask_from_lengths
27
+ from Text import code_labels, labels, text_labels
28
+
29
+ torch.manual_seed(config.seed_value)
30
+ np.random.seed(config.seed_value)
31
+ random.seed(config.seed_value)
32
+
33
+ # code encdec
34
+ text_enc = {j: i for i, j in enumerate(text_labels)}
35
+ text_dec = {i: j for i, j in enumerate(text_labels)}
36
+
37
+ # text encdec
38
+ code_enc = {j: i for i, j in enumerate(code_labels)}
39
+ code_dec = {i: j for i, j in enumerate(code_labels)}
40
+
41
+
42
def read_specific_line(filename, line_number):
    """Return line `line_number` (1-based) of `filename`, stripped of
    surrounding whitespace; linecache returns '' for out-of-range lines."""
    return linecache.getline(filename, line_number).strip()
45
+
46
+
47
+ CLIP_LENGTH = config.CLIP_LENGTH
48
+
49
+
50
class semantic_dataset(Dataset):
    """Dataset yielding (text ids, semantic code ids, reference mel clip,
    language index) items for T2S training.

    Two loading modes:
      * scale=False — load transcript, semantic tokens, and reference mels
        fully into memory from three separate files.
      * scale=True  — mmap a single combined transcript whose lines are
        "lang|path|text|semantic_ids" and index line offsets for lazy reads.

    Parse/IO failures on an item fall through to the next index (recursive
    retry) so a bad line never kills a training run.
    """

    def __init__(
        self,
        transcript_path,
        semantic_path=None,
        ref_mels_path=None,
        ref_k=1,        # number of reference mel clips to return per item
        scale=True,     # True: mmap-backed lazy loading of one combined file
    ):
        super().__init__()
        self.scale = scale
        if not scale:
            # Eager mode: everything in memory.
            with open(transcript_path, "r") as file:
                data = file.read().strip("\n").split("\n")[:]

            with open(semantic_path, "r") as file:
                semb = file.read().strip("\n").split("\n")

            with open(ref_mels_path, "rb") as file:
                self.ref_mels = pkl.load(file)

            # key -> list of semantic token strings
            semb = {
                i.split("\t")[0]: [j for j in i.split("\t")[1].split()] for i in semb
            }
            # key -> lowercased transcript text
            data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}

            # NOTE(review): entries here have 3 fields, but the scale=False
            # branch of __getitem__ unpacks 4 (lang, path, semb, text) — that
            # path would raise ValueError as written. Confirm before using
            # scale=False.
            self.data = [[i, semb[i], data[i]] for i in data.keys()]

        else:
            # Lazy mode: index the byte offset of every line once, then read
            # individual lines on demand through the mmap.
            line_index = {}
            with open(transcript_path, "rb") as file:
                mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
                line_number = 0
                offset = 0
                pbar = tqdm()
                while offset < len(mmapped_file):
                    line_index[line_number] = offset
                    offset = mmapped_file.find(b"\n", offset) + 1
                    line_number += 1
                    pbar.update(1)
                pbar.close()
            self.mmapped_file = mmapped_file
            self.line_index = line_index
            self.data_len = len(line_index)
            print("data length:", self.data_len)
            self.transcript_path = transcript_path

        self.ref_k = ref_k
        self.max_wav_value = config.MAX_WAV_VALUE

    def get_mel(self, filepath):
        # Load audio and compute its mel spectrogram; `energy` is a
        # placeholder kept for interface compatibility (always empty).
        audio_norm, sampling_rate = torchaudio.load(filepath)
        melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
        energy = []
        return melspec, list(energy)

    def __len__(self):
        if self.scale:
            return self.data_len
        return len(self.data)

    def __getitem__(self, index) -> Any:
        """Parse one transcript line into model-ready tensors.

        On any failure (malformed line, too few tokens, unreadable audio,
        unknown symbol) the next index is tried instead.
        """
        if not self.scale:
            # NOTE(review): see __init__ — self.data rows have 3 elements,
            # this 4-way unpack would fail. Presumably the eager path is
            # unused; verify before relying on it.
            lang, path, semb, text = self.data[index]
            ref_mels = self.ref_mels[path][: self.ref_k]

        else:
            # Seek to the line's recorded offset and decode it.
            self.mmapped_file.seek(self.line_index[index])
            line = self.mmapped_file.readline().decode("utf-8")

            try:
                lang, path, text, semb_ids = line.split("|")
            except Exception as e:
                print(index, line)
                if index + 1 < self.data_len:
                    return self.__getitem__(index + 1)
                return self.__getitem__(0)
            semb = semb_ids.split()
            ref_mels = [path]  # the utterance itself serves as the reference
            # ref_mels = [i.split(',') for i in ref_mels.split('\t')][:self.ref_k]

        if len(semb) < 5:
            # Too few semantic tokens to be a usable sample.
            print(index, "No Semb tokens found")
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        if len(ref_mels) == 0:
            # Fall back to the utterance's own audio as reference.
            ref_mels.append((path, 1))
            ref_mels.append((path, 1))
            ref_mels.append((path, 1))

        while len(ref_mels) < self.ref_k:
            ref_mels.append(ref_mels[-1])

        text = text.lower().strip()
        try:
            # Wrap with start/end symbols and map to integer ids.
            text_ids = (
                [text_enc["<S>"]] + [text_enc[i] for i in text] + [text_enc["<E>"]]
            )
            semb_ids = (
                [code_enc["<SST>"]] + [code_enc[i] for i in semb] + [code_enc["<EST>"]]
            )
        except Exception as e:
            print(index, e)
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        def get_random_portion(mel, mask_lengths):
            # Take a random CLIP_LENGTH-frame window from each mel that is
            # longer than CLIP_LENGTH; shorter mels are kept (zero-ish padded).
            clip = mask_lengths <= CLIP_LENGTH
            ref_mel = mel[:, :, :CLIP_LENGTH].clone()
            for n, z in enumerate(clip):
                if not z:
                    start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
                    ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
            return ref_mel

        try:
            ref_mels = [self.get_mel(path)[0] for path in ref_mels]
        except Exception as e:
            print(index, e, path)
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        # Drop any failed (None) references; duplicate the last good one so
        # exactly ref_k clips are returned.
        ref_c = []
        for i in range(self.ref_k):
            if ref_mels[i] is None:
                continue
            ref_c.append(ref_mels[i])

        if len(ref_c) == 0:
            if index + 1 < self.data_len:
                return self.__getitem__(index + 1)
            return self.__getitem__(0)

        if len(ref_c) != self.ref_k:
            while len(ref_c) < self.ref_k:
                ref_c.append(ref_c[-1])

        ref_mels = ref_c

        # Pad all reference mels to a common time length with near-zero noise.
        max_target_len = max([x.size(1) for x in ref_mels])
        ref_mels_padded = (
            torch.randn((self.ref_k, config.n_mel_channels, max_target_len))
        ) * 1e-9
        mel_length = []
        for i, mel in enumerate(ref_mels):
            ref_mels_padded[i, :, : mel.size(1)] = mel
            mel_length.append(mel.shape[-1])

        ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))

        return {
            "text_ids": text_ids,
            "semb_ids": semb_ids,
            "ref_mels": ref_mels,
            "lang": torch.tensor(config.lang_index[lang]),
        }
210
+
211
+
212
def get_padded_seq(sequences, pad_random, before=False, pad__=0):
    """Pad a batch of Python-list sequences (in place) to a common length.

    Args:
        sequences: list of lists; each entry is replaced by a padded copy so
            all entries end up at the batch maximum length.
        pad_random: if True, pad with tiny random floats (~1e-9) instead of
            the constant ``pad__``.
        before: if True, prepend the padding rather than append it.
        pad__: constant pad value used when ``pad_random`` is False.

    Returns:
        (sequences, seq_len): the padded batch and the original
        (pre-padding) length of each sequence.
    """
    max_len = max(len(s) for s in sequences)
    seq_len = []
    for i in range(len(sequences)):
        seq_len.append(len(sequences[i]))
        if pad_random:
            # Near-zero random padding. (Fixes the duplicated
            # `pad_ = pad_ = ...` assignment in the original.)
            pad_ = list((np.random.rand(max_len - len(sequences[i]))) * 1e-9)
        else:
            pad_ = [pad__] * (max_len - len(sequences[i]))
        if not before:
            sequences[i] = sequences[i] + pad_
        else:
            sequences[i] = pad_ + sequences[i]

    return sequences, seq_len
227
+
228
+
229
def collate(batch):
    """Collate dataset items into padded batch tensors.

    Text is padded with the <E> symbol, semantic codes with <EST>, and the
    reference mels with near-zero random noise along the time axis.

    Returns:
        (text_ids, codes, text_len, code_len, ref_mels, langs) — the first
        four and the last as tensors, ref_mels as a float tensor of shape
        (batch, ref_k, n_mel_channels, max_time).
    """
    text_ids = [item["text_ids"] for item in batch]
    semb_ids = [item["semb_ids"] for item in batch]
    ref_mels = [item["ref_mels"] for item in batch]
    langs = [item["lang"] for item in batch]

    text_ids, text_len = get_padded_seq(
        text_ids, pad_random=False, before=False, pad__=text_enc["<E>"]
    )
    code, code_len = get_padded_seq(semb_ids, pad_random=False, pad__=code_enc["<EST>"])

    # Pad reference mels to the longest clip with near-zero noise.
    ref_max_target_len = max(x.size(-1) for x in ref_mels)
    ref_mels_padded = (
        torch.randn(
            (
                len(batch),
                ref_mels[0].shape[0],
                config.n_mel_channels,
                ref_max_target_len,
            )
        )
    ) * 1e-9
    for idx, mel in enumerate(ref_mels):
        ref_mels_padded[idx, :, :, : mel.size(-1)] = mel

    return (
        torch.tensor(text_ids),
        torch.tensor(code),
        torch.tensor(text_len),
        torch.tensor(code_len),
        ref_mels_padded,
        torch.tensor(langs),
    )
269
+
270
+
271
def train(model, train_dataset, val_dataset, save_dir, checkpoint_initial=None):
    """Multi-GPU training loop for the T2S model via HuggingFace Accelerate.

    Args:
        model: TS_model instance returning (loss_text, loss_code, _).
        train_dataset / val_dataset: DataLoaders built by main().
        save_dir: NOTE(review) — unused; checkpoints go to config.save_root_dir.
        checkpoint_initial: optional checkpoint path to resume from.
    """
    accelerator = Accelerator(
        gradient_accumulation_steps=config.ts_gradient_accumulation_steps
    )  # ,kwargs_handlers=[ddp_kwargs]) mixed_precision="fp16",

    # wandb is initialised only on the local main process; other ranks keep
    # wandb_log = None and must never touch it.
    if config.ts_wandb_logs and accelerator.is_local_main_process:
        conf_ = {}
        for i, j in config.__dict__.items():
            conf_[str(i)] = str(j)
        wandb_log = wandb.init(
            project=config.wandb_project,
            entity=config.user_name,
            name=config.model_name,
            config=conf_,
        )
        wandb_log.watch(model, log_freq=100)
    else:
        wandb_log = None

    optimizer = optim.Adam(
        model.parameters(), lr=config.ts_lr, weight_decay=config.ts_weight_decay
    )
    # optimizer = transformers.Adafactor(model.parameters(), lr=config.ts_lr,weight_decay=config.ts_weight_decay, relative_step =False, scale_parameter =False)
    lr = config.ts_lr  # NOTE(review): only logged; no scheduler ever updates it
    step_num = 0
    start_epoch = 0
    if checkpoint_initial is not None:
        # NOTE(review): the checkpoint file is torch.load()'d four separate
        # times below (model / optimizer / step / epoch) — loading it once
        # into a local would be much cheaper.
        model.load_state_dict(
            torch.load(checkpoint_initial, map_location=torch.device("cpu"))["model"],
            strict=True,
        )
        if (
            config.ts_finetuning
        ):  # freezing heads results in less hallucinations after Ft.
            for param in model.text_head.parameters():
                param.requires_grad = False

            for param in model.code_head.parameters():
                param.requires_grad = False

        model.train()

        print("loading optimizer")
        optimizer.load_state_dict(
            torch.load(checkpoint_initial, map_location=torch.device("cpu"))[
                "optimizer"
            ]
        )
        step_num = (
            int(
                torch.load(checkpoint_initial, map_location=torch.device("cpu"))["step"]
            )
            + 1
        )
        # NOTE(review): the loaded step count is immediately discarded here —
        # resuming always restarts step_num at 0. Confirm intended.
        step_num = 0
        start_epoch = (
            int(
                torch.load(checkpoint_initial, map_location=torch.device("cpu"))[
                    "epoch"
                ]
            )
            + 1
        )
        print(f"Resuming training from epoch {start_epoch} and step {step_num}")

    train_dataloader, val_dataloader, model, optimizer = accelerator.prepare(
        train_dataset, val_dataset, model, optimizer
    )
    # NOTE(review): this overwrites the accelerator-prepared validation
    # dataloader with the raw one; validation then runs un-sharded on every
    # rank. Confirm intended before removing.
    val_dataloader = val_dataset
    min_val_loss = 1000  # sentinel "infinite" starting loss
    model.train()

    for i in range(start_epoch, config.ts_epochs):
        epoch_loss = []
        if accelerator.is_main_process:
            # Progress bar only on the main process.
            train_loader = tqdm(
                train_dataloader,
                desc="Rank %d: Training epoch %d"
                % (accelerator.local_process_index, i),
            )
        else:
            train_loader = train_dataloader

        for n, inputs in enumerate(train_loader):
            with accelerator.accumulate(model):
                # with accelerator.autocast():
                text_ids, code, text_len, code_len, ref_clips, langs = inputs
                mask_text = get_mask_from_lengths(text_len)
                code_mask = get_mask_from_lengths(code_len)
                attn_mask = torch.cat([mask_text, code_mask], dim=1)
                loss_text, loss_code, _ = model(
                    text_ids=text_ids,
                    ref_clips=ref_clips,
                    codes_ids=code,
                    language=langs,
                    return_loss=True,
                    attn_mask=attn_mask,
                )

                # Average each loss over the valid (shifted-by-one) positions
                # only, so padding does not dilute the gradient.
                loss_text *= mask_text[:, 1:].float()
                loss_text = loss_text.sum() / mask_text[:, 1:].sum()

                loss_code *= code_mask[:, 1:].float()
                loss_code = loss_code.sum() / code_mask[:, 1:].sum()

                loss = loss_text * config.text_loss_weight + loss_code

                accelerator.backward(loss)
                accelerator.clip_grad_norm_(model.parameters(), 1.0)
                optimizer.step()
                optimizer.zero_grad()
                step_num += 1

                # Log once per effective (accumulated) optimizer step.
                if (
                    step_num % config.ts_gradient_accumulation_steps == 0
                    and config.ts_wandb_logs
                    and accelerator.is_main_process
                ):
                    wandb_log.log(
                        {
                            "training_loss": loss.item(),
                            "step": step_num // config.ts_gradient_accumulation_steps,
                        }
                    )

                epoch_loss.append(loss.item())

                # Periodic mid-epoch evaluation (disabled while finetuning).
                if (
                    not config.ts_finetuning
                    and step_num
                    % (config.ts_gradient_accumulation_steps * config.ts_eval_step)
                    == 0
                ):
                    val_loss = val(model, val_dataloader, accelerator.is_main_process)
                    val_loss = accelerator.gather_for_metrics(val_loss).mean().item()
                    model.train()
                    if config.ts_wandb_logs and accelerator.is_main_process:
                        wandb_log.log(
                            {
                                "val_loss": val_loss,
                                "epoch": i,
                                "scheduled_learning_rate": lr,
                                "step": step_num // config.ts_gradient_accumulation_steps,
                            }
                        )

                    if val_loss < min_val_loss:
                        # save the model
                        accelerator.wait_for_everyone()
                        unwrapped_model = accelerator.unwrap_model(model)
                        checkpoint = {
                            "epoch": i,
                            # NOTE(review): `config.gradient_accumulation_steps`
                            # is used here and below, while the rest of this file
                            # uses `config.ts_gradient_accumulation_steps` — if
                            # config lacks the non-ts name this raises
                            # AttributeError at save time. Verify config.
                            "step": str(step_num // config.gradient_accumulation_steps),
                            "model": unwrapped_model.state_dict(),
                            "optimizer": optimizer.state_dict(),
                        }
                        torch.save(
                            checkpoint,
                            os.path.join(config.save_root_dir, "_best.pt"),
                        )
                        min_val_loss = val_loss

                    # save the latest checkpoint
                    accelerator.wait_for_everyone()
                    unwrapped_model = accelerator.unwrap_model(model)
                    checkpoint = {
                        "epoch": i,
                        "step": str(step_num // config.gradient_accumulation_steps),
                        "model": unwrapped_model.state_dict(),
                        "optimizer": optimizer.state_dict(),
                    }
                    torch.save(
                        checkpoint,
                        os.path.join(config.save_root_dir, str(step_num // config.gradient_accumulation_steps) + "_latest.pt"),
                    )
                    print(f"Saved latest checkpoint at {os.path.join(config.save_root_dir, str(step_num // config.gradient_accumulation_steps) + '_latest.pt')}")

        # End-of-epoch evaluation and checkpointing.
        val_loss = val(model, val_dataloader, accelerator.is_main_process)
        val_loss = accelerator.gather_for_metrics(val_loss).mean().item()
        model.train()
        if config.ts_wandb_logs and accelerator.is_main_process:
            wandb_log.log(
                {
                    "val_loss": val_loss,
                    "epoch": i,
                    "scheduled_learning_rate": lr,
                    "step": step_num // config.ts_gradient_accumulation_steps,
                }
            )

        if val_loss < min_val_loss:
            # save the model
            accelerator.wait_for_everyone()
            unwrapped_model = accelerator.unwrap_model(model)
            checkpoint = {
                "epoch": i,
                "step": str(step_num // config.gradient_accumulation_steps),
                "model": unwrapped_model.state_dict(),
                "optimizer": optimizer.state_dict(),
            }
            torch.save(
                checkpoint, os.path.join(config.save_root_dir, "best.pt")
            )
            min_val_loss = val_loss
            print(f"Saved best checkpoint at {os.path.join(config.save_root_dir, 'best.pt')}")
        # Per-epoch "latest" checkpoint, regardless of validation result.
        accelerator.wait_for_everyone()
        unwrapped_model = accelerator.unwrap_model(model)
        checkpoint = {
            "epoch": i,
            "step": str(step_num // config.gradient_accumulation_steps),
            "model": unwrapped_model.state_dict(),
            "optimizer": optimizer.state_dict(),
        }
        torch.save(
            checkpoint,
            os.path.join(config.save_root_dir, str(i) + "_latest.pt"),
        )

        if config.ts_wandb_logs and accelerator.is_local_main_process:
            wandb_log.log(
                {
                    "scheduled_learning_rate": lr,
                    "epoch": i,
                    "step": step_num // config.ts_gradient_accumulation_steps,
                }
            )
        print(
            "epoch_number : ", i, " training loss : ", sum(epoch_loss) / len(epoch_loss)
        )

    if config.ts_wandb_logs and accelerator.is_local_main_process:
        wandb_log.finish()
503
+
504
+
505
def val(model, val_dataloader, _main=False):
    """Run one full pass over `val_dataloader` and return the mean loss.

    Mirrors the masked text/code loss computation used in train(). The
    scalar result is returned as a tensor on the model's device so the
    caller can gather it across processes.

    Args:
        model: TS_model (left in eval mode on return; caller restores train()).
        val_dataloader: validation DataLoader.
        _main: if True, wrap the loader in a tqdm progress bar.
    """
    print("VALIDATION STARTING:")
    model.eval()
    val_loss = []
    device = next(model.parameters()).device
    if _main:
        # Progress bar only on the main process.
        val_dataloader = tqdm(val_dataloader)
    with torch.no_grad():
        for inputs in val_dataloader:
            text_ids, code, text_len, code_len, ref_clips, langs = inputs
            mask_text = get_mask_from_lengths(text_len).to(device)
            code_mask = get_mask_from_lengths(code_len).to(device)
            attn_mask = torch.cat([mask_text, code_mask], dim=1)
            loss_text, loss_code, _ = model(
                text_ids=text_ids.to(device),
                ref_clips=ref_clips.to(device),
                codes_ids=code.to(device),
                language=langs.to(device),
                return_loss=True,
                attn_mask=attn_mask,
            )

            # Average each loss over valid (shifted-by-one) positions only.
            loss_text *= mask_text[:, 1:].float()
            loss_text = loss_text.sum() / mask_text[:, 1:].sum()
            loss_code *= code_mask[:, 1:].float()
            loss_code = loss_code.sum() / code_mask[:, 1:].sum()
            loss = loss_text * config.text_loss_weight + loss_code

            val_loss.append(loss.item())

    val_loss = sum(val_loss) / len(val_loss)
    print(" Validation loss : ", val_loss)
    return torch.tensor(val_loss).to(device)
541
+
542
+
543
def main():
    """Build the T2S model and dataloaders from `config`, then train."""
    save_dir = os.path.join(config.save_root_dir, config.model_name, "T2S")
    os.makedirs(save_dir, exist_ok=True)

    model = TS_model(n_embed=1024, n_layer=30, n_head=16)

    # Validation loader: small fixed worker count.
    val_loader = DataLoader(
        semantic_dataset(config.val_file, scale=True),
        pin_memory=True,
        persistent_workers=True,
        num_workers=2,
        batch_size=config.ts_batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=collate,
    )

    # Training loader: worker count comes from config.
    train_loader = DataLoader(
        semantic_dataset(config.train_file, scale=True),
        pin_memory=True,
        persistent_workers=True,
        num_workers=config.ts_num_workers,
        batch_size=config.ts_batch_size,
        shuffle=True,
        drop_last=False,
        collate_fn=collate,
    )

    train(
        model,
        train_loader,
        val_loader,
        save_dir=save_dir,
        checkpoint_initial=config.t2s_checkpoint,
    )


if __name__ == "__main__":
    main()
T2S/mel_spec.py ADDED
@@ -0,0 +1,178 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) 2024 NVIDIA CORPORATION.
2
+ # Licensed under the MIT license.
3
+
4
+ # Adapted from https://github.com/jik876/hifi-gan under the MIT license.
5
+ # LICENSE is in incl_licenses directory.
6
+
7
+ if __name__ == "__main__":
8
+ import os
9
+ import sys
10
+
11
+ sys.path.append("../")
12
+
13
+ import math
14
+ import os
15
+ import pathlib
16
+ import random
17
+
18
+ import numpy as np
19
+ import torch
20
+ import torch.utils.data
21
+ from librosa.filters import mel as librosa_mel_fn
22
+ from librosa.util import normalize
23
+ from scipy.io.wavfile import read
24
+ from tqdm import tqdm
25
+
26
+ from config import config
27
+
28
+ MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
29
+
30
+
31
def load_wav(full_path, sr_target):
    """Read a wav file and verify it matches the required sampling rate.

    Returns (data, sampling_rate); raises RuntimeError when the file's rate
    differs from `sr_target`.
    """
    sampling_rate, data = read(full_path)
    if sampling_rate == sr_target:
        return data, sampling_rate
    raise RuntimeError(
        f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
    )
38
+
39
+
40
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes (numpy): log(max(x, clip_val) * C)."""
    clipped = np.clip(x, a_min=clip_val, a_max=None)
    return np.log(clipped * C)
42
+
43
+
44
def dynamic_range_decompression(x, C=1):
    """Invert dynamic_range_compression (numpy): exp(x) / C."""
    return np.divide(np.exp(x), C)
46
+
47
+
48
def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
    """Log-compress magnitudes (torch): log(clamp(x, min=clip_val) * C)."""
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
50
+
51
+
52
def dynamic_range_decompression_torch(x, C=1):
    """Invert dynamic_range_compression_torch: exp(x) / C."""
    return torch.exp(x).div(C)
54
+
55
+
56
def spectral_normalize_torch(magnitudes):
    """Log-compress spectrogram magnitudes (alias for
    dynamic_range_compression_torch with default C/clip)."""
    return dynamic_range_compression_torch(magnitudes)
58
+
59
+
60
def spectral_de_normalize_torch(magnitudes):
    """Invert spectral_normalize_torch (alias for
    dynamic_range_decompression_torch with default C)."""
    return dynamic_range_decompression_torch(magnitudes)
62
+
63
+
64
+ mel_basis_cache = {}
65
+ hann_window_cache = {}
66
+
67
+
68
def mel_spectrogram(
    y: torch.Tensor,
    n_fft: int,
    num_mels: int,
    sampling_rate: int,
    hop_size: int,
    win_size: int,
    fmin: int,
    fmax: int = None,
    center: bool = False,
) -> torch.Tensor:
    """
    Calculate the mel spectrogram of an input signal.
    This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).

    Args:
        y (torch.Tensor): Input signal.
        n_fft (int): FFT size.
        num_mels (int): Number of mel bins.
        sampling_rate (int): Sampling rate of the input signal.
        hop_size (int): Hop size for STFT.
        win_size (int): Window size for STFT.
        fmin (int): Minimum frequency for mel filterbank.
        fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
        center (bool): Whether to pad the input to center the frames. Default is False.

    Returns:
        torch.Tensor: Mel spectrogram.
    """
    # Audio is expected to be normalized to [-1, 1]; warn (don't fail) otherwise.
    if torch.min(y) < -1.0:
        print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
    if torch.max(y) > 1.0:
        print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")

    device = y.device
    # The mel filterbank and Hann window are cached per (parameters, device)
    # combination in the module-level dicts to avoid recomputation per call.
    key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"

    if key not in mel_basis_cache:
        mel = librosa_mel_fn(
            sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
        )
        mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
        hann_window_cache[key] = torch.hann_window(win_size).to(device)

    mel_basis = mel_basis_cache[key]
    hann_window = hann_window_cache[key]

    # Reflect-pad so the frame count matches HiFi-GAN's convention for
    # center=False STFTs.
    padding = (n_fft - hop_size) // 2
    y = torch.nn.functional.pad(
        y.unsqueeze(1), (padding, padding), mode="reflect"
    ).squeeze(1)

    spec = torch.stft(
        y,
        n_fft,
        hop_length=hop_size,
        win_length=win_size,
        window=hann_window,
        center=center,
        pad_mode="reflect",
        normalized=False,
        onesided=True,
        return_complex=True,
    )
    # Magnitude with a small epsilon inside the sqrt for numerical stability.
    spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)

    # Project the linear spectrogram onto the mel basis, then log-compress.
    mel_spec = torch.matmul(mel_basis, spec)
    mel_spec = spectral_normalize_torch(mel_spec)

    return mel_spec
138
+
139
+
140
def get_mel_spectrogram(wav, sr):
    """
    Generate a mel spectrogram from a waveform using the project's `config`
    hyperparameters (filter_length, n_mel_channels, hop/win lengths, fmin/fmax).

    Args:
        wav (torch.Tensor): Input waveform.
        sr (int): Sampling rate of `wav`; must equal config.sampling_rate.

    Returns:
        torch.Tensor: Mel spectrogram.

    Raises:
        AssertionError: if `sr` differs from config.sampling_rate.
    """

    assert sr == config.sampling_rate, (
        f"Given SR : {sr}, Required SR: {config.sampling_rate}"
    )

    return mel_spectrogram(
        wav,
        config.filter_length,
        config.n_mel_channels,
        config.sampling_rate,
        config.hop_length,
        config.win_length,
        config.mel_fmin,
        config.mel_fmax,
    )
166
+
167
+
168
if __name__ == "__main__":
    # Smoke test: load one second of a sample clip and print the resulting
    # mel spectrogram shape. The hard-coded path only exists on the authors'
    # machine.
    import torchaudio

    path = "/delta/NeuralSpeak_cfm_conv/Samples/IITM_cfm_bigv_harsh/S2A/orig/0_test.wav"
    wav, sr = torchaudio.load(path)

    # Keep only the first second of samples.
    wav = wav[:, :sr]
    print(wav.shape)
    mel_spec = get_mel_spectrogram(wav, sr)
    duration = wav.shape[-1] / sr
    print(duration, mel_spec.shape)
T2S/meta_stats.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import math
2
+ # from collections import defaultdict
3
+ # from config import config
4
+
5
+
6
+ # def calculate_duration(semt):
7
+ # return math.ceil(((len(semt.split()) + 1) / 50) * 2) / 2
8
+
9
+ # def get_weights(weights,expected_data,languages):
10
+ # new_weights = []
11
+
12
+ # expected_weights = config.weights_percentage_duration
13
+
14
+ # total_files = sum([expected_data['total'][i] for i in expected_data['total']])
15
+
16
+ # duration_multiplier = {i:config.weights_percentage_duration[i]/(expected_data['total'][i]/total_files) for i in config.weights_percentage_duration}
17
+ # print(expected_data['total'],duration_multiplier)
18
+ # for i in weights:
19
+ # new_weights.append(duration_multiplier[i[1]])
20
+ # return new_weights
21
+
22
+
23
+ # def process_file(file_path):
24
+ # weights = []
25
+ # expected_data = defaultdict(lambda: {i:0 for i in ["single_word","5s","10s","15s","20s","20_sentence",">20"]})
26
+ # languages = defaultdict(int)
27
+ # count = 0
28
+ # with open(file_path, 'r') as file:
29
+ # for line in file:
30
+ # count+=1
31
+ # lang, path, text, semt, ref_files = line.split('|')
32
+ # languages[lang]+=1
33
+ # dur = calculate_duration(semt)
34
+ # # weights.append([lang,,1.0])
35
+ # # duration_files['total'][dur] += 1
36
+
37
+ # if len(text.strip().split(' '))==1:
38
+ # expected_data[lang]["single_word"]+=1
39
+ # expected_data['total']["single_word"]+=1
40
+ # weights.append([lang,"single_word",1.0])
41
+ # continue
42
+
43
+ # if dur >19.5 and dur<=20:
44
+ # expected_data[lang]["20_sentence"]+=1
45
+ # expected_data['total']["20_sentence"]+=1
46
+ # weights.append([lang,"20_sentence",1.0])
47
+ # continue
48
+
49
+ # if dur<=5:
50
+ # expected_data[lang]["5s"]+=1
51
+ # expected_data['total']["5s"]+=1
52
+ # weights.append([lang,"5s",1.0])
53
+ # continue
54
+ # elif dur<=10:
55
+ # expected_data[lang]["10s"]+=1
56
+ # expected_data['total']["10s"]+=1
57
+ # weights.append([lang,"10s",1.0])
58
+ # continue
59
+ # elif dur<=15:
60
+ # expected_data[lang]["15s"]+=1
61
+ # expected_data['total']["15s"]+=1
62
+ # weights.append([lang,"15s",1.0])
63
+ # continue
64
+ # elif dur<=20:
65
+ # expected_data[lang]["20s"]+=1
66
+ # expected_data['total']["20s"]+=1
67
+ # weights.append([lang,"20s",1.0])
68
+ # continue
69
+ # else:
70
+ # # expected_data[lang][">20"]+=1
71
+ # # expected_data['total'][">20"]+=1
72
+ # # weights.append([lang,">20",1.0])
73
+ # continue
74
+
75
+ # final_weights = get_weights(weights,expected_data,languages)
76
+
77
+
78
+ # return final_weights,count
79
+
80
+ # def process_file_for_heads(file_path,total_processes,process_id):
81
+ # weights = []
82
+ # # heads = defaultdict(lambda: {i:[] for i in ["single_word","5s","10s","15s","20s","20_sentence",">20"]}) # to include langauges
83
+ # heads = {i:[] for i in ["single_word","5s","10s","15s","20s","20_sentence",">20"]}
84
+
85
+ # expected_data = defaultdict(lambda: {i:0 for i in ["single_word","5s","10s","15s","20s","20_sentence",">20"]})
86
+ # languages = defaultdict(int)
87
+ # count = 0
88
+ # line_number = -1
89
+ # with open(file_path, 'r') as file:
90
+ # for line in file:
91
+ # count+=1
92
+ # line_number+=1
93
+ # lang, path, text, semt, ref_files = line.split('|')
94
+ # languages[lang]+=1
95
+ # dur = calculate_duration(semt)
96
+ # # weights.append([lang,,1.0])
97
+ # # duration_files['total'][dur] += 1
98
+
99
+ # if len(text.strip().split(' '))==1:
100
+ # expected_data[lang]["single_word"]+=1
101
+ # expected_data['total']["single_word"]+=1
102
+ # weights.append([lang,"single_word",1.0])
103
+ # heads["single_word"].append(line_number)
104
+ # continue
105
+
106
+ # if dur >19.5 and dur<=20:
107
+ # expected_data[lang]["20_sentence"]+=1
108
+ # expected_data['total']["20_sentence"]+=1
109
+ # weights.append([lang,"20_sentence",1.0])
110
+ # heads["20_sentence"].append(line_number)
111
+ # continue
112
+
113
+ # if dur<=5:
114
+ # expected_data[lang]["5s"]+=1
115
+ # expected_data['total']["5s"]+=1
116
+ # weights.append([lang,"5s",1.0])
117
+ # heads["5s"].append(line_number)
118
+ # continue
119
+ # elif dur<=10:
120
+ # expected_data[lang]["10s"]+=1
121
+ # expected_data['total']["10s"]+=1
122
+ # weights.append([lang,"10s",1.0])
123
+ # heads["10s"].append(line_number)
124
+ # continue
125
+ # elif dur<=15:
126
+ # expected_data[lang]["15s"]+=1
127
+ # expected_data['total']["15s"]+=1
128
+ # weights.append([lang,"15s",1.0])
129
+ # heads["15s"].append(line_number)
130
+ # continue
131
+ # elif dur<=20:
132
+ # expected_data[lang]["20s"]+=1
133
+ # expected_data['total']["20s"]+=1
134
+ # weights.append([lang,"20s",1.0])
135
+ # heads["20s"].append(line_number)
136
+ # continue
137
+ # else:
138
+ # # expected_data[lang][">20"]+=1
139
+ # # expected_data['total'][">20"]+=1
140
+ # # weights.append([lang,">20",1.0])
141
+ # continue
142
+
143
+ # line_number+=1
144
+ # # final_weights = get_weights(weights,expected_data,languages)
145
+ # # final_weights = [1]*len(heads) # same weightage
146
+ # if config.ts_gradient_accumulation_steps>1:
147
+ # batch = config.ts_batch_size*total_processes*config.ts_gradient_accumulation_steps//config.ts_num_workers
148
+ # else:
149
+ # batch = config.ts_batch_size*total_processes*config.ts_gradient_accumulation_steps
150
+ # # batch = config.ts_batch_size*total_processes*config.ts_gradient_accumulation_steps
151
+ # # heads = heads[:-1]
152
+ # heads = {i:heads[i] for i in heads if len(heads[i])!=0}
153
+ # total_size = sum([len(heads[i]) for i in heads if len(heads[i])!=0])
154
+ # norm_nums = [len(heads[i])/total_size for i in heads if len(heads[i])!=0]
155
+ # final_weights = []
156
+
157
+ # for i in norm_nums:
158
+ # final_weights.append(max(1,math.ceil(i*batch)))
159
+
160
+ # rem_elem = sum(final_weights)-batch
161
+ # final_weights[final_weights.index(max(final_weights))]-=rem_elem
162
+
163
+ # # heads,final_weights = sorted(zip(heads,final_weights),key=lambda x:x[1])
164
+
165
+ # # process_head = []
166
+ # # proc = 0
167
+ # # sm=0
168
+ # # for i in final_weights:
169
+ # # # sm+=i
170
+ # # if sm+i >
171
+
172
+ # # process_batch_size = config.ts_batch_size*config.ts_gradient_accumulation_steps
173
+ # # proc = 0
174
+ # # lens = {i:len(heads[i]) for i in heads}
175
+ # # while proc <= process_id:
176
+ # # new_heads ={}
177
+ # # new_weights =[]
178
+ # # sm=0
179
+ # # for i,j in zip(heads,range(len(final_weights))):
180
+ # # if sm + final_weights[j] > process_batch_size:
181
+ # # if sm+final_weights[j] == process_batch_size:
182
+ # # new_heads[i] = heads[i]
183
+ # # new_weights.append(final_weights[j])
184
+ # # heads.pop(i)
185
+ # # else:
186
+ # # new_heads[i] = heads[i][:1+(lens[i]*(process_batch_size-sm)//process_batch_size)]
187
+ # # heads[i] = heads[i][1+(lens[i]*(process_batch_size-sm)//process_batch_size):]
188
+ # # if len(heads[i]) == 0:
189
+ # # heads.pop(i)
190
+ # # new_weights.append(process_batch_size-sm)
191
+ # # final_weights[j]-= process_batch_size-sm
192
+ # # sm = 0
193
+ # # proc+=1
194
+ # # final_weights=final_weights[j:]
195
+ # # break
196
+ # # else:
197
+ # # new_heads[i] = heads[i]
198
+ # # new_weights.append(final_weights[j])
199
+ # # heads.pop(i)
200
+
201
+
202
+ # # if len(heads) == 0:
203
+ # # break
204
+
205
+ # # print("weights",new_weights,[(i,len(heads[i])) for i in new_heads])
206
+ # # return new_heads,new_weights,count
207
+ # # # make it more effective as to real_batch_size instead of worker_batch_size
208
+
209
+ # # # #[867, 31444, 35458, 6764, 1561, 96, 0] per gpu for iitm
210
+ # # # [10,400,400,60,20,1]
211
+ # # #
212
+ # print("weights",final_weights,[(i,len(heads[i])) for i in heads])
213
+ # print(process_id)
214
+ # new_heads, new_weights = process_batches(heads,final_weights,process_id+1)
215
+
216
+ # assert len(new_heads) != 0 and len(new_weights) == len(new_heads), print(new_heads)
217
+
218
+ # print("process id",process_id,new_weights,[(i,len(new_heads[i])) for i in new_heads])
219
+ # return new_heads, new_weights, count
220
+
221
+ # def process_batches(heads, final_weights, process_id=0):
222
+ # if config.ts_gradient_accumulation_steps>1:
223
+ # process_batch_size = config.ts_batch_size * config.ts_gradient_accumulation_steps//config.ts_num_workers
224
+ # else:
225
+ # process_batch_size = config.ts_batch_size * config.ts_gradient_accumulation_steps
226
+ # proc = 0
227
+ # # Create a copy of the original dictionaries to avoid modifying them during iteration
228
+ # remaining_heads = heads.copy()
229
+ # remaining_weights = final_weights.copy()
230
+ # lens = {i: len(heads[i]) for i in heads}
231
+
232
+ # while proc <= process_id and remaining_heads:
233
+ # new_heads = {}
234
+ # new_weights = []
235
+ # current_sum = 0
236
+
237
+ # # Convert items to list to avoid dictionary size change during iteration
238
+ # items = list(remaining_heads.items())
239
+
240
+ # for key, head_list in items:
241
+ # weight = remaining_weights[0] # Get the corresponding weight
242
+
243
+ # if current_sum + weight > process_batch_size:
244
+ # # Calculate how much of this head we can include
245
+ # remaining_space = process_batch_size - current_sum
246
+ # if current_sum + weight == process_batch_size:
247
+ # # If it fits exactly
248
+ # new_heads[key] = head_list
249
+ # new_weights.append(weight)
250
+ # del remaining_heads[key]
251
+ # remaining_weights.pop(0)
252
+ # # print("inside first")
253
+ # else:
254
+ # # If we need to split the head
255
+ # split_point = 1 + (lens[key] * remaining_space) // process_batch_size
256
+ # new_heads[key] = head_list[:split_point]
257
+ # remaining_heads[key] = head_list[split_point:]
258
+ # # print(process_id,"inside >",remaining_heads)
259
+ # if not remaining_heads[key]: # If the remaining list is empty
260
+ # del remaining_heads[key]
261
+
262
+ # new_weights.append(remaining_space)
263
+ # remaining_weights[0] -= remaining_space
264
+ # # print(process_id,"inside >",remaining_heads)
265
+ # # print("inside seciond")
266
+ # proc += 1
267
+ # break
268
+ # else:
269
+ # # If the current head fits completely
270
+ # new_heads[key] = head_list
271
+ # new_weights.append(weight)
272
+ # del remaining_heads[key]
273
+ # remaining_weights.pop(0)
274
+ # current_sum += weight
275
+ # # print("inside third")
276
+ # if len(remaining_heads)==0:
277
+ # proc+=1
278
+ # if proc == process_id:
279
+ # # print("process id",process_id,proc,new_weights,[(i,len(new_heads[i])) for i in new_heads])
280
+ # return new_heads, new_weights
281
+
282
+ # return {}, [] # Return empty structures if no valid batch is found
T2S/plot_embed.py ADDED
@@ -0,0 +1,41 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import matplotlib.pyplot as plt
2
+ # import numpy as np
3
+ # import subprocess
4
+ # import json
5
+ # from umap import UMAP
6
+ # from tqdm import tqdm
7
+
8
+ # def count_lines_shell(file_path):
9
+ # result = subprocess.run(["wc", "-l", file_path], capture_output=True, text=True)
10
+ # return int(result.stdout.split()[0])
11
+
12
+ # def load_chunk(file_path,chunk_size):
13
+ # lines = count_lines_shell(file_path)
14
+ # with open(file_path,'r') as file:
15
+ # dataset = []
16
+ # # embed = []
17
+ # for i in tqdm(file,total=lines):
18
+ # data = json.loads(i)
19
+ # key = list(data.keys())[0]
20
+ # dataset.append([key,data[key][0]])
21
+ # # embed.append(data[key][1])
22
+ # if len(dataset)==chunk_size:
23
+ # return dataset
24
+ # dataset=[]
25
+ # # embed=[]
26
+ # if len(dataset)!=0:
27
+ # return dataset
28
+
29
+
30
+ # if __name__ == '__main__':
31
+
32
+ # file_name = "pocketfm_pure_textlossless_data_stats.json"
33
+ # bs = -1
34
+ # # data = load_chunk(file_name,-1)
35
+ # embed = np.load("/nlsasfs/home/dubverse/varshulg/work/NeuralSpeak/T2S/pocketfm_embeddings.npy")
36
+ # print(embed.shape)
37
+ # plt.scatter(embed[:,0],embed[:,1])
38
+ # # plt.imsave("gst_embed.png")
39
+ # plt.savefig('gst_embed_pocketfm.png')#, dpi=300, bbox_inches='tight')
40
+
41
+
T2S/stft.py ADDED
@@ -0,0 +1,109 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # import torch
2
+ # import numpy as np
3
+ # import torch.nn.functional as F
4
+ # from torch.autograd import Variable
5
+ # from scipy.signal import get_window
6
+ # from librosa.util import pad_center, tiny
7
+ # from utilities import window_sumsquare
8
+
9
+
10
+ # class STFT(torch.nn.Module):
11
+ # """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
12
+ # def __init__(self, filter_length=800, hop_length=200, win_length=800,
13
+ # window='hann'):
14
+ # super(STFT, self).__init__()
15
+ # self.filter_length = filter_length
16
+ # self.hop_length = hop_length
17
+ # self.win_length = win_length
18
+ # self.window = window
19
+ # self.forward_transform = None
20
+ # scale = self.filter_length / self.hop_length
21
+ # fourier_basis = np.fft.fft(np.eye(self.filter_length))
22
+
23
+ # cutoff = int((self.filter_length / 2 + 1))
24
+ # fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
25
+ # np.imag(fourier_basis[:cutoff, :])])
26
+
27
+ # forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
28
+ # inverse_basis = torch.FloatTensor(
29
+ # np.linalg.pinv(scale * fourier_basis).T[:, None, :])
30
+
31
+ # if window is not None:
32
+ # assert(filter_length >= win_length)
33
+ # # get window and zero center pad it to filter_length
34
+ # fft_window = get_window(window, win_length, fftbins=True)
35
+ # fft_window = pad_center(fft_window, size = filter_length)
36
+ # fft_window = torch.from_numpy(fft_window).float()
37
+
38
+ # # window the bases
39
+ # forward_basis *= fft_window
40
+ # inverse_basis *= fft_window
41
+
42
+ # self.register_buffer('forward_basis', forward_basis.float())
43
+ # self.register_buffer('inverse_basis', inverse_basis.float())
44
+
45
+ # def transform(self, input_data):
46
+ # num_batches = input_data.size(0)
47
+ # num_samples = input_data.size(1)
48
+
49
+ # self.num_samples = num_samples
50
+
51
+ # # similar to librosa, reflect-pad the input
52
+ # input_data = input_data.view(num_batches, 1, num_samples)
53
+ # input_data = F.pad(
54
+ # input_data.unsqueeze(1),
55
+ # (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
56
+ # mode='reflect')
57
+ # input_data = input_data.squeeze(1)
58
+
59
+ # forward_transform = F.conv1d(
60
+ # input_data,
61
+ # Variable(self.forward_basis, requires_grad=False),
62
+ # stride=self.hop_length,
63
+ # padding=0)
64
+
65
+ # cutoff = int((self.filter_length / 2) + 1)
66
+ # real_part = forward_transform[:, :cutoff, :]
67
+ # imag_part = forward_transform[:, cutoff:, :]
68
+
69
+ # magnitude = torch.sqrt(real_part**2 + imag_part**2)
70
+ # phase = torch.autograd.Variable(
71
+ # torch.atan2(imag_part.data, real_part.data))
72
+
73
+ # return magnitude, phase
74
+
75
+ # def inverse(self, magnitude, phase):
76
+ # recombine_magnitude_phase = torch.cat(
77
+ # [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
78
+
79
+ # inverse_transform = F.conv_transpose1d(
80
+ # recombine_magnitude_phase,
81
+ # Variable(self.inverse_basis, requires_grad=False),
82
+ # stride=self.hop_length,
83
+ # padding=0)
84
+
85
+ # if self.window is not None:
86
+ # window_sum = window_sumsquare(
87
+ # self.window, magnitude.size(-1), hop_length=self.hop_length,
88
+ # win_length=self.win_length, n_fft=self.filter_length,
89
+ # dtype=np.float32)
90
+ # # remove modulation effects
91
+ # approx_nonzero_indices = torch.from_numpy(
92
+ # np.where(window_sum > tiny(window_sum))[0])
93
+ # window_sum = torch.autograd.Variable(
94
+ # torch.from_numpy(window_sum), requires_grad=False)
95
+ # window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
96
+ # inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
97
+
98
+ # # scale by hop ratio
99
+ # inverse_transform *= float(self.filter_length) / self.hop_length
100
+
101
+ # inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
102
+ # inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
103
+
104
+ # return inverse_transform
105
+
106
+ # def forward(self, input_data):
107
+ # self.magnitude, self.phase = self.transform(input_data)
108
+ # reconstruction = self.inverse(self.magnitude, self.phase)
109
+ # return reconstruction
T2S/stream_generator.py ADDED
@@ -0,0 +1,1002 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Adapted from: https://github.com/LowinLi/transformers-stream-generator
2
+
3
+ import copy
4
+ import inspect
5
+ import random
6
+ import warnings
7
+ from typing import Callable, List, Optional, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+ import torch.distributed as dist
12
+ from torch import nn
13
+ from transformers import (BeamSearchScorer, ConstrainedBeamSearchScorer,
14
+ DisjunctiveConstraint, GenerationConfig,
15
+ GenerationMixin, LogitsProcessorList,
16
+ PhrasalConstraint, PreTrainedModel,
17
+ StoppingCriteriaList)
18
+ from transformers.generation.utils import GenerateOutput, SampleOutput, logger
19
+
20
+
21
+ def setup_seed(seed):
22
+ if seed == -1:
23
+ return
24
+ torch.manual_seed(seed)
25
+ if torch.cuda.is_available():
26
+ torch.cuda.manual_seed_all(seed)
27
+ np.random.seed(seed)
28
+ random.seed(seed)
29
+ torch.backends.cudnn.deterministic = True
30
+
31
+
32
+ class StreamGenerationConfig(GenerationConfig):
33
+ def __init__(self, **kwargs):
34
+ super().__init__(**kwargs)
35
+ self.do_stream = kwargs.pop("do_stream", False)
36
+
37
+
38
+ class NewGenerationMixin(GenerationMixin):
39
+ @torch.no_grad()
40
+ def generate(
41
+ self,
42
+ inputs: Optional[torch.Tensor] = None,
43
+ generation_config: Optional[StreamGenerationConfig] = None,
44
+ logits_processor: Optional[LogitsProcessorList] = None,
45
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
46
+ prefix_allowed_tokens_fn: Optional[
47
+ Callable[[int, torch.Tensor], List[int]]
48
+ ] = None,
49
+ synced_gpus: Optional[bool] = False,
50
+ seed=0,
51
+ **kwargs,
52
+ ) -> Union[GenerateOutput, torch.LongTensor]:
53
+ r"""
54
+
55
+ Generates sequences of token ids for models with a language modeling head.
56
+
57
+ <Tip warning={true}>
58
+
59
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
60
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
61
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
62
+
63
+ For an overview of generation strategies and code examples, check out the [following
64
+ guide](./generation_strategies).
65
+
66
+ </Tip>
67
+
68
+ Parameters:
69
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
70
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
71
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
72
+ should of in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
73
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
74
+ generation_config (`~generation.GenerationConfig`, *optional*):
75
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
76
+ passed to generate matching the attributes of `generation_config` will override them. If
77
+ `generation_config` is not provided, the default will be used, which had the following loading
78
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
79
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
80
+ default values, whose documentation should be checked to parameterize generation.
81
+ logits_processor (`LogitsProcessorList`, *optional*):
82
+ Custom logits processors that complement the default logits processors built from arguments and
83
+ generation config. If a logit processor is passed that is already created with the arguments or a
84
+ generation config an error is thrown. This feature is intended for advanced users.
85
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
86
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
87
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
88
+ generation config an error is thrown. This feature is intended for advanced users.
89
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
90
+ If provided, this function constraints the beam search to allowed tokens only at each step. If not
91
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
92
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
93
+ on the batch ID `batch_id` and the previously generated tokens `inputs_ids`. This argument is useful
94
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
95
+ Retrieval](https://arxiv.org/abs/2010.00904).
96
+ synced_gpus (`bool`, *optional*, defaults to `False`):
97
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
98
+ kwargs:
99
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
100
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
101
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
102
+
103
+ Return:
104
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
105
+ or when `config.return_dict_in_generate=True`) or a `torch.FloatTensor`.
106
+
107
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
108
+ [`~utils.ModelOutput`] types are:
109
+
110
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
111
+ - [`~generation.SampleDecoderOnlyOutput`],
112
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
113
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
114
+
115
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
116
+ [`~utils.ModelOutput`] types are:
117
+
118
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
119
+ - [`~generation.SampleEncoderDecoderOutput`],
120
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
121
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
122
+ """
123
+ # setup_seed(seed)
124
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
125
+ self._validate_model_class()
126
+
127
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
128
+ if generation_config is None:
129
+ # legacy: users may modify the model configuration to control generation -- update the generation config
130
+ # model attribute accordingly, if it was created from the model config
131
+ if self.generation_config._from_model_config:
132
+ new_generation_config = StreamGenerationConfig.from_model_config(
133
+ self.config
134
+ )
135
+ if new_generation_config != self.generation_config:
136
+ warnings.warn(
137
+ "You have modified the pretrained model configuration to control generation. This is a"
138
+ " deprecated strategy to control generation and will be removed soon, in a future version."
139
+ " Please use a generation configuration file (see"
140
+ " https://huggingface.co/docs/transformers/main_classes/text_generation)"
141
+ )
142
+ self.generation_config = new_generation_config
143
+ generation_config = self.generation_config
144
+
145
+ generation_config = copy.deepcopy(generation_config)
146
+ model_kwargs = generation_config.update(
147
+ **kwargs
148
+ ) # All unused kwargs must be model kwargs
149
+ # self._validate_model_kwargs(model_kwargs.copy())
150
+
151
+ # 2. Set generation parameters if not already defined
152
+ logits_processor = (
153
+ logits_processor if logits_processor is not None else LogitsProcessorList()
154
+ )
155
+ stopping_criteria = (
156
+ stopping_criteria
157
+ if stopping_criteria is not None
158
+ else StoppingCriteriaList()
159
+ )
160
+
161
+ if (
162
+ generation_config.pad_token_id is None
163
+ and generation_config.eos_token_id is not None
164
+ ):
165
+ if model_kwargs.get("attention_mask", None) is None:
166
+ logger.warning(
167
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
168
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
169
+ )
170
+ eos_token_id = generation_config.eos_token_id
171
+ if isinstance(eos_token_id, list):
172
+ eos_token_id = eos_token_id[0]
173
+ logger.warning(
174
+ f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
175
+ )
176
+ generation_config.pad_token_id = eos_token_id
177
+
178
+ # 3. Define model inputs
179
+ # inputs_tensor has to be defined
180
+ # model_input_name is defined if model-specific keyword input is passed
181
+ # otherwise model_input_name is None
182
+ # all model-specific keyword inputs are removed from `model_kwargs`
183
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
184
+ inputs, generation_config.bos_token_id, model_kwargs
185
+ )
186
+ batch_size = inputs_tensor.shape[0]
187
+
188
+ # 4. Define other model kwargs
189
+ model_kwargs["output_attentions"] = generation_config.output_attentions
190
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
191
+ model_kwargs["use_cache"] = generation_config.use_cache
192
+
193
+ accepts_attention_mask = "attention_mask" in set(
194
+ inspect.signature(self.forward).parameters.keys()
195
+ )
196
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
197
+
198
+ if (
199
+ model_kwargs.get("attention_mask", None) is None
200
+ and requires_attention_mask
201
+ and accepts_attention_mask
202
+ ):
203
+ # print(generation_config.pad_token_id,generation_config.eos_token_id,inputs_tensor)
204
+ model_kwargs["attention_mask"] = (
205
+ self._prepare_attention_mask_for_generation(
206
+ inputs_tensor,
207
+ generation_config.pad_token_id,
208
+ generation_config.eos_token_id,
209
+ )
210
+ )
211
+
212
+ # decoder-only models should use left-padding for generation
213
+ if not self.config.is_encoder_decoder:
214
+ if (
215
+ generation_config.pad_token_id is not None
216
+ and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
217
+ > 0
218
+ ):
219
+ logger.warning(
220
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
221
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
222
+ )
223
+
224
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
225
+ # if model is encoder decoder encoder_outputs are created
226
+ # and added to `model_kwargs`
227
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
228
+ inputs_tensor, model_kwargs, model_input_name
229
+ )
230
+
231
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
232
+ if self.config.is_encoder_decoder:
233
+ input_ids = self._prepare_decoder_input_ids_for_generation(
234
+ batch_size,
235
+ decoder_start_token_id=generation_config.decoder_start_token_id,
236
+ bos_token_id=generation_config.bos_token_id,
237
+ model_kwargs=model_kwargs,
238
+ device=inputs_tensor.device,
239
+ )
240
+ else:
241
+ # if decoder-only then inputs_tensor has to be `input_ids`
242
+ input_ids = inputs_tensor
243
+
244
+ # 6. Prepare `max_length` depending on other stopping criteria.
245
+ input_ids_seq_length = input_ids.shape[-1]
246
+ has_default_max_length = (
247
+ kwargs.get("max_length") is None
248
+ and generation_config.max_length is not None
249
+ )
250
+ if has_default_max_length and generation_config.max_new_tokens is None:
251
+ warnings.warn(
252
+ "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
253
+ f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
254
+ " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
255
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
256
+ UserWarning,
257
+ )
258
+ elif has_default_max_length and generation_config.max_new_tokens is not None:
259
+ generation_config.max_length = (
260
+ generation_config.max_new_tokens + input_ids_seq_length
261
+ )
262
+ elif (
263
+ not has_default_max_length and generation_config.max_new_tokens is not None
264
+ ):
265
+ raise ValueError(
266
+ "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
267
+ " limit to the generated output length. Remove one of those arguments. Please refer to the"
268
+ " documentation for more information. "
269
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
270
+ )
271
+
272
+ if (
273
+ generation_config.min_length is not None
274
+ and generation_config.min_length > generation_config.max_length
275
+ ):
276
+ raise ValueError(
277
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
278
+ f" the maximum length ({generation_config.max_length})"
279
+ )
280
+ if input_ids_seq_length >= generation_config.max_length:
281
+ input_ids_string = (
282
+ "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
283
+ )
284
+ logger.warning(
285
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
286
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
287
+ " increasing `max_new_tokens`."
288
+ )
289
+
290
+ # 7. determine generation mode
291
+ is_constraint_gen_mode = (
292
+ generation_config.constraints is not None
293
+ or generation_config.force_words_ids is not None
294
+ )
295
+
296
+ is_contrastive_search_gen_mode = (
297
+ generation_config.top_k is not None
298
+ and generation_config.top_k > 1
299
+ and generation_config.do_sample is False
300
+ and generation_config.penalty_alpha is not None
301
+ and generation_config.penalty_alpha > 0
302
+ )
303
+
304
+ is_greedy_gen_mode = (
305
+ (generation_config.num_beams == 1)
306
+ and (generation_config.num_beam_groups == 1)
307
+ and generation_config.do_sample is False
308
+ and not is_constraint_gen_mode
309
+ and not is_contrastive_search_gen_mode
310
+ )
311
+ is_sample_gen_mode = (
312
+ (generation_config.num_beams == 1)
313
+ and (generation_config.num_beam_groups == 1)
314
+ and generation_config.do_sample is True
315
+ and generation_config.do_stream is False
316
+ and not is_constraint_gen_mode
317
+ and not is_contrastive_search_gen_mode
318
+ )
319
+ is_sample_gen_stream_mode = (
320
+ (generation_config.num_beams == 1)
321
+ and (generation_config.num_beam_groups == 1)
322
+ and generation_config.do_stream is True
323
+ and not is_constraint_gen_mode
324
+ and not is_contrastive_search_gen_mode
325
+ )
326
+ is_beam_gen_mode = (
327
+ (generation_config.num_beams > 1)
328
+ and (generation_config.num_beam_groups == 1)
329
+ and generation_config.do_sample is False
330
+ and not is_constraint_gen_mode
331
+ and not is_contrastive_search_gen_mode
332
+ )
333
+ is_beam_sample_gen_mode = (
334
+ (generation_config.num_beams > 1)
335
+ and (generation_config.num_beam_groups == 1)
336
+ and generation_config.do_sample is True
337
+ and not is_constraint_gen_mode
338
+ and not is_contrastive_search_gen_mode
339
+ )
340
+ is_group_beam_gen_mode = (
341
+ (generation_config.num_beams > 1)
342
+ and (generation_config.num_beam_groups > 1)
343
+ and not is_constraint_gen_mode
344
+ and not is_contrastive_search_gen_mode
345
+ )
346
+
347
+ if generation_config.num_beam_groups > generation_config.num_beams:
348
+ raise ValueError(
349
+ "`num_beam_groups` has to be smaller or equal to `num_beams`"
350
+ )
351
+ if is_group_beam_gen_mode and generation_config.do_sample is True:
352
+ raise ValueError(
353
+ "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
354
+ )
355
+
356
+ if self.device.type != input_ids.device.type:
357
+ warnings.warn(
358
+ "You are calling .generate() with the `input_ids` being on a device type different"
359
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
360
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
361
+ " Please make sure that you have put `input_ids` to the"
362
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
363
+ " running `.generate()`.",
364
+ UserWarning,
365
+ )
366
+ # 8. prepare distribution pre_processing samplers
367
+ logits_processor = self._get_logits_processor(
368
+ generation_config=generation_config,
369
+ input_ids_seq_length=input_ids_seq_length,
370
+ encoder_input_ids=inputs_tensor,
371
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
372
+ logits_processor=logits_processor,
373
+ )
374
+
375
+ # 9. prepare stopping criteria
376
+ stopping_criteria = self._get_stopping_criteria(
377
+ generation_config=generation_config, stopping_criteria=stopping_criteria
378
+ )
379
+ # 10. go into different generation modes
380
+ if is_greedy_gen_mode:
381
+ if generation_config.num_return_sequences > 1:
382
+ raise ValueError(
383
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
384
+ " greedy search."
385
+ )
386
+
387
+ # 11. run greedy search
388
+ return self.greedy_search(
389
+ input_ids,
390
+ logits_processor=logits_processor,
391
+ stopping_criteria=stopping_criteria,
392
+ pad_token_id=generation_config.pad_token_id,
393
+ eos_token_id=generation_config.eos_token_id,
394
+ output_scores=generation_config.output_scores,
395
+ return_dict_in_generate=generation_config.return_dict_in_generate,
396
+ synced_gpus=synced_gpus,
397
+ **model_kwargs,
398
+ )
399
+
400
+ elif is_contrastive_search_gen_mode:
401
+ if generation_config.num_return_sequences > 1:
402
+ raise ValueError(
403
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
404
+ " contrastive search."
405
+ )
406
+
407
+ return self.contrastive_search(
408
+ input_ids,
409
+ top_k=generation_config.top_k,
410
+ penalty_alpha=generation_config.penalty_alpha,
411
+ logits_processor=logits_processor,
412
+ stopping_criteria=stopping_criteria,
413
+ pad_token_id=generation_config.pad_token_id,
414
+ eos_token_id=generation_config.eos_token_id,
415
+ output_scores=generation_config.output_scores,
416
+ return_dict_in_generate=generation_config.return_dict_in_generate,
417
+ synced_gpus=synced_gpus,
418
+ **model_kwargs,
419
+ )
420
+
421
+ elif is_sample_gen_mode:
422
+ # 11. prepare logits warper
423
+ logits_warper = self._get_logits_warper(generation_config)
424
+
425
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
426
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
427
+ input_ids=input_ids,
428
+ expand_size=generation_config.num_return_sequences,
429
+ is_encoder_decoder=self.config.is_encoder_decoder,
430
+ **model_kwargs,
431
+ )
432
+
433
+ # 13. run sample
434
+ return self.sample(
435
+ input_ids,
436
+ logits_processor=logits_processor,
437
+ logits_warper=logits_warper,
438
+ stopping_criteria=stopping_criteria,
439
+ pad_token_id=generation_config.pad_token_id,
440
+ eos_token_id=generation_config.eos_token_id,
441
+ output_scores=generation_config.output_scores,
442
+ return_dict_in_generate=generation_config.return_dict_in_generate,
443
+ synced_gpus=synced_gpus,
444
+ **model_kwargs,
445
+ )
446
+ elif is_sample_gen_stream_mode:
447
+ # 11. prepare logits warper
448
+ logits_warper = self._get_logits_warper(generation_config)
449
+
450
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
451
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
452
+ input_ids=input_ids,
453
+ expand_size=generation_config.num_return_sequences,
454
+ is_encoder_decoder=self.config.is_encoder_decoder,
455
+ **model_kwargs,
456
+ )
457
+
458
+ # 13. run sample
459
+ return self.sample_stream(
460
+ input_ids,
461
+ logits_processor=logits_processor,
462
+ logits_warper=logits_warper,
463
+ stopping_criteria=stopping_criteria,
464
+ pad_token_id=generation_config.pad_token_id,
465
+ eos_token_id=generation_config.eos_token_id,
466
+ output_scores=generation_config.output_scores,
467
+ return_dict_in_generate=generation_config.return_dict_in_generate,
468
+ synced_gpus=synced_gpus,
469
+ **model_kwargs,
470
+ )
471
+ elif is_beam_gen_mode:
472
+ if generation_config.num_return_sequences > generation_config.num_beams:
473
+ raise ValueError(
474
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
475
+ )
476
+
477
+ if stopping_criteria.max_length is None:
478
+ raise ValueError(
479
+ "`max_length` needs to be a stopping_criteria for now."
480
+ )
481
+
482
+ # 11. prepare beam search scorer
483
+ beam_scorer = BeamSearchScorer(
484
+ batch_size=batch_size,
485
+ num_beams=generation_config.num_beams,
486
+ device=inputs_tensor.device,
487
+ length_penalty=generation_config.length_penalty,
488
+ do_early_stopping=generation_config.early_stopping,
489
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
490
+ )
491
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
492
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
493
+ input_ids=input_ids,
494
+ expand_size=generation_config.num_beams,
495
+ is_encoder_decoder=self.config.is_encoder_decoder,
496
+ **model_kwargs,
497
+ )
498
+ # 13. run beam search
499
+ return self.beam_search(
500
+ input_ids,
501
+ beam_scorer,
502
+ logits_processor=logits_processor,
503
+ stopping_criteria=stopping_criteria,
504
+ pad_token_id=generation_config.pad_token_id,
505
+ eos_token_id=generation_config.eos_token_id,
506
+ output_scores=generation_config.output_scores,
507
+ return_dict_in_generate=generation_config.return_dict_in_generate,
508
+ synced_gpus=synced_gpus,
509
+ **model_kwargs,
510
+ )
511
+
512
+ elif is_beam_sample_gen_mode:
513
+ # 11. prepare logits warper
514
+ logits_warper = self._get_logits_warper(generation_config)
515
+
516
+ if stopping_criteria.max_length is None:
517
+ raise ValueError(
518
+ "`max_length` needs to be a stopping_criteria for now."
519
+ )
520
+ # 12. prepare beam search scorer
521
+ beam_scorer = BeamSearchScorer(
522
+ batch_size=batch_size * generation_config.num_return_sequences,
523
+ num_beams=generation_config.num_beams,
524
+ device=inputs_tensor.device,
525
+ length_penalty=generation_config.length_penalty,
526
+ do_early_stopping=generation_config.early_stopping,
527
+ )
528
+
529
+ # 13. interleave input_ids with `num_beams` additional sequences per batch
530
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
531
+ input_ids=input_ids,
532
+ expand_size=generation_config.num_beams
533
+ * generation_config.num_return_sequences,
534
+ is_encoder_decoder=self.config.is_encoder_decoder,
535
+ **model_kwargs,
536
+ )
537
+
538
+ # 14. run beam sample
539
+ return self.beam_sample(
540
+ input_ids,
541
+ beam_scorer,
542
+ logits_processor=logits_processor,
543
+ logits_warper=logits_warper,
544
+ stopping_criteria=stopping_criteria,
545
+ pad_token_id=generation_config.pad_token_id,
546
+ eos_token_id=generation_config.eos_token_id,
547
+ output_scores=generation_config.output_scores,
548
+ return_dict_in_generate=generation_config.return_dict_in_generate,
549
+ synced_gpus=synced_gpus,
550
+ **model_kwargs,
551
+ )
552
+
553
+ elif is_group_beam_gen_mode:
554
+ if generation_config.num_return_sequences > generation_config.num_beams:
555
+ raise ValueError(
556
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
557
+ )
558
+
559
+ if generation_config.num_beams % generation_config.num_beam_groups != 0:
560
+ raise ValueError(
561
+ "`num_beams` should be divisible by `num_beam_groups` for group beam search."
562
+ )
563
+
564
+ if stopping_criteria.max_length is None:
565
+ raise ValueError(
566
+ "`max_length` needs to be a stopping_criteria for now."
567
+ )
568
+
569
+ has_default_typical_p = (
570
+ kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
571
+ )
572
+ if not has_default_typical_p:
573
+ raise ValueError(
574
+ "Decoder argument `typical_p` is not supported with beam groups."
575
+ )
576
+
577
+ # 11. prepare beam search scorer
578
+ beam_scorer = BeamSearchScorer(
579
+ batch_size=batch_size,
580
+ num_beams=generation_config.num_beams,
581
+ max_length=stopping_criteria.max_length,
582
+ device=inputs_tensor.device,
583
+ length_penalty=generation_config.length_penalty,
584
+ do_early_stopping=generation_config.early_stopping,
585
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
586
+ num_beam_groups=generation_config.num_beam_groups,
587
+ )
588
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
589
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
590
+ input_ids=input_ids,
591
+ expand_size=generation_config.num_beams,
592
+ is_encoder_decoder=self.config.is_encoder_decoder,
593
+ **model_kwargs,
594
+ )
595
+ # 13. run beam search
596
+ return self.group_beam_search(
597
+ input_ids,
598
+ beam_scorer,
599
+ logits_processor=logits_processor,
600
+ stopping_criteria=stopping_criteria,
601
+ pad_token_id=generation_config.pad_token_id,
602
+ eos_token_id=generation_config.eos_token_id,
603
+ output_scores=generation_config.output_scores,
604
+ return_dict_in_generate=generation_config.return_dict_in_generate,
605
+ synced_gpus=synced_gpus,
606
+ **model_kwargs,
607
+ )
608
+
609
+ elif is_constraint_gen_mode:
610
+ if generation_config.num_return_sequences > generation_config.num_beams:
611
+ raise ValueError(
612
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
613
+ )
614
+
615
+ if stopping_criteria.max_length is None:
616
+ raise ValueError(
617
+ "`max_length` needs to be a stopping_criteria for now."
618
+ )
619
+
620
+ if generation_config.num_beams <= 1:
621
+ raise ValueError(
622
+ "`num_beams` needs to be greater than 1 for constrained generation."
623
+ )
624
+
625
+ if generation_config.do_sample:
626
+ raise ValueError(
627
+ "`do_sample` needs to be false for constrained generation."
628
+ )
629
+
630
+ if (
631
+ generation_config.num_beam_groups is not None
632
+ and generation_config.num_beam_groups > 1
633
+ ):
634
+ raise ValueError(
635
+ "`num_beam_groups` not supported yet for constrained generation."
636
+ )
637
+
638
+ final_constraints = []
639
+ if generation_config.constraints is not None:
640
+ final_constraints = generation_config.constraints
641
+
642
+ if generation_config.force_words_ids is not None:
643
+
644
+ def typeerror():
645
+ raise ValueError(
646
+ "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
647
+ f"of positive integers, but is {generation_config.force_words_ids}."
648
+ )
649
+
650
+ if (
651
+ not isinstance(generation_config.force_words_ids, list)
652
+ or len(generation_config.force_words_ids) == 0
653
+ ):
654
+ typeerror()
655
+
656
+ for word_ids in generation_config.force_words_ids:
657
+ if isinstance(word_ids[0], list):
658
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
659
+ typeerror()
660
+ if any(
661
+ not isinstance(token_ids, list) for token_ids in word_ids
662
+ ):
663
+ typeerror()
664
+ if any(
665
+ any(
666
+ (not isinstance(token_id, int) or token_id < 0)
667
+ for token_id in token_ids
668
+ )
669
+ for token_ids in word_ids
670
+ ):
671
+ typeerror()
672
+
673
+ constraint = DisjunctiveConstraint(word_ids)
674
+ else:
675
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
676
+ typeerror()
677
+ if any(
678
+ (not isinstance(token_id, int) or token_id < 0)
679
+ for token_id in word_ids
680
+ ):
681
+ typeerror()
682
+
683
+ constraint = PhrasalConstraint(word_ids)
684
+ final_constraints.append(constraint)
685
+
686
+ # 11. prepare beam search scorer
687
+ constrained_beam_scorer = ConstrainedBeamSearchScorer(
688
+ constraints=final_constraints,
689
+ batch_size=batch_size,
690
+ num_beams=generation_config.num_beams,
691
+ device=inputs_tensor.device,
692
+ length_penalty=generation_config.length_penalty,
693
+ do_early_stopping=generation_config.early_stopping,
694
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
695
+ )
696
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
697
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
698
+ input_ids=input_ids,
699
+ expand_size=generation_config.num_beams,
700
+ is_encoder_decoder=self.config.is_encoder_decoder,
701
+ **model_kwargs,
702
+ )
703
+ # 13. run beam search
704
+ return self.constrained_beam_search(
705
+ input_ids,
706
+ constrained_beam_scorer=constrained_beam_scorer,
707
+ logits_processor=logits_processor,
708
+ stopping_criteria=stopping_criteria,
709
+ pad_token_id=generation_config.pad_token_id,
710
+ eos_token_id=generation_config.eos_token_id,
711
+ output_scores=generation_config.output_scores,
712
+ return_dict_in_generate=generation_config.return_dict_in_generate,
713
+ synced_gpus=synced_gpus,
714
+ **model_kwargs,
715
+ )
716
+
717
    @torch.no_grad()
    def sample_stream(
        self,
        input_ids: torch.LongTensor,
        logits_processor: Optional[LogitsProcessorList] = None,
        stopping_criteria: Optional[StoppingCriteriaList] = None,
        logits_warper: Optional[LogitsProcessorList] = None,
        max_length: Optional[int] = None,
        pad_token_id: Optional[int] = None,
        eos_token_id: Optional[Union[int, List[int]]] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        output_scores: Optional[bool] = None,
        return_dict_in_generate: Optional[bool] = None,
        synced_gpus: Optional[bool] = False,
        **model_kwargs,
    ) -> Union[SampleOutput, torch.LongTensor]:
        r"""
        Streaming variant of `GenerationMixin.sample`: autoregressively generates
        tokens with **multinomial sampling**, but instead of returning the full
        sequence at the end, this is a generator that *yields* one decoding step
        at a time.

        Yields:
            `(next_tokens, hidden)` per step, where `next_tokens` is a
            `torch.LongTensor` of shape `(batch_size,)` holding the sampled token
            ids (replaced by `pad_token_id` for already-finished sequences) and
            `hidden` is `self.final_norm(...)` applied to the last layer's hidden
            state at the newest position.

        Parameters:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                The sequence used as a prompt for the generation.
            logits_processor (`LogitsProcessorList`, *optional*):
                Processors applied to the prediction scores at each step.
            stopping_criteria (`StoppingCriteriaList`, *optional*):
                Criteria that decide when the generation loop should stop.
            logits_warper (`LogitsProcessorList`, *optional*):
                Warpers applied to the score distribution before sampling.
            max_length (`int`, *optional*):
                **DEPRECATED** — use `stopping_criteria` instead.
            pad_token_id (`int`, *optional*):
                The id of the *padding* token.
            eos_token_id (`int` or `List[int]`, *optional*):
                The id(s) of the *end-of-sequence* token.
            output_attentions / output_hidden_states / output_scores
            (`bool`, *optional*):
                Whether to accumulate the respective tensors (only when
                `return_dict_in_generate=True`).
            return_dict_in_generate (`bool`, *optional*):
                Whether to accumulate scores/attentions/hidden states in tuples.
            synced_gpus (`bool`, *optional*, defaults to `False`):
                Whether to keep looping until max_length (needed for ZeRO stage 3).
            model_kwargs:
                Extra kwargs forwarded to the model's `forward`; for
                encoder-decoder models this should include `encoder_outputs`.

        NOTE(review): the yield line reads `outputs.hidden_states` directly, so
        this generator appears to require `output_hidden_states=True` on the
        forward pass (otherwise `hidden_states` is `None`) — TODO confirm with
        the caller. It also assumes the model defines a `final_norm` module.
        """
        # init values — fall back to generation_config defaults for every
        # argument the caller left as None.
        logits_processor = (
            logits_processor if logits_processor is not None else LogitsProcessorList()
        )
        stopping_criteria = (
            stopping_criteria
            if stopping_criteria is not None
            else StoppingCriteriaList()
        )
        if max_length is not None:
            warnings.warn(
                "`max_length` is deprecated in this function, use"
                " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
                UserWarning,
            )
            stopping_criteria = validate_stopping_criteria(
                stopping_criteria, max_length
            )
        logits_warper = (
            logits_warper if logits_warper is not None else LogitsProcessorList()
        )
        pad_token_id = (
            pad_token_id
            if pad_token_id is not None
            else self.generation_config.pad_token_id
        )
        eos_token_id = (
            eos_token_id
            if eos_token_id is not None
            else self.generation_config.eos_token_id
        )
        # Normalize to a list so the finished-sequence check below can iterate.
        if isinstance(eos_token_id, int):
            eos_token_id = [eos_token_id]
        output_scores = (
            output_scores
            if output_scores is not None
            else self.generation_config.output_scores
        )
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.generation_config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.generation_config.output_hidden_states
        )
        return_dict_in_generate = (
            return_dict_in_generate
            if return_dict_in_generate is not None
            else self.generation_config.return_dict_in_generate
        )

        # init attention / hidden states / scores tuples (only accumulated when
        # return_dict_in_generate is requested together with the matching flag)
        scores = () if (return_dict_in_generate and output_scores) else None
        decoder_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        cross_attentions = (
            () if (return_dict_in_generate and output_attentions) else None
        )
        decoder_hidden_states = (
            () if (return_dict_in_generate and output_hidden_states) else None
        )

        # keep track of which sequences are already finished (1 = still running)
        unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)

        this_peer_finished = False  # used by synced_gpus only
        # auto-regressive generation
        while True:
            if synced_gpus:
                # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
                # The following logic allows an early break if all peers finished generating their sequence
                this_peer_finished_flag = torch.tensor(
                    0.0 if this_peer_finished else 1.0
                ).to(input_ids.device)
                # send 0.0 if we finished, 1.0 otherwise
                dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
                # did all peers finish? the reduced sum will be 0.0 then
                if this_peer_finished_flag.item() == 0.0:
                    break

            # prepare model inputs (handles kv-cache truncation etc.)
            model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)

            # forward pass to get next token
            outputs = self(
                **model_inputs,
                return_dict=True,
                output_attentions=output_attentions,
                output_hidden_states=output_hidden_states,
            )

            if synced_gpus and this_peer_finished:
                continue  # don't waste resources running the code we don't need

            next_token_logits = outputs.logits[:, -1, :]

            # pre-process distribution
            next_token_scores = logits_processor(input_ids, next_token_logits)
            next_token_scores = logits_warper(input_ids, next_token_scores)

            # Store scores, attentions and hidden_states when required
            if return_dict_in_generate:
                if output_scores:
                    scores += (next_token_scores,)
                if output_attentions:
                    decoder_attentions += (
                        (outputs.decoder_attentions,)
                        if self.config.is_encoder_decoder
                        else (outputs.attentions,)
                    )
                    if self.config.is_encoder_decoder:
                        cross_attentions += (outputs.cross_attentions,)

                if output_hidden_states:
                    decoder_hidden_states += (
                        (outputs.decoder_hidden_states,)
                        if self.config.is_encoder_decoder
                        else (outputs.hidden_states,)
                    )

            # sample the next token from the (warped) distribution
            probs = nn.functional.softmax(next_token_scores, dim=-1)
            next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)

            # finished sentences should have their next token be a padding token
            if eos_token_id is not None:
                if pad_token_id is None:
                    raise ValueError(
                        "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
                    )
                next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
                    1 - unfinished_sequences
                )
            # Stream the freshly sampled token together with its projected
            # last-layer hidden state (see NOTE(review) in the docstring).
            yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
            # update generated ids, model inputs, and length for next step
            input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
            model_kwargs = self._update_model_kwargs_for_generation(
                outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
            )

            # if eos_token was found in one sentence, set sentence to finished
            # NOTE(review): with more than one eos id, sum(...) can exceed 1 for
            # non-eos tokens, leaving unfinished_sequences > 1 — matches the
            # upstream transformers code this was copied from; verify if
            # multiple eos ids are ever used here.
            if eos_token_id is not None:
                unfinished_sequences = unfinished_sequences.mul(
                    (sum(next_tokens != i for i in eos_token_id)).long()
                )

            # stop when each sentence is finished, or if we exceed the maximum length
            if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
                if not synced_gpus:
                    break
                else:
                    this_peer_finished = True
+
998
+
999
def init_stream_support():
    """Monkey-patch `PreTrainedModel` so every model exposes streaming generation."""
    patches = {
        "generate_stream": NewGenerationMixin.generate,
        "sample_stream": NewGenerationMixin.sample_stream,
    }
    for attr, impl in patches.items():
        setattr(PreTrainedModel, attr, impl)
T2S/t2s_modules.py ADDED
@@ -0,0 +1,609 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+ import torch.nn.init as init
7
+ from einops import rearrange, repeat
8
+
9
+
10
def zero_module(module):
    """
    Zero every parameter of ``module`` in place and hand the module back.

    Used to initialise "zero convolutions" so a new branch starts as a no-op.
    """
    with torch.no_grad():
        for weight in module.parameters():
            weight.zero_()
    return module
18
+
19
+
20
class GroupNorm32(nn.GroupNorm):
    """GroupNorm evaluated in float32, with the output cast back to the input dtype."""

    def forward(self, x):
        orig_dtype = x.dtype
        normed = super().forward(x.float())
        return normed.type(orig_dtype)
23
+
24
+
25
def normalization(channels):
    """
    Build a :class:`GroupNorm32` layer whose group count is adapted to ``channels``.

    Starts from 32 groups (8 when ``channels`` <= 16, 16 when <= 64) and halves
    the count until it divides ``channels`` evenly; asserts that at least 4
    groups remain.

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32
    while channels % groups:
        groups //= 2
    assert groups > 2
    return GroupNorm32(groups, channels)
41
+
42
+
43
class mySequential(nn.Sequential):
    """Sequential container able to thread several values (e.g. tensor + mask).

    A layer returning a tuple has it unpacked as positional arguments for the
    next layer; a single value is forwarded as-is.
    """

    def forward(self, *inputs):
        for layer in self._modules.values():
            inputs = layer(*inputs) if type(inputs) == tuple else layer(inputs)
        return inputs
53
+
54
+
55
class SepConv1D(nn.Module):
    """Depthwise-separable 1-D convolution that also propagates a boolean mask.

    Masked (False) time steps are zeroed before the convolution; the mask is
    returned unchanged so stacked layers can reuse it.
    """

    def __init__(
        self,
        nin,
        nout,
        kernel_size,
        stride=1,
        dilation=1,
        padding_mode="same",
        bias=True,
    ):
        super().__init__()
        # Depthwise stage: one filter per input channel (groups == nin).
        self.conv1 = nn.Conv1d(
            nin,
            nin,
            kernel_size=kernel_size,
            stride=stride,
            groups=nin,
            dilation=dilation,
            padding=padding_mode,
            bias=bias,
        )
        # Pointwise stage: 1x1 convolution mixing information across channels.
        self.conv2 = nn.Conv1d(
            nin, nout, kernel_size=1, stride=1, padding=padding_mode, bias=bias
        )

    def forward(self, x, mask=None):
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        return self.conv2(self.conv1(x)), mask
89
+
90
+
91
class Conv1DBN(nn.Module):
    """Masked Conv1d -> BatchNorm -> ReLU -> Dropout block.

    Masked (False) time steps are zeroed before the convolution; the mask is
    passed through untouched.
    """

    def __init__(
        self,
        nin,
        nout,
        kernel_size,
        stride=1,
        dilation=1,
        dropout=0.1,
        padding_mode="same",
        bias=False,
    ):
        super().__init__()
        self.conv1 = nn.Conv1d(
            nin,
            nout,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding_mode,
            dilation=dilation,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(nout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        h = self.drop(F.relu(self.bn(self.conv1(x))))
        return h, mask
124
+
125
+
126
class Conv1d(nn.Module):
    """Plain Conv1d wrapper whose forward also accepts and forwards a mask.

    Note: unlike the sibling conv wrappers, `mask` has no default here —
    callers must pass it explicitly (possibly None).
    """

    def __init__(self, nin, nout, kernel_size, padding, bias=True):
        super(Conv1d, self).__init__()
        self.l = nn.Conv1d(nin, nout, kernel_size, padding=padding, bias=bias)

    def forward(self, x, mask):
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        out = self.l(x)
        return out, mask
138
+
139
+
140
class SqueezeExcite(nn.Module):
    """Squeeze-and-Excitation gate: let the CNN decide how to reweight channels.

    Masked (padding) frames are zeroed before the time average so they do not
    bias the per-channel statistics.
    """

    def __init__(self, nin, ratio=8):
        super(SqueezeExcite, self).__init__()
        self.nin = nin
        self.ratio = ratio

        # Channel bottleneck MLP: nin -> nin/ratio -> nin.
        self.fc = mySequential(
            nn.Linear(nin, nin // ratio, bias=True),
            nn.SiLU(inplace=True),
            nn.Linear(nin // ratio, nin, bias=True),
        )

    def forward(self, x, mask=None):
        # x: (B, C, T); mask: (B, T) boolean, True on valid frames.
        if mask is None:
            mask = torch.ones((x.shape[0], x.shape[-1]), dtype=torch.bool).to(x.device)
        mask = ~mask  # invert: True now marks padding
        x = x.float()
        # NOTE(review): if x is already float32, x.float() returns x itself,
        # so this in-place fill mutates the caller's tensor.
        x.masked_fill_(mask.unsqueeze(1), 0.0)
        mask = ~mask  # restore: True on valid frames again
        # Mean over valid frames only: sum / count of unmasked positions.
        y = (
            torch.sum(x, dim=-1, keepdim=True)
            / mask.unsqueeze(1).sum(dim=-1, keepdim=True)
        ).type(x.dtype)
        # y=torch.mean(x,-1,keepdim=True)
        y = y.transpose(1, -1)  # (B, 1, C) so the Linear layers act on channels
        y = self.fc(y)
        y = torch.sigmoid(y)  # per-channel gate in (0, 1)
        y = y.transpose(1, -1)  # back to (B, C, 1)
        y = x * y
        return y, mask
172
+
173
+
174
class SCBD(nn.Module):
    """SeparableConv1D + norm (+ optional SiLU/Dropout) with mask pass-through.

    Generally used for middle layers and residual (QuartzNet-style) stacks.
    """

    def __init__(
        self, nin, nout, kernel_size, p=0.1, rd=True, separable=True, bias=True,
        eps=None,
    ):
        """
        :param nin: input channels.
        :param nout: output channels.
        :param kernel_size: convolution kernel size.
        :param p: dropout probability (only used when rd=True).
        :param rd: if True, follow the norm with SiLU activation and dropout.
        :param separable: depthwise-separable conv (SepConv1D) vs. plain Conv1d.
        :param bias: whether the convolutions use a bias term.
        :param eps: accepted for caller compatibility — QuartzNetBlock passes
            eps=0.001, which previously raised TypeError because __init__ did
            not take it. The GroupNorm built by normalization() keeps its own
            default eps, so this value is currently unused.
        """
        super(SCBD, self).__init__()
        if separable:
            self.SC = SepConv1D(nin, nout, kernel_size, bias=bias)
        else:
            self.SC = Conv1d(nin, nout, kernel_size, padding="same", bias=bias)

        if rd:  # norm + activation + dropout
            self.mout = mySequential(
                normalization(nout),
                nn.SiLU(),
                nn.Dropout(p),
            )
        else:
            self.mout = normalization(nout)

    def forward(self, x, mask=None):
        # Zero padded frames before convolving.
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        x, _ = self.SC(x, mask)
        y = self.mout(x)
        return y, mask
201
+
202
+
203
class QuartzNetBlock(nn.Module):
    """QuartzNet-style residual block: R stacked SCBD layers plus a 1x1-conv skip.

    Similar to a ResNet block with norm and dropout, using separable convs in
    the middle. If it is the last layer, set se=False and separable=False and
    use a projection layer on top of this.
    """

    def __init__(
        self,
        nin,
        nout,
        kernel_size,
        dropout=0.1,
        R=5,
        se=False,
        ratio=8,
        separable=False,
        bias=True,
    ):
        """
        :param R: number of SCBD repeats in the main branch.
        :param se: append a SqueezeExcite gate after the main branch.
        :param ratio: squeeze-excite bottleneck ratio.
        :param separable: use a separable conv for the final SCBD repeat.
        """
        super(QuartzNetBlock, self).__init__()
        self.se = se
        # 1x1 projection so the skip connection matches nout channels.
        self.residual = mySequential(
            nn.Conv1d(nin, nout, kernel_size=1, padding="same", bias=bias),
            normalization(nout),
        )
        layers = []

        # First R-1 repeats keep activation + dropout (rd=True by default).
        # Bug fix: SCBD takes no `eps` argument — the original eps=0.001
        # kwargs raised TypeError on construction.
        for _ in range(R - 1):
            layers.append(SCBD(nin, nout, kernel_size, dropout, bias=bias))
            nin = nout

        # Final repeat skips activation/dropout; they are applied after the
        # residual sum (self.mout below).
        layers.append(
            SCBD(
                nin,
                nout,
                kernel_size,
                dropout,
                rd=False,
                separable=separable,
                bias=bias,
            )
        )
        self.model = mySequential(*layers)

        if self.se:
            self.se_layer = SqueezeExcite(nin, ratio)

        self.mout = mySequential(nn.SiLU(), nn.Dropout(dropout))

    def forward(self, x, mask=None):
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        y, _ = self.model(x, mask)
        if self.se:
            y, _ = self.se_layer(y, mask)
        # Residual connection, then shared activation + dropout.
        y += self.residual(x)
        y = self.mout(y)
        return y, mask
266
+
267
+
268
class QKVAttentionLegacy(nn.Module):
    """
    A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.

        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :param mask: optional [N x T] weighting, applied AFTER the softmax
            (see note below), broadcast across heads and query positions.
        :param rel_pos: optional callable (e.g. RelativePositionBias) applied
            to the pre-softmax attention logits.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        assert width % (3 * self.n_heads) == 0
        ch = width // (3 * self.n_heads)
        # Fold heads into the batch dim, then split channels into q, k, v.
        q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k each by ch**-0.25 so their product carries the usual
        # 1/sqrt(ch) factor without a large intermediate.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum(
            "bct,bcs->bts", q * scale, k * scale
        )  # More stable with f16 than dividing afterwards
        if rel_pos is not None:
            # rel_pos expects (bs, heads, T_q, T_k); heads are re-flattened after.
            weight = rel_pos(
                weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])
            ).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
        # Softmax in float32 for numerical stability, then cast back.
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
            # NOTE(review): multiplying after softmax means rows are no longer
            # normalized over the unmasked keys.
            mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
            weight = weight * mask
        a = torch.einsum("bts,bcs->bct", weight, v)

        return a.reshape(bs, -1, length)
304
+
305
+
306
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.

    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        # NOTE(review): do_checkpoint is stored but never used in forward().
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0, (
                f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            )
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        # Single conv produces q, k and v stacked along the channel dim.
        self.qkv = nn.Conv1d(channels, channels * 3, 1)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)

        # Zero-initialized projection so the block starts as an identity.
        self.proj_out = zero_module(
            nn.Conv1d(channels, channels, 1)
        )  # no effect of attention in the inital stages.
        # NOTE(review): the relative_pos_embeddings flag is ignored — the bias
        # is always constructed (guard commented out below). Also `heads` uses
        # the num_heads ARGUMENT while `scale` uses self.num_heads; these
        # differ when num_head_channels is set — confirm intended.
        # if relative_pos_embeddings:
        self.relative_pos_embeddings = RelativePositionBias(
            scale=(channels // self.num_heads) ** 0.5,
            causal=False,
            heads=num_heads,
            num_buckets=32,
            max_distance=64,
        )  # need to read about this, vit and swin transformers
        # self.relative_pos_embeddings = FixedPositionalEmbedding(dim=channels)
        # else:
        #     self.relative_pos_embeddings = None

    def forward(self, x, mask=None):
        # Flatten any trailing spatial dims into a single sequence axis.
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        # Residual connection, restoring the original spatial shape.
        return (x + h).reshape(b, c, *spatial)
359
+
360
+
361
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, scaled by dim**-0.5.

    Improvement: the einops `rearrange(pos_emb, "n d -> () n d")` is replaced
    with a plain `unsqueeze(0)`, dropping the third-party dependency.
    """

    def __init__(self, dim, max_seq_len):
        super().__init__()
        self.scale = dim**-0.5
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        """Return positional embeddings of shape (1, x.shape[1], dim)."""
        n = torch.arange(x.shape[1], device=x.device)
        pos_emb = self.emb(n).unsqueeze(0)  # (n, d) -> (1, n, d)
        return pos_emb * self.scale
372
+
373
+
374
class FixedPositionalEmbedding(nn.Module):
    """Sinusoidal (non-learned) positional embedding.

    Improvement: the einops `rearrange(emb, "n d -> () n d")` is replaced with
    a plain `unsqueeze(0)`, dropping the third-party dependency.
    """

    def __init__(self, dim):
        super().__init__()
        # Standard transformer inverse frequencies: 10000^(-2k/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        """Return (1, seq_len, dim) sin/cos embeddings for x's seq_dim length."""
        t = (
            torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            + offset
        )
        sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        return emb.unsqueeze(0)  # (n, d) -> (1, n, d)
388
+
389
+
390
class RelativePositionBias(nn.Module):
    """T5-style bucketed relative position bias added to attention logits.

    Relative distances map to a fixed number of buckets — exact for small
    offsets, logarithmic out to max_distance beyond that — and each bucket
    owns a learned per-head bias.

    Improvement: the einops `rearrange(values, "i j h -> () h i j")` is
    replaced with native `permute(...).unsqueeze(0)`, dropping the
    third-party dependency.
    """

    def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position, causal=True, num_buckets=32, max_distance=128
    ):
        """Map signed relative positions to bucket indices (long tensor)."""
        ret = 0
        n = -relative_position
        if not causal:
            # Non-causal: half the buckets encode the sign (future vs. past).
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))

        # First half of the buckets hold exact offsets; the rest grow
        # logarithmically up to max_distance.
        max_exact = num_buckets // 2
        is_small = n < max_exact

        val_if_large = (
            max_exact
            + (
                torch.log(n.float() / max_exact)
                / math.log(max_distance / max_exact)
                * (num_buckets - max_exact)
            ).long()
        )
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1)
        )

        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        """Add the scaled relative bias to attention logits of shape (..., i, j)."""
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(
            rel_pos,
            causal=self.causal,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        values = self.relative_attention_bias(rp_bucket)  # (i, j, heads)
        bias = values.permute(2, 0, 1).unsqueeze(0)  # (i, j, h) -> (1, h, i, j)
        return qk_dots + (bias * self.scale)
444
+
445
+
446
class MultiHeadAttention(nn.Module):
    """
    Multi-head attention, used only by GST.

    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    """

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim

        self.W_query = nn.Linear(
            in_features=query_dim, out_features=num_units, bias=False
        )
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False
        )

    def forward(self, query, key):
        head_dim = self.num_units // self.num_heads

        # Project, then split the unit dim into heads stacked on a new dim 0,
        # giving [h, N, T, num_units/h] each.
        q = torch.stack(torch.split(self.W_query(query), head_dim, dim=2), dim=0)
        k = torch.stack(torch.split(self.W_key(key), head_dim, dim=2), dim=0)
        v = torch.stack(torch.split(self.W_value(key), head_dim, dim=2), dim=0)

        # Scaled dot-product attention. Note the scale uses key_dim (the full
        # key width), not the per-head dim — preserved from the original.
        logits = torch.matmul(q, k.transpose(2, 3))  # [h, N, T_q, T_k]
        logits = logits / (self.key_dim**0.5)
        attn = F.softmax(logits, dim=3)

        ctx = torch.matmul(attn, v)  # [h, N, T_q, num_units/h]
        # Re-concatenate the heads along the unit dim -> [N, T_q, num_units].
        out = torch.cat(torch.split(ctx, 1, dim=0), dim=3).squeeze(0)

        return out
498
+
499
+
500
class GST(nn.Module):
    """Reference/style encoder (GST-style).

    Encodes a reference mel-spectrogram (N, in_channels, T) into a single
    style vector of shape (N, model_channels, 1): two strided convs
    downsample time 4x, attention blocks mix information across frames, and
    the result is mean-pooled over time. The classic GST token-attention
    stage is currently disabled (commented out below), so `style_tokens` is
    unused and the pooled embedding is returned directly.
    """

    def __init__(
        self, model_channels=512, style_tokens=100, num_heads=8, in_channels=80
    ):
        super(GST, self).__init__()
        self.model_channels = model_channels
        self.style_tokens = style_tokens  # unused while token attention is disabled
        self.num_heads = num_heads

        # Earlier experiments (2D-conv reference encoder, GRU pooling) were
        # removed here; see repository history for those alternatives.
        # Two strided convs (4x time downsampling) followed by attention blocks.
        self.reference_encoder = nn.Sequential(
            nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
            nn.Conv1d(model_channels, model_channels, 3, padding=1, stride=2),
            AttentionBlock(
                model_channels,
                num_heads,
                relative_pos_embeddings=True,
                do_checkpoint=False,
            ),
            AttentionBlock(
                model_channels,
                num_heads,
                relative_pos_embeddings=True,
                do_checkpoint=False,
            ),
            AttentionBlock(
                model_channels,
                num_heads,
                relative_pos_embeddings=True,
                do_checkpoint=False,
            ),
            AttentionBlock(
                model_channels,
                num_heads,
                relative_pos_embeddings=True,
                do_checkpoint=False,
            ),
            AttentionBlock(
                model_channels,
                num_heads,
                relative_pos_embeddings=True,
                do_checkpoint=False,
            ),
        )

        # Disabled GST token attention:
        # self.attention = MultiHeadAttention(query_dim=model_channels, key_dim=model_channels//num_heads, num_units=model_channels*2, num_heads=num_heads)
        # self.style_tokens = nn.parameter.Parameter(torch.FloatTensor(style_tokens,model_channels//num_heads))
        # init.normal_(self.style_tokens, mean=0, std=0.5)

    def forward(self, x):
        # TODO(review): no padding mask is applied — padded frames bias the
        # time average below.
        batch = x.size(0)
        x = self.reference_encoder(x)  # (N, model_channels, T // 4)
        x = x.mean(dim=-1)  # mean-pool over time -> (N, model_channels)
        # Disabled token attention:
        # keys = self.style_tokens.unsqueeze(0).expand(batch, -1, -1)
        # style_embed = self.attention(x, keys)

        return x.view(batch, -1, 1)
597
+
598
+
599
+ if __name__ == "__main__":
600
+ device = torch.device("cpu")
601
+ m = GST(512, 10).to(device)
602
+ mels = torch.rand((16, 80, 1000)).to(device)
603
+
604
+ o = m(mels)
605
+ print(o.shape, "final output")
606
+
607
+ from torchinfo import summary
608
+
609
+ summary(m, input_data={"x": torch.randn(16, 80, 500).to(device)})
T2S/utilities.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import librosa.util as librosa_util
2
+ import numpy as np
3
+ import torch
4
+ from scipy.io.wavfile import read
5
+ from scipy.signal import get_window
6
+
7
+ # import librosa
8
+ from config import config
9
+
10
# find these values
# Empirical dynamic range of Tacotron-style log-mel features used for
# [-1, 1] normalization below.
TACOTRON_MEL_MAX = 2.3143386840820312
TACOTRON_MEL_MIN = -11.512925148010254


def denormalize_tacotron_mel(norm_mel):
    """Map a [-1, 1]-normalized mel back to the raw Tacotron log-mel range."""
    mel_span = TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
    return ((norm_mel + 1) / 2) * mel_span + TACOTRON_MEL_MIN


def normalize_tacotron_mel(mel):
    """Map a raw Tacotron log-mel into the [-1, 1] range."""
    mel_span = TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
    return 2 * ((mel - TACOTRON_MEL_MIN) / mel_span) - 1
23
+
24
+
25
def get_mask_from_lengths(lengths, max_len=None):
    """Boolean mask of shape (B, max_len): True where position < length."""
    if not max_len:
        max_len = torch.max(lengths).item()
    positions = torch.arange(0, max_len, device=lengths.device, dtype=torch.long)
    return (positions < lengths.unsqueeze(1)).bool()
31
+
32
+
33
def get_mask(lengths, max_len=None):
    """Boolean mask of shape (B, max_len): True where position < length.

    Bug fix: the position ramp is now created on ``lengths.device`` — it was
    always allocated on CPU before, breaking CUDA inputs (and diverging from
    get_mask_from_lengths). The redundant ``lens[:max_len]`` slice is gone.
    """
    if not max_len:
        max_len = torch.max(lengths).item()
    lens = torch.arange(max_len, device=lengths.device)
    mask = lens.unsqueeze(0) < lengths.unsqueeze(1)
    return mask
41
+
42
+
43
def dynamic_range_compression(x, C=1, clip_val=1e-5):
    """
    Log-compress x: log(clamp(x, clip_val) * C).

    PARAMS
    ------
    C: compression factor
    clip_val: lower clamp that keeps the logarithm finite
    """
    clamped = torch.clamp(x, min=clip_val)
    return torch.log(clamped * C)
50
+
51
+
52
def dynamic_range_decompression(x, C=1):
    """
    Invert dynamic_range_compression: exp(x) / C.

    PARAMS
    ------
    C: compression factor used to compress
    """
    expanded = torch.exp(x)
    return expanded / C
59
+
60
+
61
def window_sumsquare(
    window,
    n_frames,
    hop_length=200,
    win_length=800,
    n_fft=800,
    dtype=np.float32,
    norm=None,
):
    """
    # from librosa 0.6
    Compute the sum-square envelope of a window function at a given hop length.
    This is used to estimate modulation effects induced by windowing
    observations in short-time fourier transforms.
    Parameters
    ----------
    window : string, tuple, number, callable, or list-like
        Window specification, as in `get_window`
    n_frames : int > 0
        The number of analysis frames
    hop_length : int > 0
        The number of samples to advance between frames
    win_length : [optional]
        The length of the window function. By default, this matches `n_fft`.
    n_fft : int > 0
        The length of each analysis frame.
    dtype : np.dtype
        The data type of the output
    Returns
    -------
    wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
        The sum-squared envelope of the window function
    """
    if win_length is None:
        win_length = n_fft

    # Total output length: one full frame plus (n_frames - 1) hops.
    n = n_fft + hop_length * (n_frames - 1)
    x = np.zeros(n, dtype=dtype)

    # Compute the squared window at the desired length
    win_sq = get_window(window, win_length, fftbins=True)
    win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
    win_sq = librosa_util.pad_center(win_sq, size=n_fft)

    # Fill the envelope
    # Overlap-add the squared window at each frame offset, clipping at the
    # buffer end for the final frames.
    for i in range(n_frames):
        sample = i * hop_length
        x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
    return x
110
+
111
+
112
def load_wav_to_torch(full_path):
    """Read a WAV file with scipy and return (FloatTensor samples, sample rate).

    Note: integer PCM data is cast to float but NOT rescaled — values keep
    their raw integer range.
    """
    sampling_rate, data = read(full_path)
    samples = torch.FloatTensor(data)
    return samples, sampling_rate
120
+
121
+
122
+ if __name__ == "__main__":
123
+ lens = torch.tensor([2, 3, 7, 5, 4])
124
+ mask = get_mask(lens)
125
+ print(mask)
126
+ print(mask.shape)
Text/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ from .symbols import *
Text/__pycache__/__init__.cpython-310.pyc ADDED
Binary file (150 Bytes). View file
 
Text/__pycache__/symbols.cpython-310.pyc ADDED
Binary file (33.7 kB). View file
 
Text/symbols.py ADDED
@@ -0,0 +1,88 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
import sys
import pickle as pkl
# NOTE: put the repo root on sys.path so `config` resolves when this module
# is imported from inside Text/.
sys.path.append("../")
from config import config
5
+
6
+
7
+ # labels=" abcdefghijklmnopqrstuvwxyz.,:;'()?!\""
8
+ # # labels=" !\"'(),-.:;?[]abcdefghijklmnopqrstuvwxyzàâèéêü’“”"
9
+ # labels='''ଊతూിਮ০य़లഢਪਟକఝૂएड‌`यঢअచଢ଼ਧ—ତলશರଖच,பવड़ષंಈಮਤਇଥkखഗబ= इਸಣਹછ™ୟ.ोೀৎುഊଳંർਘମഴఙसଗൃlଝਜఇഓਐভയಅಠభാടਔಒ೧পஜaૅૠএଲ৯eകँ৭àৱऊટഒਗহિేயీെஈଓഭೊাੌಙ१ଈःസठખm‘ొऍಿcശrట।ऱଋઘਛெਬಂङಹஞ਼ભ১"એੂചಸગಷ়ଁമಓtஒઉಪs్-pଛ›ढ+ಆ'বનধৰউીଅઝ੍ೂʼൂఔfતषഖঢ়৬ਖक़ਵషணझപળଔઞੇವௗઁത২xెഥख़iটਲધಔೇீથ*ഝॅঃஓूఒীనਜ਼எુுహौ९ൗౌফഔોhஔণంफ़ఋçଯઊൽଆ’ୁைഛ२&ঁണ़ైৌআஆোਠਭजொમळಘஷഏি/ચਾ“ਯ$ଐീवऩ८ઢఛఎেథഠ[औಳରथୃൈಝnজਥऑଷੱल೯wओଵढ़மവरడఊbೖਈૃपdêଉఐ;ै ఢ ઔકচ৩‎ਊൾഉਕ೦ಏj€:ਦಗાളੁशफുழൻಊगફఏఅ?णറഘಞ४ಡಫଠ್ড೨ൊঞमਂસૉॉઅരஙલঘନ്ఠॄvઋృষऎகೕଘઆఞലେূஊఉૈദఫఈदকज़!ధઠవଞறಟਖ਼ਫ਼ইਢഡঠஃஸୂटঅହఆளోईৃಜ॥(ઈଏੀഈक્গ ಚಢഹೃिஏಯyশேଡೋੈਣડఃഷഇਸ਼நখಋோনૐਏgहৗೈृவੰଜग़ੋ୍)ൌరమൺংञਓપయധஇോ५ઃಲళঊತॽ­ന…ঙಭाಇउਅଶরઓି্ூমuపബ\ૌଟबਆुಕଫதছ३దਿದణஐௌ்ৈqఘலહಾ०ಛঐிওऋ‍ి৮ेਨଇүଧഞಶéਚ्৫ୋశఓદঈୀ৪ପüুങਗ਼ઑજথఖঝಐऽਰାആജीઇੜ]आବଡ଼ഫಥుಎણଃયछஅેஹംଢબoদഎగଭాേഅঋসഐಃzਡಬਝன–உಖಉഃযସୈೆకॐನഋয়సசଙড়ୱऒऐઐतଂாতરâèनಧ॑டঔभர”​జ৷ਫଣଚଦधघೌୌਉ'''
10
+ # labels = '''്ുIଢ଼िோശતયసૅऱઠ‍ಆഠৗ.ષਢీિఱहమହளಓರऎ૨ঃমଙൾਓOણBઅਨhੰટപଯrணধऽDേలौೃୀఎ்ഷwீ೫೩லഋJൈদঞਛ3ৌଏషിಭनഐઆഊਏಠଅஞ০യఊிఘଡtಃൂଠҙஙૐబୌ,गઑଉgആૈ॑ाൽೖമજ€ੌୱ్ওಾଂইોकफயেഉഎख़చകয়ঙणGశेतં৯೦ాਜ਼ொోஐকळైҳक़ूएఢçഇેযઋਪଦഡৰഢYଳढਈடપെঢ়ౌসಎਇઉಋণಶப೧ओಒ­ZVਦଶছ४ೆflुೌೇോઝୈఙಇঐெഛপૠઇഞడ়ਔീई೨Pमஷફକేঁଗೊળభన್ଁਣఅஓுৃछഏMीsWଛঅਯচਚరଖजஈ५ॠêफ़ॐನੜ৮ଈଭଢइડಥঝ½ேउଷஎ5ைુ?ಜऍଚಢഓఋ]pఃಫ+०ೋେખਡਫഘठജൻఒঠஉਜೀୋంఞलKਊಪ L-ਠಂఈഥഒdಟ१ిૂಔैఛௌNFAఝऋஔೕલೂऊથಬପற;‌ાઞ'ଝಞളಲcಖൊుোੀoଃज़ಙಕचఏ़थఠSQୁృିஸாઍଵरಯਥ“വंਸୟড়ਬ२ଐ८ॽੱథx ಧવਗବবತॅੁঊগলவ0ൗী7ੋ૪૫৷ৎશદଟnஏ!ఐ​भ*ധഫങഅC›ਅధగUಛછ4தeચदભఖ’அૃʼநଜ—తজସഈ्ଣ"Xીಘਾ2സୃઈಏஃଊಿঋaமళâ২ஹଇuূਲऩାఆृm8ઔवyভ`)ଋംଡ଼ਸ਼ ॄரऔਉ˜పಸசञಡజ৩ാ೯ডൺখધ্ూ୍ਭஊദৈHഹਿಗബশ\ଲਗ਼ਧఓಝవନऑzಮखর९शiదఇঘર1ണথ৭[એ॥হ”கҮગষॉରਖഭഖઊઓbय़ੂपઁಈ৪ਰ(టबಣਟझറड़ஜனഃಉঔEনঈਐમयವअਂःழqତँ।৫ଓణഗ১ઃोधൌબଘटದTஇഴু૧ഔఔਖ਼ংঢ&નಐ਼™ூ–ਝૌମસଆஆಅ३કૄটऒ6యൃહk৬া: ௗଧ‎ஒडvઘലొ$…ർടjಚୂઐਆਵೈউ‘നઢRফੇਹెചघ੍আకಹੈसਘএଥग़രি=षङതഝàऐ9્ૉਤଞುଫఉਮआতಷహਫ਼ৱҠಊଔढ़ಳఫਕ/үéüè#%<>@^_{|}~'''
11
+
12
+ # with open("../Text/symbols_final.pkl",'rb') as file:
13
+ # labels = pkl.load(file)
14
+
15
+ english =''' !"#$%&'()*+,-.0123456789:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{} ­´»àáâæçèéêíïñôúüœ˜ҙҡүҳ​‌–—‘’“”•…€™️'''
16
+
17
+ hindi =''' !"'(),-.0123456789;?abcdefghijklmnopqrstuvwxyzँंःअआइईउऊऋऌऍऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऻ़ािीुूृॄॅॆेैॉॊोौ्ॏॐ॒॓॔ॕॖॗक़ख़ग़ज़ड़ढ़फ़ॠॡॢ।॥०॰'''
18
+
19
+ kannada ='''  !"#$%&'()*+,-./0123456789:;<=>?[]^`abcdefghiklmnopqrstuvwxy ½ʼॐ।॥ಂಃಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೆೇೈೊೋೌ್ೕೖೞೠ೦೧೨೩೪೫೬��೮೯ ​‌‍–‘’“”•…'''
20
+
21
+ tamil =''' !"%&'()*,-./0123456789:;?[]`abcdefghijklmnopqrstuvwxyz ­ஃஅஆஇஈஉஊஎஏஐஒஓஔகஙசஜஞடணதநனபமயரறலளழவஷஸஹாிீுூெேைொோௌ்ௗ‌–‘’“”…'''
22
+
23
+ assamese =''' !%&'()+,-./0123456789:;?[]_abcdefghijklmnopqrstuvwxyz ʼ।॥ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৷৹​‌‍–—‘’“”'''
24
+
25
+ dogri =''' !$%&'()+,-./0123456789:;?@[]`abcdefghijklmnopqrstuvwxyzʼँंःअआइईउऊएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ऽािीुूृेैॉोौ्क़ज़ड़ढ़फ़य़।​‌‍–—‘’“”…′₹'''
26
+
27
+ bodo =''' !$%&'()+,-./0123456789:;?[]_abcdefghijklmnopqrstuvwxyz °º½ʼँंःअआइईउऊऋऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलवशषसह़ािीुूृॅॆेैॉॊोौ्ज़ड़फ़य़।०ॽ​‌‍–‘’“”'''
28
+
29
+ marathi =''' !'*+,-./0123456789:;?[`z ँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॑ॕक़ज़ड़ढ़फ़य़ॠॡ।॥०१२३४५६७८९ॲ‌‍–‘’“”›'''
30
+
31
+ bengali =''' !',-.0123456789;?acdefghlmnrstuyz।ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃৄেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৵৷৻‌‍—‘’'''
32
+
33
+ telugu =''' !"'*,-./258:;?o ।ఁంఃఅఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరఱలళవశషసహాిీుూృెేైొోౌ్ౖ​‌–‘’”…'''
34
+
35
+ german =''' '-abcdefghijklmnopqrstuvwxyzßàäèéöü'''
36
+
37
+ spanish =''' !'-.;?abcdefghijklmnopqrstuvwxyzáéíñóúü'''
38
+
39
+ french =''' !'-.;?abcdefghijklmnopqrstuvwxyzàâæçèéêëîïôùûüÿœ'''
40
+
41
+ punjabi =''' !"'(),-.:?bden।ਁਂਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ਼ਾਿੀੁੂੇੈੋੌ੍ਖ਼ਗ਼ਜ਼ੜਫ਼ੰੱੲੳ –‘’“”…'''
42
+
43
+ sanskrit =''' "ँंःअआइईउऊऋऌऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑ॠॡॢ।॥०१२३४५६७८९॰'''
44
+
45
+ odia =''' "',-.;।ଁଂଃଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳଵଶଷସହ଼ଽାିୀୁୂୃୄେୈୋୌ୍ୖ୘ଡ଼ଢ଼ୟୠୢ୦୧୨୩୪୫୬୭୮୯୰ୱ‌‘’”'''
46
+
47
+ urdu =''' !"',-.:`abcdeghiklrtuy،ؑؓؔ؛؟ءآأؤئابتثجحخدذرزسشصضطظعغفقكلمنهوىيًَُِّْٓٔٗ٬ٰٴٹپچڈڑژکگںھہۂۃیےۓ۔।‘’“”…ﭨﮭﮯﯾﷲﷺﺅﺗﺘﺩﺲﻧﻮ'''
48
+
49
+ gujarati =''' !',-.:;?m ।ઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૄૅેૈૉોૌ્ૐૠૢ૦૧૨૩૪૫૬૭૮૯‌–‘’…'''
50
+
51
+ rajasthani =''' !'(),-.0123456789:;?xँंःअआइईउऊऋऍएऐऑओऔकखगघचछजझञटठडढणतथदधनपफबभमयरलळवशषसह़ािीुूृेैॉोौ्क़ख़ग़ज़ड़ढ़फ़ॠ।०१२३७८९​‍‘’…'''
52
+
53
+ malayalam =''' !,?ംഃഅആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരറലളഴവശഷസഹാിീുൂൃെേൈൊോൌ്ൗൺൻർൽൾ'''
54
+
55
+ manipuri =''' ,-.mnঁংঅআইঈউঊএঐওঔকখগঘঙচছজঝঞটঠডণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎয়০১২৩৪৫৬৭৮৯ৰৱ৷'''
56
+
57
+ gujrati =''' !"'(),-.0123456789:?{}‘’ ­âઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૅેૈૉોૌ્ૠ૦૧૨૩૪૫૬૭૮૯​‌‍–—‘’“”…'''
58
+
59
+ bhojpuri ='''  !"'(),-.012346789:?`abeimpy{}·ँंःअआइईउऊऋऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलळवशषसहऺ़ऽािीुूृॅॆेैॉॊोौ्ॐॕॖक़ख़ग़ज़ड़ढ़फ़य़ॠ���०२६९॰ ​‌‍‎‘’'''
60
+
61
+ italian =''' !"$'()+,-.:;<=>?[]_`abcdefghijklmnopqrstuvwxyz{}~¡«°´µº»ßàáâãäåæèéêëìíîïðñòóôöøùúûþÿāąćčđėęěğīıľłńňōőœřśşšūŭźżžșțəʻʼʾʿː̨́абдеиклмностуцшѐёљңדהוةرسصغليḥṇṛṣṭễ‑–—‘’“”„…′☆♭♯あアカキサザノフリン・ー万三丰古多家峰張旅禅ꞌ'''
62
+
63
+ arabic =''' !"',-.:;?egt«»،؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٰٓچکیۖۗۘۙۚۛ–“”…ﺃ'''
64
+
65
+ korean =''' 가각간갈감갑값갓갔강갖같갚개객갠갤갯갱거걱건걷걸검겁것겉게겐겟겠겨격겪견결겸겹겼경곁계고곡곤곧골곰곱곳공곶과곽관괄괌광괘괭괴굉교구국군굳굴굵굶굽굿궁궈권궐궤귀규균그극근글금급긋긍기긴길김깁깃깊까깎깔깝깡깥깨깬꺼껍껏껑께껴꼈꼬꼭꼴꼼꼽꽁꽂꽃꽝꽥꾸꾼꿀꿇꿈꿎꿔꿨꿰뀌뀐끄끅끈끊끌끓끔끗끝끼끽낀낌나낙난날낡남납낫났낭낮낯낳내낸낼냄냅냇냈냉냐냥너넉넌널넓넘넛넣네넥넨넷녀녁년념녔녕녘노녹논놀놈농높놓놨뇌뇨뇽누눈눌눠뉜뉴늉느는늘늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮닳담답닷당닿대댁댐댓더덕던덜덟덤덧덩덫덮데덴델뎅뎌도독돈돋돌돔돕돗동돛돼됐되된될됨됩두둑둔둘둠둥둬뒀뒤뒷듀드득든듣들듬듭듯등디딘딜딥딨딩딪따딱딴딸땀땄땅때땐땠땡떠떡떤떨떴떻떼뗄또똑똥뚜뚝뚫뛰뛴뜨뜯뜸뜻띄띔띠띤띨띵라락란랄람랍랐랑래랙랜램랫랬랭랴략량러럭런럴럼럽렀렁렇레렉렌렘렛려력련렬렴렵렷렸령례로록론롤롬롭롯롱뢰료룡루룬룰룸룹뤄뤘뤼류륙륜률륭르륵른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맡매맥맨맹맺머먹먼멀멈멋멍메멕멘멜멧며면멸명몇모목몫몬몰몸못몽묘무묵묶문묻물뭄뭇뭐뭔뭘뮤뮬므믈미믹민믿밀밋밌밍및밑바박밖반받발밝밤밥방밭배백밴뱀뱃뱅버벅번벌범법벗벚베벤벨벳벼벽변별볍볐병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북분불붉붐붓붕붙뷔뷰브븐블비빅빈빌빗빙빚빛빠빨빵빼뺀뺌뺏뺑뻐뻑뻔뻗뻘뼈뽀뽑뽕뿌뿐뿜쁘쁜쁩삐사삭산살삶삼삽삿샀상새색샌샐샘생샤샬샵샷서석섞선섣설섬섭섯섰성세섹센셀셈셉셋셔션셜셨셰소속손솔솜솟송솥쇄쇠쇤쇼숍수숙순술숨숭숲쉬쉰쉼쉽슈슐스슨슬슴습슷승시식신실싫심십싱싶싸싹싼쌀쌈쌌쌍쌓써썩썰썼쏘쏜쏟쏠쑤쓰쓴쓸씀씁씌씨씩씬씸씻아악안앉않알앓암압앗았앙앞애액앤앨앱앵야약얇양얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엡엣여역연열엷염엽엿였영옆예옛오옥온올옮옳옴옵옷옹와완왈왔왕왜외왼요욕용우욱운울움웁웃웅워원월웠웨웬웹위윈윌윗윙유육윤율융으은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잽쟁저적전절젊젋점접젓정젖제젝젠젤져젼졌조족존졸좀좁종좋좌죄죠주죽준줄줍중줘줬쥐쥔쥘쥬즈즉즌즐즘증지직진질짊짐집짓징짖짙짚짜짝짧짬째쨌쩌쩍쩔쩜쪼쪽쫄쫓쭉쯤찌찍찔찢찧차착찬찮찰참찻창찾채책챌챔챙챠처척천철첨첩첫청체첸첼쳇쳐쳤초촉촌촘촛총촨촬최추축춘출춤충춰취츠측츰층치칙친칠침칩칫칭카칸칼캉캐캔캘캠커컥컨컫컴컵컷컸케켈켐켑켓켜켰코콘콜콤콥콧콩쾌쿄쿠쿡쿨쿼퀴큐크큰클큼키킥킨킬킷킹타탁탄탈탐탑탓탕태택탠탬탱터턱턴털텃텅테텍텐텔템텼토톡톤톨톰통퇴투툴툼퉁튀튜튬트특튼튿틀틈티틱틴틸팀팅파팍팎판팔팜팡패팩팬팰팻팽퍼펀펄펌페펜펠펫펴편펼폄폈평폐포폭폰폴폼표푸푹푼풀품풋풍퓨퓰프픈플픔피픽핀필핏핑하학한할함합핫항해핵핸햇했행향허헌헐험헝헤헨헬헴헷혀혁현혈혐협혔형혜호혹혼홀홈홉홍화확환활황회획횡효후훈훌훤훨훼휘휩휴흉흐흑흔흘흙흠흡흥흩희흰히힌힐힘'''
66
+
67
+ russian =''' !"'(),-.:;?abcefghiklmnoprstxz«»абвгдежзийклмнопрстуфхцчшщъыьэюяё‑–—“„…−'''
68
+
69
+ thai =''' กขคฆงจฉชซญณดตถทธนบปผฝพฟภมยรฤลวศษสหอฮะัาำิีึืุูเแโใไ็่้๊๋์'''
70
+
71
+ japanese =''' !',-.?abcdefghijklmnopqrstuvwxyzμ―‘’“”…☆♡⤴、。々〇〈〉「」『』〜ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐをん゛ゝゞァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャ���ュユョヨラリルレロワンヴヵヶ・ー一丁七万丈三上下不与且世丘両並中串丸丹主丼乃久之乏乖乗乙九也乱乳乾亀了予争事二于云互五井亘亜亡交京亭人仁今介仏仕他付仙代令以仮仰仲件任企伊伍伏伐休会伝传伯伴伸似但佇位低住佐体何余作佳併使例侑供依価侮侵便係俄俗保俟信俣修俯俳俺倉個倍倒候借倣値倫倶偉偏做停健側偵偶偽傍傑傘備催債傷傾僅働像僕僚僧儀億儘儚償優允元兄充兆先光克免兎児党入全八公六共兵具典兼内円冊再冒冗写冠冬冴冷凄凌凍凛凝凡処凧凪凶出函刀刃分切刈刊刑列初判別利到制刷券刺刻則削前剛剣剤剥副割創劇力功加劣助努劫励労効勃勇勉動勘務勝募勢勤勧勿匂包化北匠匹区医匿十千午半卑卒卓協南単博占印危即却卵厄厚原厨厭厳去参及友双反収取受叙叛叡口古句叩只叫召可台叱史右叶号司各合吉吊同名后吐向君吟吠否含吸吹吾呂呆呈呉告呑呟周味呻呼命咋和咥咲咳咽哀品員哲唄唆唇唐唖唯唱唸唾商問善喉喋喚喜喧喩喪喫喬喰営嗅嘆嘉嘘嘩嘲噂噌噛器噴囚四回因団困囲図固国圏園土圧在地坂均坊坐坑坪垂型垢垣埃埋城域執培基埼堀堂堅堕堤堪報場堵堺塀塁塊塔塗塙塚塩塵塾境墓増墨壁壇壊壌士壬壮声壱売壷壺変夏夕外多夜夢大天太夫央失奄奇奈奉奏契奔套奥奨奪奮女奴好如妃妄妖妙妥妨妬妹妻姉始姓委姦姫姶姻姿威娘娯婆婚婦婿媒媛媽嫁嫉嫌嬉嬢嬬子孔孕字存孝季孤学孫宅宇守安完宍宗官宙定宛宜宝実客宣室宮宰害宴宵家容宿寂寄密富寒寛寝察寧審寵寸寺対寿封専射将尊尋導小少尖尚尤就尺尻尼尽尾尿局居屈届屋屏屑展属層履屯山岐岡岩岬岸峙峠峡峰島崇崎崔崖崩嵐嵩嶋川州巡巣工左巧巨差己巻巾市布帆希帖帝師席帯帰帳常帽幅幌幕幡干平年并幸幹幻幼幽幾庁広庄床序底庖店府度座庫庭庵康廃廉廊廓延廷建廻廿弁弄弊式弐弓引弘弟弥弦弱張強弾当彙形彦彩彫彰影役彼往征待律後徐徒従得御復循微徳徴徹徽心必忍志忘忙応忠快念怒怖思怠急性怨怪怯恋恐恥恨恩息恰恵悔悟悠患悦悩悪悲悶情惑惚惜惟惨惰想惹愉意愕愚愛感慄慈態慌慎慕慢慣慨慮慰慶憂憎憐憑憚憤憧憩憲憶應懊懐懸成我戒或戚戦戯戸戻房所扇扉手才打払扱扶批承技抄把抑投抗折抜択披抱抵抹押抽担拉拍拒拓拗拘招拝拠拡括拭拳拶拾持指挑挙挟挨挫振挾挿捉捌捕捜捧捨捲捻掃授掌排掘掛掟掠採探接控推措掬掲掴掻揃揄描提揖揚換握揮援揶揺損搬搭携搾摂摘摩撃撒撚撞撤撫播撮擁操擢擦支改攻放政故敏救敗教敢散敦敬数整敵敷文斉斎斐斑斗料斜斧斬断斯新方於施旅族旗既日旦旧旨早旬旭旺昂昆昇昌明易昔星映春昧昨昭是昼時晒晦晩普景晴晶智暇暑暖暗暦暫暮暴曇曖曜曲曳更書曹曽替最月有朋服朕朗望朝期朦朧木未末本札朴机权杉李材村杓杖杜束条来杭杯東杵松板析枕林枚果枝枢枯架柄柏染柔柢柱柳柴柵査柿栃栄栖栗校株核根格栽桁桂桃案桎桐桑桔桜桶梁梅梏梗條梧梨梯械梱梼棄棋棒棚棟森棲椅植椎検楊楢業極楽概榛構槌様槻槽標模権横樹樽橄橋橙機橿檎檜櫃櫻欄欅欒欖欠次欧欲欺歌歎歓止正此武歩歪歯歳歴死歿殊残殖殴段殺殻殿母毎毒比毛毫毯氏民気水氷永汁求汗汚汝江池汲決汽沈沖沙没沢沫河沸油治沼沿況泉泊法泡波泣泥注泪泰泳洋洒洗洞津洪洲活派流浄浅浙浜浦浩浪浮浴海浸消涌涎涙涯液涼淀淡深淵混淹添清渇済渉渋渓減渡渦温測港渾湖湧湯湾湿満源準溜溝溢溶溺滅滋滑滝滞滲滴漁漂漆漏演漕漠漢漫漬潔潜潟潤潮潰澄澤激濁濃濡濫濯瀋瀞瀟瀬灘火灯灰災炉炊炎炒炬炭炸点為烈烏烹焔無焦然焼煉煎煙照煮熊熟熱熾燃燈燕燥燵爆爪爵父爺爽爾片版牙牛牝牟牡牢牧物牲特牽犠犧犬犯状狂狐狙狛狩独狭猛猟猪猫献猶猿獄獅獣獰獲玄率玉王玖珀珂珍珠班現球理琢琥琴瑕瑛瑞璃璧環瓢瓦瓶甘生産甥用田由甲申男町画界畏畑留畜畠畢略番異畳畷疆疑疫疲疵疾病症痛痩痰痴痺瘠療癒癖癪発登發白百的皆皇皮皺皿盆益盗盛盟監盤目盲直相盾省眉看県真眠眸眺眼着睡督睦瞑瞬瞭瞳矛矢知矩短矮矯石砂研砕砥砦砲破砺硝硬确碁碑碗碧確磁磐磨磯礎礫示礼社祈祉祓祖祝神祠祢祥票祭禁禄禍福禽秀私秋科秒秘秦秩称移程税稚種稲稼稽稿穀穂積穏穴究空突窃窓窟窪竈立竜竟章童端競竹笑笛笠符第笹筆筈等筋筑筒答策箆箇箋箔箕算管箱箸節範築篠簔簡簾簿籍籐籠米粉粋粒粕粗粘粛粟粧精糖糞糠糧糸系紀約紅紋納紐純紗紙級紛素索紫細紳紹終組絆経結絞絡給絨統絵絶絹継続綜維綱網綺綻綾綿緊総緑緒線締編緩練縁縄縛縞縦縫縮績繁繊繋織繕繞繫繰繹纂纏缶罠罪置罰署罵羅羊美群義羽翁翅翌習翔翻翼老考者耆而耐耕耳聖聞聴職肆肉肌肘肝股肢肥肩肪肯育肺胃胆背胎胞胡胸能脂脅脆脇脈脊脚脱脳脹腎腐腑腕腰
腱腹腿膚膜膝膨膳膵臆臍臓臣臨自臭至致臼興舌舎舐舗舘舞舟航般船艘艦良色艶芋芝芥芦芯花芳芸芽苅苑苓苔苗苛若苦苫英苺茂茅茎茨茶茹草荒荘荷菅菊菓菜華菰菱萌萎萩落葉著葛葦葬葵蒙蒡蒲蒸蓄蓋蓬蓮蔑蔓蔭蔵蕎蕨薄薇薔薩薬藁藍藝藤藻蘂蘆蘇蘭虎虚虫虹虻蚊蛇蛍蛙蛮蛹蛾蝉蝶融蟲蟹蟻蠣血衆行術街衛衝衡衣表衰衾袋袖被袴裁裂装裏裕補裳裸製裾複褐褪襟襲西要覆覇見規視覗覚覧親観角解触言訂計訊討訓託記訟訣訪設許訳訴診証詐詑評詞詠詣試詩詮詰話詳誇誉誌認誑誓誕誘語誠誤誦説読誰課調談請諏論諦諫諸謀謂謎謙講謝謡謬謹識譜警議譲護讐认话读谷豆豊豚象豪貌貝貞負財貢貧貨販貪貫責貯貰貴買貸費貼貿賀賂賃賄資賊賑賛賞賠賢賦質賭購賽贅贈赤赦走起超越趣足距跡跨路跳践踊踏蹟蹴躇躊躍身車軋軌軍軒軟転軸軽較載輝輩輪輸輾轄辛辞辰辱農辺込辿迅迎近返这迦迫述迷追退送逃逅逆透逐途逗這通逝速造逢連逮週進逸遂遅遇遊運遍過道達違遠遣遥適遭遮遷選遺遼遽避邂還邑那邦邪邸郊郎郡部郵郷都鄙鄭配酒酔酢酬酵酷酸醍醐醒醜醤釈里重野量金釘釜針釣釧鈍鈴鉄鉛鉢鉱鉾銀銃銅銚銭鋤鋭錆錠錦録鍋鍛鍵鎌鎖鏡鐘鑑長門閃閉開閑間関閣閨閲闇闘阅阜阪防阻阿陀附降限陛陝院陣除陥陰陳陵陸険陽隅隆隊隋階随隔隙際障隠隣隷隻雀雄雅集雇雑離難雨雪雫雰雲雷電需震霊霞霧露霹靂青静非面革靴鞍鞘鞠鞭韋韓韮音韻響頂頃項順須預頑頓頗領頬頭頷頸頻頼題額顎顔顕願類顧领風飛飜食飢飯飲飴飼飽飾餃餅養餌餓館饅首香馬馭馳馴駄駅駆駈駐駒駝駿騎騒験騨騰驚骨骸高髪髭鬱鬼魂魅魔魚鮫鮭鮮鯖鯛鰐鰡鰭鰺鰻鳥鳩鳴鴉鴎鴨鴻鶏鶴鷲鷹鸞鹸鹿麗麦麺麻黄黒黙鼓鼠鼻鼾齢齧龍녕러분세안여요하︎!(),./;=?\^abcdfghijklmnoprstuvwyz~「」ゥ'''
72
+
73
+
74
+ final_characters_ = ''' !"#$%&'()*+,-./:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{}~‘’ ¡«­°´µ·º»½ßàáâãäåæçèéêëìíîïðñòóôöøùúûüþÿāąćčđėęěğīıľłńňōőœřśşšūŭźżžșțəʻʼʾʿː˜̨́μабвгдежзийклмнопрстуфхцчшщъыьэюяѐёљҙҡңүҳדהו،ؑؓؔ؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٓٔٗ٬ٰٴٹپچڈڑژکگںھہۂۃیےۓ۔ۖۗۘۙۚۛँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसहऺऻ़ऽािीुूृॄॅॆेैॉॊोौ्ॏॐ॒॑॓॔ॕॖॗक़ख़ग़ज़ड़ढ़फ़य़ॠॡॢ।॥०१२३४५६७८९॰ॲॽঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃৄেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৵৷৹৻ਁਂਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ਼ਾਿੀੁੂੇੈੋੌ੍ਖ਼ਗ਼ਜ਼ੜਫ਼ੰੱੲੳઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૄૅેૈૉોૌ્ૐૠૢ૦૧૨૩૪૫૬૭૮૯ଁଂଃଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳଵଶଷସହ଼ଽାିୀୁୂୃୄେୈୋୌ୍ୖ୘ଡ଼ଢ଼ୟୠୢ୦୧୨୩୪୫୬୭୮୯୰ୱஃஅஆஇஈஉஊஎஏஐஒஓஔகஙசஜஞடணதநனபமயரறலளழவஷஸஹாிீுூெேைொோௌ்ௗఁంఃఅఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరఱలళవశషసహాిీుూృెేైొోౌ్ౖಂಃಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೆೇೈೊೋೌ್ೕೖೞೠ೦೧೨೩೪೫೬೭೮೯ംഃഅആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരറലളഴവശഷസഹാിീുൂൃെേൈൊോൌ്ൗൺൻർൽൾกขคฆงจฉชซญณดตถทธนบปผฝพฟภมยรฤลวศษสหอฮะัาำิีึืุูเแโใไ็่้๊๋์ḥṇṛṣṭễ  
​‌‍‎‑–—―‘’“”„•…′›€₹™−��♡♭♯⤴、。々〇〈〉「」『』〜ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐをん゛ゝゞァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロワンヴヵヶ・ー一丁七万丈三上下不与且世丘両並中丰串丸丹主丼乃久之乏乖乗乙九也乱乳乾亀了予争事二于云互五井亘亜亡交京亭人仁今介仏仕他付仙代令以仮仰仲件任企伊伍伏伐休会伝传伯伴伸似但佇位低住佐体何余作佳併使例侑供依価侮侵便係俄俗保俟信俣修俯俳俺倉個倍倒候借倣値倫倶偉偏做停健側偵偶偽傍傑傘備催債傷傾僅働像僕僚僧儀億儘儚償優允元兄充兆先光克免兎児党入全八公六共兵具典兼内円冊再冒冗写冠冬冴冷凄凌凍凛凝凡処凧凪凶出函刀刃分切刈刊刑列初判別利到制刷券刺刻則削前剛剣剤剥副割創劇力功加劣助努劫励労効勃勇勉動勘務勝募勢勤勧勿匂包化北匠匹区医匿十千午半卑卒卓協南単博占印危即却卵厄厚原厨厭厳去参及友双反収取受叙叛叡口古句叩只叫召可台叱史右叶号司各合吉吊同名后吐向君吟吠否含吸吹吾呂呆呈呉告呑呟周味呻呼命咋和咥咲咳咽哀品員哲唄唆唇唐唖唯唱唸唾商問善喉喋喚喜喧喩喪喫喬喰営嗅嘆嘉嘘嘩嘲噂噌噛器噴囚四回因団困囲図固国圏園土圧在地坂均坊坐坑坪垂型垢垣埃埋城域執培基埼堀堂堅堕堤堪報場堵堺塀塁塊塔塗塙塚塩塵塾境墓増墨壁壇壊壌士壬壮声壱売壷壺変夏夕外多夜夢大天太夫央失奄奇奈奉奏契奔套奥奨奪奮女奴好如妃妄妖妙妥妨妬妹妻姉始姓委姦姫姶姻姿威娘娯婆婚婦婿媒媛媽嫁嫉嫌嬉嬢嬬子孔孕字存孝季孤学孫宅宇守安完宍宗官宙定宛宜宝実客宣室宮宰害宴宵家容宿寂寄密富寒寛寝察寧審寵寸寺対寿封専射将尊尋導小少尖尚尤就尺尻尼尽尾尿局居屈届屋屏屑展属層履屯山岐岡岩岬岸峙峠峡峰島崇崎崔崖崩嵐嵩嶋川州巡巣工左巧巨差己巻巾市布帆希帖帝師席帯帰帳常帽幅幌幕幡干平年并幸幹幻幼幽幾庁広庄床序底庖店府度座庫庭庵康廃廉廊廓延廷建廻廿弁弄弊式弐弓引弘弟弥弦弱張強弾当彙形彦彩彫彰影役彼往征待律後徐徒従得御復循微徳徴徹徽心必忍志忘忙応忠快念怒怖思怠急性怨怪怯恋恐恥恨恩息恰恵悔悟悠患悦悩悪悲悶情惑惚惜惟惨惰想惹愉意愕愚愛感慄慈態慌慎慕慢慣慨慮慰慶憂憎憐憑憚憤憧憩憲憶應懊懐懸成我戒或戚戦戯戸戻房所扇扉手才打払扱扶批承技抄把抑投抗折抜択披抱抵抹押抽担拉拍拒拓拗拘招拝拠拡括拭拳拶拾持指挑挙挟挨挫振挾挿捉捌捕捜捧捨捲捻掃授掌排掘掛掟掠採探接控推措掬掲掴掻揃揄描提揖揚換握揮援揶揺損搬搭携搾摂摘摩撃撒撚撞撤撫播撮擁操擢擦支改攻放政故敏救敗教敢散敦敬数整敵敷文斉斎斐斑斗料斜斧斬断斯新方於施旅族旗既日旦旧旨早旬旭旺昂昆昇昌明易昔星映春昧昨昭是昼時晒晦晩普景晴晶智暇暑暖暗暦暫暮暴曇曖曜曲曳更書曹曽替最月有朋服朕朗望朝期朦朧木未末本札朴机权杉李材村杓杖杜束条来杭杯東杵松板析枕林枚果枝枢枯架柄柏染柔柢柱柳柴柵査柿栃栄栖栗校株核根格栽桁桂桃案桎桐桑桔桜桶梁梅梏梗條梧梨梯械梱梼棄棋棒棚棟森棲椅植椎検楊楢業極楽概榛構槌様槻槽標模権横樹樽橄橋橙機橿檎檜櫃櫻欄欅欒欖欠次欧欲欺歌歎歓止正此武歩歪歯歳歴死歿殊残殖殴段殺殻殿母毎毒比毛毫毯氏民気水氷永汁求汗汚汝江池汲決汽沈沖沙没沢沫河沸油治沼沿況泉泊法泡波泣泥注泪泰泳洋洒洗洞津洪洲活派流浄浅浙浜浦浩浪浮浴海浸消涌涎涙涯液涼淀淡深淵混淹添清渇済渉渋渓減渡渦温測港渾湖湧湯湾湿満源準溜溝溢溶溺滅滋滑滝滞滲滴漁漂漆漏演漕漠漢漫漬潔潜潟潤潮潰澄澤激濁濃濡濫濯瀋瀞瀟瀬灘火灯灰災炉炊炎炒炬炭炸点為烈烏烹焔無焦然焼煉煎煙照煮熊熟熱熾燃燈燕燥燵爆爪爵父爺爽爾片版牙牛牝牟牡牢牧物牲特牽犠犧犬犯状狂狐狙狛狩独狭猛猟猪猫献猶猿獄獅獣獰獲玄率玉王玖珀珂珍珠班現球理琢琥琴瑕瑛瑞璃璧環瓢瓦瓶甘生産甥用田由甲申男町画界畏畑留畜畠畢略番異畳畷疆疑疫疲疵疾病症痛痩痰痴痺瘠療癒癖癪発登發白百的皆皇皮皺皿盆益盗盛盟監盤目盲直相盾省眉看県真眠眸眺眼着睡督睦瞑瞬瞭瞳矛矢知矩短矮矯石砂研砕砥砦砲破砺硝硬确碁碑碗碧確磁磐磨磯礎礫示礼社祈祉祓���祝神祠祢祥票祭禁禄禅禍福禽秀私秋科秒秘秦秩称移程税稚種稲稼稽稿穀穂積穏穴究空突窃窓窟窪竈立竜竟章童端競竹笑笛笠符第笹筆筈等筋筑筒答策箆箇箋箔箕算管箱箸節範築篠簔簡簾簿籍籐籠米粉粋粒粕粗粘粛粟粧精糖糞糠糧糸系紀約紅紋納紐純紗紙級紛素索紫細紳紹終組絆経結絞絡給絨統絵絶絹継続綜維綱網綺綻綾綿緊総緑緒線締編緩練縁縄縛縞縦縫縮績繁繊繋織繕繞繫繰繹纂纏缶罠罪置罰署罵羅羊美群義羽翁翅翌習翔翻翼老考者耆而耐耕耳聖聞聴職肆肉肌肘肝股肢肥肩肪肯育肺胃胆背胎胞胡胸能脂脅脆脇脈脊脚脱脳脹腎腐腑腕腰腱腹腿膚膜膝膨膳膵臆臍臓臣臨自臭至致臼興舌舎舐舗舘舞舟航般
船艘艦良色艶芋芝芥芦芯花芳芸芽苅苑苓苔苗苛若苦苫英苺茂茅茎茨茶茹草荒荘荷菅菊菓菜華菰菱萌萎萩落葉著葛葦葬葵蒙蒡蒲蒸蓄蓋蓬蓮蔑蔓蔭蔵蕎蕨薄薇薔薩薬藁藍藝藤藻蘂蘆蘇蘭虎虚虫虹虻蚊蛇蛍蛙蛮蛹蛾蝉蝶融蟲蟹蟻蠣血衆行術街衛衝衡衣表衰衾袋袖被袴裁裂装裏裕補裳裸製裾複褐褪襟襲西要覆覇見規視覗覚覧親観角解触言訂計訊討訓託記訟訣訪設許訳訴診証詐詑評詞詠詣試詩詮詰話詳誇誉誌認誑誓誕誘語誠誤誦説読誰課調談請諏論諦諫諸謀謂謎謙講謝謡謬謹識譜警議譲護讐认话读谷豆豊豚象豪貌貝貞負財貢貧貨販貪貫責貯貰貴買貸費貼貿賀賂賃賄資賊賑賛賞賠賢賦質賭購賽贅贈赤赦走起超越趣足距跡跨路跳践踊踏蹟蹴躇躊躍身車軋軌軍軒軟転軸軽較載輝輩輪輸輾轄辛辞辰辱農辺込辿迅迎近返这迦迫述迷追退送逃逅逆透逐途逗這通逝速造逢連逮週進逸遂遅遇遊運遍過道達違遠遣遥適遭遮遷選遺遼遽避邂還邑那邦邪邸郊郎郡部郵郷都鄙鄭配酒酔酢酬酵酷酸醍醐醒醜醤釈里重野量金釘釜針釣釧鈍鈴鉄鉛鉢鉱鉾銀銃銅銚銭鋤鋭錆錠錦録鍋鍛鍵鎌鎖鏡鐘鑑長門閃閉開閑間関閣閨閲闇闘阅阜阪防阻阿陀附降限陛陝院陣除陥陰陳陵陸険陽隅隆隊隋階随隔隙際障隠隣隷隻雀雄雅集雇雑離難雨雪雫雰雲雷電需震霊霞霧露霹靂青静非面革靴鞍鞘鞠鞭韋韓韮音韻響頂頃項順須預頑頓頗領頬頭頷頸頻頼題額顎顔顕願類顧领風飛飜食飢飯飲飴飼飽飾餃餅養餌餓館饅首香馬馭馳馴駄駅駆駈駐駒駝駿騎騒験騨騰驚骨骸高髪髭鬱鬼魂魅魔魚鮫鮭鮮鯖鯛鰐鰡鰭鰺鰻鳥鳩鳴鴉鴎鴨鴻鶏鶴鷲鷹鸞鹸鹿麗麦麺麻黄黒黙鼓鼠鼻鼾齢齧龍ꞌ가각간갈감갑값갓갔강갖같갚개객갠갤갯갱거걱건걷걸검겁것겉게겐겟겠겨격겪견결겸겹겼경곁계고곡곤곧골곰곱곳공곶과곽관괄괌광괘괭괴굉교구국군굳굴굵굶굽굿궁궈권궐궤귀규균그극근글금급긋긍기긴길김깁깃깊까깎깔깝깡깥깨깬꺼껍껏껑께껴꼈꼬꼭꼴꼼꼽꽁꽂꽃꽝꽥꾸꾼꿀꿇꿈꿎꿔꿨꿰뀌뀐끄끅끈끊끌끓끔끗끝끼끽낀낌나낙난날낡남납낫났낭낮낯낳내낸낼냄냅냇냈냉냐냥너넉넌널넓넘넛넣네넥넨넷녀녁년념녔녕녘노녹논놀놈농높놓놨뇌뇨뇽누눈눌눠뉜뉴늉느는늘늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮닳담답닷당닿대댁댐댓더덕던덜덟덤덧덩덫덮데덴델뎅뎌도독돈돋돌돔돕돗동돛돼됐되된될됨됩두둑둔둘둠둥둬뒀뒤뒷듀드득든듣들듬듭듯등디딘딜딥딨딩딪따딱딴딸땀땄땅때땐땠땡떠떡떤떨떴떻떼뗄또똑똥뚜뚝뚫뛰뛴뜨뜯뜸뜻띄띔띠띤띨띵라락란랄람랍랐랑래랙랜램랫랬랭랴략량러럭런럴럼럽렀렁렇레렉렌렘렛려력련렬렴렵렷렸령례로록론롤롬롭롯롱뢰료룡루룬룰룸룹뤄뤘뤼류륙륜률륭르륵른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맡매맥맨맹맺머먹먼멀멈멋멍메멕멘멜멧며면멸명몇모목몫몬몰몸못몽묘무묵묶문묻물뭄뭇뭐뭔뭘뮤뮬므믈미믹민믿밀밋밌밍및밑바박밖반받발밝밤밥방밭배백밴뱀뱃뱅버벅번벌범법벗벚베벤벨벳벼벽변별볍볐병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북분불붉붐붓붕붙뷔뷰브븐블비빅빈빌빗빙빚빛빠빨빵빼뺀뺌뺏뺑뻐뻑뻔뻗뻘뼈뽀뽑뽕뿌뿐뿜쁘쁜쁩삐사삭산살삶삼삽삿샀상새색샌샐샘생샤샬샵샷서석섞선섣설섬섭섯섰성세섹센셀셈셉셋셔션셜셨셰소속손솔솜솟송솥쇄쇠쇤쇼숍수숙순술숨숭숲쉬쉰쉼쉽슈슐스슨슬슴습슷승시식신실싫심십싱싶싸싹싼쌀쌈쌌쌍쌓써썩썰썼쏘쏜쏟쏠쑤쓰쓴쓸씀씁씌씨씩씬씸씻아악안앉않알앓암압앗았앙앞애액앤앨앱앵야약얇양얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엡엣여역연열엷염엽엿였영옆예옛오옥온올옮옳옴옵옷옹와완왈왔왕왜외왼요욕용우욱운울움웁웃웅워원월웠웨웬웹위윈윌윗윙유육윤율융으은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잽쟁저적전절젊젋점접젓정젖제젝젠젤져젼졌조족존졸좀좁종좋좌죄죠주죽준줄줍중줘줬쥐쥔쥘쥬즈즉즌즐즘증지직진질짊짐집짓징짖짙짚짜짝짧짬째쨌쩌쩍쩔쩜쪼쪽쫄쫓쭉쯤찌찍찔찢찧차착찬찮찰참찻창찾채책챌챔챙챠처척천철첨첩첫청체첸첼쳇쳐쳤초촉촌촘촛총촨촬최추축춘출춤충춰취츠측츰층치칙친칠침칩칫칭카칸칼캉캐캔캘캠커컥컨컫컴컵컷컸케켈켐켑켓켜켰코콘콜콤콥콧콩쾌쿄쿠쿡쿨쿼퀴큐크큰클큼키킥킨킬킷킹타탁탄탈탐탑탓탕태택탠탬탱터턱턴털텃텅테텍텐텔템텼토톡톤톨톰통퇴투툴툼퉁튀튜튬트특튼튿틀틈티틱틴틸팀팅파팍팎판팔팜팡패팩팬팰팻팽퍼펀펄펌페펜펠펫펴편펼폄폈평폐포폭폰폴폼표푸푹푼풀품풋풍퓨퓰프픈플픔피픽핀필핏핑하학한할함합핫항해핵핸햇했행향허헌헐험헝헤헨헬헴헷혀혁현혈혐협혔형혜호혹혼홀홈홉홍화확환활황회획횡효후훈훌훤훨훼휘휩휴흉흐흑흔흘흙흠흡흥흩희흰히힌힐힘ﭨﮭﮯﯾﷲﷺ︎️ﺃﺅﺗﺘﺩﺲﻧﻮ!(),./;=?\^abcdfghijklmnoprstuvwyz~「」ゥ'''
75
# removed numbers

# `final_characters_` (above) is the combined multilingual character
# inventory; `labels` starts with one entry per character.
labels= [i for i in final_characters_]

# Text-side vocabulary: a copy of the characters plus start/end/pad specials.
text_labels = [i for i in labels]
text_labels+='<S>','<E>','<PAD>'

# Semantic-code-side vocabulary: one string label per centroid id.
code_labels= [str(i) for i in range(config.semantic_model_centroids)]
labels+=code_labels
# NOTE(review): these specials are appended to code_labels only AFTER it was
# merged into `labels`, so `labels` does not pick them up here — confirm
# intended.
code_labels+='<SST>','<EST>','<PAD>'

# Joint vocabulary specials: text start/end, semantic-token start/end, pad.
labels+='<S>','<E>','<SST>','<EST>','<PAD>'

print('length of the labels: ',len(labels))
Text/symbols_final.pkl ADDED
Binary file (4.75 kB). View file
 
Text/text_meta.txt ADDED
@@ -0,0 +1,62 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ 'odia', 'assamese', 'thai', 'gujrati', 'russian', 'japanese', 'punjabi', 'hindi', 'manipuri', 'korean', 'bhojpuri', 'sanskrit', 'english', 'french', 'bodo', 'malayalam', 'telugu', 'kannada', 'dogri', 'marathi', 'german', 'italian', 'rajasthani', 'spanish', 'arabic', 'urdu', 'gujarati', 'tamil', 'bengali',
2
+
3
+ final_characters_ = '''  !"#$%&'()*+,-./0123456789:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{}~‘’ ¡«­°´µ·º»½ßàáâãäåæçèéêëìíîïðñòóôöøùúûüþÿāąćčđėęěğīıľłńňōőœřśşšūŭźżžșțəʻʼʾʿː˜̨́μабвгдежзийклмнопрстуфхцчшщъыьэюяѐёљҙҡңүҳדהו،ؑؓؔ؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٓٔٗ٬ٰٴٹپچڈڑژکگںھہۂۃیےۓ۔ۖۗۘۙۚۛँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसहऺऻ़ऽािीुूृॄॅॆेैॉॊोौ्ॏॐ॒॑॓॔ॕॖॗक़ख़ग़ज़ड़ढ़फ़य़ॠॡॢ।॥०१२३४५६७८९॰ॲॽঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃৄেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৵৷৹৻ਁਂਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ਼ਾਿੀੁੂੇੈੋੌ੍ਖ਼ਗ਼ਜ਼ੜਫ਼ੰੱੲੳઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૄૅેૈૉોૌ્ૐૠૢ૦૧૨૩૪૫૬૭૮૯ଁଂଃଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳଵଶଷସହ଼ଽାିୀୁୂୃୄେୈୋୌ୍ୖ୘ଡ଼ଢ଼ୟୠୢ୦୧୨୩୪୫୬୭୮୯୰ୱஃஅஆஇஈஉஊஎஏஐஒஓஔகஙசஜஞடணதநனபமயரறலளழவஷஸஹாிீுூெேைொோௌ்ௗఁంఃఅఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరఱలళవశషసహాిీుూృెేైొోౌ్ౖಂಃಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೆೇೈೊೋೌ್ೕೖೞೠ೦೧೨೩೪೫೬೭೮೯ംഃഅആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരറലളഴവശഷസഹാിീുൂൃെേൈൊോൌ്ൗൺൻർൽൾกขคฆงจฉชซญณดตถทธนบปผฝพฟภมยรฤลวศษสหอฮะัาำิีึืุูเแโใไ็่้๊๋์ḥṇṛṣṭễ  
​‌‍‎‑–—―‘’“”„•…′›€₹™−☆♡♭♯⤴、。々〇〈〉「」『』〜ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐをん゛ゝゞァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロワンヴヵヶ・ー一丁七万丈三上下不与且世丘両並中丰串丸丹主丼乃久之乏乖乗乙九也乱乳乾亀了予争事二于云互五井亘亜亡交京亭人仁今介仏仕他付仙代令以仮仰仲件任企伊伍伏伐休会伝传伯伴伸似但佇位低住佐体何余作佳併使例侑供依価侮侵便係俄俗保俟信俣修俯俳俺倉個倍倒候借倣値倫倶偉偏做停健側偵偶偽傍傑傘備催債傷傾僅働像僕僚僧儀億儘儚償優允元兄充兆先光克免兎児党入全八公六共兵具典兼内円冊再冒冗写冠冬冴冷凄凌凍凛凝凡処凧凪凶出函刀刃分切刈刊刑列初判別利到制刷券刺刻則削前剛剣剤剥副割創劇力功加劣助努劫励労効勃勇勉動勘務勝募勢勤勧勿匂包化北匠匹区医匿十千午半卑卒卓協南単博占印危即却卵厄厚原厨厭厳去参及友双反収取受叙叛叡口古句叩只叫召可台叱史右叶号司各合吉吊同名后吐向君吟吠否含吸吹吾呂呆呈呉告呑呟周味呻呼命咋和咥咲咳咽哀品員哲唄唆唇唐唖唯唱唸唾商問善喉喋喚喜喧喩喪喫喬喰営嗅嘆嘉嘘嘩嘲噂噌噛器噴囚四回因団困囲図固国圏園土圧在地坂均坊坐坑坪垂型垢垣埃埋城域執培基埼堀堂堅堕堤堪報場堵堺塀塁塊塔塗塙塚塩塵塾境墓増墨壁壇壊壌士壬壮声壱売壷壺変夏夕外多夜夢大天太夫央失奄奇奈奉奏契奔套奥奨奪���女奴好如妃妄妖妙妥妨妬妹妻姉始姓委姦姫姶姻姿威娘娯婆婚婦婿媒媛媽嫁嫉嫌嬉嬢嬬子孔孕字存孝季孤学孫宅宇守安完宍宗官宙定宛宜宝実客宣室宮宰害宴宵家容宿寂寄密富寒寛寝察寧審寵寸寺対寿封専射将尊尋導小少尖尚尤就尺尻尼尽尾尿局居屈届屋屏屑展属層履屯山岐岡岩岬岸峙峠峡峰島崇崎崔崖崩嵐嵩嶋川州巡巣工左巧巨差己巻巾市布帆希帖帝師席帯帰帳常帽幅幌幕幡干平年并幸幹幻幼幽幾庁広庄床序底庖店府度座庫庭庵康廃廉廊廓延廷建廻廿弁弄弊式弐弓引弘弟弥弦弱張強弾当彙形彦彩彫彰影役彼往征待律後徐徒従得御復循微徳徴徹徽心必忍志忘忙応忠快念怒怖思怠急性怨怪怯恋恐恥恨恩息恰恵悔悟悠患悦悩悪悲悶情惑惚惜惟惨惰想惹愉意愕愚愛感慄慈態慌慎慕慢慣慨慮慰慶憂憎憐憑憚憤憧憩憲憶應懊懐懸成我戒或戚戦戯戸戻房所扇扉手才打払扱扶批承技抄把抑投抗折抜択披抱抵抹押抽担拉拍拒拓拗拘招拝拠拡括拭拳拶拾持指挑挙挟挨挫振挾挿捉捌捕捜捧捨捲捻掃授掌排掘掛掟掠採探接控推措掬掲掴掻揃揄描提揖揚換握揮援揶揺損搬搭携搾摂摘摩撃撒撚撞撤撫播撮擁操擢擦支改攻放政故敏救敗教敢散敦敬数整敵敷文斉斎斐斑斗料斜斧斬断斯新方於施旅族旗既日旦旧旨早旬旭旺昂昆昇昌明易昔星映春昧昨昭是昼時晒晦晩普景晴晶智暇暑暖暗暦暫暮暴曇曖曜曲曳更書曹曽替最月有朋服朕朗望朝期朦朧木未末本札朴机权杉李材村杓杖杜束条来杭杯東杵松板析枕林枚果枝枢枯架柄柏染柔柢柱柳柴柵査柿栃栄栖栗校株核根格栽桁桂桃案桎桐桑桔桜桶梁梅梏梗條梧梨梯械梱梼棄棋棒棚棟森棲椅植椎検楊楢業極楽概榛構槌様槻槽標模権横樹樽橄橋橙機橿檎檜櫃櫻欄欅欒欖欠次欧欲欺歌歎歓止正此武歩歪歯歳歴死歿殊残殖殴段殺殻殿母毎毒比毛毫毯氏民気水氷永汁求汗汚汝江池汲決汽沈沖沙没沢沫河沸油治沼沿況泉泊法泡波泣泥注泪泰泳洋洒洗洞津洪洲活派流浄浅浙浜浦浩浪浮浴海浸消涌涎涙涯液涼淀淡深淵混淹添清渇済渉渋渓減渡渦温測港渾湖湧湯湾湿満源準溜溝溢溶溺滅滋滑滝滞滲滴漁漂漆漏演漕漠漢漫漬潔潜潟潤潮潰澄澤激濁濃濡濫濯瀋瀞瀟瀬灘火灯灰災炉炊炎炒炬炭炸点為烈烏烹焔無焦然焼煉煎煙照煮熊熟熱熾燃燈燕燥燵爆爪爵父爺爽爾片版牙牛牝牟牡牢牧物牲特牽犠犧犬犯状狂狐狙狛狩独狭猛猟猪猫献猶猿獄獅獣獰獲玄率玉王玖珀珂珍珠班現球理琢琥琴瑕瑛瑞璃璧環瓢瓦瓶甘生産甥用田由甲申男町画界畏畑留畜畠畢略番異畳畷疆疑疫疲疵疾病症痛痩痰痴痺瘠療癒癖癪発登發白百的皆皇皮皺皿盆益盗盛盟監盤目盲直相盾省眉看県真眠眸眺眼着睡督睦瞑瞬瞭瞳矛矢知矩短矮矯石砂研砕砥砦砲破砺硝硬确碁碑碗碧確磁磐磨磯礎礫示礼社祈祉祓祖祝神祠祢祥票祭禁禄禅禍福禽秀私秋科秒秘秦秩称移程税稚種稲稼稽稿穀穂積穏穴究空突窃窓窟窪竈立竜竟章童端競竹笑笛笠符第笹筆筈等筋筑筒答策箆箇箋箔箕算管箱箸節範築篠簔簡簾簿籍籐籠米粉粋粒粕粗粘粛粟粧精糖糞糠糧糸系紀約紅紋納紐純紗紙級紛素索紫細紳紹終組絆経結絞絡給絨統絵絶絹継続綜維綱網綺綻綾綿緊総緑緒線締編緩練縁縄縛縞縦縫縮績繁繊繋織繕繞繫繰繹纂纏缶罠罪置罰署罵羅羊美群義羽翁翅翌習翔翻翼老考者耆而耐耕耳聖聞聴職肆肉肌肘肝股肢肥肩肪肯育肺胃胆背胎胞胡胸能脂脅脆脇脈脊脚脱脳脹腎腐腑腕腰腱腹腿膚膜膝膨膳膵臆臍臓臣臨自臭至致臼興舌舎舐舗舘舞舟航般船
艘艦良色艶芋芝芥芦芯花芳芸芽苅苑苓苔苗苛若苦苫英苺茂茅茎茨茶茹草荒荘荷菅菊菓菜華菰菱萌萎萩落葉著葛葦葬葵蒙蒡蒲蒸蓄蓋蓬蓮蔑蔓蔭蔵蕎蕨薄薇薔薩薬藁藍藝藤藻蘂蘆蘇蘭虎虚虫虹虻蚊蛇蛍蛙蛮蛹蛾蝉蝶融蟲蟹蟻蠣血衆行術街衛衝衡衣表衰衾袋袖被袴裁裂装裏裕補裳裸製裾複褐褪襟襲西要覆覇見規視覗覚覧親観角解触言訂計訊討訓託記訟訣訪設許訳訴診証詐詑評詞詠詣試詩詮詰話詳誇誉誌認誑誓誕誘語誠誤誦説読誰課調談請諏論諦諫諸謀謂謎謙講謝謡謬謹識譜警議譲護讐认话读谷豆豊豚象豪貌貝貞負財貢貧貨販貪貫責貯貰貴買貸費貼貿賀賂賃賄資賊賑賛賞賠賢賦質賭購賽贅贈赤赦走起超越趣足距跡跨路跳践踊踏蹟蹴躇躊躍身車軋軌軍軒軟転軸軽較載輝輩輪輸輾轄辛辞辰辱農辺込辿迅迎近返这迦迫述迷追退送逃逅逆透逐途逗這通逝速造逢連逮週進逸遂遅遇遊運遍過道達違遠遣遥適遭遮遷選遺遼遽避邂還邑那邦邪邸郊郎郡部郵郷都鄙鄭配酒酔酢酬酵酷酸醍醐醒醜醤釈里重野量金釘釜針釣釧鈍鈴鉄鉛鉢鉱鉾銀銃銅銚銭鋤鋭錆錠錦録鍋鍛鍵鎌鎖鏡鐘鑑長門閃閉開閑間関閣閨閲闇闘阅阜阪防阻阿陀附降限陛陝院陣除陥陰陳陵陸険陽隅隆隊隋階随隔隙際障隠隣隷隻雀雄雅集雇雑離難雨雪雫雰雲雷電需震霊霞霧露霹靂青静非面革靴鞍鞘鞠鞭韋韓韮音韻響頂頃項順須預頑頓頗領頬頭頷頸頻頼題額顎顔顕願類顧领風飛飜食飢飯飲飴飼飽飾餃餅養餌餓館饅首香馬馭馳馴駄駅駆駈駐駒駝駿騎騒験騨騰驚骨骸高髪髭鬱鬼魂魅魔魚鮫鮭鮮鯖鯛鰐鰡鰭鰺鰻鳥鳩鳴鴉鴎鴨鴻鶏鶴鷲鷹鸞鹸鹿麗麦麺麻黄黒黙鼓鼠鼻鼾齢齧龍ꞌ가각간갈감갑값갓갔강갖같갚개객갠갤갯갱거걱건걷걸검겁것겉게겐겟겠겨격겪견결겸겹겼경곁계고곡곤곧골곰곱곳공곶과곽관괄괌광괘괭괴굉교구국군굳굴굵굶굽굿궁궈권궐궤귀규균그극근글금급긋긍기긴길김깁깃깊까깎깔깝깡깥깨깬꺼껍껏껑께껴꼈꼬꼭꼴꼼꼽꽁꽂꽃꽝꽥꾸꾼꿀꿇꿈꿎꿔꿨꿰뀌뀐끄끅끈끊끌끓끔끗끝끼끽낀낌나낙난날낡남납낫났낭낮낯낳내낸낼냄냅냇냈냉냐냥너넉넌널넓넘넛넣네넥넨넷녀녁년념녔녕녘노녹논놀놈농높놓놨뇌뇨뇽누눈눌눠뉜뉴늉느는늘늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮닳담답닷당닿대댁댐댓더덕던덜덟덤덧덩덫덮데덴델뎅뎌도독돈돋돌돔돕돗동돛돼됐되된될됨됩두둑둔둘둠둥둬뒀뒤뒷듀드득든듣들듬듭듯등디딘딜딥딨딩딪따딱딴딸땀땄땅때땐땠땡떠떡떤떨떴떻떼뗄또똑똥뚜뚝뚫뛰뛴뜨뜯뜸뜻띄띔띠띤띨띵라락란랄람랍랐랑래랙랜램랫랬랭랴략량러럭런럴럼럽렀렁렇레렉렌렘렛려력련렬렴렵렷렸령례로록론롤롬롭롯롱뢰료룡루룬룰룸룹뤄뤘뤼류륙륜률륭르륵른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맡매맥맨맹맺머먹먼멀멈멋멍메멕멘멜멧며면멸명몇모목몫몬몰몸못몽묘무묵묶문묻물뭄뭇뭐뭔뭘뮤뮬므믈미믹민믿밀밋밌밍및밑바박밖반받발밝밤밥방밭배백밴뱀뱃뱅버벅번벌범법벗벚베벤벨벳벼벽변별볍볐병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북분불붉붐붓붕붙뷔뷰브븐블비빅빈빌빗빙빚빛빠빨빵빼뺀뺌뺏뺑뻐뻑뻔뻗뻘뼈뽀뽑뽕뿌뿐뿜쁘쁜쁩삐사삭산살삶삼삽삿샀상새색샌샐샘생샤샬샵샷서석섞선섣설섬섭섯섰성세섹센셀셈셉셋셔션셜셨셰소속손솔솜솟송솥쇄쇠쇤쇼숍수숙순술숨숭숲쉬쉰쉼쉽슈슐스슨슬슴습슷승시식신실싫심십싱싶싸싹싼쌀쌈쌌쌍쌓써썩썰썼쏘쏜쏟쏠쑤쓰쓴쓸씀씁씌씨씩씬씸씻아악안앉않알앓암압앗았앙앞애액앤앨앱앵야약얇양얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엡엣여역연열엷염엽엿였영옆예옛오옥온올옮옳옴옵옷옹와완왈왔왕왜외왼요욕용우욱운울움웁웃웅워원월웠웨웬웹위윈윌윗윙유육윤율융으은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잽쟁저적전절젊젋점접젓정젖제젝젠젤져젼졌조족존졸좀좁종좋좌죄죠주죽준줄줍중줘줬쥐쥔쥘쥬즈즉즌즐즘증지직진질짊짐집짓징짖짙짚짜짝짧짬째쨌쩌쩍쩔쩜쪼쪽쫄쫓쭉쯤찌찍찔찢찧차착찬찮찰참찻창찾채책챌챔챙챠처척천철첨첩첫청체첸첼쳇쳐쳤초촉촌촘촛총촨촬최추축춘출춤충춰취츠측츰층치칙친칠침칩칫칭카칸칼캉캐캔캘캠커컥컨컫컴컵컷컸케켈켐켑켓켜켰코콘콜콤콥콧콩쾌쿄쿠쿡쿨쿼퀴큐크큰클큼키킥킨킬킷킹타탁탄탈탐탑탓탕태택탠탬탱터턱턴털텃텅테텍텐텔템텼토톡톤톨톰통퇴투툴툼퉁튀튜튬트특튼튿틀틈티틱틴틸팀팅파팍팎판팔팜팡패팩팬팰팻팽퍼펀펄펌페펜펠펫펴편펼폄폈평폐포폭폰폴폼표푸푹푼풀품풋풍퓨퓰프픈플픔피픽핀필핏핑하학한할함합핫항해핵핸햇했행향허헌헐험헝헤헨헬헴헷혀혁현혈혐협혔형혜호혹혼홀홈홉홍화확환활황회획횡효후훈훌훤훨훼휘휩휴흉흐흑흔흘흙흠흡흥흩희흰히힌힐힘ﭨﮭﮯﯾﷲﷺ︎️ﺃﺅﺗﺘﺩﺲﻧﻮ!(),./;=?\^abcdfghijklmnoprstuvwyz~「」ゥ'''
4
+
5
+ english =''' !"#$%&'()*+,-.0123456789:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{} ­´»àáâæçèéêíïñôúüœ˜ҙҡүҳ​‌–—‘’“”•…€™️'''
6
+
7
+ hindi =''' !"'(),-.0123456789;?abcdefghijklmnopqrstuvwxyzँंःअआइईउऊऋऌऍऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऻ़ािीुूृॄॅॆेैॉॊोौ्ॏॐ॒॓॔ॕॖॗक़ख़ग़ज़ड़ढ़फ़ॠॡॢ।॥०॰'''
8
+
9
+ kannada ='''  !"#$%&'()*+,-./0123456789:;<=>?[]^`abcdefghiklmnopqrstuvwxy ½ʼॐ।॥ಂಃಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೆೇೈೊೋೌ್ೕೖೞೠ೦೧೨೩೪೫೬೭೮೯ ​‌‍–‘’“”•…'''
10
+
11
+ tamil =''' !"%&'()*,-./0123456789:;?[]`abcdefghijklmnopqrstuvwxyz ­ஃஅஆஇஈஉஊஎஏஐஒஓஔகஙசஜஞடணதநனபமயரறலளழவஷஸஹாிீுூெேைொோௌ்ௗ‌–‘’“”…'''
12
+
13
+ assamese =''' !%&'()+,-./0123456789:;?[]_abcdefghijklmnopqrstuvwxyz ʼ।॥ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৷৹​‌‍–—‘’“”'''
14
+
15
+ dogri =''' !$%&'()+,-./0123456789:;?@[]`abcdefghijklmnopqrstuvwxyzʼँंःअआइईउऊएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ऽािीुूृेैॉोौ्क़ज़ड़ढ़फ़य़।​‌‍–—‘’“”…′₹'''
16
+
17
+ bodo =''' !$%&'()+,-./0123456789:;?[]_abcdefghijklmnopqrstuvwxyz °º½ʼँंःअआइईउऊऋऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलवशषसह़ािीुूृॅॆेैॉॊोौ्ज़ड़फ़य़।०ॽ​‌‍–‘’“”'''
18
+
19
+ marathi =''' !'*+,-./0123456789:;?[`z ँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॑ॕक़ज़ड़ढ़फ़य़ॠॡ।॥०१२३४५६७८९ॲ‌‍–‘’“”›'''
20
+
21
+ bengali =''' !',-.0123456789;?acdefghlmnrstuyz।ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃৄেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৵৷৻‌‍—‘’'''
22
+
23
+ telugu =''' !"'*,-./258:;?o ।ఁంఃఅఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరఱలళవశషసహాిీుూృెేైొోౌ్ౖ​‌–‘’”…'''
24
+
25
+ german =''' '-abcdefghijklmnopqrstuvwxyzßàäèéöü'''
26
+
27
+ spanish =''' !'-.;?abcdefghijklmnopqrstuvwxyzáéíñóúü'''
28
+
29
+ french =''' !'-.;?abcdefghijklmnopqrstuvwxyzàâæçèéêëîïôùûüÿœ'''
30
+
31
+ punjabi =''' !"'(),-.:?bden।ਁਂਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ਼ਾਿੀੁੂੇੈੋੌ੍ਖ਼ਗ਼ਜ਼ੜਫ਼ੰੱੲੳ –‘’“”…'''
32
+
33
+ sanskrit =''' "ँंःअआइईउऊऋऌऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑ॠॡॢ।॥०१२३४५६७८९॰'''
34
+
35
+ odia =''' "',-.;।ଁଂଃଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳଵଶଷସହ଼ଽାିୀୁୂୃୄେୈୋୌ୍ୖ୘ଡ଼ଢ଼ୟୠୢ୦୧୨୩୪୫୬୭୮୯୰ୱ‌‘’”'''
36
+
37
+ urdu =''' !"',-.:`abcdeghiklrtuy،ؑؓؔ؛؟ءآأؤئابتثجحخدذرزسشصضطظعغفقكلمنهوىيًَُِّْٓٔٗ٬ٰٴٹپچڈڑژکگںھہۂۃیےۓ۔।‘’“”…ﭨﮭﮯﯾﷲﷺﺅﺗﺘﺩﺲﻧﻮ'''
38
+
39
+ gujarati =''' !',-.:;?m ।ઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૄૅેૈૉોૌ્ૐૠૢ૦૧૨૩૪૫૬૭૮૯‌–‘’…'''
40
+
41
+ rajasthani =''' !'(),-.0123456789:;?xँंःअआइईउऊऋऍएऐऑओऔकखगघचछजझञटठडढणतथदधनपफबभमयरलळवशषसह़ािीुूृेैॉोौ्क़ख़ग़ज़ड़ढ़फ़ॠ।०१२३७८९​‍‘’…'''
42
+
43
+ malayalam =''' !,?ംഃഅആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരറലളഴവശഷസഹാിീുൂൃെേൈൊോൌ്ൗൺൻർൽൾ'''
44
+
45
+ manipuri =''' ,-.mnঁংঅআইঈউঊএঐওঔকখগঘঙচছজঝঞটঠডণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎয়০১২৩৪৫৬৭৮৯ৰৱ৷'''
46
+
47
+ gujrati =''' !"'(),-.0123456789:?{}‘’ ­âઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૅેૈૉોૌ્ૠ૦૧૨૩૪૫૬૭૮૯​‌‍–—‘’“”…'''
48
+
49
+ bhojpuri ='''  !"'(),-.012346789:?`abeimpy{}·ँंःअआइईउऊऋऍऎ��ऐऑऒओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलळवशषसहऺ़ऽािीुूृॅॆेैॉॊोौ्ॐॕॖक़ख़ग़ज़ड़ढ़फ़य़ॠ।०२६९॰ ​‌‍‎‘’'''
50
+
51
+ italian =''' !"$'()+,-.:;<=>?[]_`abcdefghijklmnopqrstuvwxyz{}~¡«°´µº»ßàáâãäåæèéêëìíîïðñòóôöøùúûþÿāąćčđėęěğīıľłńňōőœřśşšūŭźżžșțəʻʼʾʿː̨́абдеиклмностуцшѐёљңדהוةرسصغليḥṇṛṣṭễ‑–—‘’“”„…′☆♭♯あアカキサザノフリン・ー万三丰古多家峰張旅禅ꞌ'''
52
+
53
+ arabic =''' !"',-.:;?egt«»،؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٰٓچکیۖۗۘۙۚۛ–“”…ﺃ'''
54
+
55
+ korean =''' 가각간갈감갑값갓갔강갖같갚개객갠갤갯갱거걱건걷걸검겁것겉게겐겟겠겨격겪견결겸겹겼경곁계고곡곤곧골곰곱곳공곶과곽관괄괌광괘괭괴굉교구국군굳굴굵굶굽굿궁궈권궐궤귀규균그극근글금급긋긍기긴길김깁깃깊까깎깔깝깡깥깨깬꺼껍껏껑께껴꼈꼬꼭꼴꼼꼽꽁꽂꽃꽝꽥꾸꾼꿀꿇꿈꿎꿔꿨꿰뀌뀐끄끅끈끊끌끓끔끗끝끼끽낀낌나낙난날낡남납낫났낭낮낯낳내낸낼냄냅냇냈냉냐냥너넉넌널넓넘넛넣네넥넨넷녀녁년념녔녕녘노녹논놀놈농높놓놨뇌뇨뇽누눈눌눠뉜뉴늉느는늘늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮닳담답닷당닿대댁댐댓더덕던덜덟덤덧덩덫덮데덴델뎅뎌도독돈돋돌돔돕돗동돛돼됐되된될됨됩두둑둔둘둠둥둬뒀뒤뒷듀드득든듣들듬듭듯등디딘딜딥딨딩딪따딱딴딸땀땄땅때땐땠땡떠떡떤떨떴떻떼뗄또똑똥뚜뚝뚫뛰뛴뜨뜯뜸뜻띄띔띠띤띨띵라락란랄람랍랐랑래랙랜램랫랬랭랴략량러럭런럴럼럽렀렁렇레렉렌렘렛려력련렬렴렵렷렸령례로록론롤롬롭롯롱뢰료룡루룬룰룸룹뤄뤘뤼류륙륜률륭르륵른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맡매맥맨맹맺머먹먼멀멈멋멍메멕멘멜멧며면멸명몇모목몫몬몰몸못몽묘무묵묶문묻물뭄뭇뭐뭔뭘뮤뮬므믈미믹민믿밀밋밌밍및밑바박밖반받발밝밤밥방밭배백밴뱀뱃뱅버벅번벌범법벗벚베벤벨벳벼벽변별볍볐병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북분불붉붐붓붕붙뷔뷰브븐블비빅빈빌빗빙빚빛빠빨빵빼뺀뺌뺏뺑뻐뻑뻔뻗뻘뼈뽀뽑뽕뿌뿐뿜쁘쁜쁩삐사삭산살삶삼삽삿샀상새색샌샐샘생샤샬샵샷서석섞선섣설섬섭섯섰성세섹센셀셈셉셋셔션셜셨셰소속손솔솜솟송솥쇄쇠쇤쇼숍수숙순술숨숭숲쉬쉰쉼쉽슈슐스슨슬슴습슷승시식신실싫심십싱싶싸싹싼쌀쌈쌌쌍쌓써썩썰썼쏘쏜쏟쏠쑤쓰쓴쓸씀씁씌씨씩씬씸씻아악안앉않알앓암압앗았앙앞애액앤앨앱앵야약얇양얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엡엣여역연열엷염엽엿였영옆예옛오옥온올옮옳옴옵옷옹와완왈왔왕왜외왼요욕용우욱운울움웁웃웅워원월웠웨웬웹위윈윌윗윙유육윤율융으은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잽쟁저적전절젊젋점접젓정젖제젝젠젤져젼졌조족존졸좀좁종좋좌죄죠주죽준줄줍중줘줬쥐쥔쥘쥬즈즉즌즐즘증지직진질짊짐집짓징짖짙짚짜짝짧짬째쨌쩌쩍쩔쩜쪼쪽쫄쫓쭉쯤찌찍찔찢찧차착찬찮찰참찻창찾채책챌챔챙챠처척천철첨첩첫청체첸첼쳇쳐쳤초촉촌촘촛총촨촬최추축춘출춤충춰취츠측츰층치칙친칠침칩칫칭카칸칼캉캐캔캘캠커컥컨컫컴컵컷컸케켈켐켑켓켜켰코콘콜콤콥콧콩쾌쿄쿠쿡쿨쿼퀴큐크큰클큼키킥킨킬킷킹타탁탄탈탐탑탓탕태택탠탬탱터턱턴털텃텅테텍텐텔템텼토톡톤톨톰통퇴투툴툼퉁튀튜튬트특튼튿틀틈티틱틴틸팀팅파팍팎판팔팜팡패팩팬팰팻팽퍼펀펄펌페펜펠펫펴편펼폄폈평폐포폭폰폴폼표푸푹푼풀품풋풍퓨퓰프픈플픔피픽핀필핏핑하학한할함합핫항해핵핸햇했행향허헌헐험헝헤헨헬헴헷혀혁현혈혐협혔형혜호혹혼홀홈홉홍화확환활황회획횡효후훈훌훤훨훼휘휩휴흉흐흑흔흘흙흠흡흥흩희흰히힌힐힘'''
56
+
57
+ russian =''' !"'(),-.:;?abcefghiklmnoprstxz«»абвгдежзийклмнопрстуфхцчшщъыьэюяё‑–—“„…−'''
58
+
59
+ thai =''' กขคฆงจฉชซญณดตถทธนบปผฝพฟภมยรฤลวศษสหอฮะัาำิีึืุูเแโใไ็่้๊๋์'''
60
+
61
+ japanese =''' !',-.?abcdefghijklmnopqrstuvwxyzμ―‘’“”…☆♡⤴、。々〇〈〉「」『』〜ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわ��をん゛ゝゞァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロワンヴヵヶ・ー一丁七万丈三上下不与且世丘両並中串丸丹主丼乃久之乏乖乗乙九也乱乳乾亀了予争事二于云互五井亘亜亡交京亭人仁今介仏仕他付仙代令以仮仰仲件任企伊伍伏伐休会伝传伯伴伸似但佇位低住佐体何余作佳併使例侑供依価侮侵便係俄俗保俟信俣修俯俳俺倉個倍倒候借倣値倫倶偉偏做停健側偵偶偽傍傑傘備催債傷傾僅働像僕僚僧儀億儘儚償優允元兄充兆先光克免兎児党入全八公六共兵具典兼内円冊再冒冗写冠冬冴冷凄凌凍凛凝凡処凧凪凶出函刀刃分切刈刊刑列初判別利到制刷券刺刻則削前剛剣剤剥副割創劇力功加劣助努劫励労効勃勇勉動勘務勝募勢勤勧勿匂包化北匠匹区医匿十千午半卑卒卓協南単博占印危即却卵厄厚原厨厭厳去参及友双反収取受叙叛叡口古句叩只叫召可台叱史右叶号司各合吉吊同名后吐向君吟吠否含吸吹吾呂呆呈呉告呑呟周味呻呼命咋和咥咲咳咽哀品員哲唄唆唇唐唖唯唱唸唾商問善喉喋喚喜喧喩喪喫喬喰営嗅嘆嘉嘘嘩嘲噂噌噛器噴囚四回因団困囲図固国圏園土圧在地坂均坊坐坑坪垂型垢垣埃埋城域執培基埼堀堂堅堕堤堪報場堵堺塀塁塊塔塗塙塚塩塵塾境墓増墨壁壇壊壌士壬壮声壱売壷壺変夏夕外多夜夢大天太夫央失奄奇奈奉奏契奔套奥奨奪奮女奴好如妃妄妖妙妥妨妬妹妻姉始姓委姦姫姶姻姿威娘娯婆婚婦婿媒媛媽嫁嫉嫌嬉嬢嬬子孔孕字存孝季孤学孫宅宇守安完宍宗官宙定宛宜宝実客宣室宮宰害宴宵家容宿寂寄密富寒寛寝察寧審寵寸寺対寿封専射将尊尋導小少尖尚尤就尺尻尼尽尾尿局居屈届屋屏屑展属層履屯山岐岡岩岬岸峙峠峡峰島崇崎崔崖崩嵐嵩嶋川州巡巣工左巧巨差己巻巾市布帆希帖帝師席帯帰帳常帽幅幌幕幡干平年并幸幹幻幼幽幾庁広庄床序底庖店府度座庫庭庵康廃廉廊廓延廷建廻廿弁弄弊式弐弓引弘弟弥弦弱張強弾当彙形彦彩彫彰影役彼往征待律後徐徒従得御復循微徳徴徹徽心必忍志忘忙応忠快念怒怖思怠急性怨怪怯恋恐恥恨恩息恰恵悔悟悠患悦悩悪悲悶情惑惚惜惟惨惰想惹愉意愕愚愛感慄慈態慌慎慕慢慣慨慮慰慶憂憎憐憑憚憤憧憩憲憶應懊懐懸成我戒或戚戦戯戸戻房所扇扉手才打払扱扶批承技抄把抑投抗折抜択披抱抵抹押抽担拉拍拒拓拗拘招拝拠拡括拭拳拶拾持指挑挙挟挨挫振挾挿捉捌捕捜捧捨捲捻掃授掌排掘掛掟掠採探接控推措掬掲掴掻揃揄描提揖揚換握揮援揶揺損搬搭携搾摂摘摩撃撒撚撞撤撫播撮擁操擢擦支改攻放政故敏救敗教敢散敦敬数整敵敷文斉斎斐斑斗料斜斧斬断斯新方於施旅族旗既日旦旧旨早旬旭旺昂昆昇昌明易昔星映春昧昨昭是昼時晒晦晩普景晴晶智暇暑暖暗暦暫暮暴曇曖曜曲曳更書曹曽替最月有朋服朕朗望朝期朦朧木未末本札朴机权杉李材村杓杖杜束条来杭杯東杵松板析枕林枚果枝枢枯架柄柏染柔柢柱柳柴柵査柿栃栄栖栗校株核根格栽桁桂桃案桎桐桑桔桜桶梁梅梏梗條梧梨梯械梱梼棄棋棒棚棟森棲椅植椎検楊楢業極楽概榛構槌様槻槽標模権横樹樽橄橋橙機橿檎檜櫃櫻欄欅欒欖欠次欧欲欺歌歎歓止正此武歩歪歯歳歴死歿殊残殖殴段殺殻殿母毎毒比毛毫毯氏民気水氷永汁求汗汚汝江池汲決汽沈沖沙没沢沫河沸油治沼沿況泉泊法泡波泣泥注泪泰泳洋洒洗洞津洪洲活派流浄浅浙浜浦浩浪浮浴海浸消涌涎涙涯液涼淀淡深淵混淹添清渇済渉渋渓減渡渦温測港渾湖湧湯湾湿満源準溜溝溢溶溺滅滋滑滝滞滲滴漁漂漆漏演漕漠漢漫漬潔潜潟潤潮潰澄澤激濁濃濡濫濯瀋瀞瀟瀬灘火灯灰災炉炊炎炒炬炭炸点為烈烏烹焔無焦然焼煉煎煙照煮熊熟熱熾燃燈燕燥燵爆爪爵父爺爽爾片版牙牛牝牟牡牢牧物牲特牽犠犧犬犯状狂狐狙狛狩独狭猛猟猪猫献猶猿獄獅獣獰獲玄率玉王玖珀珂珍珠班現球理琢琥琴瑕瑛瑞璃璧環瓢瓦瓶甘生産甥用田由甲申男町画界畏畑留畜畠畢略番異畳畷疆疑疫疲疵疾病症痛痩痰痴痺瘠療癒癖癪発登發白百的皆皇皮皺皿盆益盗盛盟監盤目盲直相盾省眉看県真眠眸眺眼着睡督睦瞑瞬瞭瞳矛矢知矩短矮矯石砂研砕砥砦砲破砺硝硬确碁碑碗碧確磁磐磨磯礎礫示礼社祈祉祓祖祝神祠祢祥票祭禁禄禍福禽秀私秋科秒秘秦秩称移程税稚種稲稼稽稿穀穂積穏穴究空突窃窓窟窪竈立竜竟章童端競竹笑笛笠符第笹筆筈等筋筑筒答策箆箇箋箔箕算管箱箸節範築篠簔簡簾簿籍籐籠米粉粋粒粕粗粘粛粟���精糖糞糠糧糸系紀約紅紋納紐純紗紙級紛素索紫細紳紹終組絆経結絞絡給絨統絵絶絹継続綜維綱網綺綻綾綿緊総緑緒線締編緩練縁縄縛縞縦縫縮績繁繊繋織繕繞繫繰繹纂纏缶罠罪置罰署罵羅羊美群義羽翁翅翌習翔翻翼老考者耆而耐耕耳聖聞聴職肆肉肌肘肝股肢肥肩肪肯育肺胃胆背胎胞胡胸能脂脅脆脇脈脊脚脱脳脹腎腐腑腕
腰腱腹腿膚膜膝膨膳膵臆臍臓臣臨自臭至致臼興舌舎舐舗舘舞舟航般船艘艦良色艶芋芝芥芦芯花芳芸芽苅苑苓苔苗苛若苦苫英苺茂茅茎茨茶茹草荒荘荷菅菊菓菜華菰菱萌萎萩落葉著葛葦葬葵蒙蒡蒲蒸蓄蓋蓬蓮蔑蔓蔭蔵蕎蕨薄薇薔薩薬藁藍藝藤藻蘂蘆蘇蘭虎虚虫虹虻蚊蛇蛍蛙蛮蛹蛾蝉蝶融蟲蟹蟻蠣血衆行術街衛衝衡衣表衰衾袋袖被袴裁裂装裏裕補裳裸製裾複褐褪襟襲西要覆覇見規視覗覚覧親観角解触言訂計訊討訓託記訟訣訪設許訳訴診証詐詑評詞詠詣試詩詮詰話詳誇誉誌認誑誓誕誘語誠誤誦説読誰課調談請諏論諦諫諸謀謂謎謙講謝謡謬謹識譜警議譲護讐认话读谷豆豊豚象豪貌貝貞負財貢貧貨販貪貫責貯貰貴買貸費貼貿賀賂賃賄資賊賑賛賞賠賢賦質賭購賽贅贈赤赦走起超越趣足距跡跨路跳践踊踏蹟蹴躇躊躍身車軋軌軍軒軟転軸軽較載輝輩輪輸輾轄辛辞辰辱農辺込辿迅迎近返这迦迫述迷追退送逃逅逆透逐途逗這通逝速造逢連逮週進逸遂遅遇遊運遍過道達違遠遣遥適遭遮遷選遺遼遽避邂還邑那邦邪邸郊郎郡部郵郷都鄙鄭配酒酔酢酬酵酷酸醍醐醒醜醤釈里重野量金釘釜針釣釧鈍鈴鉄鉛鉢鉱鉾銀銃銅銚銭鋤鋭錆錠錦録鍋鍛鍵鎌鎖鏡鐘鑑長門閃閉開閑間関閣閨閲闇闘阅阜阪防阻阿陀附降限陛陝院陣除陥陰陳陵陸険陽隅隆隊隋階随隔隙際障隠隣隷隻雀雄雅集雇雑離難雨雪雫雰雲雷電需震霊霞霧露霹靂青静非面革靴鞍鞘鞠鞭韋韓韮音韻響頂頃項順須預頑頓頗領頬頭頷頸頻頼題額顎顔顕願類顧领風飛飜食飢飯飲飴飼飽飾餃餅養餌餓館饅首香馬馭馳馴駄駅駆駈駐駒駝駿騎騒験騨騰驚骨骸高髪髭鬱鬼魂魅魔魚鮫鮭鮮鯖鯛鰐鰡鰭鰺鰻鳥鳩鳴鴉鴎鴨鴻鶏鶴鷲鷹鸞鹸鹿麗麦麺麻黄黒黙鼓鼠鼻鼾齢齧龍녕러분세안여요하︎!(),./;=?\^abcdfghijklmnoprstuvwyz~「」ゥ'''
62
+
app.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import sys, os, torch
2
+ # sys.path.append("Testing/")
3
+ import gradio as gr
4
+ from inference import infer, prepare_inputs, load_t2s_model, load_cfm, create_wav_header
5
+ from tqdm import tqdm
6
+
7
+ # Setup
8
+ os.makedirs("generated_samples/", exist_ok=True)
9
+ device = "cuda" if torch.cuda.is_available() else "cpu"
10
+ print("Using device", device)
11
+
12
+ # Model checkpoints
13
+ m1_checkpoint = "pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt"
14
+ m2_checkpoint = "pretrained_checkpoint/m2.pt"
15
+ vocoder_checkpoint = 'pretrained_checkpoint/700_580k_multilingual_infer_ready/'
16
+
17
+ global FM, vocoder, m2, mu, std, m1
18
+
19
+ # Load models
20
+ FM, vocoder, m2, mu, std = load_cfm(m2_checkpoint, vocoder_checkpoint, device)
21
+ m1 = load_t2s_model(m1_checkpoint, device)
22
+
23
+
24
+ # Speaker reference clips
25
+ speaker_refs = {
26
+ "Speaker1": [
27
+ "speakers/female1/train_hindifemale_02794.wav",
28
+ "speakers/female1/train_hindifemale_04167.wav",
29
+ "speakers/female1/train_hindifemale_02795.wav"
30
+ ]
31
+ }
32
+
33
+ # Available languages (can be extended)
34
+ available_languages = ["hindi"]
35
+
36
+ # Inference function
37
+ def generate_audio(text, speaker_name, language):
38
+ if speaker_name not in speaker_refs:
39
+ return f"Reference clips not available for {speaker_name}", None
40
+
41
+ ref_clips = speaker_refs[speaker_name]
42
+
43
+ text_ids, code_ids, language_code, ref_mels_m1, ref_mels_m2 = prepare_inputs(
44
+ text.lower(),
45
+ ref_clips_m1=ref_clips,
46
+ ref_clips_m2=ref_clips,
47
+ language=language,
48
+ device=device
49
+ )
50
+
51
+ audio_wav = infer(m1, m2, vocoder, FM, mu, std, text_ids, code_ids, language_code, ref_mels_m1, ref_mels_m2, device)
52
+ return 24000,audio_wav
53
+
54
+ # Gradio UI
55
+ interface = gr.Interface(
56
+ fn=generate_audio,
57
+ inputs=[
58
+ gr.Textbox(label="Enter Text"),
59
+ gr.Dropdown(choices=list(speaker_refs.keys()), label="Select Speaker"),
60
+ gr.Dropdown(choices=available_languages, label="Select Language")
61
+ ],
62
+ outputs=gr.Audio(label="Generated Speech"),
63
+ title="MAHATTSv2 Demo",
64
+ description="Enter text, choose a speaker and language to generate speech."
65
+ )
66
+
67
+ interface.launch(share=True,server_port=9999)
bigvgan_v2_24khz_100band_256x ADDED
@@ -0,0 +1 @@
 
 
1
+ Subproject commit c329ede9e9bbc100ddf5c91e2330a61921262370
config.py ADDED
@@ -0,0 +1,111 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,torch
2
+ import sys
3
+
4
+ dir_path = os.path.abspath(os.path.dirname(__file__))
5
+
6
+ class config:
7
+
8
+ data_path='./'
9
+ model_name="OSS Shogun"
10
+ user_name = "xxx" #for wandb
11
+
12
+ desc = '''
13
+ OSS
14
+ '''
15
+
16
+ train_file = "Semantic_tokens/SEMANTICS_/train.txt"
17
+ val_file = "Semantic_tokens/SEMANTICS_/val.txt"
18
+
19
+ norms = torch.load(f"{dir_path}/mel_norms.pt")
20
+ mu = norms["mean_val"]
21
+ std = norms["std"]
22
+ scale = True
23
+ semantic_model_centroids = 10000 + 1
24
+ seed_value = 3407
25
+
26
+ t2s_checkpoint = "/omega/Models/FT_ENGLISH/T2S/1_latest.pt"
27
+ ts_finetuning = True
28
+ ts_wandb_logs = False
29
+ text_loss_weight = 0.01
30
+ t2s_position = 8192
31
+ ts_batch_size = 1
32
+ ts_epochs = 10
33
+ ts_lr = 1e-5
34
+ ts_weight_decay = 1e-4
35
+ ts_eval_epoch = 1
36
+ ts_num_workers = 8
37
+ ts_gradient_accumulation_steps = 1 # EfBS of 128 for finetuning, 256 for pretraining, around 9k steps for sft for 2 epochs
38
+ ts_eval_step = 10000
39
+
40
+ langs = [
41
+ "odia",
42
+ "assamese",
43
+ "thai",
44
+ "gujrati",
45
+ "russian",
46
+ "japanese",
47
+ "punjabi",
48
+ "hindi",
49
+ "manipuri",
50
+ "korean",
51
+ "bhojpuri",
52
+ "sanskrit",
53
+ "english",
54
+ "french",
55
+ "bodo",
56
+ "malayalam",
57
+ "telugu",
58
+ "kannada",
59
+ "dogri",
60
+ "marathi",
61
+ "german",
62
+ "italian",
63
+ "rajasthani",
64
+ "spanish",
65
+ "arabic",
66
+ "urdu",
67
+ "gujarati",
68
+ "tamil",
69
+ "bengali",
70
+ ]
71
+
72
+ lang_index = {i: j for j, i in enumerate(langs)}
73
+
74
+ # Train s2a
75
+ sa_wandb_logs = False
76
+ joint_training = (False,) # doesn't work
77
+ checkpoint = "/omega/Models/" + "FT_ENGLISH/"
78
+ sa_timesteps_max = 1000
79
+ sa_batch_size = 32
80
+ sa_epochs = 5000000
81
+ gradient_accumulation_steps = 4
82
+ sa_lr = 1e-4
83
+ sa_weight_decay = 1e-2
84
+ sa_eval_step = 10000
85
+ sa_infer = True
86
+ sa_infer_epoch = 1
87
+ sa_num_workers = 24
88
+
89
+ # Train Dvae (not using)
90
+ dvae_wandb_logs = True
91
+ dvae_batch_size = 128
92
+ dvae_epochs = 5000
93
+ dvae_lr = 3e-4
94
+ dvae_weight_decay = 1e-2
95
+ dvae_eval_epoch = 1
96
+ dvae_infer = True
97
+ dvae_infer_epoch = 1
98
+ dvae_num_workers = 16
99
+
100
+ # Acoustic Properties, Do not change
101
+ CLIP_LENGTH = 500
102
+ MAX_WAV_VALUE = 32768.0 - 1
103
+ filter_length = 1024
104
+ hop_length = 256 # 256
105
+ window = "hann"
106
+ win_length = 1024
107
+ n_mel_channels = 100
108
+ sampling_rate = 24000
109
+ mel_fmin = 0.0
110
+ mel_fmax = None
111
+ normalize = True
generated_samples/0_hindi.wav ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:4f7e0dda3c6d53b0658aff88ec422608f6fda9e3a9626d49338a700c67d296fe
3
+ size 966188
inference.py ADDED
@@ -0,0 +1,217 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os,sys,time,struct
2
+ import torch,torchaudio
3
+
4
+ # sys.path.append("S2A/bigvgan_v2_24khz_100band_256x")
5
+ # sys.path.append("S2A/")
6
+ # sys.path.append("T2S/")
7
+ # sys.path.append("hifi-gan/")
8
+
9
+ from S2A.inference import *
10
+ from S2A.diff_model import DiffModel
11
+ from T2S.autoregressive import TS_model
12
+ from T2S.mel_spec import get_mel_spectrogram
13
+ from Text import labels,text_labels,code_labels
14
+ from config import config
15
+
16
+ #code encdec
17
+ text_enc = {j:i for i,j in enumerate(text_labels)}
18
+ text_dec = {i:j for i,j in enumerate(text_labels)}
19
+
20
+ # text encdec
21
+ code_enc = {j:i for i,j in enumerate(code_labels)}
22
+ code_dec = {i:j for i,j in enumerate(code_labels)}
23
+
24
+ def create_wav_header(sample_rate = 24000, bits_per_sample=16, channels=1):
25
+ # "RIFF" chunk descriptor
26
+ chunk_id = b'RIFF'
27
+ chunk_size = 0xFFFFFFFF # Placeholder for chunk size (unknown during streaming)
28
+ format = b'WAVE'
29
+
30
+ # "fmt " sub-chunk (16 bytes for PCM format)
31
+ subchunk1_id = b'fmt '
32
+ subchunk1_size = 16 # PCM format
33
+ audio_format = 1 # PCM = 1 (linear quantization)
34
+ num_channels = channels
35
+ sample_rate = sample_rate
36
+ byte_rate = sample_rate * num_channels * bits_per_sample // 8
37
+ block_align = num_channels * bits_per_sample // 8
38
+ bits_per_sample = bits_per_sample
39
+
40
+ # "data" sub-chunk
41
+ subchunk2_id = b'data'
42
+ subchunk2_size = 0xFFFFFFFF # Placeholder for data size (unknown during streaming)
43
+
44
+ # Pack the header into a byte object using struct
45
+ header = struct.pack('<4sI4s4sIHHIIHH4sI',
46
+ chunk_id,
47
+ chunk_size,
48
+ format,
49
+ subchunk1_id,
50
+ subchunk1_size,
51
+ audio_format,
52
+ num_channels,
53
+ sample_rate,
54
+ byte_rate,
55
+ block_align,
56
+ bits_per_sample,
57
+ subchunk2_id,
58
+ subchunk2_size)
59
+
60
+ return header
61
+
62
+
63
+ def get_processed_clips(ref_clips):
64
+ frame_rate = 24000
65
+ new_ref_clips = []
66
+ for i in ref_clips:
67
+ if '_proc.wav' in i:
68
+ new_ref_clips.append(i)
69
+ continue
70
+ audio = AudioSegment.from_file(i)
71
+ audio = audio.set_channels(1)
72
+ audio = audio.set_frame_rate(frame_rate).set_sample_width(2)
73
+ audio.export(i[:-4]+'_proc.wav',format='wav')
74
+ new_ref_clips.append(i[:-4]+'_proc.wav')
75
+
76
+ return new_ref_clips
77
+
78
+ def get_ref_mels(ref_clips):
79
+ ref_mels = []
80
+ for i in ref_clips:
81
+ audio_norm,sampling_rate = torchaudio.load(i)
82
+ ref_mels.append(get_mel_spectrogram(audio_norm,sampling_rate).squeeze(0)[:, :500])
83
+
84
+ ref_mels_padded = (torch.randn((len(ref_mels), 100, 500))) * 1e-9
85
+ for i, mel in enumerate(ref_mels):
86
+ ref_mels_padded[i, :, : mel.size(1)] = mel
87
+ return ref_mels_padded.unsqueeze(0)
88
+
89
+ def load_cfm(checkpoint,vocoder_checkpoint=None,device="cpu"):
90
+ FM = BASECFM()
91
+ if vocoder_checkpoint is None:
92
+ hifi = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
93
+ else:
94
+ hifi = bigvgan.BigVGAN.from_pretrained(vocoder_checkpoint, use_cuda_kernel=False)
95
+
96
+ hifi.remove_weight_norm()
97
+ hifi = hifi.eval().to(device)
98
+
99
+ model = DiffModel(input_channels=100,
100
+ output_channels=100,
101
+ model_channels=512,
102
+ num_heads=8,
103
+ dropout=0.1,
104
+ num_layers=8,
105
+ enable_fp16=False,
106
+ condition_free_per=0.0,
107
+ multispeaker=True,
108
+ style_tokens=100,
109
+ training=False,
110
+ ar_active=False)
111
+
112
+ model.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu'))['model'])
113
+ model.eval().to(device)
114
+ mu= torch.load(checkpoint,map_location=torch.device('cpu'))['norms']['mean_val']
115
+ std = torch.load(checkpoint,map_location=torch.device('cpu'))['norms']['std']
116
+
117
+ return FM,hifi,model,mu,std
118
+
119
+ def load_t2s_model(checkpoint,device):
120
+ model = TS_model(n_embed= 1024, n_layer= 30, n_head = 16)
121
+ model.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu'))['model'],strict=True)
122
+ model.eval()
123
+ model.to(device)
124
+ model.init_gpt_for_inference()
125
+ return model
126
+
127
+ def prepare_inputs(text,ref_clips_m1,ref_clips_m2,language,device):
128
+ code_ids = [code_enc["<SST>"]]
129
+ text_ids = (
130
+ torch.tensor(
131
+ [text_enc["<S>"]]
132
+ + [text_enc[i] for i in text.strip()]
133
+ + [text_enc["<E>"]]
134
+ )
135
+ .to(device)
136
+ .unsqueeze(0)
137
+ )
138
+ language = (
139
+ torch.tensor(config.lang_index[language]).to(device).unsqueeze(0)
140
+ )
141
+
142
+ ref_mels_m1 = get_ref_mels(get_processed_clips(ref_clips_m1))
143
+ ref_mels_m2 = get_ref_mels(get_processed_clips(ref_clips_m2))
144
+
145
+ return text_ids,code_ids,language,ref_mels_m1,ref_mels_m2
146
+
147
+
148
+ def infer(m1,m2,vocoder,FM,mu,std,text_ids,code_ids,language,ref_mels_m1,ref_mels_m2,device):
149
+ with torch.no_grad():
150
+ cond_latents = m1.get_speaker_latent(ref_mels_m1.to(device))
151
+ code_emb = m1.generate(
152
+ language.to(device), cond_latents.to(device), text_ids.to(device), code_ids, **{
153
+ "temperature": 0.8,
154
+ "length_penalty": None,
155
+ "repetition_penalty": None,
156
+ "top_k": 50,
157
+ "top_p": 0.8,
158
+ "do_sample": True,
159
+ "num_beams": 1,
160
+ "max_new_tokens": 1500
161
+ }
162
+ )[:, :-1]
163
+
164
+ mel = FM(m2, code_emb+1, (1, 100, int(1+93*(code_emb.shape[-1]+1)/50)), ref_mels_m2.to(device), n_timesteps=20, temperature=1.0)
165
+ mel = denormalize_tacotron_mel(mel,mu,std)
166
+ audio = vocoder(mel)
167
+ audio = audio.squeeze(0).detach().cpu()
168
+ audio = audio * 32767.0
169
+ audio = (
170
+ audio.numpy().reshape(-1).astype(np.int16)
171
+ )
172
+
173
+ return audio
174
+
175
+ if __name__ == '__main__':
176
+ os.makedirs("generated_samples/",exist_ok=True)
177
+ start = time.time()
178
+ device = "cuda" if torch.cuda.is_available() else "cpu"
179
+ m1_checkpoint = "pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt"
180
+ m2_checkpoint = "pretrained_checkpoint/m2.pt"
181
+ vocoder_checkpoint = 'pretrained_checkpoint/700_580k_multilingual_infer_ready/'
182
+
183
+ FM,vocoder,m2,mu,std = load_cfm(m2_checkpoint,vocoder_checkpoint,device)
184
+ m1 = load_t2s_model(m1_checkpoint,device)
185
+
186
+ model_loading_time = time.time()
187
+
188
+ output_file_path = "test.wav"
189
+
190
+ texts = ["यह एक उदाहरणात्मक हिंदी पाठ है जिसका उद्देश्य भाषा की संरचना और शब्दों के प्रवाह को समझना है। भारत एक विविधताओं से भरा देश है जहाँ अनेक भाषाएँ, धर्म, और संस्कृतियाँ एक साथ मिलकर रहते हैं। यहाँ की परंपराएँ, त्योहार और भोजन इसकी सांस्कृतिक समृद्धि को दर्शाते हैं।"]
191
+ languages=['hindi']
192
+
193
+ ref_clips = [
194
+ 'speakers/female1/train_hindifemale_02794.wav',
195
+ 'speakers/female1/train_hindifemale_04167.wav',
196
+ 'speakers/female1/train_hindifemale_02795.wav'
197
+ ]
198
+
199
+ for n,(lang,text) in tqdm(enumerate(zip(languages,texts))):
200
+ text_ids,code_ids,language,ref_mels_m1,ref_mels_m2 = prepare_inputs(text.lower(),
201
+ ref_clips_m1=ref_clips,
202
+ ref_clips_m2=ref_clips,
203
+ language=lang
204
+ ,device=device)
205
+ audio_wav = infer(m1,m2,vocoder,FM,mu,std,text_ids,code_ids,language,ref_mels_m1,ref_mels_m2,device)
206
+
207
+ with open(f"generated_samples/{n}_{lang}.wav",'wb') as file:
208
+ file.write(create_wav_header(sample_rate = 24000, bits_per_sample=16, channels=1))
209
+ file.write(audio_wav.tobytes())
210
+
211
+ audio_generation_time=time.time()
212
+
213
+ print()
214
+ print(text)
215
+ print(audio_generation_time-start,":Total Time taken")
216
+ print(model_loading_time-start, ":Model Loading time")
217
+ print(audio_generation_time-model_loading_time, ":Audio Generation time")
mel_norms.pt ADDED
Binary file (3.75 kB). View file
 
pretrained_checkpoint/700_580k_multilingual_infer_ready/bigvgan_generator.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:e612993a903a446054367d1b7701577d35948634324dd7ee3c00ff5288a986c5
3
+ size 450080461