diff --git a/.gitattributes b/.gitattributes
new file mode 100644
index 0000000000000000000000000000000000000000..28e77eacc7c443912f7d130ad5fef919619a5070
--- /dev/null
+++ b/.gitattributes
@@ -0,0 +1,16 @@
+pretrained_checkpoint/m2.pt filter=lfs diff=lfs merge=lfs -text
+speakers/female1/train_hindifemale_02794.wav filter=lfs diff=lfs merge=lfs -text
+speakers/male1 filter=lfs diff=lfs merge=lfs -text
+pretrained_checkpoint/700_580k_multilingual_infer_ready/bigvgan_generator.pt filter=lfs diff=lfs merge=lfs -text
+speakers/female1/train_hindifemale_02794_proc.wav filter=lfs diff=lfs merge=lfs -text
+*.wav filter=lfs diff=lfs merge=lfs -text
+pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt filter=lfs diff=lfs merge=lfs -text
+speakers/female1/train_hindifemale_02795.wav filter=lfs diff=lfs merge=lfs -text
+speakers/female1/train_hindifemale_04167_proc.wav filter=lfs diff=lfs merge=lfs -text
+speakers/female1/train_hindifemale_04167.wav filter=lfs diff=lfs merge=lfs -text
+*.mp3 filter=lfs diff=lfs merge=lfs -text
+*.flac filter=lfs diff=lfs merge=lfs -text
+speakers/female1 filter=lfs diff=lfs merge=lfs -text
+pretrained_checkpoint/700_580k_multilingual_infer_ready/config.json filter=lfs diff=lfs merge=lfs -text
+speakers/female1/train_hindifemale_02795_proc.wav filter=lfs diff=lfs merge=lfs -text
+pretrained_checkpoint/700_580k_multilingual_infer_ready filter=lfs diff=lfs merge=lfs -text
diff --git a/.gitmodules b/.gitmodules
new file mode 100644
index 0000000000000000000000000000000000000000..f7e6a78c975b6070a89cf868fe25074e09b8e330
--- /dev/null
+++ b/.gitmodules
@@ -0,0 +1,6 @@
+[submodule "Semantic_tokens/seamless_communication"]
+ path = Semantic_tokens/seamless_communication
+ url = https://github.com/facebookresearch/seamless_communication.git
+[submodule "bigvgan_v2_24khz_100band_256x"]
+ path = bigvgan_v2_24khz_100band_256x
+ url = https://huggingface.co/nvidia/bigvgan_v2_24khz_100band_256x
diff --git a/README.md b/README.md
new file mode 100755
index 0000000000000000000000000000000000000000..6172918584fd79286dad17f3c26f98ebf0dab139
--- /dev/null
+++ b/README.md
@@ -0,0 +1,95 @@
+# MahaTTS v2: An Open-Source Large Speech Generation Model
+
+*A Dubverse Black initiative*
+
+------
+## Description
+We introduce MahaTTS v2, a multi-speaker text-to-speech (TTS) system trained on 50k hours of Indic and global languages.
+We follow a text-to-semantic-to-acoustic approach that leverages wav2vec2 tokens, which gives out-of-the-box generalization to unseen, low-resource languages.
+We previously open-sourced the first version (MahaTTS), trained on English and Indic languages as two separate models on 9k and 400 hours of open-source datasets.
+For MahaTTS v2, we have collected over 20k hours of training data into a single multilingual, cross-lingual model.
+We use Gemma as the backbone for text-to-semantic modeling and a conditional flow-matching model for semantic-to-mel-spectrogram generation, with a BigVGAN vocoder producing the final audio waveform.
+The model shows markedly better robustness and quality than the previous version.
+We are also open-sourcing the ability to finetune on your own voice.
+
+### With this release you can:
+- generate voices in multiple seen and unseen speaker identities (voice cloning)
+- generate voices in multiple languages (multilingual and cross-lingual voice cloning)
+- copy the style of speech from one speaker to another (cross-lingual voice cloning with prosody and intonation transfer)
+- train your own large-scale pretraining or finetuning models
+
+### MahaTTS Architecture
+
+
+
+
+
+
+
+
+
+
+### Model Params
+| Model | Parameters | Model Type | Output |
+|:------------------------:|:----------:|:----------:|:-------------------:|
+| Text to Semantic (M1) | 510M | Causal LM | 10,001 tokens |
+| Semantic to MelSpec (M2) | 71M | Flow | 100-channel melspec |
+| BigVGAN Vocoder | 112M | GAN | Audio waveform |
+
+
+## 🌐 Supported Languages
+
+The following languages are currently supported:
+
+| Language | Status |
+|------------------|:------:|
+| English (en) | ✅ |
+| Hindi (hi) | ✅ |
+| Assamese (as) | ✅ |
+| Gujarati (gu) | ✅ |
+| Telugu (te) | ✅ |
+| Punjabi (pa) | ✅ |
+| Marathi (mr) | ✅ |
+| Tamil (ta) | ✅ |
+| Bengali (bn) | ✅ |
+| Odia (or) | ✅ |
+| Manipuri (mni) | ✅ |
+| Bhojpuri (bho) | ✅ |
+| Sanskrit (sa) | ✅ |
+| Bodo (brx) | ✅ |
+| Malayalam (ml) | ✅ |
+| Kannada (kn) | ✅ |
+| Dogri (doi) | ✅ |
+| Rajasthani (raj) | ✅ |
+| Thai (th) | ✅ |
+| Japanese (ja) | ✅ |
+| French (fr) | ✅ |
+| German (de) | ✅ |
+| Italian (it) | ✅ |
+| Spanish (es) | ✅ |
+
+
+## TODO
+1. Add training instructions.
+2. Add a Colab notebook for the same.
+
+
+## License
+MahaTTS is licensed under the Apache 2.0 License.
+
+## 🙏 Appreciation
+
+- [Tortoise-tts](https://github.com/neonbjb/tortoise-tts) for inspiring the architecture
+- [Seamless M4T](https://github.com/facebookresearch/seamless_communication), [AudioLM](https://arxiv.org/abs/2209.03143), and many other ground-breaking papers that enabled the development of MahaTTS
+- [BigVGAN](https://github.com/NVIDIA/BigVGAN) for the out-of-the-box vocoder
+- [Matcha-TTS](https://github.com/shivammehta25/Matcha-TTS) for the flow-matching training code
+- [Huggingface](https://huggingface.co/docs/transformers/index) for related training and inference code
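The three stages in the tables above can be sketched as a pipeline. Below is a minimal, runnable sketch with stand-in functions; `text_to_semantic`, `semantic_to_mel`, and `vocode` are illustrative names, not the actual MahaTTS API. It only mimics the shapes: M1 emits tokens from a 10,001-entry vocabulary, M2 emits 100-channel mel frames, and the 24 kHz BigVGAN variant upsamples 256x per frame.

```python
# Hypothetical stand-ins for the three MahaTTS v2 stages (not the real API).
def text_to_semantic(text):
    # stand-in for M1, the 510M causal LM: text -> semantic token ids
    return [hash(w) % 10001 for w in text.split()]

def semantic_to_mel(tokens):
    # stand-in for M2, the 71M flow model: tokens -> 100-channel mel frames
    return [[0.0] * 100 for _ in tokens]

def vocode(mel):
    # stand-in for the 112M BigVGAN vocoder: 256 audio samples per mel frame
    return [0.0] * (len(mel) * 256)

tokens = text_to_semantic("hello world")
mel = semantic_to_mel(tokens)
audio = vocode(mel)
print(len(tokens), len(mel[0]), len(audio))
```

The point of the sketch is only the data flow: each stage consumes the previous stage's representation, which is what lets M1 be swapped or finetuned independently of M2 and the vocoder.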
diff --git a/S2A/__pycache__/Diffusion.cpython-310.pyc b/S2A/__pycache__/Diffusion.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..1908fe030209e428f70dec58da299a71d626df53
Binary files /dev/null and b/S2A/__pycache__/Diffusion.cpython-310.pyc differ
diff --git a/S2A/__pycache__/diff_model.cpython-310.pyc b/S2A/__pycache__/diff_model.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b7de964cb5fb844b45049a016127d994390e5bba
Binary files /dev/null and b/S2A/__pycache__/diff_model.cpython-310.pyc differ
diff --git a/S2A/__pycache__/flow_matching.cpython-310.pyc b/S2A/__pycache__/flow_matching.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..b44153c0297855c37f013413a6610f9df2038a99
Binary files /dev/null and b/S2A/__pycache__/flow_matching.cpython-310.pyc differ
diff --git a/S2A/__pycache__/inference.cpython-310.pyc b/S2A/__pycache__/inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..9d11e3d9d4f5971e06bf0fb37d15c59b2d63d392
Binary files /dev/null and b/S2A/__pycache__/inference.cpython-310.pyc differ
diff --git a/S2A/__pycache__/mel_spec.cpython-310.pyc b/S2A/__pycache__/mel_spec.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..c6071517f11990f277829009f2690e7326a134ec
Binary files /dev/null and b/S2A/__pycache__/mel_spec.cpython-310.pyc differ
diff --git a/S2A/__pycache__/modules.cpython-310.pyc b/S2A/__pycache__/modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..05a704500c304c4b0792a675b1071587fb7c48a1
Binary files /dev/null and b/S2A/__pycache__/modules.cpython-310.pyc differ
diff --git a/S2A/__pycache__/stft.cpython-310.pyc b/S2A/__pycache__/stft.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..8afb10b0059356078a4c9e1bd6fe114f30b17080
Binary files /dev/null and b/S2A/__pycache__/stft.cpython-310.pyc differ
diff --git a/S2A/__pycache__/utilities.cpython-310.pyc b/S2A/__pycache__/utilities.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..864802e76aa8aa81d6a1661ca630c816e5827939
Binary files /dev/null and b/S2A/__pycache__/utilities.cpython-310.pyc differ
diff --git a/S2A/diff_model.py b/S2A/diff_model.py
new file mode 100755
index 0000000000000000000000000000000000000000..91752c3f7f031d8f7191e554d71159e2e2de7f99
--- /dev/null
+++ b/S2A/diff_model.py
@@ -0,0 +1,335 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torch import autocast
+
+from config import config
+
+from .modules import GST, AttentionBlock, mySequential, normalization
+
+
+def timestep_embedding(timesteps, dim, max_period=10000):
+ """
+ Create sinusoidal timestep embeddings.
+
+ :param timesteps: a 1-D Tensor of N indices, one per batch element.
+ These may be fractional.
+ :param dim: the dimension of the output.
+ :param max_period: controls the minimum frequency of the embeddings.
+ :return: an [N x dim] Tensor of positional embeddings.
+ """
+ half = dim // 2
+ freqs = torch.exp(
+ -math.log(max_period)
+ * torch.arange(start=0, end=half, dtype=torch.float32)
+ / half
+ ).to(device=timesteps.device)
+ args = timesteps[:, None].float() * freqs[None]
+ embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
+ if dim % 2:
+ embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1)
+ return embedding
+
+
+class TimestepBlock(nn.Module):
+    def forward(self, x, emb):
+        """
+        Apply the module to `x` given `emb` timestep embeddings.
+        """
+        raise NotImplementedError
+
+
+class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
+ def forward(self, x, emb):
+ for layer in self:
+ if isinstance(layer, TimestepBlock):
+ x = layer(x, emb)
+ else:
+ x = layer(x)
+ return x
+
+
+class QuartzNetBlock(TimestepBlock):
+    """Similar to a ResNet block with BatchNorm and dropout, using a separable conv in the middle.
+    If it is the last layer, set se = False and separable = False, and use a projection layer on top of this.
+    """
+
+ def __init__(
+ self,
+ nin,
+ nout,
+ emb_channels,
+ kernel_size=3,
+ dropout=0.1,
+ R=1,
+ se=True,
+ ratio=8,
+ separable=False,
+ bias=True,
+ use_scale_shift_norm=True,
+ ):
+ super(QuartzNetBlock, self).__init__()
+ self.use_scale_shift_norm = use_scale_shift_norm
+ self.se = se
+ self.in_layers = mySequential(
+ nn.Conv1d(nin, nout, kernel_size=1, padding="same", bias=bias),
+ nn.SiLU(),
+ normalization(nout),
+ )
+
+ if nin == nout:
+ self.residual = nn.Identity()
+ else:
+ self.residual = nn.Conv1d(
+ nin, nout, kernel_size=1, padding="same", bias=bias
+ )
+
+ nin = nout
+ self.model = nn.Sequential(
+ nn.Conv1d(nin, nout, kernel_size, padding="same"),
+ nn.SiLU(),
+ normalization(nout),
+ nn.Dropout(p=dropout),
+ )
+
+ self.emb_layers = nn.Sequential(
+ nn.Linear(
+ emb_channels,
+ 2 * nout if use_scale_shift_norm else nout,
+ ),
+ nn.SiLU(),
+ )
+
+    def forward(self, x, emb, mask=None):
+        x_new = self.in_layers(x)
+        emb = self.emb_layers(emb)
+        while len(emb.shape) < len(x_new.shape):
+            emb = emb[..., None]
+        if self.use_scale_shift_norm:
+            # emb_layers emits 2 * nout channels: a scale and a shift
+            scale, shift = torch.chunk(emb, 2, dim=1)
+            x_new = x_new * (1 + scale) + shift
+        else:
+            x_new = x_new + emb
+        y = self.model(x_new)
+
+        return y + self.residual(x)
+
+
+class QuartzAttn(TimestepBlock):
+ def __init__(self, model_channels, dropout, num_heads):
+ super().__init__()
+ self.resblk = QuartzNetBlock(
+ model_channels,
+ model_channels,
+ model_channels,
+ dropout=dropout,
+ use_scale_shift_norm=True,
+ )
+ self.attn = AttentionBlock(
+ model_channels, num_heads, relative_pos_embeddings=True
+ )
+
+ def forward(self, x, time_emb):
+ y = self.resblk(x, time_emb)
+ return self.attn(y)
+
+
+class QuartzNet9x5(nn.Module):
+ def __init__(self, model_channels, num_heads, dropout=0.1, enable_fp16=False):
+ super(QuartzNet9x5, self).__init__()
+ self.enable_fp16 = enable_fp16
+ kernels = [3] * 10
+ quartznet = []
+ attn = []
+ for i in kernels:
+ quartznet.append(
+ QuartzNetBlock(
+ model_channels,
+ model_channels,
+ model_channels,
+ kernel_size=i,
+ dropout=dropout,
+ R=5,
+ se=True,
+ )
+ )
+ attn.append(
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
+ )
+
+ self.quartznet = nn.ModuleList(quartznet)
+ self.attn = nn.ModuleList(attn)
+ self.conv2 = nn.ModuleList(
+ [
+ QuartzNetBlock(
+ model_channels,
+ model_channels,
+ model_channels,
+ kernel_size=3,
+ dropout=dropout,
+ R=3,
+ separable=False,
+ )
+ for i in range(3)
+ ]
+ )
+ self.conv3 = nn.Sequential(
+ nn.Conv1d(model_channels, model_channels, 3, padding="same"),
+ nn.SiLU(),
+ normalization(model_channels),
+ nn.Conv1d(model_channels, 100, 1, padding="same"),
+ )
+
+ def forward(self, x, time_emb):
+ for n, (layer, attn) in enumerate(zip(self.quartznet, self.attn)):
+ x = layer(x, time_emb) # 256 dim
+ x = attn(x)
+ for layer in self.conv2:
+ x = layer(x, time_emb)
+
+ x = self.conv3(x)
+ return x
+
+
+class DiffModel(nn.Module):
+ def __init__(
+ self,
+ input_channels=80,
+ output_channels=160,
+ model_channels=256,
+ num_heads=8,
+ dropout=0.1,
+ num_layers=8,
+ multispeaker=True,
+ style_tokens=100,
+ enable_fp16=False,
+ condition_free_per=0.1,
+ training=False,
+ ar_active=False,
+ in_latent_channels=10004,
+ ):
+ super().__init__()
+ self.input_channels = input_channels
+ self.model_channels = model_channels
+ self.output_channels = output_channels
+ self.num_heads = num_heads
+ self.dropout = dropout
+ self.num_layers = num_layers
+ self.enable_fp16 = enable_fp16
+ self.condition_free_per = condition_free_per
+ self.training = training
+ self.multispeaker = multispeaker
+ self.ar_active = ar_active
+ self.in_latent_channels = in_latent_channels
+
+ if not self.ar_active:
+ self.code_emb = nn.Embedding(
+ config.semantic_model_centroids + 1, model_channels
+ )
+ self.code_converter = mySequential(
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+ )
+ else:
+ self.code_converter = mySequential(
+ nn.Conv1d(
+ self.in_latent_channels, model_channels, 3, padding=1, bias=True
+ ),
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+ AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
+ )
+ if self.multispeaker:
+ self.GST = GST(
+ model_channels, style_tokens, num_heads, in_channels=input_channels
+ )
+
+ self.code_norm = normalization(model_channels)
+ self.time_norm = normalization(model_channels)
+ self.code_time_norm = normalization(model_channels)
+
+ self.time_embed = mySequential(
+ nn.Linear(model_channels, model_channels),
+ nn.SiLU(),
+ nn.Linear(model_channels, model_channels),
+ )
+
+ self.input_block = nn.Conv1d(input_channels, model_channels, 3, 1, 1, bias=True)
+ self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
+ self.integrating_conv = nn.Conv1d(
+ model_channels * 2, model_channels, kernel_size=1
+ )
+
+ self.code_time = TimestepEmbedSequential(
+ QuartzAttn(model_channels, dropout, num_heads),
+ QuartzAttn(model_channels, dropout, num_heads),
+ QuartzAttn(model_channels, dropout, num_heads),
+ )
+
+        self.layers = QuartzNet9x5(
+            model_channels, num_heads, dropout=self.dropout, enable_fp16=self.enable_fp16
+        )
+
+ def get_speaker_latent(self, ref_mels):
+ ref_mels = ref_mels.unsqueeze(1) if len(ref_mels.shape) == 3 else ref_mels
+
+ conds = []
+ for j in range(ref_mels.shape[1]):
+ conds.append(self.GST(ref_mels[:, j, :, :]))
+
+ conds = torch.cat(conds, dim=-1)
+ conds = conds.mean(dim=-1)
+
+ return conds.unsqueeze(2)
+
+ def forward(
+ self,
+ x,
+ t,
+ code_emb,
+ ref_clips=None,
+ speaker_latents=None,
+ conditioning_free=False,
+ ):
+ time_embed = self.time_norm(
+ self.time_embed(
+ timestep_embedding(t.unsqueeze(-1), self.model_channels)
+ ).permute(0, 2, 1)
+ ).squeeze(2)
+ if conditioning_free:
+ code_embed = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
+ else:
+ if not self.ar_active:
+ code_embed = self.code_norm(
+ self.code_converter(self.code_emb(code_emb).permute(0, 2, 1))
+ )
+ else:
+ code_embed = self.code_norm(self.code_converter(code_emb))
+ if self.multispeaker:
+ assert speaker_latents is not None or ref_clips is not None
+ if ref_clips is not None:
+ speaker_latents = self.get_speaker_latent(ref_clips)
+ cond_scale, cond_shift = torch.chunk(speaker_latents, 2, dim=1)
+ code_embed = code_embed * (1 + cond_scale) + cond_shift
+
+ if self.training and self.condition_free_per > 0:
+ unconditioned_batches = (
+ torch.rand((code_embed.shape[0], 1, 1), device=code_embed.device)
+ < self.condition_free_per
+ )
+ code_embed = torch.where(
+ unconditioned_batches,
+ self.unconditioned_embedding.repeat(code_embed.shape[0], 1, 1),
+ code_embed,
+ )
+
+ expanded_code_emb = F.interpolate(code_embed, size=x.shape[-1], mode="linear")
+
+ x_cond = self.code_time_norm(self.code_time(expanded_code_emb, time_embed))
+
+ x = self.input_block(x)
+ x = torch.cat([x, x_cond], dim=1)
+ x = self.integrating_conv(x)
+ out = self.layers(x, time_embed)
+
+ return out
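The sinusoidal `timestep_embedding` at the top of this file can be checked in isolation. Below is a pure-Python sketch of the same formula (no torch): frequencies decay geometrically from 1 down to 1/max_period, cosines fill the first half, sines the second, and odd dims get one zero pad, as in the torch version.

```python
import math

def timestep_embedding(t, dim, max_period=10000.0):
    """Pure-Python sketch of the sinusoidal timestep embedding above."""
    half = dim // 2
    freqs = [math.exp(-math.log(max_period) * i / half) for i in range(half)]
    args = [t * f for f in freqs]
    emb = [math.cos(a) for a in args] + [math.sin(a) for a in args]
    if dim % 2:  # odd dims are zero-padded, as in the torch version
        emb.append(0.0)
    return emb

emb = timestep_embedding(0.0, 8)
print(emb)  # at t=0: cos terms are all 1.0, sin terms all 0.0
```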
diff --git a/S2A/flow_matching.py b/S2A/flow_matching.py
new file mode 100755
index 0000000000000000000000000000000000000000..fd7c1995764033d9f3eb86fe45ad462376f33458
--- /dev/null
+++ b/S2A/flow_matching.py
@@ -0,0 +1,123 @@
+import torch
+import torch.nn.functional as F
+
+
+class BASECFM(torch.nn.Module):
+ def __init__(
+ self,
+ ):
+ super().__init__()
+        self.sigma_min = 1e-4
+
+ @torch.inference_mode()
+ def forward(
+ self, model, code, output_shape, ref_mels, n_timesteps=20, temperature=1.0
+ ):
+        """Sample a mel spectrogram by integrating the learned flow from noise to data.
+
+        Args:
+            model: conditional vector-field estimator (e.g. DiffModel)
+            code (torch.Tensor): semantic code conditioning
+            output_shape: shape of the generated mel, (batch_size, n_feats, mel_timesteps)
+            ref_mels (torch.Tensor): reference mels for speaker conditioning
+            n_timesteps (int): number of Euler integration steps
+            temperature (float, optional): scaling for the initial noise. Defaults to 1.0.
+
+        Returns:
+            sample: generated mel-spectrogram
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+ z = torch.randn(output_shape, device=code.device) * temperature
+ t_span = torch.linspace(0, 1, n_timesteps + 1, device=code.device)
+ return self.solve_euler(model, z, t_span=t_span, code=code, ref_mels=ref_mels)
+
+ def solve_euler(self, model, x, t_span, code, ref_mels):
+        """
+        Fixed-step Euler solver for the flow ODE.
+        Args:
+            model: conditional vector-field estimator
+            x (torch.Tensor): random noise
+                shape: (batch_size, n_feats, mel_timesteps)
+            t_span (torch.Tensor): integration times
+                shape: (n_timesteps + 1,)
+            code (torch.Tensor): semantic code conditioning
+            ref_mels (torch.Tensor): reference mels for speaker conditioning
+        """
+ t, _, dt = t_span[0], t_span[-1], t_span[1] - t_span[0]
+
+ # I am storing this because I can later plot it by putting a debugger here and saving it to a file
+ # Or in future might add like a return_all_steps flag
+ sol = []
+
+ for step in range(1, len(t_span)):
+ dphi_dt = model(x, t.unsqueeze(0), code_emb=code, ref_clips=ref_mels)
+
+ x = x + dt * dphi_dt
+ t = t + dt
+ sol.append(x)
+ if step < len(t_span) - 1:
+ dt = t_span[step + 1] - t
+
+ return sol[-1]
+
+ def compute_loss(self, model, x1, mask, code, ref_mels):
+        """Computes the conditional flow-matching loss
+
+        Args:
+            model: conditional vector-field estimator
+            x1 (torch.Tensor): target mel
+                shape: (batch_size, n_feats, mel_timesteps)
+            mask (torch.Tensor): target mask
+                shape: (batch_size, mel_timesteps)
+            code (torch.Tensor): semantic code conditioning
+            ref_mels (torch.Tensor): reference mels for speaker conditioning
+
+        Returns:
+            loss: conditional flow matching loss
+            y: sample on the conditional flow at time t
+                shape: (batch_size, n_feats, mel_timesteps)
+        """
+ b, _, t = x1.shape
+
+ # random timestep
+ t = torch.rand([b, 1, 1], device=x1.device, dtype=x1.dtype)
+ # sample noise p(x_0)
+ z = torch.randn_like(x1)
+
+ y = (1 - (1 - self.sigma_min) * t) * z + t * x1
+ u = x1 - (1 - self.sigma_min) * z
+
+        # masked MSE between the predicted vector field and the target u,
+        # averaged over valid frames and mel channels
+
+ loss = torch.sum(
+ F.mse_loss(
+ model(y, t.squeeze(), code_emb=code, ref_clips=ref_mels),
+ u,
+ reduction="none",
+ )
+ * mask.unsqueeze(1)
+ ) / (mask.sum() * u.shape[1])
+
+ return loss, y, t
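In scalar form, the path used by `compute_loss` and integrated by `solve_euler` is easy to verify. A minimal sketch, assuming the model predicts the exact vector field: the conditional path is `y = (1 - (1 - sigma) t) z + t x1` with constant target field `u = x1 - (1 - sigma) z`, so Euler integration carries the noise `z` at t=0 to (approximately) the data `x1` at t=1.

```python
# sigma_min matches the value hard-coded in BASECFM above
SIGMA_MIN = 1e-4

def interpolant(z, x1, t, sigma=SIGMA_MIN):
    """Point on the conditional path from noise z (t=0) to data x1 (t=1)."""
    return (1 - (1 - sigma) * t) * z + t * x1

def target_field(z, x1, sigma=SIGMA_MIN):
    """The regression target u = dx/dt along that path (constant in t)."""
    return x1 - (1 - sigma) * z

def euler_solve(z, x1, n_steps=20, sigma=SIGMA_MIN):
    """Fixed-step Euler from t=0 to t=1, pretending the model is exact."""
    x, dt = z, 1.0 / n_steps
    for _ in range(n_steps):
        x = x + dt * target_field(z, x1, sigma)
    return x

z, x1 = 2.0, -0.7
x_end = euler_solve(z, x1)
# analytically x(1) = sigma*z + x1, i.e. essentially x1 for small sigma
print(abs(x_end - (SIGMA_MIN * z + x1)))
```

Because the field is constant along each conditional path, the Euler error here is purely floating-point; in the real sampler the model's field varies with `x` and `t`, which is why `n_timesteps` matters.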
diff --git a/S2A/inference.py b/S2A/inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..6cbf9e91cdd490fa91c73a0b4073b6a7d716e599
--- /dev/null
+++ b/S2A/inference.py
@@ -0,0 +1,73 @@
+import os
+import sys
+
+sys.path.append(
+ os.path.abspath(
+ os.path.join(os.path.dirname(__file__), "../bigvgan_v2_24khz_100band_256x/")
+ )
+)
+
+import bigvgan
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from pydub import AudioSegment
+from tqdm import tqdm
+
+from config import config
+
+from .flow_matching import BASECFM
+from .utilities import denormalize_tacotron_mel, normalize_tacotron_mel
+
+
+def infer(model, timeshapes, code_embs, ref_mels, epoch=0):
+ os.makedirs("Samples/" + config.model_name + "/S2A/", exist_ok=True)
+ FM = BASECFM()
+ device = next(model.parameters()).device
+
+ hifi = bigvgan.BigVGAN.from_pretrained(
+ "nvidia/bigvgan_v2_24khz_100band_256x", use_cuda_kernel=False
+ )
+ hifi.remove_weight_norm()
+ hifi = hifi.eval().to(device)
+
+ audio_paths = []
+ mels = []
+ for n, (timeshape, code_emb, ref_mel) in enumerate(
+ zip(timeshapes, code_embs, ref_mels)
+ ):
+ with torch.no_grad():
+ mel = FM(
+ model,
+ code_emb.unsqueeze(0).to(device),
+ (1, 100, timeshape),
+ ref_mel.unsqueeze(0).to(device),
+ n_timesteps=20,
+ temperature=1.0,
+ )
+ mel = denormalize_tacotron_mel(mel)
+ mels.append(mel)
+ audio = hifi(mel)
+ audio = audio.squeeze(0).detach().cpu()
+ audio = audio * 32767.0
+ audio = audio.numpy().reshape(-1).astype(np.int16)
+
+            audio_path = (
+                "Samples/"
+                + config.model_name
+                + "/S2A/"
+                + str(epoch)
+                + "_"
+                + str(n)
+                + ".wav"
+            )
+ AudioSegment(
+ audio.tobytes(),
+ frame_rate=24000,
+ sample_width=audio.dtype.itemsize,
+ channels=1,
+ ).export(audio_path, format="wav")
+ audio_paths.append(audio_path)
+
+ return audio_paths, mels
diff --git a/S2A/mel_spec.py b/S2A/mel_spec.py
new file mode 100755
index 0000000000000000000000000000000000000000..c21ced11241dcc850a74f7cb4f24e4071d3fc935
--- /dev/null
+++ b/S2A/mel_spec.py
@@ -0,0 +1,176 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+if __name__ == "__main__":
+ import os
+ import sys
+
+ sys.path.append("../")
+
+import math
+import os
+import pathlib
+import random
+
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from librosa.util import normalize
+from scipy.io.wavfile import read
+from tqdm import tqdm
+
+from config import config
+
+MAX_WAV_VALUE = 32767.0 # NOTE: 32768.0 -1 to prevent int16 overflow (results in popping sound in corner cases)
+
+
+def load_wav(full_path, sr_target):
+ sampling_rate, data = read(full_path)
+ if sampling_rate != sr_target:
+ raise RuntimeError(
+ f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
+ )
+ return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ return dynamic_range_compression_torch(magnitudes)
+
+
+def spectral_de_normalize_torch(magnitudes):
+ return dynamic_range_decompression_torch(magnitudes)
+
+
+mel_basis_cache = {}
+hann_window_cache = {}
+
+
+def mel_spectrogram(
+ y: torch.Tensor,
+ n_fft: int,
+ num_mels: int,
+ sampling_rate: int,
+ hop_size: int,
+ win_size: int,
+ fmin: int,
+ fmax: int = None,
+ center: bool = False,
+) -> torch.Tensor:
+ """
+ Calculate the mel spectrogram of an input signal.
+ This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
+
+ Args:
+ y (torch.Tensor): Input signal.
+ n_fft (int): FFT size.
+ num_mels (int): Number of mel bins.
+ sampling_rate (int): Sampling rate of the input signal.
+ hop_size (int): Hop size for STFT.
+ win_size (int): Window size for STFT.
+ fmin (int): Minimum frequency for mel filterbank.
+ fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
+ center (bool): Whether to pad the input to center the frames. Default is False.
+
+ Returns:
+ torch.Tensor: Mel spectrogram.
+ """
+ if torch.min(y) < -1.0:
+ print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
+ if torch.max(y) > 1.0:
+ print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
+
+ device = y.device
+ key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
+
+ if key not in mel_basis_cache:
+ mel = librosa_mel_fn(
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+ )
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
+ hann_window_cache[key] = torch.hann_window(win_size).to(device)
+
+ mel_basis = mel_basis_cache[key]
+ hann_window = hann_window_cache[key]
+
+ padding = (n_fft - hop_size) // 2
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1), (padding, padding), mode="reflect"
+ ).squeeze(1)
+
+ spec = torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window,
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
+
+ mel_spec = torch.matmul(mel_basis, spec)
+ mel_spec = spectral_normalize_torch(mel_spec)
+
+ return mel_spec
+
+
+def get_mel_spectrogram(wav, sr):
+    """
+    Generate a mel spectrogram from a waveform using the hyperparameters in `config`
+    (filter_length, n_mel_channels, sampling_rate, hop_length, win_length, mel_fmin, mel_fmax).
+
+    Args:
+        wav (torch.Tensor): Input waveform.
+        sr (int): Sampling rate of the waveform; must equal config.sampling_rate.
+
+    Returns:
+        torch.Tensor: Mel spectrogram.
+    """
+
+ assert sr == config.sampling_rate
+
+ return mel_spectrogram(
+ wav,
+ config.filter_length,
+ config.n_mel_channels,
+ config.sampling_rate,
+ config.hop_length,
+ config.win_length,
+ config.mel_fmin,
+ config.mel_fmax,
+ )
+
+
+if __name__ == "__main__":
+ import torchaudio
+
+ path = "/delta/NeuralSpeak_cfm_conv/Samples/IITM_cfm_bigv_harsh/S2A/orig/0_test.wav"
+ wav, sr = torchaudio.load(path)
+
+ wav = wav[:, :sr]
+ print(wav.shape)
+ mel_spec = get_mel_spectrogram(wav, sr)
+ duration = wav.shape[-1] / sr
+ print(duration, mel_spec.shape)
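The log-compression pair defined above is easy to sanity-check in isolation. A pure-Python sketch mirroring `dynamic_range_compression`/`dynamic_range_decompression` (the `compress`/`decompress` names are illustrative): the roundtrip is exact above `clip_val`, and values below it are floored, which is exactly why the clamp exists (to avoid `log(0)` on silent frames).

```python
import math

def compress(x, C=1, clip_val=1e-5):
    """Log-compress a magnitude, clamping tiny values to avoid log(0)."""
    return math.log(max(x, clip_val) * C)

def decompress(y, C=1):
    """Inverse of compress; exact only for inputs above clip_val."""
    return math.exp(y) / C

mag = 0.25
roundtrip = decompress(compress(mag))
print(abs(roundtrip - mag) < 1e-9)
# a fully silent frame (0.0) is floored to clip_val, so that roundtrip is lossy
print(compress(0.0) == math.log(1e-5))
```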
diff --git a/S2A/modules.py b/S2A/modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..63822d00c9246fb8d3baa22e07589fc0019fb4b2
--- /dev/null
+++ b/S2A/modules.py
@@ -0,0 +1,683 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from einops import rearrange, repeat
+from torch.nn.utils import weight_norm
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ Using it for Zero Convolutions
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def normalization(channels):
+    """
+    Make a standard GroupNorm normalization layer, picking a group count
+    (at most 32) that divides `channels`.
+
+    :param channels: number of input channels.
+    :return: an nn.Module for normalization.
+    """
+    groups = 32
+    if channels <= 16:
+        groups = 8
+    elif channels <= 64:
+        groups = 16
+    while channels % groups != 0:
+        groups = int(groups / 2)
+    assert groups > 2
+    return GroupNorm32(groups, channels)
+
+
+class mySequential(nn.Sequential):
+ """Using this to pass mask variable to nn layers"""
+
+ def forward(self, *inputs):
+ for module in self._modules.values():
+ if type(inputs) == tuple:
+ inputs = module(*inputs)
+ else:
+ inputs = module(inputs)
+ return inputs
+
+
+class SepConv1D(nn.Module):
+ """Depth wise separable Convolution layer with mask"""
+
+ def __init__(
+ self,
+ nin,
+ nout,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ padding_mode="same",
+ bias=False,
+ ):
+ super(SepConv1D, self).__init__()
+ self.conv1 = nn.Conv1d(
+ nin,
+ nin,
+ kernel_size=kernel_size,
+ stride=stride,
+ groups=nin,
+ dilation=dilation,
+ padding=padding_mode,
+ bias=bias,
+ )
+ self.conv2 = nn.Conv1d(
+ nin, nout, kernel_size=1, stride=1, padding=padding_mode, bias=bias
+ )
+
+ def forward(self, x, mask=None):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x, mask
+
+
+class Conv1DBN(nn.Module):
+ def __init__(
+ self,
+ nin,
+ nout,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ dropout=0.1,
+ padding_mode="same",
+ bias=False,
+ ):
+ super(Conv1DBN, self).__init__()
+ self.conv1 = nn.Conv1d(
+ nin,
+ nout,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding_mode,
+ dilation=dilation,
+ bias=bias,
+ )
+ self.bn = nn.BatchNorm1d(nout)
+ self.drop = nn.Dropout(dropout)
+
+ def forward(self, x, mask=None):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ x = self.conv1(x)
+ x = self.bn(x)
+ x = F.silu(x)
+ x = self.drop(x)
+ return x, mask
+
+
+class Conv1d(nn.Module):
+ """normal conv1d with mask"""
+
+ def __init__(self, nin, nout, kernel_size, padding, bias=False):
+ super(Conv1d, self).__init__()
+ self.l = nn.Conv1d(nin, nout, kernel_size, padding=padding, bias=bias)
+
+ def forward(self, x, mask):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ y = self.l(x)
+ return y, mask
+
+
+class SqueezeExcite(nn.Module):
+ """Let the CNN decide how to add across channels"""
+
+ def __init__(self, nin, ratio=8):
+ super(SqueezeExcite, self).__init__()
+ self.nin = nin
+ self.ratio = ratio
+
+ self.fc = mySequential(
+ nn.Linear(nin, nin // ratio, bias=True),
+ nn.SiLU(inplace=True),
+ nn.Linear(nin // ratio, nin, bias=True),
+ )
+
+ def forward(self, x, mask=None):
+ if mask is None:
+ mask = torch.ones((x.shape[0], x.shape[-1]), dtype=torch.bool).to(x.device)
+ mask = ~mask
+ x = x.float()
+ x.masked_fill_(mask.unsqueeze(1), 0.0)
+ mask = ~mask
+ y = (
+ torch.sum(x, dim=-1, keepdim=True)
+ / mask.unsqueeze(1).sum(dim=-1, keepdim=True)
+ ).type(x.dtype)
+ # y=torch.mean(x,-1,keepdim=True)
+ y = y.transpose(1, -1)
+ y = self.fc(y)
+ y = torch.sigmoid(y)
+ y = y.transpose(1, -1)
+ y = x * y
+ return y, mask
+
+
+class SCBD(nn.Module):
+ """SeparableConv1D + Batchnorm + Dropout, Generally use it for middle layers and resnet"""
+
+ def __init__(
+ self, nin, nout, kernel_size, p=0.1, rd=True, separable=True, bias=False
+ ):
+ super(SCBD, self).__init__()
+ if separable:
+ self.SC = SepConv1D(nin, nout, kernel_size, bias=bias)
+ else:
+ self.SC = Conv1d(nin, nout, kernel_size, padding="same", bias=bias)
+
+ if rd: # relu and Dropout
+ self.mout = mySequential(
+ normalization(nout),
+ nn.SiLU(), # nn.BatchNorm1d(nout,eps)
+ nn.Dropout(p),
+ )
+ else:
+ self.mout = normalization(nout) # nn.BatchNorm1d(nout,eps)
+
+ def forward(self, x, mask=None):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ x, _ = self.SC(x, mask)
+ y = self.mout(x)
+ return y, mask
+
+
+class QuartzNetBlock(nn.Module):
+    """Similar to a ResNet block with BatchNorm and dropout, using a separable conv in the middle.
+    If it is the last layer, set se = False and separable = False, and use a projection layer on top of this.
+    """
+
+ def __init__(
+ self,
+ nin,
+ nout,
+ kernel_size,
+ dropout=0.1,
+ R=5,
+ se=False,
+ ratio=8,
+ separable=False,
+ bias=False,
+ ):
+ super(QuartzNetBlock, self).__init__()
+ self.se = se
+ self.residual = mySequential(
+ nn.Conv1d(nin, nout, kernel_size=1, padding="same", bias=bias),
+ normalization(nout), # nn.BatchNorm1d(nout,eps)
+ )
+ model = []
+
+ # SCBD takes no eps parameter (its normalization layer sets eps internally)
+ for i in range(R - 1):
+ model.append(SCBD(nin, nout, kernel_size, dropout, bias=bias))
+ nin = nout
+
+ if separable:
+ model.append(SCBD(nin, nout, kernel_size, dropout, rd=False, bias=bias))
+ else:
+ model.append(
+ SCBD(nin, nout, kernel_size, dropout, rd=False, separable=False, bias=bias)
+ )
+ self.model = mySequential(*model)
+
+ if self.se:
+ # model.append(SqueezeExcite(nin,ratio))
+ self.se_layer = SqueezeExcite(nin, ratio)
+
+ self.mout = mySequential(nn.SiLU(), nn.Dropout(dropout))
+
+ def forward(self, x, mask=None):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ y, _ = self.model(x, mask)
+ if self.se:
+ y, _ = self.se_layer(y, mask)
+ y += self.residual(x)
+ y = self.mout(y)
+ return y, mask
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv, mask=None, rel_pos=None):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = torch.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ if rel_pos is not None:
+ weight = rel_pos(
+ weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])
+ ).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ if mask is not None:
+ # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
+ weight = weight * mask
+ a = torch.einsum("bts,bcs->bct", weight, v)
+
+ return a.reshape(bs, -1, length)
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ do_checkpoint=True,
+ relative_pos_embeddings=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.do_checkpoint = do_checkpoint
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert channels % num_head_channels == 0, (
+ f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ )
+ self.num_heads = channels // num_head_channels
+ self.norm = normalization(channels)
+ self.qkv = nn.Conv1d(channels, channels * 3, 1, bias=False)
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+ self.proj_out = zero_module(
+ nn.Conv1d(channels, channels, 1, bias=False)
+ ) # zero-initialized, so attention has no effect in the initial stages of training.
+ # relative position embeddings are always constructed here; the
+ # relative_pos_embeddings flag is currently unused.
+ self.relative_pos_embeddings = RelativePositionBias(
+ scale=(channels // self.num_heads) ** 0.5,
+ causal=False,
+ heads=num_heads,
+ num_buckets=64,
+ max_distance=128,
+ )
+
+ def forward(self, x, mask=None):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv, mask, self.relative_pos_embeddings)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.scale = dim**-0.5
+ self.emb = nn.Embedding(max_seq_len, dim)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ pos_emb = self.emb(n)
+ pos_emb = rearrange(pos_emb, "n d -> () n d")
+ return pos_emb * self.scale
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = (
+ torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ + offset
+ )
+ sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return rearrange(emb, "n d -> () n d")
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, scale, causal=False, num_buckets=16, max_distance=32, heads=8):
+ super().__init__()
+ self.scale = scale
+ self.causal = causal
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(
+ relative_position, causal=True, num_buckets=16, max_distance=32
+ ):
+ ret = 0
+ n = -relative_position
+ if not causal:
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+ else:
+ n = torch.max(n, torch.zeros_like(n))
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = (
+ max_exact
+ + (
+ torch.log(n.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).long()
+ )
+ val_if_large = torch.min(
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
+ )
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
+
+ def forward(self, qk_dots):
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
+ rel_pos = k_pos[None, :] - q_pos[:, None]
+ rp_bucket = self._relative_position_bucket(
+ rel_pos,
+ causal=self.causal,
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
+ values = self.relative_attention_bias(rp_bucket)
+ bias = rearrange(values, "i j h -> () h i j")
+ return qk_dots + (bias * self.scale)
+
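+# Illustrative bucket mapping for _relative_position_bucket (assuming the
+# constructor defaults num_buckets=16, max_distance=32, causal=False):
+#   relative position  0 (key == query)        -> bucket 0
+#   relative position -3 (key 3 steps behind)  -> bucket 3
+#   relative position +3 (key 3 steps ahead)   -> bucket 8 + 3 = 11 (upper half)
+#   |relative position| >= max_distance        -> clamped to the last log-spaced
+#   bucket of its half; nearby offsets get exact buckets, distant ones share.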
+
+class MultiHeadAttention(nn.Module):
+ """
+ only for GST
+ input:
+ query --- [N, T_q, query_dim]
+ key --- [N, T_k, key_dim]
+ output:
+ out --- [N, T_q, num_units]
+ """
+
+ def __init__(self, query_dim, key_dim, num_units, num_heads):
+ super().__init__()
+ self.num_units = num_units
+ self.num_heads = num_heads
+ self.key_dim = key_dim
+
+ self.W_query = nn.Linear(
+ in_features=query_dim, out_features=num_units, bias=False
+ )
+ self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
+ self.W_value = nn.Linear(
+ in_features=key_dim, out_features=num_units, bias=False
+ )
+
+ def forward(self, query, key):
+ querys = self.W_query(query) # [N, T_q, num_units]
+ keys = self.W_key(key) # [N, T_k, num_units]
+ values = self.W_value(key)
+
+ split_size = self.num_units // self.num_heads
+ querys = torch.stack(
+ torch.split(querys, split_size, dim=2), dim=0
+ ) # [h, N, T_q, num_units/h]
+ keys = torch.stack(
+ torch.split(keys, split_size, dim=2), dim=0
+ ) # [h, N, T_k, num_units/h]
+ values = torch.stack(
+ torch.split(values, split_size, dim=2), dim=0
+ ) # [h, N, T_k, num_units/h]
+
+ # score = softmax(QK^T / (d_k ** 0.5))
+ scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
+ scores = scores / (self.key_dim**0.5)
+ scores = F.softmax(scores, dim=3)
+
+ # out = score * V
+ out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
+ out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
+ 0
+ ) # [N, T_q, num_units]
+
+ return out
+
+
+class GST(nn.Module):
+ def __init__(
+ self, model_channels=512, style_tokens=100, num_heads=8, in_channels=100
+ ):
+ super(GST, self).__init__()
+ self.model_channels = model_channels
+ self.style_tokens = style_tokens
+ self.num_heads = num_heads
+
+ # self.reference_encoder=nn.Sequential(
+ # nn.Conv2d(1,32,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(32),nn.ReLU(inplace=True),
+ # nn.Conv2d(32,32,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(32),nn.ReLU(inplace=True),
+ # nn.Conv2d(32,64,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(64),nn.ReLU(inplace=True),
+ # nn.Conv2d(64,64,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(64),nn.ReLU(inplace=True),
+ # AttentionBlock(64, 8, relative_pos_embeddings=True),
+ # nn.Conv2d(64,128,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(128),nn.ReLU(inplace=True),
+ # AttentionBlock(128, 8, relative_pos_embeddings=True),
+ # nn.Conv2d(128,128,kernel_size=(3,3),stride=(2,2),padding=(1, 1)),normalization(128),nn.ReLU(inplace=True),
+ # AttentionBlock(128, 8, relative_pos_embeddings=True),
+ # nn.Conv2d(128,model_channels,kernel_size=(3,3),stride=(1,1),padding=(1, 1)),normalization(model_channels),nn.ReLU(inplace=True),
+ # AttentionBlock(model_channels, 16, relative_pos_embeddings=True)
+ # )
+
+ # self.reference_encoder=nn.Sequential(
+ # nn.Conv1d(80,model_channels,3,padding=1,stride=2),
+ # nn.Conv1d(model_channels, model_channels,3,padding=1,stride=2),
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False),
+ # AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True, do_checkpoint=False)
+ # )
+
+ # in_channels=1
+ # num_heads = 8
+ self.reference_encoder = nn.Sequential(
+ nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2, bias=False),
+ nn.Conv1d(
+ model_channels, model_channels * 2, 3, padding=1, stride=2, bias=False
+ ),
+ AttentionBlock(
+ model_channels * 2,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ AttentionBlock(
+ model_channels * 2,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ AttentionBlock(
+ model_channels * 2,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ AttentionBlock(
+ model_channels * 2,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ AttentionBlock(
+ model_channels * 2,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ # nn.Conv1d(model_channels*2, 64,3,padding=1,stride=2),
+ # nn.Conv1d(64, model_channels*2,3,padding=1,stride=2) #added bottleneck
+ )
+ # bottleneck = 64
+ # self.bottleneck = nn.Sequential(nn.Conv1d(model_channels*2,bottleneck,3,padding=1,stride=1),nn.SiLU(),
+ # nn.Conv1d(bottleneck,model_channels*2,3,padding=1,stride=1),nn.SiLU())
+ # self.gru=nn.GRU(128*2,256,batch_first=True,bidirectional=True)
+ # self.attention = MultiHeadAttention(query_dim=model_channels, key_dim=model_channels//num_heads, num_units=model_channels*2, num_heads=num_heads)
+ # self.style_tokens = nn.parameter.Parameter(torch.FloatTensor(style_tokens,model_channels//num_heads))
+
+ # init.normal_(self.style_tokens, mean=0, std=0.5)
+
+ def forward(self, x):
+ # add masking
+ # batch=x.size(0)
+ # x=x.view(batch,1,-1,80) # (N,1,t,80)
+ x = self.reference_encoder(x) # (N, 2*model_channels, t//4)
+ # print(x.shape)
+ # x = self.bottleneck(x)
+ # print(x.shape)
+ # print(x.shape,'encoder')
+ # x = x.mean(dim=-1)#.mean(dim=-1)
+ # # x=x.transpose(1,2).contiguous() #(N,t,128,80//x)
+ # # time=x.size(1)
+ # # x=x.view(batch,time,-1)
+ # # _,x=self.gru(x)
+ # # print(x.shape,'gru')
+ # x=x.view(batch,1,-1)
+ # keys = self.style_tokens.unsqueeze(0).expand(batch, -1, -1) # [N, token_num, E // num_heads]
+ # # print(keys.shape,'keys')
+ # style_embed = self.attention(x, keys)
+ # # print(style_embed.shape,'gst tokens')
+
+ # add normalization?
+
+ return x
+
+
+# class GST(nn.Module):
+# """
+# inputs --- [N, Ty/r, n_mels*r] mels
+# outputs --- [N, ref_enc_gru_size]
+# """
+
+# def __init__(self, spec_channels=80, gin_channels=512, layernorm=True):
+# super().__init__()
+# self.spec_channels = spec_channels
+# ref_enc_filters = [32, 32, 64, 64, 128, 128]
+# K = len(ref_enc_filters)
+# filters = [1] + ref_enc_filters
+# convs = [
+# weight_norm(
+# nn.Conv2d(
+# in_channels=filters[i],
+# out_channels=filters[i + 1],
+# kernel_size=(3, 3),
+# stride=(2, 2),
+# padding=(1, 1),
+# )
+# )
+# for i in range(K)
+# ]
+# self.convs = nn.ModuleList(convs)
+
+# out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
+# self.gru = nn.GRU(
+# input_size=ref_enc_filters[-1] * out_channels,
+# hidden_size=256 // 2,
+# batch_first=True,
+# )
+# self.proj = nn.Linear(128, gin_channels)
+# if layernorm:
+# self.layernorm = nn.LayerNorm(self.spec_channels)
+# else:
+# self.layernorm = None
+
+# def forward(self, inputs, mask=None):
+# N = inputs.size(0)
+
+# out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
+# if self.layernorm is not None:
+# out = self.layernorm(out)
+
+# for conv in self.convs:
+# out = conv(out)
+# # out = wn(out)
+# out = F.silu(out) # [N, 128, Ty//2^K, n_mels//2^K]
+
+# out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
+# T = out.size(1)
+# N = out.size(0)
+# out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
+
+# self.gru.flatten_parameters()
+# memory, out = self.gru(out) # out --- [1, N, 128]
+
+# return self.proj(out.squeeze(0))
+
+# def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
+# for i in range(n_convs):
+# L = (L - kernel_size + 2 * pad) // stride + 1
+# return L
+
+
+if __name__ == "__main__":
+ device = torch.device("cpu")
+ m = GST(512, 10, in_channels=80).to(device) # in_channels must match the 80-mel inputs below
+ mels = torch.rand((16, 80, 1000)).to(device)
+
+ o = m(mels)
+ print(o.shape, "final output")
+
+ from torchinfo import summary
+
+ summary(m, input_data={"x": torch.randn(16, 80, 500).to(device)})
diff --git a/S2A/train.py b/S2A/train.py
new file mode 100755
index 0000000000000000000000000000000000000000..ff43250e09cd7a360c467b427d4de2978d0f2e56
--- /dev/null
+++ b/S2A/train.py
@@ -0,0 +1,661 @@
+import os
+import sys
+import time
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+import linecache
+import mmap
+import pickle as pkl
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchaudio
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from mel_spec import get_mel_spectrogram
+from torch.distributed import destroy_process_group, init_process_group
+from torch.distributed.elastic.utils.data import ElasticDistributedSampler
+from torch.nn.parallel import DistributedDataParallel as DDP
+from torch.utils.data import DataLoader, Dataset
+from torch.utils.data.distributed import DistributedSampler
+from tqdm.auto import tqdm
+
+import wandb
+from config import config
+from S2A.diff_model import DiffModel
+from S2A.flow_matching import BASECFM
+from S2A.inference import infer
+from S2A.utilities import (dynamic_range_compression, get_mask,
+ get_mask_from_lengths, load_wav_to_torch,
+ normalize_tacotron_mel)
+from Text import code_labels, labels, text_labels
+
+# import torch
+torch.backends.cuda.matmul.allow_tf32 = True
+torch.backends.cudnn.allow_tf32 = True
+
+torch.manual_seed(config.seed_value)
+np.random.seed(config.seed_value)
+random.seed(config.seed_value)
+
+CLIP_LENGTH = config.CLIP_LENGTH
+
+# code encdec
+text_enc = {j: i for i, j in enumerate(text_labels)}
+text_dec = {i: j for i, j in enumerate(text_labels)}
+
+# text encdec
+code_enc = {j: i for i, j in enumerate(code_labels)}
+code_dec = {i: j for i, j in enumerate(code_labels)}
+
+
+def read_specific_line(filename, line_number):
+ line = linecache.getline(filename, line_number)
+ return line.strip()
+
+
+class Acoustic_dataset(Dataset):
+ def __init__(
+ self,
+ transcript_path,
+ semantic_path=None,
+ ref_mels_path=None,
+ ref_k=1,
+ scale=True,
+ ar_active=False,
+ clip=True,
+ dur_=None,
+ ):
+ super(Acoustic_dataset, self).__init__()
+ self.scale = scale
+ self.ar_active = ar_active
+ self.clip = clip
+ self.dur_ = dur_
+ if self.dur_ is None:
+ self.dur_ = 2
+ if not scale:
+ with open(transcript_path, "r") as file:
+ data = file.read().strip("\n").split("\n")[:]
+
+ with open(semantic_path, "r") as file:
+ semb = file.read().strip("\n").split("\n")
+
+ with open(ref_mels_path, "rb") as file:
+ self.ref_mels = pkl.load(file)
+
+ semb = {
+ i.split("\t")[0]: [j for j in i.split("\t")[1].split()] for i in semb
+ }
+ data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}
+
+ self.data = [[i, semb[i], data[i]] for i in data.keys()][:]
+
+ else:
+ self.transcript_path = transcript_path
+ line_index = {}
+ with open(transcript_path, "rb") as file:
+ mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
+ line_number = 0
+ offset = 0
+ progress_bar = tqdm(desc="processing:")
+ while offset < len(mmapped_file):
+ line_index[line_number] = offset
+ offset = mmapped_file.find(b"\n", offset) + 1
+ line_number += 1
+ progress_bar.update(1)
+ progress_bar.close()
+ self.mmapped_file = mmapped_file
+ self.data_len = len(line_index)
+ self.line_index = line_index
+
+ self.ref_k = ref_k
+ self.max_wav_value = config.MAX_WAV_VALUE
+
+ def get_mel(self, filepath, semb_ids=None, align=False, ref_clip=False):
+ audio_norm, sampling_rate = torchaudio.load(filepath)
+ dur = audio_norm.shape[-1] / sampling_rate
+
+ if self.clip and dur > self.dur_ and align:
+ max_audio_start = int(dur - self.dur_)
+ if max_audio_start <= 0:
+ audio_start = 0
+ else:
+ audio_start = np.random.randint(0, max_audio_start)
+
+ audio_norm = audio_norm[
+ :,
+ audio_start * sampling_rate : (audio_start + self.dur_) * sampling_rate,
+ ]
+ semb_ids = semb_ids[audio_start * 50 : ((audio_start + self.dur_) * 50) - 1]
+
+ # 86 mel -> 1s for 22050 setting
+ # 93 mel -> 1s for 24000 setting
+
+ if ref_clip:
+ dur_ = 6
+ max_audio_start = int(dur - dur_)
+ if max_audio_start <= 0:
+ audio_start = 0
+ else:
+ audio_start = np.random.randint(0, max_audio_start)
+ audio_norm = audio_norm[
+ :, audio_start * sampling_rate : (audio_start + dur_) * sampling_rate
+ ]
+
+ melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
+ energy = []
+ if align:
+ return melspec, list(energy), semb_ids
+ return melspec, list(energy)
+
+ def __len__(self):
+ if self.scale:
+ return self.data_len
+ return len(self.data)
+
+ def __getitem__(self, index):
+ """
+ mel_spec,semb
+ """
+ if not self.scale:
+ lang, path, semb, text = self.data[index]
+ ref_mels = self.ref_mels[path][: self.ref_k]
+ semb_ids = [int(i) + 1 for i in semb] # 0 for pad
+
+ else:
+ self.mmapped_file.seek(self.line_index[index])
+ line = self.mmapped_file.readline().decode("utf-8")
+
+ lang, path, text, semb_ids = line.split("|")
+ semb_ids = [int(i) + 1 for i in semb_ids.split()]
+ ref_mels = [path][: self.ref_k]
+
+ try:
+ mel_spec, energy, semb_ids = self.get_mel(path, semb_ids, align=True)
+ if len(semb_ids) == 0:
+ raise Exception("Sorry, no semb ids" + str(line))
+ except Exception as e:
+ print(index, e)
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ if len(ref_mels) == 0:
+ print(index, "no ref mels present")
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ while len(ref_mels) < self.ref_k:
+ ref_mels.append(ref_mels[-1])
+
+ if mel_spec is None:
+ print(index, "mel_spec error present")
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ def get_random_portion(mel, mask_lengths):
+ clip = mask_lengths <= CLIP_LENGTH
+ ref_mel = mel[:, :, :CLIP_LENGTH].clone()
+ for n, z in enumerate(clip):
+ if not z:
+ start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
+ ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
+ return ref_mel
+
+ try:
+ ref_mels = [self.get_mel(path, ref_clip=True)[0] for path in ref_mels]
+ except Exception as e:
+ print(index, e, "ref_mels mel_spec error")
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ ref_c = []
+ for i in range(self.ref_k):
+ if ref_mels[i] is None:
+ continue
+ ref_c.append(ref_mels[i])
+
+ if len(ref_c) == 0:
+ print("no refs mel spec found")
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ if len(ref_c) != self.ref_k:
+ while len(ref_c) < self.ref_k:
+ ref_c.append(ref_c[-1])
+
+ ref_mels = ref_c
+ max_target_len = max([x.size(1) for x in ref_mels])
+ ref_mels_padded = (
+ torch.randn((self.ref_k, config.n_mel_channels, max_target_len)) * 1e-9
+ )
+ mel_length = []
+ for i, mel in enumerate(ref_mels):
+ ref_mels_padded[i, :, : mel.size(1)] = mel
+ mel_length.append(mel.shape[-1])
+
+ ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))
+
+ text_ids = (
+ [text_enc[""]]
+ + [text_enc[i] for i in text.strip() if i in text_enc]
+ + [text_enc[""]]
+ )
+ if self.ar_active:
+ semb_ids = (
+ [code_enc[""]]
+ + [code_enc[str(i - 1)] for i in semb_ids]
+ + [code_enc[""]]
+ )
+
+ return {
+ "mel": mel_spec,
+ "code": semb_ids,
+ "path": path,
+ "ref_mels": ref_mels,
+ "text": text_ids,
+ } # , 'ref_mel_length':mel_length}
+
+
+def get_padded_seq(sequences, pad_random, before=False, pad__=0):
+ max_len = max([len(s) for s in sequences])
+ seq_len = []
+ for i in range(len(sequences)):
+ seq_len.append(len(sequences[i]))
+ if pad_random:
+ pad_ = list((np.random.rand(max_len - len(sequences[i]))) * 1e-9)
+ else:
+ pad_ = [pad__] * (max_len - len(sequences[i]))
+ if not before:
+ sequences[i] = sequences[i] + pad_
+ else:
+ sequences[i] = pad_ + sequences[i]
+
+ return sequences, seq_len
+
+
+def collate(batch):
+ mel_specs = []
+ code = []
+ paths = []
+ ref_mels = []
+ text_ids = []
+
+ for b in batch:
+ mel_specs.append(b["mel"])
+ code.append(b["code"])
+ paths.append(b["path"])
+ ref_mels.append(b["ref_mels"])
+ text_ids.append(b["text"])
+
+ if code[-1][-1] == code_enc[""]:
+ code, code_len = get_padded_seq(code, pad_random=False, pad__=code_enc[""])
+ else:
+ code, code_len = get_padded_seq(code, pad_random=False)
+
+ text_ids, text_len = get_padded_seq(
+ text_ids, pad_random=False, before=True, pad__=text_enc[""]
+ )
+ ref_max_target_len = max([x.size(-1) for x in ref_mels])
+ ref_mels_padded = (
+ torch.randn(
+ (
+ len(batch),
+ ref_mels[0].shape[0],
+ config.n_mel_channels,
+ ref_max_target_len,
+ )
+ )
+ * 1e-9
+ )
+
+ for i, mel in enumerate(ref_mels):
+ ref_mels_padded[i, :, :, : mel.size(-1)] = mel
+
+ max_target_len = max([x.size(-1) for x in mel_specs])
+ mel_padded = torch.randn((len(batch), config.n_mel_channels, max_target_len)) * 1e-9
+ mel_length = []
+ for i, mel in enumerate(mel_specs):
+ mel_padded[i, :, : mel.size(-1)] = mel
+ mel_length.append(mel.shape[-1])
+
+ return (
+ normalize_tacotron_mel(mel_padded),
+ torch.tensor(code),
+ torch.tensor(mel_length),
+ torch.tensor(code_len),
+ ref_mels_padded,
+ torch.tensor(text_ids),
+ torch.tensor(text_len),
+ paths,
+ )
+
+
+def train(
+ model,
+ diffuser,
+ train_dataloader,
+ val_dataloader,
+ schedule_sampler=None,
+ rank=0,
+ ar_active=False,
+ m1=None,
+ checkpoint_initial=None,
+):
+ ddp_kwargs = DistributedDataParallelKwargs(find_unused_parameters=True)
+ accelerator = Accelerator(
+ gradient_accumulation_steps=config.gradient_accumulation_steps,
+ kwargs_handlers=[ddp_kwargs],
+ )
+ if config.sa_wandb_logs and accelerator.is_local_main_process:
+ conf_ = {}
+ for i, j in config.__dict__.items():
+ conf_[str(i)] = str(j)
+ wandb_log = wandb.init(
+ project=config.wandb_project,
+ entity=config.user_name,
+ name=config.model_name,
+ config=conf_,
+ )
+ wandb_log.watch(model, log_freq=100)
+ else:
+ wandb_log = None
+
+ model.train()
+ optimizer = optim.AdamW(
+ model.parameters(), lr=config.sa_lr, weight_decay=config.sa_weight_decay
+ )
+ lr = config.sa_lr
+ min_val_loss = float("inf")
+ step_num = 0
+ start_epoch = 0
+ if checkpoint_initial is not None:
+ print(checkpoint_initial)
+ # load the checkpoint once rather than re-reading the file for every field
+ ckpt = torch.load(checkpoint_initial, map_location=torch.device("cpu"))
+ model.load_state_dict(ckpt["model"], strict=True)
+ model.train()
+ optimizer.load_state_dict(ckpt["optimizer"])
+ step_num = 0 # the saved ckpt["step"] is intentionally ignored on resume
+ start_epoch = int(ckpt["epoch"]) + 1
+ print(f"resuming training from epoch {start_epoch} and step {step_num}")
+
+ train_dataloader, model, optimizer = accelerator.prepare(
+ train_dataloader, model, optimizer
+ )
+
+ FM = BASECFM()
+ device = next(model.parameters()).device
+ if ar_active:
+ m1 = m1.to(device)
+
+ loading_time = []
+ for i in range(start_epoch, config.sa_epochs):
+ epoch_loss = {"vlb": [], "mse": [], "loss": []}
+ if accelerator.is_local_main_process:
+ train_loader = tqdm(train_dataloader, desc="Training epoch %d" % (i))
+ else:
+ train_loader = train_dataloader
+
+ for inputs in train_loader:
+ with accelerator.accumulate(model):
+ optimizer.zero_grad()
+ x1, code_emb, mask_lengths, _, ref_mels, text_ids, _, _ = inputs
+ mask = get_mask_from_lengths(mask_lengths).unsqueeze(1)
+ mask = mask.squeeze(1).float()
+
+ loss, _, t = FM.compute_loss(model, x1, mask, code_emb, ref_mels)
+
+ accelerator.backward(loss)
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ step_num += 1
+
+ epoch_loss["loss"].append(loss.item())
+
+ if step_num % config.gradient_accumulation_steps == 0:
+ epoch_training_loss = torch.tensor(
+ sum(epoch_loss["loss"]) / len(epoch_loss["loss"])
+ ).to(device)
+ epoch_loss = {"vlb": [], "mse": [], "loss": []}
+ epoch_training_loss = (
+ accelerator.gather_for_metrics(epoch_training_loss)
+ .mean()
+ .item()
+ )
+
+ if config.sa_wandb_logs and accelerator.is_local_main_process:
+ wandb_log.log({"training_loss": epoch_training_loss})
+
+ if (
+ step_num % (config.sa_eval_step * config.gradient_accumulation_steps)
+ == 0
+ ):
+ print(f"evaluation at step_num {step_num}")
+ if accelerator.is_local_main_process:
+ # save the latest checkpoint
+ unwrapped_model = accelerator.unwrap_model(model)
+ checkpoint = {
+ "epoch": i,
+ "step": step_num // config.gradient_accumulation_steps,
+ "model": unwrapped_model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "norms": config.norms,
+ }
+ torch.save(
+ checkpoint,
+ os.path.join(config.save_root_dir, "latest.pt"),
+ )
+
+ if accelerator.is_local_main_process:
+ val_loss, val_mse, val_vlb, time_steps_mean = val(
+ model,
+ FM,
+ val_dataloader,
+ infer_=config.sa_infer,
+ epoch=i,
+ rank=accelerator.is_local_main_process,
+ ar_active=ar_active,
+ m1=m1,
+ )
+ model.train()
+
+ print(
+ "validation loss : ",
+ val_loss,
+ "\nvalidation mse loss : ",
+ val_mse,
+ "\nvalidation vlb loss : ",
+ val_vlb,
+ )
+ if config.sa_wandb_logs:
+ wandb_log.log({"val_loss": val_loss})
+ if val_loss < min_val_loss:
+ unwrapped_model = accelerator.unwrap_model(model)
+ checkpoint = {
+ "epoch": i,
+ "step": step_num // config.gradient_accumulation_steps,
+ "model": unwrapped_model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ "norms": config.norms,
+ }
+ torch.save(
+ checkpoint,
+ os.path.join(config.save_root_dir, "_best.pt"),
+ )
+ min_val_loss = val_loss
+
+ if i == start_epoch + 12: # hard stop: train at most 12 epochs per run
+ exit()
+ if config.sa_wandb_logs and accelerator.is_local_main_process:
+ wandb_log.finish()
+
+
+def val(
+ model,
+ FM,
+ val_dataloader,
+ infer_=False,
+ epoch=0,
+ rank=False,
+ ar_active=False,
+ m1=None,
+):
+ """
+ Return the loss value
+ """
+ model.eval()
+ epoch_loss = {"vlb": [], "mse": [], "loss": []}
+ code_emb = None
+ x = None
+ mask_lengths = None
+ time_steps_mean = []
+ device = next(model.parameters()).device
+ if rank:
+ val_dataloader = tqdm(val_dataloader, desc="validation epoch %d" % (epoch))
+
+ with torch.no_grad():
+ for inputs in val_dataloader:
+ x1, code_emb, mask_lengths, code_len, ref_mels, text_ids, _, _ = inputs
+
+ mask = get_mask_from_lengths(mask_lengths).unsqueeze(1).to(device)
+ mask = mask.squeeze(1).float()
+ x1 = x1.to(device)
+ code_emb = code_emb.to(device)
+ text_ids = text_ids.to(device)
+ ref_mels = ref_mels.to(device)
+
+ loss, _, t = FM.compute_loss(model, x1, mask, code_emb, ref_mels)
+ time_steps_mean.extend(t.detach().cpu().squeeze(-1).squeeze(-1).tolist())
+ mse = loss # flow matching yields a single loss; logged under all three keys
+ vlb = loss
+
+ epoch_loss["loss"].append(loss.item())
+ epoch_loss["mse"].append(mse.item())
+ epoch_loss["vlb"].append(vlb.item())
+
+ epoch_vlb_loss = sum(epoch_loss["vlb"]) / len(epoch_loss["vlb"])
+ epoch_training_loss = sum(epoch_loss["loss"]) / len(epoch_loss["loss"])
+ epoch_mse_loss = sum(epoch_loss["mse"]) / len(epoch_loss["mse"])
+ if rank and infer_ and epoch % config.sa_infer_epoch == 0:
+ k = 4
+ if ar_active:
+ code_embs = [code_emb[i, :, : code_len[i]] for i in range(k)]
+ else:
+ code_embs = [code_emb[i, : code_len[i]] for i in range(k)]
+ audio_paths, mels = infer(
+ model, mask_lengths[:k], code_embs, ref_mels[:k, :], epoch
+ )
+
+ if config.sa_wandb_logs:
+ images = [
+ wandb.Image(mel[0], caption="epoch: " + str(epoch)) for mel in mels
+ ]
+ x = [
+ wandb.Image(x1[i, :, : mask_lengths[i]], caption="Actual: ")
+ for i in range(k)
+ ]
+ wandb.log(
+ {
+ "predicted audio": [
+ wandb.Audio(audio_path) for audio_path in audio_paths
+ ],
+ "predicted melspec": images,
+ "actual melspec": x,
+ "epoch": epoch,
+ }
+ )
+
+ return (
+ epoch_training_loss,
+ epoch_mse_loss,
+ epoch_vlb_loss,
+ sum(time_steps_mean) / len(time_steps_mean),
+ )
+
+
+if __name__ == "__main__":
+ os.makedirs(os.path.join(config.save_root_dir, config.model_name, "S2A"), exist_ok=True)
+
+ model = DiffModel(
+ input_channels=100,
+ output_channels=100,
+ model_channels=512, # 1024
+ num_heads=8, # 16
+ dropout=0.10,
+ num_layers=8,
+ enable_fp16=False,
+ condition_free_per=0.0,
+ multispeaker=True,
+ style_tokens=100,
+ training=True,
+ ar_active=False,
+ in_latent_channels=len(code_labels),
+ )
+ m1 = None
+ checkpoint = None
+ print("Model Loaded")
+ print("batch_size :", config.sa_batch_size)
+ print("Diffusion timesteps:", config.sa_timesteps_max)
+
+ file_name_train = config.train_file
+ file_name_val = config.val_file
+
+ train_dataset = Acoustic_dataset(file_name_train, scale=config.scale)
+ train_dataloader = DataLoader(
+ train_dataset,
+ pin_memory=True,
+ persistent_workers=True,
+ num_workers=config.sa_num_workers,
+ batch_size=config.sa_batch_size,
+ shuffle=True,
+ drop_last=False,
+ collate_fn=collate,
+ )
+
+ val_dataset = Acoustic_dataset(file_name_val, scale=config.scale, dur_=5)
+ val_dataloader = DataLoader(
+ val_dataset,
+ pin_memory=True,
+ persistent_workers=True,
+ num_workers=config.sa_num_workers,
+ batch_size=config.sa_batch_size,
+ shuffle=True,
+ drop_last=True,
+ collate_fn=collate,
+ )
+
+ train(
+ model,
+ diffuser=None,
+ train_dataloader=train_dataloader,
+ val_dataloader=val_dataloader,
+ rank=0,
+ ar_active=False,
+ m1=m1,
+ checkpoint_initial=checkpoint,
+ )
diff --git a/S2A/utilities.py b/S2A/utilities.py
new file mode 100755
index 0000000000000000000000000000000000000000..b6f6f7525ccc96a1129e18f9c77b97155a917723
--- /dev/null
+++ b/S2A/utilities.py
@@ -0,0 +1,175 @@
+import librosa.util as librosa_util
+import numpy as np
+import torch
+from scipy.io.wavfile import read
+from scipy.signal import get_window
+
+# import librosa
+from config import config
+
+# find these values
+# TACOTRON_MEL_MAX = 2.4
+# TACOTRON_MEL_MIN = -11.5130
+# tensor(-11.5129) tensor(2.0743) indic tts
+# tensor(-11.5129) tensor(2.3314) LibriTTS
+# tensor(-11.5129) tensor(2.3996) Eng10k scale
+
+
+def normalize_tacotron_mel(data, mu=config.mu, std=config.std):
+ # return data
+ if not isinstance(mu, (float, int)):
+ if isinstance(mu, list):
+ mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
+ elif isinstance(mu, torch.Tensor):
+ mu = mu.to(data.device)
+ elif isinstance(mu, np.ndarray):
+ mu = torch.from_numpy(mu).to(data.device)
+ mu = mu.unsqueeze(-1)
+
+ if not isinstance(std, (float, int)):
+ if isinstance(std, list):
+ std = torch.tensor(std, dtype=data.dtype, device=data.device)
+ elif isinstance(std, torch.Tensor):
+ std = std.to(data.device)
+ elif isinstance(std, np.ndarray):
+ std = torch.from_numpy(std).to(data.device)
+ std = std.unsqueeze(-1)
+
+ return (data - mu) / std
+
+
+def denormalize_tacotron_mel(data, mu=config.mu, std=config.std):
+ # return data
+    if not isinstance(mu, (float, int)):
+ if isinstance(mu, list):
+ mu = torch.tensor(mu, dtype=data.dtype, device=data.device)
+ elif isinstance(mu, torch.Tensor):
+ mu = mu.to(data.device)
+ elif isinstance(mu, np.ndarray):
+ mu = torch.from_numpy(mu).to(data.device)
+ mu = mu.unsqueeze(-1)
+
+    if not isinstance(std, (float, int)):
+ if isinstance(std, list):
+ std = torch.tensor(std, dtype=data.dtype, device=data.device)
+ elif isinstance(std, torch.Tensor):
+ std = std.to(data.device)
+ elif isinstance(std, np.ndarray):
+ std = torch.from_numpy(std).to(data.device)
+ std = std.unsqueeze(-1)
+
+ return data * std + mu
+
+
+# def denormalize_tacotron_mel(norm_mel):
+# return ((norm_mel+1)/2)*(TACOTRON_MEL_MAX-TACOTRON_MEL_MIN)+TACOTRON_MEL_MIN
+
+
+# def normalize_tacotron_mel(mel):
+# return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
+
+# min_max to global scaling
+
+
+def get_mask_from_lengths(lengths, max_len=None):
+ if not max_len:
+ max_len = torch.max(lengths).item()
+ ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long)
+ mask = (ids < lengths.unsqueeze(1)).bool()
+ return mask
+
+
+def get_mask(lengths, max_len=None):
+    if not max_len:
+        max_len = torch.max(lengths).item()
+    # build the index row on the same device as `lengths` to avoid a CPU/GPU mismatch
+    lens = torch.arange(max_len, device=lengths.device)
+    mask = lens.unsqueeze(0) < lengths.unsqueeze(1)
+    return mask
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
+
+
+def window_sumsquare(
+ window,
+ n_frames,
+ hop_length=200,
+ win_length=800,
+ n_fft=800,
+ dtype=np.float32,
+ norm=None,
+):
+ """
+ # from librosa 0.6
+ Compute the sum-square envelope of a window function at a given hop length.
+ This is used to estimate modulation effects induced by windowing
+ observations in short-time fourier transforms.
+ Parameters
+ ----------
+ window : string, tuple, number, callable, or list-like
+ Window specification, as in `get_window`
+ n_frames : int > 0
+ The number of analysis frames
+ hop_length : int > 0
+ The number of samples to advance between frames
+ win_length : [optional]
+ The length of the window function. By default, this matches `n_fft`.
+ n_fft : int > 0
+ The length of each analysis frame.
+ dtype : np.dtype
+ The data type of the output
+ Returns
+ -------
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+ The sum-squared envelope of the window function
+ """
+ if win_length is None:
+ win_length = n_fft
+
+ n = n_fft + hop_length * (n_frames - 1)
+ x = np.zeros(n, dtype=dtype)
+
+ # Compute the squared window at the desired length
+ win_sq = get_window(window, win_length, fftbins=True)
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
+ win_sq = librosa_util.pad_center(win_sq, size=n_fft)
+
+ # Fill the envelope
+ for i in range(n_frames):
+ sample = i * hop_length
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+ return x
+
+
+def load_wav_to_torch(full_path):
+ sampling_rate, data = read(
+ full_path,
+ )
+ # print(data)
+ # data,sampling_rate = librosa.load(full_path)
+ # print(data)
+ return torch.FloatTensor(data), sampling_rate
+
+
+if __name__ == "__main__":
+ lens = torch.tensor([2, 3, 7, 5, 4])
+ mask = get_mask(lens)
+ print(mask)
+ print(mask.shape)
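The `normalize_tacotron_mel`/`denormalize_tacotron_mel` pair in `utilities.py` reduces to the affine map `z = (x - mu) / std` and its inverse `x = z * std + mu`. A minimal pure-Python sketch of the round-trip, with illustrative `mu`/`std` values rather than the project's `config.mu`/`config.std`:

```python
# Sketch of the mel normalize/denormalize round-trip using plain floats
# instead of torch tensors. The mu/std values here are illustrative only.

def normalize(data, mu, std):
    return [(x - mu) / std for x in data]

def denormalize(data, mu, std):
    return [z * std + mu for z in data]

mel = [-11.5, -4.0, 2.3]          # typical log-mel dynamic range
norm = normalize(mel, mu=-5.0, std=2.0)
recovered = denormalize(norm, mu=-5.0, std=2.0)
print(recovered)                   # matches mel up to float rounding
```

The per-channel variant in the repo is the same map with `mu` and `std` broadcast along the time axis after `unsqueeze(-1)`.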
diff --git a/Semantic_tokens/convert_factorize.py b/Semantic_tokens/convert_factorize.py
new file mode 100755
index 0000000000000000000000000000000000000000..fc6a01392ac2c629caa7bc519a82d0c3fe412fd3
--- /dev/null
+++ b/Semantic_tokens/convert_factorize.py
@@ -0,0 +1,273 @@
+import argparse
+import glob
+import math
+import multiprocessing as mp
+import os
+import statistics
+import sys
+from multiprocessing import Pool
+
+import numpy as np
+import pandas as pd
+import soundfile as sf
+import torchaudio
+from pydub import AudioSegment
+from tqdm import tqdm
+
+
+def check_sample_rate(file_path):
+    """Return the path if the file is not sampled at 24 kHz, else None."""
+    try:
+        info = sf.info(file_path)
+        if info.samplerate != 24000:
+            return file_path
+    except Exception:
+        pass  # unreadable file: skip it
+    return None
+
+
+def process_files(file_list):
+ with Pool() as pool:
+ result = list(
+ tqdm(pool.imap(check_sample_rate, file_list), total=len(file_list))
+ )
+ return [file for file in result if file is not None]
+
+
+def read_paths_from_file(filename):
+ with open(filename, "r", encoding="utf8") as file:
+ path_list = []
+ for i in tqdm(file):
+ path = i.split("|")[1]
+ path_list.append(path.strip("\n"))
+
+ return path_list[:]
+
+
+def gather_paths_from_glob():
+ return glob.glob("./**/*.wav", recursive=True)
+
+
+def detect_leading_silence(sound, silence_threshold=-50, chunk_size=64):
+    assert chunk_size > 0
+    trim_ms = 0
+    # check the bound first so we never measure a slice past the end of the clip
+    while trim_ms < len(sound) and (
+        sound[trim_ms : trim_ms + chunk_size].dBFS < silence_threshold
+    ):
+        trim_ms += chunk_size
+    return trim_ms
+
+
+def preprocess_audio(path, target_dBFS, frame_rate):
+ durations = []
+ dbfs = []
+ audio = AudioSegment.from_file(path)
+ dbfs.append(audio.dBFS)
+ audio = audio.set_channels(1)
+ audio = audio.set_frame_rate(frame_rate).set_sample_width(2)
+
+ start_trim = detect_leading_silence(audio)
+ end_trim = detect_leading_silence(audio.reverse())
+
+ duration = len(audio)
+ audio = audio[start_trim : duration - end_trim]
+    # pad with silence at the resampled rate, not a hard-coded 22050 Hz
+    audio = (
+        AudioSegment.silent(duration=256, frame_rate=frame_rate)
+        + audio
+        + AudioSegment.silent(duration=256, frame_rate=frame_rate)
+    )
+
+ if path[-4:] == ".wav":
+ audio.export(path[:-4] + ".wav", format="wav")
+ elif path[-5:] == ".flac":
+ audio.export(path[:-5] + ".flac", format="flac")
+ else:
+ audio.export(path[:-4] + ".wav", format="wav")
+
+ durations.append(audio.duration_seconds)
+
+ return dbfs, durations
+
+
+def preprocess_audio_chunk(args):
+ path_list_chunk, target_dBFS, frame_rate, n = args
+ dbfs = []
+ durations = []
+ for i in tqdm(path_list_chunk, desc="preprocess " + str(n)):
+ try:
+            # preprocess_audio loads the file itself; no need to load it twice here
+            dbfs_i, durations_i = preprocess_audio(i, target_dBFS, frame_rate)
+ dbfs.extend(dbfs_i)
+ durations.extend(durations_i)
+ except Exception as e:
+ print(n, i, e)
+
+ return dbfs, durations
+
+
+def preprocess_audio_paths(path_list, target_dBFS, frame_rate, num_workers):
+ chunk_size = len(path_list) // num_workers
+ path_chunks = [
+ path_list[i : i + chunk_size] for i in range(0, len(path_list), chunk_size)
+ ]
+
+ with Pool(num_workers) as pool:
+ results = pool.map(
+ preprocess_audio_chunk,
+ [
+ (chunk, target_dBFS, frame_rate, n)
+ for n, chunk in enumerate(path_chunks)
+ ],
+ )
+
+ dbfs = []
+ durations = []
+ for dbfs_i, durations_i in results:
+ dbfs.extend(dbfs_i)
+ durations.extend(durations_i)
+ return dbfs, durations
+
+
+def gather_metadata_chunk(args):
+ path_list_chunk = args
+ dbfs = []
+ durations = []
+ files = []
+ for i in tqdm(path_list_chunk):
+ try:
+ path = i.split("|")[0]
+ audio = AudioSegment.from_file(path)
+ if audio.dBFS == -math.inf:
+ print("=====================")
+ print(path)
+ print("=====================")
+ continue
+ dbfs.append(audio.dBFS)
+ durations.append(audio.duration_seconds)
+ files.append((audio.duration_seconds, i))
+ if audio.duration_seconds == 0:
+ print(i)
+ except Exception as e:
+ print(e, i)
+
+ return dbfs, durations, files
+
+
+def gather_metadata(path_list, num_workers):
+ chunk_size = len(path_list) // num_workers
+ path_chunks = [
+ path_list[i : i + chunk_size] for i in range(0, len(path_list), chunk_size)
+ ]
+
+ with Pool(num_workers) as pool:
+ results = pool.map(gather_metadata_chunk, [chunk for chunk in path_chunks])
+
+ dbfs = []
+ durations = []
+ files = []
+ for dbfs_i, durations_i, files_i in results:
+ dbfs.extend(dbfs_i)
+ durations.extend(durations_i)
+ files.extend(files_i)
+
+ files = sorted(files, key=lambda x: x[0])
+ with open(os.path.join(data_dir, "files_duration.txt"), "w") as file:
+ file.write("\n".join([i[1] + "|" + str(i[0]) for i in files]))
+
+ with open(os.path.join(data_dir, "files.txt"), "w") as file:
+ file.write("\n".join([i[1] for i in files if 2.0 < float(i[0]) < 15.0]))
+
+ return dbfs, durations
+
+
+def process_audio_data(input_file, mode, num_workers, data_dir):
+ if input_file:
+ path_list = read_paths_from_file(input_file)
+ else:
+ path_list = gather_paths_from_glob()
+
+ speakers = []
+ for n, i in enumerate(path_list):
+ try:
+ speakers.append(i.split("/")[-2])
+        except Exception:
+ print(n, i)
+
+ print("total audio files:", len(path_list))
+
+ if mode == "preprocess":
+ print("Preprocessing!")
+ target_dBFS = -24.196741 # not using
+
+ frame_rate = 24000
+ dbfs, durations = preprocess_audio_paths(
+ path_list[:], target_dBFS, frame_rate, num_workers
+ )
+
+ print("min duration : ", min(durations))
+ print("max duration : ", max(durations))
+ print("avg duration : ", sum(durations) / len(durations))
+ print("Standard Deviation of durations % s" % (statistics.stdev(durations)))
+ print("total duration : ", sum(durations))
+ print("DONE")
+
+ if mode == "metadata":
+ print("Gathering metadata")
+ dbfs, durations = gather_metadata(path_list[:], num_workers)
+
+ print("min duration : ", min(durations))
+ print("max duration : ", max(durations))
+ print("avg duration : ", sum(durations) / len(durations))
+ print("total duration : ", sum(durations))
+ print("Standard Deviation of sample is % s" % (statistics.stdev(durations)))
+ print("DONE")
+
+ # pd.DataFrame({'dBFS': dbfs, 'duration': durations, 'files': [i[1] for i in files]}).to_csv("meta.csv", index=False)
+
+
+if __name__ == "__main__":
+ parser = argparse.ArgumentParser(
+ description="Audio processing script for preprocessing and metadata gathering",
+ formatter_class=argparse.RawDescriptionHelpFormatter,
+ epilog="""
+Examples:
+ # Preprocess audio files from a text file
+ python convert_factorize.py --input paths.txt --mode preprocess
+
+ # Gather metadata from audio files in current directory
+ python convert_factorize.py --mode metadata
+
+ # Preprocess audio files found recursively in current directory
+ python convert_factorize.py --mode preprocess
+ """
+ )
+
+ parser.add_argument(
+ "--input", "-i",
+ type=str,
+ help="Path to text file containing audio file paths in 'language|abspath|text' format"
+ )
+
+ parser.add_argument(
+ "--mode", "-m",
+ type=str,
+ choices=["preprocess", "metadata"],
+ required=True,
+ help="Processing mode: 'preprocess' to process audio files, 'metadata' to gather statistics"
+ )
+
+ parser.add_argument(
+ "--workers", "-w",
+ type=int,
+ default=4,
+ help="Number of worker processes for parallel processing (default: 4)"
+ )
+
+ args = parser.parse_args()
+
+    # --input is optional (metadata mode can glob the cwd), so guard against None
+    data_dir = os.path.dirname(args.input) if args.input else "."
+    print(f"Data directory: {data_dir}")
+
+ print(f"Input file: {args.input}")
+ print(f"Mode: {args.mode}")
+ print(f"Number of workers: {args.workers}")
+
+ process_audio_data(args.input, args.mode, args.workers, data_dir)
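The silence trimming in `convert_factorize.py` advances through the clip in fixed-size chunks until one chunk is louder than the threshold. A pure-Python sketch of that scan, where loudness is a precomputed list of per-chunk dBFS values (hypothetical numbers) rather than a pydub `AudioSegment`:

```python
# Sketch of detect_leading_silence's scan over fixed-size chunks.
# chunk_dbfs holds one (hypothetical) dBFS reading per chunk_ms window.

def leading_silence_ms(chunk_dbfs, silence_threshold=-50, chunk_ms=64):
    trim = 0
    for dbfs in chunk_dbfs:
        if dbfs >= silence_threshold:
            break          # first audible chunk: stop trimming
        trim += chunk_ms   # still below threshold: extend the trim
    return trim

print(leading_silence_ms([-80, -75, -60, -20, -18]))  # 192
```

Running the same scan over the reversed clip, as the script does with `audio.reverse()`, gives the trailing-silence trim.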
diff --git a/Semantic_tokens/extract_m4t_tokens_multi.py b/Semantic_tokens/extract_m4t_tokens_multi.py
new file mode 100755
index 0000000000000000000000000000000000000000..034327308a17843e617fceaa4ce1fddad6fb85cc
--- /dev/null
+++ b/Semantic_tokens/extract_m4t_tokens_multi.py
@@ -0,0 +1,228 @@
+import argparse
+import multiprocessing
+import os
+import random
+import shutil
+import sys
+import time
+from functools import partial
+
+import numpy as np
+import torch
+import torchaudio
+from tqdm import tqdm
+
+from seamless_communication.models.unit_extractor import (
+ KmeansModel, UnitExtractor, Wav2Vec2LayerOutputModel)
+
+
+def train_test_split_large_file(input_file, train_file, test_file, test_ratio=0.2, seed=42):
+ """
+ Memory-efficient train-test split for large files.
+ Performs two passes: first to count lines, second to split.
+ """
+ random.seed(seed)
+
+ # First pass: count lines
+ with open(input_file, "r", encoding="utf-8") as f:
+ total_lines = sum(1 for _ in f)
+
+ n_test = int(total_lines * test_ratio)
+
+ # Choose random line numbers for test set
+ test_indices = set(random.sample(range(total_lines), n_test))
+
+ # Second pass: write to train/test files
+ with (
+ open(input_file, "r", encoding="utf-8") as in_f,
+ open(train_file, "w", encoding="utf-8") as train_f,
+ open(test_file, "w", encoding="utf-8") as test_f,
+ ):
+ for idx, line in enumerate(in_f):
+ if idx in test_indices:
+ test_f.write(line)
+ else:
+ train_f.write(line)
+
+ print(
+ f"Split {total_lines} lines into {total_lines - n_test} train and {n_test} test lines."
+ )
+
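The two-pass split above never holds the file in memory: pass one counts lines, then a presampled set of line indices routes each streamed line to the train or test file. The index sampling can be sketched on its own (illustrative sizes, same seeding scheme):

```python
# Sketch of the index presampling behind train_test_split_large_file:
# seed, compute the test-set size, and sample that many line numbers.
import random

def split_indices(total_lines, test_ratio, seed=42):
    random.seed(seed)
    n_test = int(total_lines * test_ratio)
    return set(random.sample(range(total_lines), n_test))

test_idx = split_indices(100, 0.2)
print(len(test_idx))  # 20
```

Because the seed is fixed before sampling, the same file and ratio always produce the same split, which keeps train/val assignments reproducible across reruns.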
+
+lock = multiprocessing.Lock()
+
+
+def process_data(data, device, out_layer_number, process_no, kmeans_uri, model_name, data_dir, batch_size=10000):
+ lock.acquire()
+ unit_extractor = UnitExtractor(
+ model_name, kmeans_uri, device=torch.device(f"cuda:{device}")
+ )
+ lock.release()
+ results = []
+ # if i == 0:
+ data = tqdm(data, desc="process no : " + str(process_no))
+ for i in data:
+ try:
+ audio, sr = torchaudio.load(i[1])
+ audio = (
+ torchaudio.functional.resample(audio, sr, 16000)
+ .squeeze(0)
+ .unsqueeze(-1)
+ )
+ with torch.no_grad():
+ units = (
+ unit_extractor.predict(audio.to(device), out_layer_number - 1)
+ .detach()
+ .cpu()
+ .numpy()
+ )
+ text = " ".join([str(k) for k in units]).strip()
+ results.append("|".join(i) + "|" + text + "\n")
+ except Exception as e:
+ print(i, e)
+
+ if len(results) == batch_size:
+ semt = "".join(results)
+ with open(os.path.join(data_dir, "SEMANTICS_/" + str(process_no) + "_semt.txt"), "a") as file:
+ file.write(semt)
+ results = []
+
+ if len(results) != 0:
+ semt = "".join(results)
+ with open(os.path.join(data_dir, "SEMANTICS_/" + str(process_no) + "_semt.txt"), "a") as file:
+ file.write(semt)
+ results = []
+
+
+def parse_arguments():
+ """Parse command line arguments."""
+ parser = argparse.ArgumentParser(
+ description="Extract M4T semantic tokens from audio files using multi-GPU processing"
+ )
+
+ parser.add_argument(
+ "input_file",
+ help="Path to the input file containing audio file paths and metadata"
+ )
+
+ parser.add_argument(
+ "--kmeans-uri",
+ default="https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy",
+ help="URI for the kmeans model (default: M4T kmeans model)"
+ )
+
+ parser.add_argument(
+ "--model-name",
+ default="xlsr2_1b_v2",
+ help="Model name for unit extraction (default: xlsr2_1b_v2)"
+ )
+
+ parser.add_argument(
+ "--out-layer-number",
+ type=int,
+ default=35,
+ help="Output layer number for feature extraction (default: 35)"
+ )
+
+ parser.add_argument(
+ "--gpu-multiplier",
+ type=int,
+ default=1,
+ help="Multiplier for number of GPUs to use (default: 1)"
+ )
+
+ parser.add_argument(
+ "--test-ratio",
+ type=float,
+ default=0.1,
+ help="Ratio of data to use for validation set (default: 0.1)"
+ )
+
+ parser.add_argument(
+ "--seed",
+ type=int,
+ default=42,
+ help="Random seed for reproducibility (default: 42)"
+ )
+
+ parser.add_argument(
+ "--batch-size",
+ type=int,
+ default=10000,
+ help="Number of results to accumulate before writing to file (default: 10000)"
+ )
+
+ return parser.parse_args()
+
+
+if __name__ == "__main__":
+ start = time.time()
+
+ # Parse command line arguments
+ args = parse_arguments()
+
+ # Extract parameters from arguments
+ kmeans_uri = args.kmeans_uri
+ model_name = args.model_name
+ out_layer_number = args.out_layer_number
+ k = args.gpu_multiplier
+ test_ratio = args.test_ratio
+ seed = args.seed
+ batch_size = args.batch_size
+
+ num_gpus_to_use = torch.cuda.device_count() * k
+    data_dir = os.path.dirname(args.input_file)
+
+ if os.path.exists(os.path.join(data_dir, "SEMANTICS_")):
+ shutil.rmtree(os.path.join(data_dir, "SEMANTICS_"))
+ os.makedirs(os.path.join(data_dir, "SEMANTICS_"), exist_ok=True)
+
+ with open(args.input_file, "r") as file:
+ data = file.read().strip("\n").split("\n")[:]
+
+ data = [i.split("|") for i in data][:]
+
+ # Split data into chunks for each GPU
+ chunk_size = len(data) // num_gpus_to_use
+ data_chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
+ if len(data_chunks) != num_gpus_to_use:
+ data_chunks[-2] += data_chunks[-1]
+ data_chunks = data_chunks[:-1]
+
+ processes = []
+ for i in range(num_gpus_to_use):
+ p = multiprocessing.Process(
+ target=process_data,
+ args=(data_chunks[i], i // k, out_layer_number, i, kmeans_uri, model_name, data_dir, batch_size)
+ )
+ processes.append(p)
+ p.start()
+
+ for p in processes:
+ p.join()
+
+ for i in range(num_gpus_to_use):
+ with open(os.path.join(data_dir, "SEMANTICS_/" + str(i) + "_semt.txt"), "r") as file:
+ data = file.read()
+
+ with open(
+ os.path.join(data_dir, "SEMANTICS_/" + os.path.basename(args.input_file).split(".")[0] + "_semt.txt"), "a"
+ ) as file:
+ file.write(data)
+
+ input_file = os.path.join(data_dir, "SEMANTICS_/" + os.path.basename(args.input_file).split(".")[0] + "_semt.txt")
+ train_file = (
+ os.path.join(data_dir, "SEMANTICS_/" + os.path.basename(args.input_file).split(".")[0] + "_semt_train.txt")
+ )
+ test_file = (
+ os.path.join(data_dir, "SEMANTICS_/" + os.path.basename(args.input_file).split(".")[0] + "_semt_val.txt")
+ )
+
+ train_test_split_large_file(
+ input_file, train_file, test_file, test_ratio=test_ratio, seed=seed
+ )
+    print("processing with %d worker processes took %.1f s" % (num_gpus_to_use, time.time() - start))
+ print(
+ "data ready at:",
+ os.path.join(data_dir, "SEMANTICS_/" + os.path.basename(args.input_file).split(".")[0] + "_semt.txt"),
+ )
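The sharding in `extract_m4t_tokens_multi.py` relies on integer division, which can leave a short trailing chunk; the script folds that remainder into the previous chunk so the chunk count matches the number of worker processes. A self-contained sketch of that logic:

```python
# Sketch of the GPU work-sharding used above: slice into chunk_size
# pieces, then merge any short trailing chunk into its neighbour so
# len(chunks) == num_gpus. Assumes len(data) >= num_gpus.

def split_for_gpus(data, num_gpus):
    chunk_size = len(data) // num_gpus
    chunks = [data[i : i + chunk_size] for i in range(0, len(data), chunk_size)]
    if len(chunks) != num_gpus:
        chunks[-2] += chunks[-1]  # fold the remainder into the previous chunk
        chunks = chunks[:-1]
    return chunks

print([len(c) for c in split_for_gpus(list(range(10)), 3)])  # [3, 3, 4]
```

Every element lands in exactly one chunk, so the per-process `_semt.txt` outputs concatenate back to the full dataset.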
diff --git a/Semantic_tokens/seamless_communication b/Semantic_tokens/seamless_communication
new file mode 160000
index 0000000000000000000000000000000000000000..90e2b57ac4d82fa2bfaa25caeffe39ceb8b2ebec
--- /dev/null
+++ b/Semantic_tokens/seamless_communication
@@ -0,0 +1 @@
+Subproject commit 90e2b57ac4d82fa2bfaa25caeffe39ceb8b2ebec
diff --git a/T2S/__pycache__/autoregressive.cpython-310.pyc b/T2S/__pycache__/autoregressive.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..abade8db9d9611e6c448cc2e9d4649e227389984
Binary files /dev/null and b/T2S/__pycache__/autoregressive.cpython-310.pyc differ
diff --git a/T2S/__pycache__/gpt_inference.cpython-310.pyc b/T2S/__pycache__/gpt_inference.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..bd3dd1e2a0011843a888223f1c3290b7bb415693
Binary files /dev/null and b/T2S/__pycache__/gpt_inference.cpython-310.pyc differ
diff --git a/T2S/__pycache__/mel_spec.cpython-310.pyc b/T2S/__pycache__/mel_spec.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..71f0393cd748a8271177d93e506fcf73fc17c3cd
Binary files /dev/null and b/T2S/__pycache__/mel_spec.cpython-310.pyc differ
diff --git a/T2S/__pycache__/t2s_modules.cpython-310.pyc b/T2S/__pycache__/t2s_modules.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..ca9c7ee31f86afb16284642b884a34d014538d22
Binary files /dev/null and b/T2S/__pycache__/t2s_modules.cpython-310.pyc differ
diff --git a/T2S/__pycache__/utilities.cpython-310.pyc b/T2S/__pycache__/utilities.cpython-310.pyc
new file mode 100755
index 0000000000000000000000000000000000000000..73a08a96cf63147424de61ef5dd3f0e3a4c031a7
Binary files /dev/null and b/T2S/__pycache__/utilities.cpython-310.pyc differ
diff --git a/T2S/autoregressive.py b/T2S/autoregressive.py
new file mode 100755
index 0000000000000000000000000000000000000000..8eb58f3139ce41692b861a8ca4dcce67fef50e0c
--- /dev/null
+++ b/T2S/autoregressive.py
@@ -0,0 +1,296 @@
+import functools
+import os
+import sys
+from typing import Any
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
+from transformers import (GemmaConfig, GemmaModel, GPT2Config, GPT2LMHeadModel,
+ GPT2Model, GPT2Tokenizer)
+
+from config import config
+from Text import code_labels, labels, text_labels
+
+from .gpt_inference import GPT2InferenceModel
+from .t2s_modules import GST
+
+# code encdec
+text_enc = {j: i for i, j in enumerate(text_labels)}
+text_dec = {i: j for i, j in enumerate(text_labels)}
+
+# text encdec
+code_enc = {j: i for i, j in enumerate(code_labels)}
+code_dec = {i: j for i, j in enumerate(code_labels)}
+
+
+def null_position_embeddings(range, dim):
+ return torch.zeros((range.shape[0], range.shape[1], dim), device=range.device)
+
+
+class TS_model(nn.Module):
+ def __init__(
+ self, n_embed=1024, n_layer=30, n_head=16, n_positions=config.t2s_position
+ ):
+ super(TS_model, self).__init__()
+        assert (n_embed / n_head) % 2 == 0, "n_embed / n_head must be even"
+ self.vocab_size = len(labels)
+ self.n_positions = n_positions
+ self.n_embed = n_embed
+ self.n_layer = n_layer
+ self.n_head = n_head
+
+        if self.vocab_size % 2 != 0:
+            self.vocab_size += 1
+        # intermediate_size multiplier; must be bound on every path, since it
+        # is used below regardless of whether vocab_size was padded
+        k = 1
+
+ self.config = GemmaConfig(
+ vocab_size=self.vocab_size,
+ hidden_size=self.n_embed,
+ intermediate_size=self.n_embed * k,
+ num_hidden_layers=self.n_layer,
+ num_attention_heads=self.n_head,
+ num_key_value_heads=self.n_head,
+ head_dim=int(self.n_embed / self.n_head),
+ hidden_act="gelu_pytorch_tanh",
+ hidden_activation=None,
+ max_position_embeddings=self.n_positions,
+ initializer_range=0.02,
+ rms_norm_eps=1e-06,
+ use_cache=True,
+ pad_token_id=0,
+ eos_token_id=1,
+ bos_token_id=2,
+ tie_word_embeddings=True,
+ rope_theta=10000.0,
+ attention_bias=False,
+ attention_dropout=0.0,
+ )
+
+ self.gpt = GemmaModel(self.config)
+ del self.gpt.embed_tokens
+
+ self.GST = GST(
+ model_channels=self.n_embed,
+ num_heads=self.n_head,
+ in_channels=config.n_mel_channels,
+ )
+ self.text_head = nn.Linear(self.n_embed, len(text_labels))
+ self.code_head = nn.Linear(self.n_embed, len(code_labels))
+
+ self.text_embed = nn.Embedding(len(text_labels), self.n_embed)
+ self.code_embed = nn.Embedding(len(code_labels), self.n_embed)
+ self.language_embed = nn.Embedding(len(config.lang_index), self.n_embed)
+ self.final_norm = nn.LayerNorm(self.n_embed)
+
+ def init_random_embeddings(self):
+ self.text_embed.weight.data.uniform_(-1, 1)
+ self.code_embed.weight.data.uniform_(-1, 1)
+
+ def get_speaker_latent(self, ref_mels):
+ ref_mels = ref_mels.unsqueeze(1) if len(ref_mels.shape) == 3 else ref_mels
+
+ conds = []
+ for j in range(ref_mels.shape[1]):
+ conds.append(self.GST(ref_mels[:, j, :, :]))
+
+ conds = torch.cat(conds, dim=-1)
+ conds = conds.mean(dim=-1)
+
+ return conds.unsqueeze(1)
+
+ def forward(
+ self,
+ text_ids,
+ codes_ids=None,
+ speaker_embed=None,
+ ref_clips=None,
+ language=torch.tensor(0),
+ attn_mask=None,
+ return_loss=False,
+ ):
+ assert speaker_embed is not None or ref_clips is not None
+ text_embed = self.text_embed(text_ids)
+
+ lanugage_embed = self.language_embed(language).unsqueeze(1)
+ code_embed = None
+ code_probs = None
+
+ if codes_ids is not None:
+ code_embed = self.code_embed(codes_ids)
+
+ if ref_clips is not None:
+ speaker_embed = self.get_speaker_latent(ref_clips)
+
+ text_embed, code_embed = self.get_logits(
+ lanugage_embed=lanugage_embed,
+ speaker_embed=speaker_embed,
+ text_embed=text_embed,
+ code_embed=code_embed,
+ attn_mask=attn_mask,
+ )
+ text_probs = self.text_head(text_embed).permute(0, 2, 1)
+
+ if codes_ids is not None:
+ code_probs = self.code_head(code_embed).permute(0, 2, 1)
+
+ if return_loss:
+            loss_text = F.cross_entropy(
+                text_probs[:, :, :-1], text_ids[:, 1:].long(), reduction="none"
+            )
+            loss_mel = F.cross_entropy(
+                code_probs[:, :, :-1], codes_ids[:, 1:].long(), reduction="none"
+            )
+ return loss_text, loss_mel, code_probs
+
+ return text_probs, code_probs
+
+ def get_logits(
+ self, lanugage_embed, speaker_embed, text_embed, code_embed=None, attn_mask=None
+ ):
+ if code_embed is not None:
+ embed = torch.cat(
+ [lanugage_embed, speaker_embed, text_embed, code_embed], dim=1
+ )
+            position_ids = torch.zeros(
+                (embed.shape[0], embed.shape[1]), dtype=torch.long, device=embed.device
+            )
+ indices = torch.tensor(
+ [0, 0]
+ + list(range(text_embed.shape[1]))
+ + list(range(code_embed.shape[1])),
+ device=embed.device,
+ )
+ position_ids[:, : indices.size(0)] = indices
+ else:
+ embed = torch.cat([lanugage_embed, speaker_embed, text_embed], dim=1)
+            position_ids = torch.zeros(
+                (embed.shape[0], embed.shape[1]), dtype=torch.long, device=embed.device
+            )
+ indices = torch.tensor(
+ [0, 0] + list(range(text_embed.shape[1])), device=embed.device
+ )
+ position_ids[:, : indices.size(0)] = indices
+
+ if attn_mask is None:
+            attn_mask = torch.ones(embed.shape[:2], device=embed.device)
+ else:
+ attn_mask = torch.cat(
+ [torch.ones((embed.shape[0], 2)).to(embed.device), attn_mask], dim=1
+ )
+ gpt_output = self.gpt(
+ inputs_embeds=embed,
+ attention_mask=attn_mask,
+ position_ids=position_ids,
+ return_dict=True,
+ )
+ enc = gpt_output.last_hidden_state[:, 2:]
+ enc = self.final_norm(enc)
+ if code_embed is not None:
+ return enc[:, : text_embed.shape[1]], enc[:, -code_embed.shape[1] :]
+
+ return enc[:, : text_embed.shape[1]], None
+
+ def init_gpt_for_inference(self, kv_cache=True, use_deepspeed=False):
+ self.gpt_inference = GPT2InferenceModel(
+ self.config,
+ self.gpt,
+ None,
+ self.code_embed,
+ self.final_norm,
+ self.code_head,
+ kv_cache=kv_cache,
+ )
+ self.gpt.embed_tokens = self.code_embed
+
+ if use_deepspeed:
+ import deepspeed
+
+            self.ds_engine = deepspeed.init_inference(
+                model=self.gpt_inference.half(),  # the Transformers inference model
+                mp_size=1,  # number of GPUs
+                dtype=torch.float32,  # desired output dtype
+                replace_method="auto",  # let DeepSpeed automatically identify layers to replace
+                replace_with_kernel_inject=True,  # swap in DeepSpeed's kernel-injected modules
+            )
+ self.gpt_inference = self.ds_engine.module.eval()
+
+ def compute_embeddings(self, language, cond_latents, text_inputs, code_inputs):
+ text_embed = self.text_embed(text_inputs)
+ lanugage_embed = self.language_embed(language).unsqueeze(1)
+
+ emb = torch.cat([lanugage_embed, cond_latents, text_embed], dim=1)
+
+ position_ids = torch.zeros(
+ (emb.shape[0], emb.shape[1] + len(code_inputs)), device=emb.device
+ )
+ indices = torch.tensor(
+ [0, 0] + list(range(text_embed.shape[1])) + list(range(len(code_inputs))),
+ device=emb.device,
+ )
+ position_ids[:, : indices.size(0)] = indices
+
+ self.gpt_inference.store_prefix_emb(emb)
+ gpt_inputs = torch.full(
+ (
+ emb.shape[0],
+                emb.shape[1] + len(code_inputs),  # room for the code prompt tokens
+ ),
+ fill_value=1,
+ dtype=torch.long,
+ device=text_inputs.device,
+ )
+ gpt_inputs[:, -len(code_inputs) :] = torch.tensor(code_inputs)
+ return (gpt_inputs, position_ids)
+
+ def generate(
+ self,
+ language,
+ cond_latents,
+ text_inputs,
+ code_inputs=[code_enc[""]],
+ **hf_generate_kwargs,
+ ):
+ gpt_inputs, position_ids = self.compute_embeddings(
+ language, cond_latents, text_inputs, code_inputs
+ )
+ gen = self.gpt_inference.generate(
+ gpt_inputs,
+ bos_token_id=code_enc[""],
+ pad_token_id=code_enc[""],
+ eos_token_id=code_enc[""],
+ max_length=self.n_positions,
+ position_ids=position_ids,
+ **hf_generate_kwargs,
+ )
+ if "return_dict_in_generate" in hf_generate_kwargs:
+ return gen.sequences[:, gpt_inputs.shape[1] :], gen
+ return gen[:, gpt_inputs.shape[1] - len(code_inputs) + 1 :]
+
+ def get_generator(self, fake_inputs, **hf_generate_kwargs):
+ return self.gpt_inference.generate_stream(
+ fake_inputs,
+ bos_token_id=code_enc[""],
+ pad_token_id=code_enc[""],
+ eos_token_id=code_enc[""],
+ max_length=self.n_positions,
+ do_stream=True,
+ **hf_generate_kwargs,
+ )
+
+
+class LearnedPositionEmbeddings(nn.Module):
+ def __init__(self, seq_len, model_dim, init=0.02):
+ super().__init__()
+ self.emb = nn.Embedding(seq_len, model_dim)
+ self.emb.weight.data.normal_(mean=0.0, std=init)
+
+ def forward(self, x):
+ sl = x.shape[1]
+ return self.emb(torch.arange(0, sl, device=x.device))
+
+ def get_fixed_embedding(self, ind, dev):
+ return self.emb(torch.tensor([ind], device=dev)).unsqueeze(0)
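The position-id layout built in `get_logits` and `compute_embeddings` is worth spelling out: the language and speaker embeddings both sit at position 0, and the text and code segments each count up from 0 independently. A sketch with plain lists and hypothetical segment lengths:

```python
# Sketch of the position-id layout used by TS_model: [lang, speaker]
# share position 0, then text positions restart at 0, then code
# positions restart at 0 again. Lengths here are hypothetical.

def build_position_ids(text_len, code_len=0):
    ids = [0, 0] + list(range(text_len))  # lang + speaker + text
    if code_len:
        ids += list(range(code_len))      # code segment restarts from 0
    return ids

print(build_position_ids(4, 3))  # [0, 0, 0, 1, 2, 3, 0, 1, 2]
```

Restarting the count per segment means the rotary embeddings see each modality as its own sequence, rather than one long stream of absolute positions.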
diff --git a/T2S/dataset.py b/T2S/dataset.py
new file mode 100755
index 0000000000000000000000000000000000000000..d11a28f7fbedcaa8754a9fae433f846b9c74d951
--- /dev/null
+++ b/T2S/dataset.py
@@ -0,0 +1,562 @@
+import os
+import sys
+from typing import Any
+
+sys.path.append("../")
+import linecache
+import mmap
+import pickle as pkl
+import random
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchaudio
+import transformers
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from autoregressive import TS_model
+from cleaners import english_cleaners
+from librosa.filters import mel as librosa_mel_fn
+from mel_spec import get_mel_spectrogram
+from meta_stats import process_file, process_file_for_heads
+from stft import STFT
+from torch.utils.data import (DataLoader, Dataset, WeightedRandomSampler,
+ get_worker_info)
+from tqdm import tqdm
+from utilities import get_mask_from_lengths
+
+import wandb
+from config import config
+from Text import code_labels, labels, text_labels
+
+torch.manual_seed(config.seed_value)
+np.random.seed(config.seed_value)
+random.seed(config.seed_value)
+print(text_labels)
+# add semantic tokens:
+# tok_enc = {j:i for i,j in enumerate(labels)}
+# tok_dec = {j:i for i,j in enumerate(labels)}
+
+# code encdec
+text_enc = {j: i for i, j in enumerate(text_labels)}
+text_dec = {i: j for i, j in enumerate(text_labels)}
+
+# text encdec
+code_enc = {j: i for i, j in enumerate(code_labels)}
+code_dec = {i: j for i, j in enumerate(code_labels)}
+
+
+def read_specific_line(filename, line_number):
+ line = linecache.getline(filename, line_number)
+ return line.strip() # Remove any leading or trailing whitespace
+
+
+CLIP_LENGTH = config.CLIP_LENGTH
+
+
+class semantic_dataset_batch(Dataset):
+ def __init__(
+ self,
+ transcript_path,
+ semantic_path=None,
+ ref_mels_path=None,
+ ref_k=3,
+ scale=False,
+ process_id=None,
+ total_processes=None,
+ ):
+ super().__init__()
+ self.scale = scale
+ if not scale:
+ with open(transcript_path, "r") as file:
+ data = file.read().strip("\n").split("\n")[:]
+
+ with open(semantic_path, "r") as file:
+ semb = file.read().strip("\n").split("\n")
+
+ with open(ref_mels_path, "rb") as file:
+ self.ref_mels = pkl.load(file)
+
+ semb = {
+ i.split("\t")[0]: [j for j in i.split("\t")[1].split()] for i in semb
+ }
+ data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}
+
+ self.data = [[i, semb[i], data[i]] for i in data.keys()]
+
+ else:
+ # with open(transcript_path,'r') as file:
+ # get meta for dataset
+ # for count, line in enumerate(file):
+ # pass
+ # count = 80
+ print(transcript_path)
+ # self.weights,self.count = process_file(transcript_path)
+ self.heads, self.weights, self.count = process_file_for_heads(
+ transcript_path, total_processes, process_id
+ )
+ print("length :", self.count)
+ self.data_len = self.count
+ self.transcript_path = transcript_path
+ line_index = {}
+ with open(transcript_path, "rb") as file:
+ mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
+ line_number = 0
+ offset = 0
+ while offset < len(mmapped_file):
+ line_index[line_number] = offset
+ offset = mmapped_file.find(b"\n", offset) + 1
+ # print(line_number,offset)
+ line_number += 1
+ self.mmapped_file = mmapped_file
+ self.line_index = line_index
+
+ self.process_id = process_id
+ self.total_processes = total_processes
+ self.iterator = None
+
+ self.ref_k = ref_k
+ self.max_wav_value = config.MAX_WAV_VALUE
+ self.stft_fn = STFT(config.filter_length, config.hop_length, config.win_length)
+
+ mel_basis = librosa_mel_fn(
+ sr=config.sampling_rate,
+ n_fft=config.filter_length,
+ n_mels=config.n_mel_channels,
+ fmin=config.mel_fmin,
+ fmax=config.mel_fmax,
+ )
+
+ self.mel_basis = torch.from_numpy(mel_basis).float()
+
+ def get_mel(self, filepath):
+ # audio, sampling_rate = load_wav_to_torch(filepath)
+ # audio_norm = audio / self.max_wav_value
+ audio_norm, sampling_rate = torchaudio.load(filepath)
+
+ # dur = audio_norm.shape[-1]/sampling_rate
+
+ # if dur<0.5:
+ # return None,None,None
+
+ # if self.clip and dur>10 and align:
+ # # print('big file',dur)
+ # max_audio_start = int(dur - 10)
+ # audio_start = random.randint(0, max_audio_start)
+
+ # audio_norm = audio_norm[:,audio_start*sampling_rate:(audio_start+10)*sampling_rate]
+ # semb_ids = semb_ids[audio_start*50:(audio_start+10)*50 -1]
+
+        # 86 mel frames -> 1 s at the 22050 Hz setting
+        # 93 mel frames -> 1 s at the 24000 Hz setting
+
+ # add 64ms of values to start and end
+ # audio_norm += torch.randn(audio_norm.shape[0])*1e-8
+ # audio_norm = torch.concat([torch.randn(1412)*1e-8,audio_norm,torch.randn(1412)*1e-8])
+ # audio_norm = audio_norm.unsqueeze(0)
+ # y = torch.autograd.Variable(audio_norm, requires_grad=False)
+
+ # assert(torch.min(y.data) >= -1)
+ # assert(torch.max(y.data) <= 1)
+ # magnitudes, phases = self.stft_fn.transform(y)
+ # magnitudes = magnitudes.data
+ # mel_output = torch.matmul(self.mel_basis, magnitudes)
+ # mel_output = dynamic_range_compression(mel_output)
+ # melspec = torch.squeeze(mel_output, 0)
+ # energy = torch.norm(magnitudes, dim=1).squeeze(0)
+ # melspec,energy = mel_spectrogram(audio_norm)
+ melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
+ energy = []
+ # if align:
+ # return melspec,list(energy),semb_ids
+ return melspec, list(energy)
+
+ def __len__(self):
+ if self.scale:
+ return self.data_len
+ return len(self.data)
+
+ # def get_process_heads(self,):
+ # '''
+ # divide data and heads based on the batch_size and weights
+ # '''
+
+ # new_heads ={}
+ # new_weights =[]
+ # process_batch_size = config.ts_batch_size*config.ts_gradient_accumulation_steps
+ # sm=0
+ # for i,j in zip(self.heads,self.weights):
+
+ # if sm + j > process_batch_size:
+ # if sm+j == process_batch_size:
+ # new_heads[i] = self.heads[i]
+ # new_weights.append(j)
+ # else:
+ # new_heads[i] = self.heads[i][:len(self.heads[i])*(process_batch_size-sm)//process_batch_size]
+ # new_weights.append(process_batch_size-sm)
+ # else:
+ # new_heads[i] = self.heads[i]
+ # new_weights.append(j)
+
+ # self.get_worker_heads()
+
+ # old heads and weights
+ # new_heads = {}
+ # for i in self.heads:
+ # segment_size = (len(self.heads[i]) + self.total_processes - 1) // self.total_processes
+ # start_idx = self.process_id * segment_size
+ # end_idx = start_idx + segment_size
+
+ # if end_idx > len(self.heads[i]):
+ # # Create a list that wraps around to the beginning
+ # segment = self.heads[i][start_idx:] + self.heads[i][:end_idx - len(self.heads[i])]
+ # else:
+ # segment = self.heads[i][start_idx:end_idx]
+ # new_heads[i]=segment
+ # self.heads = new_heads
+ # print(self.process_id,[len(self.heads[i]) for i in self.heads])
+ # self.get_worker_heads()
+
+ def get_worker_heads(
+ self,
+ ):
+ self.worker_id = get_worker_info().id
+ self.num_worker = get_worker_info().num_workers
+ new_heads = {}
+ for i in self.heads:
+ segment_size = (len(self.heads[i]) + self.num_worker - 1) // self.num_worker
+ start_idx = self.worker_id * segment_size
+ end_idx = start_idx + segment_size
+
+ if end_idx > len(self.heads[i]):
+ # Create a list that wraps around to the beginning
+ segment = (
+ self.heads[i][start_idx:]
+ + self.heads[i][: end_idx - len(self.heads[i])]
+ )
+ else:
+ segment = self.heads[i][start_idx:end_idx]
+ new_heads[i] = segment
+ self.heads = new_heads
+ # print("worker:",self.worker_id,self.process_id,[len(self.heads[i]) for i in self.heads],self.weights)
+
+ def get_head(self):
+ # self.get_process_heads()
+ self.get_worker_heads()
+ # print("weights:",self.weights,[h for h in self.heads])
+ self.indices = [0] * len(self.heads)
+ # self.process_heads = [{i:self.heads[i][self.process_id:]}for i in self.heads]
+ while True:
+ for (
+ n,
+ (head, weight),
+ ) in enumerate(zip(self.heads, self.weights)):
+ # if process_id == 0:
+ # print(weight,head)
+ for i in range(weight):
+ if self.indices[n] < len(self.heads[head]):
+ # print(self.heads[head][self.indices[n]],worker_id,self.indices)
+ yield self.heads[head][self.indices[n]]
+ self.indices[n] += 1
+ else:
+ self.indices[n] = 0
+ random.shuffle(self.heads[head])
+ # shuffle the indices
+
+ def __getitem__(self, index) -> Any:
+ if self.iterator is None:
+ self.iterator = self.get_head()
+ if not self.scale:
+ lang, path, semb, text = self.data[index]
+ ref_mels = self.ref_mels[path][: self.ref_k]
+
+ else:
+ # line = read_specific_line(self.transcript_path,index+1)
+
+ index = next(self.iterator)
+ # print(self.worker_id,self.process_id,index)
+ self.mmapped_file.seek(self.line_index[index])
+ line = self.mmapped_file.readline().decode("utf-8")
+
+ lang, path, text, semb_ids, ref_mels = line.split("|")
+ # a=5/0
+ # semb_ids = [int(i)+1 for i in semb_ids.split()]
+ semb = semb_ids.split()
+ ref_mels = [i.split(",") for i in ref_mels.split("\t")][: self.ref_k]
+
+ if len(semb) < 25:
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ if len(ref_mels) == 0:
+ ref_mels.append((path, 1))
+ ref_mels.append((path, 1))
+ ref_mels.append((path, 1))
+
+ while len(ref_mels) < self.ref_k:
+ ref_mels.append(ref_mels[-1])
+
+ text = text.lower().strip()
+ # try:
+ text_ids = [text_enc[""]] + [text_enc[i] for i in text] + [text_enc[""]]
+ semb_ids = (
+ [code_enc[""]] + [code_enc[i] for i in semb] + [code_enc[""]]
+ )
+
+ # except Exception as e:
+ # print(e)
+ # print(lang,path,text,index)
+ # exit
+ # input_ids = text_ids+semb_ids
+ # pad_length = config.t2s_position-(len(text_ids)+len(semb_ids))
+
+ # token_type_ids = [0]*len(text_ids)+[1]*len(semb_ids)+[0]*pad_length
+ # positional_ids = [i for i in range(len(text_ids))]+[i for i in range(len(semb_ids))]+[0]*pad_length
+ # labels = [-100]*len(text_ids)+semb_ids+[-100]*pad_length
+ # attention_mask = [1]*len(input_ids)+[0]*pad_length
+ # input_ids += [tok_enc['']]*pad_length
+
+ def get_random_portion(mel, mask_lengths):
+ clip = mask_lengths <= CLIP_LENGTH
+ ref_mel = mel[:, :, :CLIP_LENGTH].clone()
+ for n, z in enumerate(clip):
+ if not z:
+ start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
+ ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
+ return ref_mel
+
+ try:
+ ref_mels = [self.get_mel(path)[0] for path, score in ref_mels]
+ except Exception as e:
+ print(index, e)
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ ref_c = []
+ for i in range(self.ref_k):
+ if ref_mels[i] is None:
+ continue
+ ref_c.append(ref_mels[i])
+
+ if len(ref_c) == 0:
+ # print('no refs worthy')
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ if len(ref_c) != self.ref_k:
+ # print('less refs found',len(ref_c))
+ while len(ref_c) < self.ref_k:
+ ref_c.append(ref_c[-1])
+
+ ref_mels = ref_c
+
+ max_target_len = max([x.size(1) for x in ref_mels])
+ ref_mels_padded = (
+ torch.randn((self.ref_k, config.n_mel_channels, max_target_len))
+ ) * 1e-9
+ mel_length = []
+ for i, mel in enumerate(ref_mels):
+ ref_mels_padded[i, :, : mel.size(1)] = mel
+ mel_length.append(mel.shape[-1])
+
+ ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))
+
+ return {
+ "text_ids": text_ids,
+ "semb_ids": semb_ids,
+ "ref_mels": ref_mels,
+ "lang": torch.tensor(config.lang_index[lang]),
+ }
+
+
+# def get_padded_seq(sequences):
+
+# max_len=max([len(s) for s in sequences])
+# for i in range(len(sequences)):
+# sequences[i]=sequences[i]+tok_enc['']*(max_len-len(sequences[i]))
+
+# return sequences
+
+
+def get_padded_seq(sequences, pad_random, before=False, pad__=0):
+ max_len = max([len(s) for s in sequences])
+ seq_len = []
+ for i in range(len(sequences)):
+ seq_len.append(len(sequences[i]))
+ if pad_random:
+            pad_ = list(np.random.rand(max_len - len(sequences[i])) * 1e-9)
+ else:
+ pad_ = [pad__] * (max_len - len(sequences[i]))
+ if not before:
+ sequences[i] = sequences[i] + pad_
+ else:
+ sequences[i] = pad_ + sequences[i]
+
+ return sequences, seq_len
+
+
+def collate(batch):
+ text_ids = []
+ semb_ids = []
+ # paths=[]
+ ref_mels = []
+ langs = []
+ # ref_mels_length=[]
+
+ for b in batch:
+ text_ids.append(b["text_ids"])
+ semb_ids.append(b["semb_ids"])
+ # paths.append(b['path'])
+ ref_mels.append(b["ref_mels"])
+ langs.append(b["lang"])
+ # ref_mels_length.append(b['ref_mel_length'])
+
+ text_ids, text_len = get_padded_seq(
+ text_ids, pad_random=False, before=False, pad__=text_enc[""]
+ )
+ code, code_len = get_padded_seq(semb_ids, pad_random=False, pad__=code_enc[""])
+
+ ref_max_target_len = max([x.size(-1) for x in ref_mels])
+ ref_mels_padded = (
+ torch.randn(
+ (
+ len(batch),
+ ref_mels[0].shape[0],
+ config.n_mel_channels,
+ ref_max_target_len,
+ )
+ )
+ ) * 1e-9
+
+ for i, mel in enumerate(ref_mels):
+ ref_mels_padded[i, :, :, : mel.size(-1)] = mel
+
+ # print(mel_padded.shape,torch.tensor(code).shape,torch.tensor(mel_length),get_mask_from_lengths(torch.tensor(mel_length)))
+
+ return (
+ torch.tensor(text_ids),
+ torch.tensor(code),
+ torch.tensor(text_len),
+ torch.tensor(code_len),
+ ref_mels_padded,
+ torch.tensor(langs),
+ )
+
+
+def get_dataset(transcript_path, get_process_id, total_processes):
+ return semantic_dataset_batch(
+ transcript_path,
+ scale=True,
+ process_id=get_process_id,
+ total_processes=total_processes,
+ )
+
+
+if __name__ == "__main__":
+ accelerator = Accelerator(
+ gradient_accumulation_steps=config.ts_gradient_accumulation_steps
+    )  # optional: mixed_precision="fp16", kwargs_handlers=[ddp_kwargs]
+
+ get_process_id = accelerator.process_index
+ total_processes = accelerator.num_processes
+
+ train_dataset_ = semantic_dataset_batch(
+ config.data_path + "/transcript_train_20s_final_normalized_filtered.txt",
+ "../" + config.data_path + "/semt.txt",
+ "../" + config.data_path + "/ref_clips.pkl",
+ scale=True,
+ process_id=get_process_id,
+ total_processes=total_processes,
+ )
+ # sampler = WeightedRandomSampler(
+ # train_dataset_.weights,
+ # train_dataset_.count,
+ # replacement=False)
+ train_dataset = DataLoader(
+ train_dataset_,
+ pin_memory=True,
+ persistent_workers=True,
+ num_workers=config.ts_num_workers,
+ batch_size=config.ts_batch_size,
+ shuffle=False,
+ drop_last=False,
+ collate_fn=collate,
+ sampler=None,
+ )
+ print("batch", config.ts_batch_size)
+ # val_dataset = DataLoader(semantic_dataset_batch(config.data_path+'/transcript_test_20_final_normalized.txt','../'+config.data_path+'/semt.txt','../'+config.data_path+'/ref_clips.pkl',scale=True,process_id=get_process_id,total_processes = total_processes),pin_memory=True,
+ # persistent_workers=True,num_workers=2,batch_size=config.ts_batch_size,shuffle=True,drop_last=False,collate_fn=collate)
+
+ train_dataloader = accelerator.prepare(train_dataset)
+ # if accelerator.is_local_main_process:
+ # from IPython import embed
+ # embed()
+
+    # check that the sampler is working
+ import math
+ from collections import defaultdict
+
+ def calculate_duration(code_len):
+ return math.ceil(((code_len + 1) / 50) * 2) / 2
+
+ sampling = defaultdict(int)
+ dataset = []
+ batch_data = {}
+ batch = 0
+ batch_data[batch] = defaultdict(int)
+ for n, data in enumerate(tqdm(train_dataloader)):
+ # break
+ text_ids, code, text_len, code_len, ref_clips, langs = data
+ # print(text_ids)
+ # print('=====')
+ # # break
+ for i, j in zip(code_len, text_ids):
+ dur = calculate_duration(i - 2)
+ # print(dur,i,code.shape)
+ # sampling[calculate_duration(i)]+=1
+ dataset.append(list(j.detach().cpu().numpy()))
+
+ if dur > 19.5:
+ batch_data[batch]["20_sentence"] += 1
+ continue
+ if dur <= 5:
+ batch_data[batch]["5s"] += 1
+ continue
+ elif dur <= 10:
+ batch_data[batch]["10s"] += 1
+ continue
+ elif dur <= 15:
+ batch_data[batch]["15s"] += 1
+ continue
+ elif dur <= 20:
+ batch_data[batch]["20s"] += 1
+ continue
+ # print(batch)
+ if (n + 1) % config.ts_gradient_accumulation_steps == 0:
+ batch += 1
+ batch_data[batch] = defaultdict(int)
+ # break
+ # if n==20:
+ # break
+ # # print(sampling)
+ with open(
+ f"Sampling_data_meta/sampling_{accelerator.process_index}.pkl", "wb"
+ ) as file:
+ pkl.dump(batch_data, file)
+ with open(
+ f"Sampling_data_meta/sampling_dataset_{accelerator.process_index}.pkl", "wb"
+ ) as file:
+ pkl.dump(dataset, file)
+ print(batch_data[0])
+ # # # return 0
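The line index built in `semantic_dataset_batch.__init__` is the key trick that lets `__getitem__` fetch an arbitrary transcript line in O(1) without loading the file into memory. A standalone sketch of that pattern (the file contents and helper names here are illustrative, not from the repo):

```python
import mmap
import os
import tempfile

def build_line_index(path):
    """Scan the file once, mapping line number -> byte offset of that line's start."""
    index = {}
    with open(path, "rb") as f:
        mm = mmap.mmap(f.fileno(), 0, access=mmap.ACCESS_READ)
        offset = 0
        line_number = 0
        while offset < len(mm):
            index[line_number] = offset
            # jump past the next newline; assumes the file ends with "\n"
            offset = mm.find(b"\n", offset) + 1
            line_number += 1
    # the mmap stays valid after the file handle is closed
    return mm, index

def read_line(mm, index, n):
    """Seek directly to line n and decode it."""
    mm.seek(index[n])
    return mm.readline().decode("utf-8").rstrip("\n")

# demo on a throwaway transcript file
with tempfile.NamedTemporaryFile("w", suffix=".txt", delete=False) as tmp:
    tmp.write("en|a.wav|hello\nhi|b.wav|namaste\n")
    path = tmp.name
mm, index = build_line_index(path)
print(read_line(mm, index, 1))  # -> hi|b.wav|namaste
os.remove(path)
```

This is why the dataset can serve shuffled, weighted access over very large transcript files: only the offset table lives in memory, and the OS pages in just the bytes each lookup touches.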
diff --git a/T2S/gpt_inference.py b/T2S/gpt_inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..1f6747d4cfda9b19d89ba28e7c1ba2a5aa5b6a6a
--- /dev/null
+++ b/T2S/gpt_inference.py
@@ -0,0 +1,198 @@
+# copied from coqui
+import math
+
+import torch
+from torch import nn
+from transformers import (GemmaForCausalLM, GemmaPreTrainedModel,
+ GPT2PreTrainedModel)
+from transformers.modeling_outputs import (CausalLMOutputWithCrossAttentions,
+ CausalLMOutputWithPast)
+
+
+class GPT2InferenceModel(GemmaForCausalLM):
+ """Override GPT2LMHeadModel to allow for prefix conditioning."""
+
+ def __init__(self, config, gpt, pos_emb, embeddings, norm, linear, kv_cache):
+ super().__init__(config)
+ self.transformer = gpt
+ self.pos_embedding = pos_emb
+ self.embeddings = embeddings
+ self.final_norm = norm
+ self.lm_head = nn.Sequential(norm, linear)
+ self.kv_cache = kv_cache
+ self.cur_pos = None
+
+ def store_prefix_emb(self, prefix_emb):
+ self.cached_prefix_emb = prefix_emb
+
+ def prepare_inputs_for_generation(
+ self,
+ input_ids,
+ past_key_values=None,
+ attention_mask=None,
+ inputs_embeds=None,
+ cache_position=None,
+ use_cache=True,
+ **kwargs,
+ ):
+        past_length = 0
+        if past_key_values is not None and past_key_values.get_seq_length() != 0:
+ # Past key values are always initialized with a `Cache` object -> no need for if-else anymore
+ past_length = (
+ cache_position[0]
+ if cache_position is not None
+ else past_key_values.get_seq_length()
+ )
+ # from IPython import embed; embed()
+ max_cache_length = (
+ torch.tensor(past_key_values.get_seq_length(), device=input_ids.device)
+ if past_key_values.get_seq_length() is not None
+ else None
+ )
+ cache_length = (
+ past_length
+ if max_cache_length is None
+ else torch.min(max_cache_length, past_length)
+ )
+
+ # Keep only the unprocessed tokens:
+ # 1 - If the length of the attention_mask exceeds the length of input_ids, then we are in a setting where
+ # some of the inputs are exclusively passed as part of the cache (e.g. when passing input_embeds as input)
+ if (
+ attention_mask is not None
+ and attention_mask.shape[1] > input_ids.shape[1]
+ ):
+ input_ids = input_ids[:, -(attention_mask.shape[1] - past_length) :]
+ # 2 - If the past_length is smaller than input_ids', then input_ids holds all input tokens. We can discard
+ # input_ids based on the past_length.
+ elif past_length < input_ids.shape[1]:
+ input_ids = input_ids[:, past_length:]
+ # 3 - Otherwise (past_length >= input_ids.shape[1]), let's assume input_ids only has unprocessed tokens.
+
+ # If we are about to go beyond the maximum cache length, we need to crop the input attention mask.
+ if (
+ max_cache_length is not None
+ and attention_mask is not None
+ and cache_length + input_ids.shape[1] > max_cache_length
+ ):
+ attention_mask = attention_mask[:, -max_cache_length:]
+
+ position_ids = kwargs.get("position_ids", None)
+ if attention_mask is not None and position_ids is None:
+ # create position_ids on the fly for batch generation
+ position_ids = attention_mask.long().cumsum(-1) - 1
+ position_ids.masked_fill_(attention_mask == 0, 1)
+ if past_key_values:
+ position_ids = position_ids[:, -input_ids.shape[1] :]
+
+ # if `inputs_embeds` are passed, we only want to use them in the 1st generation step
+ if inputs_embeds is not None and past_length == 0:
+ model_inputs = {"inputs_embeds": inputs_embeds}
+ else:
+ # The `contiguous()` here is necessary to have a static stride during decoding. torchdynamo otherwise
+ # recompiles graphs as the stride of the inputs is a guard. Ref: https://github.com/huggingface/transformers/pull/29114
+ # TODO: use `next_tokens` directly instead.
+ model_inputs = {"input_ids": input_ids.contiguous()}
+
+ input_length = (
+ position_ids.shape[-1] if position_ids is not None else input_ids.shape[-1]
+ )
+ if cache_position is None:
+ cache_position = torch.arange(
+ past_length, past_length + input_length, device=input_ids.device
+ )
+ elif use_cache:
+ cache_position = cache_position[-input_length:]
+
+ model_inputs.update(
+ {
+ "position_ids": position_ids,
+ "cache_position": cache_position,
+ "past_key_values": past_key_values,
+ "use_cache": use_cache,
+ "attention_mask": attention_mask,
+ }
+ )
+ return model_inputs
+
+ def forward(
+ self,
+ input_ids=None,
+ past_key_values=None,
+ attention_mask=None,
+ token_type_ids=None,
+ position_ids=None,
+ head_mask=None,
+ inputs_embeds=None,
+ encoder_hidden_states=None,
+ encoder_attention_mask=None,
+ labels=None,
+ use_cache=None,
+ output_attentions=None,
+ output_hidden_states=None,
+ return_dict=None,
+ cache_position=None,
+ ):
+ assert self.cached_prefix_emb is not None
+ assert inputs_embeds is None # Not supported by this inference model.
+ assert labels is None # Training not supported by this inference model.
+ return_dict = (
+ return_dict if return_dict is not None else self.config.use_return_dict
+ )
+
+ # assert len(past_key_values) + len(input_ids) == attention_mask.shape[1]
+
+ # Create embedding
+ prefix_len = self.cached_prefix_emb.shape[1]
+ if input_ids.shape[1] != 1:
+ gen_inputs = input_ids[:, prefix_len:]
+ gen_emb = self.embeddings(gen_inputs)
+ gen_emb = gen_emb # + self.pos_embedding(gen_emb)
+ if self.cached_prefix_emb.shape[0] != gen_emb.shape[0]:
+ prefix_emb = self.cached_prefix_emb.repeat_interleave(
+ gen_emb.shape[0] // self.cached_prefix_emb.shape[0], 0
+ )
+ else:
+ prefix_emb = self.cached_prefix_emb.to(gen_emb.dtype)
+ emb = torch.cat([prefix_emb, gen_emb], dim=1)
+ self.cur_pos = position_ids[:, -1:]
+ else:
+ emb = self.embeddings(input_ids)
+            self.cur_pos += 1  # integer increment keeps position_ids' long dtype
+ position_ids = self.cur_pos
+
+ transformer_outputs = self.transformer(
+ inputs_embeds=emb,
+ past_key_values=past_key_values,
+ attention_mask=attention_mask,
+ position_ids=position_ids,
+ use_cache=use_cache,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ return_dict=return_dict,
+ )
+ hidden_states = transformer_outputs[0]
+ lm_logits = self.lm_head(hidden_states)
+
+ if not return_dict:
+ return (lm_logits,) + transformer_outputs[1:]
+
+ return CausalLMOutputWithPast(
+ loss=None,
+ logits=lm_logits,
+ past_key_values=transformer_outputs.past_key_values,
+ hidden_states=transformer_outputs.hidden_states,
+ attentions=transformer_outputs.attentions,
+ )
+
+ @staticmethod
+ def _reorder_cache(past_key_values, beam_idx):
+ reordered_past = ()
+ for layer_past in past_key_values:
+ reordered_past += (
+ tuple(
+ past_state.index_select(0, beam_idx.to(past_state.device))
+ for past_state in layer_past
+ ),
+ )
+ return reordered_past
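The "keep only the unprocessed tokens" bookkeeping in `prepare_inputs_for_generation` above reduces to a small rule that can be sketched with plain lists (`trim_unprocessed` is a hypothetical helper for illustration, not part of this file):

```python
def trim_unprocessed(input_ids, past_length):
    """Return the suffix of input_ids the model still has to process,
    given that past_length tokens are already in the KV cache."""
    if past_length < len(input_ids):
        # input_ids holds all tokens seen so far; drop the cached prefix
        return input_ids[past_length:]
    # the cache already covers everything passed in; assume only the
    # newest token is unprocessed
    return input_ids[-1:]

# first decoding step: nothing cached, the whole prompt is processed
assert trim_unprocessed([5, 7, 9], 0) == [5, 7, 9]
# later steps: the cache covers the prompt, only the newest token remains
assert trim_unprocessed([5, 7, 9, 11], 3) == [11]
```

The real method layers attention-mask cropping and `cache_position` handling on top of this, but the core invariant is the same: each forward pass should receive exactly the tokens whose key/value states are not yet cached.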
diff --git a/T2S/gpt_model_train.py b/T2S/gpt_model_train.py
new file mode 100755
index 0000000000000000000000000000000000000000..1a80f8081763bc5bedfc2ece9057423a83de758c
--- /dev/null
+++ b/T2S/gpt_model_train.py
@@ -0,0 +1,586 @@
+import os
+import sys
+
+sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
+import linecache
+import mmap
+import pickle as pkl
+import random
+from typing import Any
+
+import numpy as np
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.optim as optim
+import torchaudio
+import transformers
+from accelerate import Accelerator, DistributedDataParallelKwargs
+from torch.utils.data import DataLoader, Dataset, WeightedRandomSampler
+from tqdm import tqdm
+
+import wandb
+from config import config
+from T2S.autoregressive import TS_model
+from T2S.mel_spec import get_mel_spectrogram
+from T2S.utilities import get_mask_from_lengths
+from Text import code_labels, labels, text_labels
+
+torch.manual_seed(config.seed_value)
+np.random.seed(config.seed_value)
+random.seed(config.seed_value)
+
+# text encdec
+text_enc = {j: i for i, j in enumerate(text_labels)}
+text_dec = {i: j for i, j in enumerate(text_labels)}
+
+# code encdec
+code_enc = {j: i for i, j in enumerate(code_labels)}
+code_dec = {i: j for i, j in enumerate(code_labels)}
+
+
+def read_specific_line(filename, line_number):
+ line = linecache.getline(filename, line_number)
+ return line.strip() # Remove any leading or trailing whitespace
+
+
+CLIP_LENGTH = config.CLIP_LENGTH
+
+
+class semantic_dataset(Dataset):
+ def __init__(
+ self,
+ transcript_path,
+ semantic_path=None,
+ ref_mels_path=None,
+ ref_k=1,
+ scale=True,
+ ):
+ super().__init__()
+ self.scale = scale
+ if not scale:
+ with open(transcript_path, "r") as file:
+ data = file.read().strip("\n").split("\n")[:]
+
+ with open(semantic_path, "r") as file:
+ semb = file.read().strip("\n").split("\n")
+
+ with open(ref_mels_path, "rb") as file:
+ self.ref_mels = pkl.load(file)
+
+            semb = {i.split("\t")[0]: i.split("\t")[1].split() for i in semb}
+            data = {i.split("|")[0]: i.split("|")[1].strip().lower() for i in data}
+
+ self.data = [[i, semb[i], data[i]] for i in data.keys()]
+
+ else:
+ line_index = {}
+ with open(transcript_path, "rb") as file:
+ mmapped_file = mmap.mmap(file.fileno(), 0, access=mmap.ACCESS_READ)
+ line_number = 0
+ offset = 0
+ pbar = tqdm()
+ while offset < len(mmapped_file):
+ line_index[line_number] = offset
+ offset = mmapped_file.find(b"\n", offset) + 1
+ line_number += 1
+ pbar.update(1)
+ pbar.close()
+ self.mmapped_file = mmapped_file
+ self.line_index = line_index
+ self.data_len = len(line_index)
+ print("data length:", self.data_len)
+ self.transcript_path = transcript_path
+
+ self.ref_k = ref_k
+ self.max_wav_value = config.MAX_WAV_VALUE
+
+ def get_mel(self, filepath):
+ audio_norm, sampling_rate = torchaudio.load(filepath)
+ melspec = get_mel_spectrogram(audio_norm, sampling_rate).squeeze(0)
+ energy = []
+ return melspec, list(energy)
+
+ def __len__(self):
+ if self.scale:
+ return self.data_len
+ return len(self.data)
+
+ def __getitem__(self, index) -> Any:
+ if not self.scale:
+ lang, path, semb, text = self.data[index]
+ ref_mels = self.ref_mels[path][: self.ref_k]
+
+ else:
+ self.mmapped_file.seek(self.line_index[index])
+ line = self.mmapped_file.readline().decode("utf-8")
+
+ try:
+ lang, path, text, semb_ids = line.split("|")
+ except Exception as e:
+ print(index, line)
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+ semb = semb_ids.split()
+ ref_mels = [path]
+ # ref_mels = [i.split(',') for i in ref_mels.split('\t')][:self.ref_k]
+
+ if len(semb) < 5:
+ print(index, "No Semb tokens found")
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ if len(ref_mels) == 0:
+ ref_mels.append((path, 1))
+ ref_mels.append((path, 1))
+ ref_mels.append((path, 1))
+
+ while len(ref_mels) < self.ref_k:
+ ref_mels.append(ref_mels[-1])
+
+ text = text.lower().strip()
+ try:
+ text_ids = (
+ [text_enc[""]] + [text_enc[i] for i in text] + [text_enc[""]]
+ )
+ semb_ids = (
+ [code_enc[""]] + [code_enc[i] for i in semb] + [code_enc[""]]
+ )
+ except Exception as e:
+ print(index, e)
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ def get_random_portion(mel, mask_lengths):
+ clip = mask_lengths <= CLIP_LENGTH
+ ref_mel = mel[:, :, :CLIP_LENGTH].clone()
+ for n, z in enumerate(clip):
+ if not z:
+ start = np.random.randint(0, mask_lengths[n].item() - CLIP_LENGTH)
+ ref_mel[n, :, :] = mel[n, :, start : start + CLIP_LENGTH].clone()
+ return ref_mel
+
+ try:
+ ref_mels = [self.get_mel(path)[0] for path in ref_mels]
+ except Exception as e:
+ print(index, e, path)
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ ref_c = []
+ for i in range(self.ref_k):
+ if ref_mels[i] is None:
+ continue
+ ref_c.append(ref_mels[i])
+
+ if len(ref_c) == 0:
+ if index + 1 < self.data_len:
+ return self.__getitem__(index + 1)
+ return self.__getitem__(0)
+
+ if len(ref_c) != self.ref_k:
+ while len(ref_c) < self.ref_k:
+ ref_c.append(ref_c[-1])
+
+ ref_mels = ref_c
+
+ max_target_len = max([x.size(1) for x in ref_mels])
+ ref_mels_padded = (
+ torch.randn((self.ref_k, config.n_mel_channels, max_target_len))
+ ) * 1e-9
+ mel_length = []
+ for i, mel in enumerate(ref_mels):
+ ref_mels_padded[i, :, : mel.size(1)] = mel
+ mel_length.append(mel.shape[-1])
+
+ ref_mels = get_random_portion(ref_mels_padded, torch.tensor(mel_length))
+
+ return {
+ "text_ids": text_ids,
+ "semb_ids": semb_ids,
+ "ref_mels": ref_mels,
+ "lang": torch.tensor(config.lang_index[lang]),
+ }
+
+
+def get_padded_seq(sequences, pad_random, before=False, pad__=0):
+ max_len = max([len(s) for s in sequences])
+ seq_len = []
+ for i in range(len(sequences)):
+ seq_len.append(len(sequences[i]))
+ if pad_random:
+            pad_ = list(np.random.rand(max_len - len(sequences[i])) * 1e-9)
+ else:
+ pad_ = [pad__] * (max_len - len(sequences[i]))
+ if not before:
+ sequences[i] = sequences[i] + pad_
+ else:
+ sequences[i] = pad_ + sequences[i]
+
+ return sequences, seq_len
+
+
+def collate(batch):
+ text_ids = []
+ semb_ids = []
+ ref_mels = []
+ langs = []
+
+ for b in batch:
+ text_ids.append(b["text_ids"])
+ semb_ids.append(b["semb_ids"])
+ ref_mels.append(b["ref_mels"])
+ langs.append(b["lang"])
+
+ text_ids, text_len = get_padded_seq(
+ text_ids, pad_random=False, before=False, pad__=text_enc[""]
+ )
+ code, code_len = get_padded_seq(semb_ids, pad_random=False, pad__=code_enc[""])
+
+ ref_max_target_len = max([x.size(-1) for x in ref_mels])
+ ref_mels_padded = (
+ torch.randn(
+ (
+ len(batch),
+ ref_mels[0].shape[0],
+ config.n_mel_channels,
+ ref_max_target_len,
+ )
+ )
+ ) * 1e-9
+
+ for i, mel in enumerate(ref_mels):
+ ref_mels_padded[i, :, :, : mel.size(-1)] = mel
+
+ return (
+ torch.tensor(text_ids),
+ torch.tensor(code),
+ torch.tensor(text_len),
+ torch.tensor(code_len),
+ ref_mels_padded,
+ torch.tensor(langs),
+ )
+
+
+def train(model, train_dataset, val_dataset, save_dir, checkpoint_initial=None):
+ accelerator = Accelerator(
+ gradient_accumulation_steps=config.ts_gradient_accumulation_steps
+    )  # optional: mixed_precision="fp16", kwargs_handlers=[ddp_kwargs]
+
+ if config.ts_wandb_logs and accelerator.is_local_main_process:
+ conf_ = {}
+ for i, j in config.__dict__.items():
+ conf_[str(i)] = str(j)
+ wandb_log = wandb.init(
+ project=config.wandb_project,
+ entity=config.user_name,
+ name=config.model_name,
+ config=conf_,
+ )
+ wandb_log.watch(model, log_freq=100)
+ else:
+ wandb_log = None
+
+ optimizer = optim.Adam(
+ model.parameters(), lr=config.ts_lr, weight_decay=config.ts_weight_decay
+ )
+ # optimizer = transformers.Adafactor(model.parameters(), lr=config.ts_lr,weight_decay=config.ts_weight_decay, relative_step =False, scale_parameter =False)
+ lr = config.ts_lr
+ step_num = 0
+ start_epoch = 0
+    if checkpoint_initial is not None:
+        # load the checkpoint once instead of re-reading it for every field
+        checkpoint = torch.load(checkpoint_initial, map_location=torch.device("cpu"))
+        model.load_state_dict(checkpoint["model"], strict=True)
+        if (
+            config.ts_finetuning
+        ):  # freezing the heads results in fewer hallucinations after fine-tuning
+            for param in model.text_head.parameters():
+                param.requires_grad = False
+
+            for param in model.code_head.parameters():
+                param.requires_grad = False
+
+        model.train()
+
+        print("loading optimizer")
+        optimizer.load_state_dict(checkpoint["optimizer"])
+        # the step counter restarts from 0 on resume (the checkpoint's step is
+        # discarded); only the epoch is restored
+        step_num = 0
+        start_epoch = int(checkpoint["epoch"]) + 1
+        print(f"Resuming training from epoch {start_epoch} and step {step_num}")
+
+ train_dataloader, val_dataloader, model, optimizer = accelerator.prepare(
+ train_dataset, val_dataset, model, optimizer
+ )
+    val_dataloader = val_dataset  # note: overrides the prepared loader; validation runs on the raw DataLoader
+    min_val_loss = float("inf")
+ model.train()
+
+ for i in range(start_epoch, config.ts_epochs):
+ epoch_loss = []
+ if accelerator.is_main_process:
+ train_loader = tqdm(
+ train_dataloader,
+ desc="Rank %d: Training epoch %d"
+ % (accelerator.local_process_index, i),
+ )
+ else:
+ train_loader = train_dataloader
+
+ for n, inputs in enumerate(train_loader):
+ with accelerator.accumulate(model):
+ # with accelerator.autocast():
+ text_ids, code, text_len, code_len, ref_clips, langs = inputs
+ mask_text = get_mask_from_lengths(text_len)
+ code_mask = get_mask_from_lengths(code_len)
+ attn_mask = torch.cat([mask_text, code_mask], dim=1)
+ loss_text, loss_code, _ = model(
+ text_ids=text_ids,
+ ref_clips=ref_clips,
+ codes_ids=code,
+ language=langs,
+ return_loss=True,
+ attn_mask=attn_mask,
+ )
+
+ loss_text *= mask_text[:, 1:].float()
+ loss_text = loss_text.sum() / mask_text[:, 1:].sum()
+
+ loss_code *= code_mask[:, 1:].float()
+ loss_code = loss_code.sum() / code_mask[:, 1:].sum()
+
+ loss = loss_text * config.text_loss_weight + loss_code
+
+ accelerator.backward(loss)
+ accelerator.clip_grad_norm_(model.parameters(), 1.0)
+ optimizer.step()
+ optimizer.zero_grad()
+ step_num += 1
+
+ if (
+ step_num % config.ts_gradient_accumulation_steps == 0
+ and config.ts_wandb_logs
+ and accelerator.is_main_process
+ ):
+ wandb_log.log(
+ {
+ "training_loss": loss.item(),
+ "step": step_num // config.ts_gradient_accumulation_steps,
+ }
+ )
+
+ epoch_loss.append(loss.item())
+
+ if (
+ not config.ts_finetuning
+ and step_num
+ % (config.ts_gradient_accumulation_steps * config.ts_eval_step)
+ == 0
+ ):
+ val_loss = val(model, val_dataloader, accelerator.is_main_process)
+ val_loss = accelerator.gather_for_metrics(val_loss).mean().item()
+ model.train()
+ if config.ts_wandb_logs and accelerator.is_main_process:
+ wandb_log.log(
+ {
+ "val_loss": val_loss,
+ "epoch": i,
+ "scheduled_learning_rate": lr,
+ "step": step_num // config.ts_gradient_accumulation_steps,
+ }
+ )
+
+ if val_loss < min_val_loss:
+ # save the model
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ checkpoint = {
+ "epoch": i,
+ "step": str(step_num // config.ts_gradient_accumulation_steps),
+ "model": unwrapped_model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ torch.save(
+ checkpoint,
+ os.path.join(config.save_root_dir, "best.pt"),
+ )
+ min_val_loss = val_loss
+
+ # save the latest checkpoint
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ checkpoint = {
+ "epoch": i,
+ "step": str(step_num // config.ts_gradient_accumulation_steps),
+ "model": unwrapped_model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ latest_path = os.path.join(
+ config.save_root_dir,
+ f"{step_num // config.ts_gradient_accumulation_steps}_latest.pt",
+ )
+ torch.save(checkpoint, latest_path)
+ print(f"Saved latest checkpoint at {latest_path}")
+
+ val_loss = val(model, val_dataloader, accelerator.is_main_process)
+ val_loss = accelerator.gather_for_metrics(val_loss).mean().item()
+ model.train()
+ if config.ts_wandb_logs and accelerator.is_main_process:
+ wandb_log.log(
+ {
+ "val_loss": val_loss,
+ "epoch": i,
+ "scheduled_learning_rate": lr,
+ "step": step_num // config.ts_gradient_accumulation_steps,
+ }
+ )
+
+ if val_loss < min_val_loss:
+ # save the model
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ checkpoint = {
+ "epoch": i,
+ "step": str(step_num // config.ts_gradient_accumulation_steps),
+ "model": unwrapped_model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ torch.save(
+ checkpoint, os.path.join(config.save_root_dir, "best.pt")
+ )
+ min_val_loss = val_loss
+ print(f"Saved best checkpoint at {os.path.join(config.save_root_dir, 'best.pt')}")
+ accelerator.wait_for_everyone()
+ unwrapped_model = accelerator.unwrap_model(model)
+ checkpoint = {
+ "epoch": i,
+ "step": str(step_num // config.ts_gradient_accumulation_steps),
+ "model": unwrapped_model.state_dict(),
+ "optimizer": optimizer.state_dict(),
+ }
+ torch.save(
+ checkpoint,
+ os.path.join(config.save_root_dir, str(i) + "_latest.pt"),
+ )
+
+ if config.ts_wandb_logs and accelerator.is_local_main_process:
+ wandb_log.log(
+ {
+ "scheduled_learning_rate": lr,
+ "epoch": i,
+ "step": step_num // config.ts_gradient_accumulation_steps,
+ }
+ )
+ print(
+ "epoch_number : ", i, " training loss : ", sum(epoch_loss) / len(epoch_loss)
+ )
+
+ if config.ts_wandb_logs and accelerator.is_local_main_process:
+ wandb_log.finish()
+
+
+def val(model, val_dataloader, _main=False):
+ """
+ Run validation over val_dataloader and return the mean loss as a tensor on the model's device.
+ """
+ print("VALIDATION STARTING:")
+ model.eval()
+ val_loss = []
+ device = next(model.parameters()).device
+ if _main:
+ val_dataloader = tqdm(val_dataloader)
+ with torch.no_grad():
+ for inputs in val_dataloader:
+ text_ids, code, text_len, code_len, ref_clips, langs = inputs
+ mask_text = get_mask_from_lengths(text_len).to(device)
+ code_mask = get_mask_from_lengths(code_len).to(device)
+ attn_mask = torch.cat([mask_text, code_mask], dim=1)
+ loss_text, loss_code, _ = model(
+ text_ids=text_ids.to(device),
+ ref_clips=ref_clips.to(device),
+ codes_ids=code.to(device),
+ language=langs.to(device),
+ return_loss=True,
+ attn_mask=attn_mask,
+ )
+
+ loss_text *= mask_text[:, 1:].float()
+ loss_text = loss_text.sum() / mask_text[:, 1:].sum()
+ loss_code *= code_mask[:, 1:].float()
+ loss_code = loss_code.sum() / code_mask[:, 1:].sum()
+ loss = loss_text * config.text_loss_weight + loss_code
+
+ val_loss.append(loss.item())
+
+ val_loss = sum(val_loss) / len(val_loss)
+ print(" Validation loss : ", val_loss)
+ return torch.tensor(val_loss).to(device)
+
+
+def main():
+
+ os.makedirs(os.path.join(config.save_root_dir, config.model_name, "T2S"), exist_ok=True)
+
+ file_name_train = config.train_file
+ file_name_val = config.val_file
+
+ checkpoint = config.t2s_checkpoint
+ model = TS_model(n_embed=1024, n_layer=30, n_head=16)
+
+ val_dataset = DataLoader(
+ semantic_dataset(file_name_val, scale=True),
+ pin_memory=True,
+ persistent_workers=True,
+ num_workers=2,
+ batch_size=config.ts_batch_size,
+ shuffle=True,
+ drop_last=False,
+ collate_fn=collate,
+ )
+
+ train_dataset_ = semantic_dataset(file_name_train, scale=True)
+ train_dataset = DataLoader(
+ train_dataset_,
+ pin_memory=True,
+ persistent_workers=True,
+ num_workers=config.ts_num_workers,
+ batch_size=config.ts_batch_size,
+ shuffle=True,
+ drop_last=False,
+ collate_fn=collate,
+ )
+
+ train(
+ model,
+ train_dataset,
+ val_dataset,
+ save_dir=os.path.join(config.save_root_dir, config.model_name, "T2S"),
+ checkpoint_initial=checkpoint
+ )
+
+
+if __name__ == "__main__":
+ main()
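The training and validation loops above compute per-token losses, zero out padded positions with length masks, drop the first position, and average over the remaining valid tokens. A minimal self-contained sketch of that masking pattern, with `make_length_mask` standing in for the repo's `get_mask_from_lengths` (whose exact semantics are assumed here):

```python
import torch

def make_length_mask(lengths: torch.Tensor) -> torch.Tensor:
    # Boolean mask: True for valid positions, False for padding.
    # Assumed to mirror get_mask_from_lengths used in the training loop.
    max_len = int(lengths.max())
    positions = torch.arange(max_len, device=lengths.device).unsqueeze(0)
    return positions < lengths.unsqueeze(1)

def masked_token_loss(per_token_loss: torch.Tensor, lengths: torch.Tensor) -> torch.Tensor:
    # Zero out padded positions, then average over valid tokens only.
    # Like the loop, drop the first position (mask[:, 1:]) because each
    # token is predicted from the previous one.
    mask = make_length_mask(lengths)[:, 1:].float()
    loss = per_token_loss * mask
    return loss.sum() / mask.sum()

lengths = torch.tensor([2, 4])
per_token_loss = torch.ones(2, 3)  # per-token loss for max_len - 1 = 3 positions
print(masked_token_loss(per_token_loss, lengths).item())  # → 1.0
```

Dividing by `mask.sum()` rather than the tensor size keeps short sequences from being down-weighted by their padding.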
diff --git a/T2S/mel_spec.py b/T2S/mel_spec.py
new file mode 100755
index 0000000000000000000000000000000000000000..eca4a9d38ef3840f20a55d9fcdd52019689c889e
--- /dev/null
+++ b/T2S/mel_spec.py
@@ -0,0 +1,178 @@
+# Copyright (c) 2024 NVIDIA CORPORATION.
+# Licensed under the MIT license.
+
+# Adapted from https://github.com/jik876/hifi-gan under the MIT license.
+# LICENSE is in incl_licenses directory.
+
+if __name__ == "__main__":
+ import os
+ import sys
+
+ sys.path.append("../")
+
+import math
+import os
+import pathlib
+import random
+
+import numpy as np
+import torch
+import torch.utils.data
+from librosa.filters import mel as librosa_mel_fn
+from librosa.util import normalize
+from scipy.io.wavfile import read
+from tqdm import tqdm
+
+from config import config
+
+MAX_WAV_VALUE = 32767.0  # NOTE: 32768.0 - 1, to prevent int16 overflow (overflow causes a popping sound in corner cases)
+
+
+def load_wav(full_path, sr_target):
+ sampling_rate, data = read(full_path)
+ if sampling_rate != sr_target:
+ raise RuntimeError(
+ f"Sampling rate of the file {full_path} is {sampling_rate} Hz, but the model requires {sr_target} Hz"
+ )
+ return data, sampling_rate
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ return np.log(np.clip(x, a_min=clip_val, a_max=None) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ return np.exp(x) / C
+
+
+def dynamic_range_compression_torch(x, C=1, clip_val=1e-5):
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression_torch(x, C=1):
+ return torch.exp(x) / C
+
+
+def spectral_normalize_torch(magnitudes):
+ return dynamic_range_compression_torch(magnitudes)
+
+
+def spectral_de_normalize_torch(magnitudes):
+ return dynamic_range_decompression_torch(magnitudes)
+
+
+mel_basis_cache = {}
+hann_window_cache = {}
+
+
+def mel_spectrogram(
+ y: torch.Tensor,
+ n_fft: int,
+ num_mels: int,
+ sampling_rate: int,
+ hop_size: int,
+ win_size: int,
+ fmin: int,
+ fmax: int = None,
+ center: bool = False,
+) -> torch.Tensor:
+ """
+ Calculate the mel spectrogram of an input signal.
+ This function uses slaney norm for the librosa mel filterbank (using librosa.filters.mel) and uses Hann window for STFT (using torch.stft).
+
+ Args:
+ y (torch.Tensor): Input signal.
+ n_fft (int): FFT size.
+ num_mels (int): Number of mel bins.
+ sampling_rate (int): Sampling rate of the input signal.
+ hop_size (int): Hop size for STFT.
+ win_size (int): Window size for STFT.
+ fmin (int): Minimum frequency for mel filterbank.
+ fmax (int): Maximum frequency for mel filterbank. If None, defaults to half the sampling rate (fmax = sr / 2.0) inside librosa_mel_fn
+ center (bool): Whether to pad the input to center the frames. Default is False.
+
+ Returns:
+ torch.Tensor: Mel spectrogram.
+ """
+ if torch.min(y) < -1.0:
+ print(f"[WARNING] Min value of input waveform signal is {torch.min(y)}")
+ if torch.max(y) > 1.0:
+ print(f"[WARNING] Max value of input waveform signal is {torch.max(y)}")
+
+ device = y.device
+ key = f"{n_fft}_{num_mels}_{sampling_rate}_{hop_size}_{win_size}_{fmin}_{fmax}_{device}"
+
+ if key not in mel_basis_cache:
+ mel = librosa_mel_fn(
+ sr=sampling_rate, n_fft=n_fft, n_mels=num_mels, fmin=fmin, fmax=fmax
+ )
+ mel_basis_cache[key] = torch.from_numpy(mel).float().to(device)
+ hann_window_cache[key] = torch.hann_window(win_size).to(device)
+
+ mel_basis = mel_basis_cache[key]
+ hann_window = hann_window_cache[key]
+
+ padding = (n_fft - hop_size) // 2
+ y = torch.nn.functional.pad(
+ y.unsqueeze(1), (padding, padding), mode="reflect"
+ ).squeeze(1)
+
+ spec = torch.stft(
+ y,
+ n_fft,
+ hop_length=hop_size,
+ win_length=win_size,
+ window=hann_window,
+ center=center,
+ pad_mode="reflect",
+ normalized=False,
+ onesided=True,
+ return_complex=True,
+ )
+ spec = torch.sqrt(torch.view_as_real(spec).pow(2).sum(-1) + 1e-9)
+
+ mel_spec = torch.matmul(mel_basis, spec)
+ mel_spec = spectral_normalize_torch(mel_spec)
+
+ return mel_spec
+
+
+def get_mel_spectrogram(wav, sr):
+ """
+ Generate a mel spectrogram from a waveform using the hyperparameters in config.
+
+ Args:
+ wav (torch.Tensor): Input waveform.
+ sr (int): Sampling rate of the waveform; must equal config.sampling_rate.
+
+ Returns:
+ torch.Tensor: Mel spectrogram.
+ """
+
+ assert sr == config.sampling_rate, (
+ f"Given SR : {sr}, Required SR: {config.sampling_rate}"
+ )
+
+ return mel_spectrogram(
+ wav,
+ config.filter_length,
+ config.n_mel_channels,
+ config.sampling_rate,
+ config.hop_length,
+ config.win_length,
+ config.mel_fmin,
+ config.mel_fmax,
+ )
+
+
+if __name__ == "__main__":
+ import torchaudio
+
+ path = "/delta/NeuralSpeak_cfm_conv/Samples/IITM_cfm_bigv_harsh/S2A/orig/0_test.wav"
+ wav, sr = torchaudio.load(path)
+
+ wav = wav[:, :sr]
+ print(wav.shape)
+ mel_spec = get_mel_spectrogram(wav, sr)
+ duration = wav.shape[-1] / sr
+ print(duration, mel_spec.shape)
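Since `mel_spectrogram` reflect-pads by `(n_fft - hop_size) // 2` on each side and runs `torch.stft` with `center=False`, the number of output frames follows a simple formula. A small sketch of that frame count, using illustrative values (`n_fft=1024`, `hop=256`; the actual values live in `config` and are not shown here):

```python
def num_mel_frames(num_samples: int, n_fft: int, hop_size: int) -> int:
    # mel_spectrogram reflect-pads (n_fft - hop_size) // 2 samples on each side,
    # then torch.stft with center=False produces
    # 1 + (padded_len - n_fft) // hop_size frames.
    padded_len = num_samples + 2 * ((n_fft - hop_size) // 2)
    return 1 + (padded_len - n_fft) // hop_size

# When n_fft - hop_size is even, this reduces to
# 1 + (num_samples - hop_size) // hop_size, i.e. roughly num_samples // hop_size.
print(num_mel_frames(24000, 1024, 256))  # 1 s of 24 kHz audio at hop 256 → 93
```

This is the same arithmetic the `__main__` block above exercises when it prints the duration next to `mel_spec.shape`.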
diff --git a/T2S/meta_stats.py b/T2S/meta_stats.py
new file mode 100755
index 0000000000000000000000000000000000000000..a49b0c93cb84608a23f7ef108e7960956dd8052e
--- /dev/null
+++ b/T2S/meta_stats.py
@@ -0,0 +1,282 @@
+# import math
+# from collections import defaultdict
+# from config import config
+
+
+# def calculate_duration(semt):
+# return math.ceil(((len(semt.split()) + 1) / 50) * 2) / 2
+
+# def get_weights(weights,expected_data,languages):
+# new_weights = []
+
+# expected_weights = config.weights_percentage_duration
+
+# total_files = sum([expected_data['total'][i] for i in expected_data['total']])
+
+# duration_multiplier = {i:config.weights_percentage_duration[i]/(expected_data['total'][i]/total_files) for i in config.weights_percentage_duration}
+# print(expected_data['total'],duration_multiplier)
+# for i in weights:
+# new_weights.append(duration_multiplier[i[1]])
+# return new_weights
+
+
+# def process_file(file_path):
+# weights = []
+# expected_data = defaultdict(lambda: {i:0 for i in ["single_word","5s","10s","15s","20s","20_sentence",">20"]})
+# languages = defaultdict(int)
+# count = 0
+# with open(file_path, 'r') as file:
+# for line in file:
+# count+=1
+# lang, path, text, semt, ref_files = line.split('|')
+# languages[lang]+=1
+# dur = calculate_duration(semt)
+# # weights.append([lang,,1.0])
+# # duration_files['total'][dur] += 1
+
+# if len(text.strip().split(' '))==1:
+# expected_data[lang]["single_word"]+=1
+# expected_data['total']["single_word"]+=1
+# weights.append([lang,"single_word",1.0])
+# continue
+
+# if dur >19.5 and dur<=20:
+# expected_data[lang]["20_sentence"]+=1
+# expected_data['total']["20_sentence"]+=1
+# weights.append([lang,"20_sentence",1.0])
+# continue
+
+# if dur<=5:
+# expected_data[lang]["5s"]+=1
+# expected_data['total']["5s"]+=1
+# weights.append([lang,"5s",1.0])
+# continue
+# elif dur<=10:
+# expected_data[lang]["10s"]+=1
+# expected_data['total']["10s"]+=1
+# weights.append([lang,"10s",1.0])
+# continue
+# elif dur<=15:
+# expected_data[lang]["15s"]+=1
+# expected_data['total']["15s"]+=1
+# weights.append([lang,"15s",1.0])
+# continue
+# elif dur<=20:
+# expected_data[lang]["20s"]+=1
+# expected_data['total']["20s"]+=1
+# weights.append([lang,"20s",1.0])
+# continue
+# else:
+# # expected_data[lang][">20"]+=1
+# # expected_data['total'][">20"]+=1
+# # weights.append([lang,">20",1.0])
+# continue
+
+# final_weights = get_weights(weights,expected_data,languages)
+
+
+# return final_weights,count
+
+# def process_file_for_heads(file_path,total_processes,process_id):
+# weights = []
+# # heads = defaultdict(lambda: {i:[] for i in ["single_word","5s","10s","15s","20s","20_sentence",">20"]}) # to include langauges
+# heads = {i:[] for i in ["single_word","5s","10s","15s","20s","20_sentence",">20"]}
+
+# expected_data = defaultdict(lambda: {i:0 for i in ["single_word","5s","10s","15s","20s","20_sentence",">20"]})
+# languages = defaultdict(int)
+# count = 0
+# line_number = -1
+# with open(file_path, 'r') as file:
+# for line in file:
+# count+=1
+# line_number+=1
+# lang, path, text, semt, ref_files = line.split('|')
+# languages[lang]+=1
+# dur = calculate_duration(semt)
+# # weights.append([lang,,1.0])
+# # duration_files['total'][dur] += 1
+
+# if len(text.strip().split(' '))==1:
+# expected_data[lang]["single_word"]+=1
+# expected_data['total']["single_word"]+=1
+# weights.append([lang,"single_word",1.0])
+# heads["single_word"].append(line_number)
+# continue
+
+# if dur >19.5 and dur<=20:
+# expected_data[lang]["20_sentence"]+=1
+# expected_data['total']["20_sentence"]+=1
+# weights.append([lang,"20_sentence",1.0])
+# heads["20_sentence"].append(line_number)
+# continue
+
+# if dur<=5:
+# expected_data[lang]["5s"]+=1
+# expected_data['total']["5s"]+=1
+# weights.append([lang,"5s",1.0])
+# heads["5s"].append(line_number)
+# continue
+# elif dur<=10:
+# expected_data[lang]["10s"]+=1
+# expected_data['total']["10s"]+=1
+# weights.append([lang,"10s",1.0])
+# heads["10s"].append(line_number)
+# continue
+# elif dur<=15:
+# expected_data[lang]["15s"]+=1
+# expected_data['total']["15s"]+=1
+# weights.append([lang,"15s",1.0])
+# heads["15s"].append(line_number)
+# continue
+# elif dur<=20:
+# expected_data[lang]["20s"]+=1
+# expected_data['total']["20s"]+=1
+# weights.append([lang,"20s",1.0])
+# heads["20s"].append(line_number)
+# continue
+# else:
+# # expected_data[lang][">20"]+=1
+# # expected_data['total'][">20"]+=1
+# # weights.append([lang,">20",1.0])
+# continue
+
+# line_number+=1
+# # final_weights = get_weights(weights,expected_data,languages)
+# # final_weights = [1]*len(heads) # same weightage
+# if config.ts_gradient_accumulation_steps>1:
+# batch = config.ts_batch_size*total_processes*config.ts_gradient_accumulation_steps//config.ts_num_workers
+# else:
+# batch = config.ts_batch_size*total_processes*config.ts_gradient_accumulation_steps
+# # batch = config.ts_batch_size*total_processes*config.ts_gradient_accumulation_steps
+# # heads = heads[:-1]
+# heads = {i:heads[i] for i in heads if len(heads[i])!=0}
+# total_size = sum([len(heads[i]) for i in heads if len(heads[i])!=0])
+# norm_nums = [len(heads[i])/total_size for i in heads if len(heads[i])!=0]
+# final_weights = []
+
+# for i in norm_nums:
+# final_weights.append(max(1,math.ceil(i*batch)))
+
+# rem_elem = sum(final_weights)-batch
+# final_weights[final_weights.index(max(final_weights))]-=rem_elem
+
+# # heads,final_weights = sorted(zip(heads,final_weights),key=lambda x:x[1])
+
+# # process_head = []
+# # proc = 0
+# # sm=0
+# # for i in final_weights:
+# # # sm+=i
+# # if sm+i >
+
+# # process_batch_size = config.ts_batch_size*config.ts_gradient_accumulation_steps
+# # proc = 0
+# # lens = {i:len(heads[i]) for i in heads}
+# # while proc <= process_id:
+# # new_heads ={}
+# # new_weights =[]
+# # sm=0
+# # for i,j in zip(heads,range(len(final_weights))):
+# # if sm + final_weights[j] > process_batch_size:
+# # if sm+final_weights[j] == process_batch_size:
+# # new_heads[i] = heads[i]
+# # new_weights.append(final_weights[j])
+# # heads.pop(i)
+# # else:
+# # new_heads[i] = heads[i][:1+(lens[i]*(process_batch_size-sm)//process_batch_size)]
+# # heads[i] = heads[i][1+(lens[i]*(process_batch_size-sm)//process_batch_size):]
+# # if len(heads[i]) == 0:
+# # heads.pop(i)
+# # new_weights.append(process_batch_size-sm)
+# # final_weights[j]-= process_batch_size-sm
+# # sm = 0
+# # proc+=1
+# # final_weights=final_weights[j:]
+# # break
+# # else:
+# # new_heads[i] = heads[i]
+# # new_weights.append(final_weights[j])
+# # heads.pop(i)
+
+
+# # if len(heads) == 0:
+# # break
+
+# # print("weights",new_weights,[(i,len(heads[i])) for i in new_heads])
+# # return new_heads,new_weights,count
+# # # make it more effective as to real_batch_size instead of worker_batch_size
+
+# # # #[867, 31444, 35458, 6764, 1561, 96, 0] per gpu for iitm
+# # # [10,400,400,60,20,1]
+# # #
+# print("weights",final_weights,[(i,len(heads[i])) for i in heads])
+# print(process_id)
+# new_heads, new_weights = process_batches(heads,final_weights,process_id+1)
+
+# assert len(new_heads) != 0 and len(new_weights) == len(new_heads), print(new_heads)
+
+# print("process id",process_id,new_weights,[(i,len(new_heads[i])) for i in new_heads])
+# return new_heads, new_weights, count
+
+# def process_batches(heads, final_weights, process_id=0):
+# if config.ts_gradient_accumulation_steps>1:
+# process_batch_size = config.ts_batch_size * config.ts_gradient_accumulation_steps//config.ts_num_workers
+# else:
+# process_batch_size = config.ts_batch_size * config.ts_gradient_accumulation_steps
+# proc = 0
+# # Create a copy of the original dictionaries to avoid modifying them during iteration
+# remaining_heads = heads.copy()
+# remaining_weights = final_weights.copy()
+# lens = {i: len(heads[i]) for i in heads}
+
+# while proc <= process_id and remaining_heads:
+# new_heads = {}
+# new_weights = []
+# current_sum = 0
+
+# # Convert items to list to avoid dictionary size change during iteration
+# items = list(remaining_heads.items())
+
+# for key, head_list in items:
+# weight = remaining_weights[0] # Get the corresponding weight
+
+# if current_sum + weight > process_batch_size:
+# # Calculate how much of this head we can include
+# remaining_space = process_batch_size - current_sum
+# if current_sum + weight == process_batch_size:
+# # If it fits exactly
+# new_heads[key] = head_list
+# new_weights.append(weight)
+# del remaining_heads[key]
+# remaining_weights.pop(0)
+# # print("inside first")
+# else:
+# # If we need to split the head
+# split_point = 1 + (lens[key] * remaining_space) // process_batch_size
+# new_heads[key] = head_list[:split_point]
+# remaining_heads[key] = head_list[split_point:]
+# # print(process_id,"inside >",remaining_heads)
+# if not remaining_heads[key]: # If the remaining list is empty
+# del remaining_heads[key]
+
+# new_weights.append(remaining_space)
+# remaining_weights[0] -= remaining_space
+# # print(process_id,"inside >",remaining_heads)
+# # print("inside seciond")
+# proc += 1
+# break
+# else:
+# # If the current head fits completely
+# new_heads[key] = head_list
+# new_weights.append(weight)
+# del remaining_heads[key]
+# remaining_weights.pop(0)
+# current_sum += weight
+# # print("inside third")
+# if len(remaining_heads)==0:
+# proc+=1
+# if proc == process_id:
+# # print("process id",process_id,proc,new_weights,[(i,len(new_heads[i])) for i in new_heads])
+# return new_heads, new_weights
+
+# return {}, [] # Return empty structures if no valid batch is found
diff --git a/T2S/plot_embed.py b/T2S/plot_embed.py
new file mode 100755
index 0000000000000000000000000000000000000000..3728c81dfd05e0985cb8068b61b13ca5fb5550e4
--- /dev/null
+++ b/T2S/plot_embed.py
@@ -0,0 +1,41 @@
+# import matplotlib.pyplot as plt
+# import numpy as np
+# import subprocess
+# import json
+# from umap import UMAP
+# from tqdm import tqdm
+
+# def count_lines_shell(file_path):
+# result = subprocess.run(["wc", "-l", file_path], capture_output=True, text=True)
+# return int(result.stdout.split()[0])
+
+# def load_chunk(file_path,chunk_size):
+# lines = count_lines_shell(file_path)
+# with open(file_path,'r') as file:
+# dataset = []
+# # embed = []
+# for i in tqdm(file,total=lines):
+# data = json.loads(i)
+# key = list(data.keys())[0]
+# dataset.append([key,data[key][0]])
+# # embed.append(data[key][1])
+# if len(dataset)==chunk_size:
+# return dataset
+# dataset=[]
+# # embed=[]
+# if len(dataset)!=0:
+# return dataset
+
+
+# if __name__ == '__main__':
+
+# file_name = "pocketfm_pure_textlossless_data_stats.json"
+# bs = -1
+# # data = load_chunk(file_name,-1)
+# embed = np.load("/nlsasfs/home/dubverse/varshulg/work/NeuralSpeak/T2S/pocketfm_embeddings.npy")
+# print(embed.shape)
+# plt.scatter(embed[:,0],embed[:,1])
+# # plt.imsave("gst_embed.png")
+# plt.savefig('gst_embed_pocketfm.png')#, dpi=300, bbox_inches='tight')
+
+
diff --git a/T2S/stft.py b/T2S/stft.py
new file mode 100755
index 0000000000000000000000000000000000000000..748488322b72bd6d84c0d41a4483ed724d7cfce3
--- /dev/null
+++ b/T2S/stft.py
@@ -0,0 +1,109 @@
+# import torch
+# import numpy as np
+# import torch.nn.functional as F
+# from torch.autograd import Variable
+# from scipy.signal import get_window
+# from librosa.util import pad_center, tiny
+# from utilities import window_sumsquare
+
+
+# class STFT(torch.nn.Module):
+# """adapted from Prem Seetharaman's https://github.com/pseeth/pytorch-stft"""
+# def __init__(self, filter_length=800, hop_length=200, win_length=800,
+# window='hann'):
+# super(STFT, self).__init__()
+# self.filter_length = filter_length
+# self.hop_length = hop_length
+# self.win_length = win_length
+# self.window = window
+# self.forward_transform = None
+# scale = self.filter_length / self.hop_length
+# fourier_basis = np.fft.fft(np.eye(self.filter_length))
+
+# cutoff = int((self.filter_length / 2 + 1))
+# fourier_basis = np.vstack([np.real(fourier_basis[:cutoff, :]),
+# np.imag(fourier_basis[:cutoff, :])])
+
+# forward_basis = torch.FloatTensor(fourier_basis[:, None, :])
+# inverse_basis = torch.FloatTensor(
+# np.linalg.pinv(scale * fourier_basis).T[:, None, :])
+
+# if window is not None:
+# assert(filter_length >= win_length)
+# # get window and zero center pad it to filter_length
+# fft_window = get_window(window, win_length, fftbins=True)
+# fft_window = pad_center(fft_window, size = filter_length)
+# fft_window = torch.from_numpy(fft_window).float()
+
+# # window the bases
+# forward_basis *= fft_window
+# inverse_basis *= fft_window
+
+# self.register_buffer('forward_basis', forward_basis.float())
+# self.register_buffer('inverse_basis', inverse_basis.float())
+
+# def transform(self, input_data):
+# num_batches = input_data.size(0)
+# num_samples = input_data.size(1)
+
+# self.num_samples = num_samples
+
+# # similar to librosa, reflect-pad the input
+# input_data = input_data.view(num_batches, 1, num_samples)
+# input_data = F.pad(
+# input_data.unsqueeze(1),
+# (int(self.filter_length / 2), int(self.filter_length / 2), 0, 0),
+# mode='reflect')
+# input_data = input_data.squeeze(1)
+
+# forward_transform = F.conv1d(
+# input_data,
+# Variable(self.forward_basis, requires_grad=False),
+# stride=self.hop_length,
+# padding=0)
+
+# cutoff = int((self.filter_length / 2) + 1)
+# real_part = forward_transform[:, :cutoff, :]
+# imag_part = forward_transform[:, cutoff:, :]
+
+# magnitude = torch.sqrt(real_part**2 + imag_part**2)
+# phase = torch.autograd.Variable(
+# torch.atan2(imag_part.data, real_part.data))
+
+# return magnitude, phase
+
+# def inverse(self, magnitude, phase):
+# recombine_magnitude_phase = torch.cat(
+# [magnitude*torch.cos(phase), magnitude*torch.sin(phase)], dim=1)
+
+# inverse_transform = F.conv_transpose1d(
+# recombine_magnitude_phase,
+# Variable(self.inverse_basis, requires_grad=False),
+# stride=self.hop_length,
+# padding=0)
+
+# if self.window is not None:
+# window_sum = window_sumsquare(
+# self.window, magnitude.size(-1), hop_length=self.hop_length,
+# win_length=self.win_length, n_fft=self.filter_length,
+# dtype=np.float32)
+# # remove modulation effects
+# approx_nonzero_indices = torch.from_numpy(
+# np.where(window_sum > tiny(window_sum))[0])
+# window_sum = torch.autograd.Variable(
+# torch.from_numpy(window_sum), requires_grad=False)
+# window_sum = window_sum.cuda() if magnitude.is_cuda else window_sum
+# inverse_transform[:, :, approx_nonzero_indices] /= window_sum[approx_nonzero_indices]
+
+# # scale by hop ratio
+# inverse_transform *= float(self.filter_length) / self.hop_length
+
+# inverse_transform = inverse_transform[:, :, int(self.filter_length/2):]
+# inverse_transform = inverse_transform[:, :, :-int(self.filter_length/2):]
+
+# return inverse_transform
+
+# def forward(self, input_data):
+# self.magnitude, self.phase = self.transform(input_data)
+# reconstruction = self.inverse(self.magnitude, self.phase)
+# return reconstruction
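The commented-out `STFT` module above implements the forward transform as a strided 1-D convolution against a DFT basis (real rows stacked on imaginary rows). A minimal sketch checking that this trick reproduces `torch.stft` magnitudes, using a rectangular window, no centering, and illustrative sizes:

```python
import numpy as np
import torch
import torch.nn.functional as F

n_fft, hop = 16, 4
cutoff = n_fft // 2 + 1

# DFT basis: real rows on top, imaginary rows below, as in the STFT module.
fourier_basis = np.fft.fft(np.eye(n_fft))
basis = np.vstack([np.real(fourier_basis[:cutoff]), np.imag(fourier_basis[:cutoff])])
kernel = torch.FloatTensor(basis[:, None, :])

x = torch.randn(1, 64)

# A strided conv1d is a sliding-window matrix multiply with the DFT basis.
frames = F.conv1d(x.unsqueeze(1), kernel, stride=hop)
real, imag = frames[:, :cutoff], frames[:, cutoff:]
magnitude = torch.sqrt(real ** 2 + imag ** 2)

# Reference: torch.stft with a rectangular window and no centering.
reference = torch.stft(
    x, n_fft, hop_length=hop, win_length=n_fft,
    window=torch.ones(n_fft), center=False, return_complex=True,
).abs()
print(torch.allclose(magnitude, reference, atol=1e-3))  # True
```

The module additionally multiplies the basis by a zero-padded Hann window before registering it as a buffer; the windowless case above isolates the conv-as-DFT idea itself.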
diff --git a/T2S/stream_generator.py b/T2S/stream_generator.py
new file mode 100755
index 0000000000000000000000000000000000000000..e468091dae76b77a585cb7c1fd569f42d2c9d314
--- /dev/null
+++ b/T2S/stream_generator.py
@@ -0,0 +1,1002 @@
+# Adapted from: https://github.com/LowinLi/transformers-stream-generator
+
+import copy
+import inspect
+import random
+import warnings
+from typing import Callable, List, Optional, Union
+
+import numpy as np
+import torch
+import torch.distributed as dist
+from torch import nn
+from transformers import (BeamSearchScorer, ConstrainedBeamSearchScorer,
+ DisjunctiveConstraint, GenerationConfig,
+ GenerationMixin, LogitsProcessorList,
+ PhrasalConstraint, PreTrainedModel,
+ StoppingCriteriaList)
+from transformers.generation.utils import GenerateOutput, SampleOutput, logger
+
+
+def setup_seed(seed):
+ if seed == -1:
+ return
+ torch.manual_seed(seed)
+ if torch.cuda.is_available():
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ torch.backends.cudnn.deterministic = True
+
+
+class StreamGenerationConfig(GenerationConfig):
+ def __init__(self, **kwargs):
+ super().__init__(**kwargs)
+ self.do_stream = kwargs.pop("do_stream", False)
+
+
+class NewGenerationMixin(GenerationMixin):
+ @torch.no_grad()
+ def generate(
+ self,
+ inputs: Optional[torch.Tensor] = None,
+ generation_config: Optional[StreamGenerationConfig] = None,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ prefix_allowed_tokens_fn: Optional[
+ Callable[[int, torch.Tensor], List[int]]
+ ] = None,
+ synced_gpus: Optional[bool] = False,
+ seed=0,
+ **kwargs,
+ ) -> Union[GenerateOutput, torch.LongTensor]:
+ r"""
+
+ Generates sequences of token ids for models with a language modeling head.
+
+
+
+ Most generation-controlling parameters are set in `generation_config` which, if not passed, will be set to the
+ model's default generation configuration. You can override any `generation_config` by passing the corresponding
+ parameters to generate(), e.g. `.generate(inputs, num_beams=4, do_sample=True)`.
+
+ For an overview of generation strategies and code examples, check out the [following
+ guide](./generation_strategies).
+
+
+
+ Parameters:
+ inputs (`torch.Tensor` of varying shape depending on the modality, *optional*):
+ The sequence used as a prompt for the generation or as model inputs to the encoder. If `None` the
+ method initializes it with `bos_token_id` and a batch size of 1. For decoder-only models `inputs`
+ should be in the format of `input_ids`. For encoder-decoder models *inputs* can represent any of
+ `input_ids`, `input_values`, `input_features`, or `pixel_values`.
+ generation_config (`~generation.GenerationConfig`, *optional*):
+ The generation configuration to be used as base parametrization for the generation call. `**kwargs`
+ passed to generate matching the attributes of `generation_config` will override them. If
+ `generation_config` is not provided, the default will be used, which has the following loading
+ priority: 1) from the `generation_config.json` model file, if it exists; 2) from the model
+ configuration. Please note that unspecified parameters will inherit [`~generation.GenerationConfig`]'s
+ default values, whose documentation should be checked to parameterize generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ Custom logits processors that complement the default logits processors built from arguments and
+ generation config. If a logit processor is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ Custom stopping criteria that complement the default stopping criteria built from arguments and a
+ generation config. If a stopping criteria is passed that is already created with the arguments or a
+ generation config an error is thrown. This feature is intended for advanced users.
+ prefix_allowed_tokens_fn (`Callable[[int, torch.Tensor], List[int]]`, *optional*):
+ If provided, this function constrains the beam search to allowed tokens only at each step. If not
+ provided no constraint is applied. This function takes 2 arguments: the batch ID `batch_id` and
+ `input_ids`. It has to return a list with the allowed tokens for the next generation step conditioned
+ on the batch ID `batch_id` and the previously generated tokens `input_ids`. This argument is useful
+ for constrained generation conditioned on the prefix, as described in [Autoregressive Entity
+ Retrieval](https://arxiv.org/abs/2010.00904).
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ kwargs:
+ Ad hoc parametrization of `generate_config` and/or additional model-specific kwargs that will be
+ forwarded to the `forward` function of the model. If the model is an encoder-decoder model, encoder
+ specific kwargs should not be prefixed and decoder specific kwargs should be prefixed with *decoder_*.
+
+ Return:
+ [`~utils.ModelOutput`] or `torch.LongTensor`: A [`~utils.ModelOutput`] (if `return_dict_in_generate=True`
+ or when `config.return_dict_in_generate=True`) or a `torch.LongTensor`.
+
+ If the model is *not* an encoder-decoder model (`model.config.is_encoder_decoder=False`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchDecoderOnlyOutput`],
+ - [`~generation.SampleDecoderOnlyOutput`],
+ - [`~generation.BeamSearchDecoderOnlyOutput`],
+ - [`~generation.BeamSampleDecoderOnlyOutput`]
+
+ If the model is an encoder-decoder model (`model.config.is_encoder_decoder=True`), the possible
+ [`~utils.ModelOutput`] types are:
+
+ - [`~generation.GreedySearchEncoderDecoderOutput`],
+ - [`~generation.SampleEncoderDecoderOutput`],
+ - [`~generation.BeamSearchEncoderDecoderOutput`],
+ - [`~generation.BeamSampleEncoderDecoderOutput`]
+ """
+ # setup_seed(seed)
+ # 1. Handle `generation_config` and kwargs that might update it, and validate the `.generate()` call
+ self._validate_model_class()
+
+ # priority: `generation_config` argument > `model.generation_config` (the default generation config)
+ if generation_config is None:
+ # legacy: users may modify the model configuration to control generation -- update the generation config
+ # model attribute accordingly, if it was created from the model config
+ if self.generation_config._from_model_config:
+ new_generation_config = StreamGenerationConfig.from_model_config(
+ self.config
+ )
+ if new_generation_config != self.generation_config:
+ warnings.warn(
+ "You have modified the pretrained model configuration to control generation. This is a"
+ " deprecated strategy to control generation and will be removed soon, in a future version."
+ " Please use a generation configuration file (see"
+ " https://huggingface.co/docs/transformers/main_classes/text_generation)"
+ )
+ self.generation_config = new_generation_config
+ generation_config = self.generation_config
+
+ generation_config = copy.deepcopy(generation_config)
+ model_kwargs = generation_config.update(
+ **kwargs
+ ) # All unused kwargs must be model kwargs
+ # self._validate_model_kwargs(model_kwargs.copy())
+
+ # 2. Set generation parameters if not already defined
+ logits_processor = (
+ logits_processor if logits_processor is not None else LogitsProcessorList()
+ )
+ stopping_criteria = (
+ stopping_criteria
+ if stopping_criteria is not None
+ else StoppingCriteriaList()
+ )
+
+ if (
+ generation_config.pad_token_id is None
+ and generation_config.eos_token_id is not None
+ ):
+ if model_kwargs.get("attention_mask", None) is None:
+ logger.warning(
+ "The attention mask and the pad token id were not set. As a consequence, you may observe "
+ "unexpected behavior. Please pass your input's `attention_mask` to obtain reliable results."
+ )
+ eos_token_id = generation_config.eos_token_id
+ if isinstance(eos_token_id, list):
+ eos_token_id = eos_token_id[0]
+ logger.warning(
+ f"Setting `pad_token_id` to `eos_token_id`:{eos_token_id} for open-end generation."
+ )
+ generation_config.pad_token_id = eos_token_id
+
+ # 3. Define model inputs
+ # inputs_tensor has to be defined
+ # model_input_name is defined if model-specific keyword input is passed
+ # otherwise model_input_name is None
+ # all model-specific keyword inputs are removed from `model_kwargs`
+ inputs_tensor, model_input_name, model_kwargs = self._prepare_model_inputs(
+ inputs, generation_config.bos_token_id, model_kwargs
+ )
+ batch_size = inputs_tensor.shape[0]
+
+ # 4. Define other model kwargs
+ model_kwargs["output_attentions"] = generation_config.output_attentions
+ model_kwargs["output_hidden_states"] = generation_config.output_hidden_states
+ model_kwargs["use_cache"] = generation_config.use_cache
+
+ accepts_attention_mask = "attention_mask" in set(
+ inspect.signature(self.forward).parameters.keys()
+ )
+ requires_attention_mask = "encoder_outputs" not in model_kwargs
+
+ if (
+ model_kwargs.get("attention_mask", None) is None
+ and requires_attention_mask
+ and accepts_attention_mask
+ ):
+ model_kwargs["attention_mask"] = (
+ self._prepare_attention_mask_for_generation(
+ inputs_tensor,
+ generation_config.pad_token_id,
+ generation_config.eos_token_id,
+ )
+ )
+
+ # decoder-only models should use left-padding for generation
+ if not self.config.is_encoder_decoder:
+ if (
+ generation_config.pad_token_id is not None
+ and torch.sum(inputs_tensor[:, -1] == generation_config.pad_token_id)
+ > 0
+ ):
+ logger.warning(
+ "A decoder-only architecture is being used, but right-padding was detected! For correct "
+ "generation results, please set `padding_side='left'` when initializing the tokenizer."
+ )
+
+ if self.config.is_encoder_decoder and "encoder_outputs" not in model_kwargs:
+ # if model is encoder decoder encoder_outputs are created
+ # and added to `model_kwargs`
+ model_kwargs = self._prepare_encoder_decoder_kwargs_for_generation(
+ inputs_tensor, model_kwargs, model_input_name
+ )
+
+ # 5. Prepare `input_ids` which will be used for auto-regressive generation
+ if self.config.is_encoder_decoder:
+ input_ids = self._prepare_decoder_input_ids_for_generation(
+ batch_size,
+ decoder_start_token_id=generation_config.decoder_start_token_id,
+ bos_token_id=generation_config.bos_token_id,
+ model_kwargs=model_kwargs,
+ device=inputs_tensor.device,
+ )
+ else:
+ # if decoder-only then inputs_tensor has to be `input_ids`
+ input_ids = inputs_tensor
+
+ # 6. Prepare `max_length` depending on other stopping criteria.
+ input_ids_seq_length = input_ids.shape[-1]
+ has_default_max_length = (
+ kwargs.get("max_length") is None
+ and generation_config.max_length is not None
+ )
+ if has_default_max_length and generation_config.max_new_tokens is None:
+ warnings.warn(
+ "Neither `max_length` nor `max_new_tokens` has been set, `max_length` will default to"
+ f" {generation_config.max_length} (`generation_config.max_length`). Controlling `max_length` via the"
+ " config is deprecated and `max_length` will be removed from the config in v5 of Transformers -- we"
+ " recommend using `max_new_tokens` to control the maximum length of the generation.",
+ UserWarning,
+ )
+ elif has_default_max_length and generation_config.max_new_tokens is not None:
+ generation_config.max_length = (
+ generation_config.max_new_tokens + input_ids_seq_length
+ )
+ elif (
+ not has_default_max_length and generation_config.max_new_tokens is not None
+ ):
+ raise ValueError(
+ "Both `max_new_tokens` and `max_length` have been set but they serve the same purpose -- setting a"
+ " limit to the generated output length. Remove one of those arguments. Please refer to the"
+ " documentation for more information. "
+ "(https://huggingface.co/docs/transformers/main/en/main_classes/text_generation)"
+ )
+
+ if (
+ generation_config.min_length is not None
+ and generation_config.min_length > generation_config.max_length
+ ):
+ raise ValueError(
+ f"Unfeasible length constraints: the minimum length ({generation_config.min_length}) is larger than"
+ f" the maximum length ({generation_config.max_length})"
+ )
+ if input_ids_seq_length >= generation_config.max_length:
+ input_ids_string = (
+ "decoder_input_ids" if self.config.is_encoder_decoder else "input_ids"
+ )
+ logger.warning(
+ f"Input length of {input_ids_string} is {input_ids_seq_length}, but `max_length` is set to"
+ f" {generation_config.max_length}. This can lead to unexpected behavior. You should consider"
+ " increasing `max_new_tokens`."
+ )
+
+ # 7. determine generation mode
+ is_constraint_gen_mode = (
+ generation_config.constraints is not None
+ or generation_config.force_words_ids is not None
+ )
+
+ is_contrastive_search_gen_mode = (
+ generation_config.top_k is not None
+ and generation_config.top_k > 1
+ and generation_config.do_sample is False
+ and generation_config.penalty_alpha is not None
+ and generation_config.penalty_alpha > 0
+ )
+
+ is_greedy_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_sample_gen_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
+ and generation_config.do_stream is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_sample_gen_stream_mode = (
+ (generation_config.num_beams == 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_stream is True
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_beam_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is False
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_beam_sample_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups == 1)
+ and generation_config.do_sample is True
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+ is_group_beam_gen_mode = (
+ (generation_config.num_beams > 1)
+ and (generation_config.num_beam_groups > 1)
+ and not is_constraint_gen_mode
+ and not is_contrastive_search_gen_mode
+ )
+
+ if generation_config.num_beam_groups > generation_config.num_beams:
+ raise ValueError(
+ "`num_beam_groups` has to be smaller or equal to `num_beams`"
+ )
+ if is_group_beam_gen_mode and generation_config.do_sample is True:
+ raise ValueError(
+ "Diverse beam search cannot be used in sampling mode. Make sure that `do_sample` is set to `False`."
+ )
+
+ if self.device.type != input_ids.device.type:
+ warnings.warn(
+ "You are calling .generate() with the `input_ids` being on a device type different"
+ f" than your model's device. `input_ids` is on {input_ids.device.type}, whereas the model"
+ f" is on {self.device.type}. You may experience unexpected behaviors or slower generation."
+ " Please make sure that you have put `input_ids` to the"
+ f" correct device by calling for example input_ids = input_ids.to('{self.device.type}') before"
+ " running `.generate()`.",
+ UserWarning,
+ )
+ # 8. prepare distribution pre_processing samplers
+ logits_processor = self._get_logits_processor(
+ generation_config=generation_config,
+ input_ids_seq_length=input_ids_seq_length,
+ encoder_input_ids=inputs_tensor,
+ prefix_allowed_tokens_fn=prefix_allowed_tokens_fn,
+ logits_processor=logits_processor,
+ )
+
+ # 9. prepare stopping criteria
+ stopping_criteria = self._get_stopping_criteria(
+ generation_config=generation_config, stopping_criteria=stopping_criteria
+ )
+ # 10. go into different generation modes
+ if is_greedy_gen_mode:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " greedy search."
+ )
+
+ # 11. run greedy search
+ return self.greedy_search(
+ input_ids,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_contrastive_search_gen_mode:
+ if generation_config.num_return_sequences > 1:
+ raise ValueError(
+ f"num_return_sequences has to be 1, but is {generation_config.num_return_sequences} when doing"
+ " contrastive search."
+ )
+
+ return self.contrastive_search(
+ input_ids,
+ top_k=generation_config.top_k,
+ penalty_alpha=generation_config.penalty_alpha,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_sample_gen_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 13. run sample
+ return self.sample(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+ elif is_sample_gen_stream_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ # 12. expand input_ids with `num_return_sequences` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 13. run sample
+ return self.sample_stream(
+ input_ids,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+ elif is_beam_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ # 11. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.beam_search(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_beam_sample_gen_mode:
+ # 11. prepare logits warper
+ logits_warper = self._get_logits_warper(generation_config)
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+ # 12. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size * generation_config.num_return_sequences,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ )
+
+ # 13. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams
+ * generation_config.num_return_sequences,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+
+ # 14. run beam sample
+ return self.beam_sample(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ logits_warper=logits_warper,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_group_beam_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if generation_config.num_beams % generation_config.num_beam_groups != 0:
+ raise ValueError(
+ "`num_beams` should be divisible by `num_beam_groups` for group beam search."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ has_default_typical_p = (
+ kwargs.get("typical_p") is None and generation_config.typical_p == 1.0
+ )
+ if not has_default_typical_p:
+ raise ValueError(
+ "Decoder argument `typical_p` is not supported with beam groups."
+ )
+
+ # 11. prepare beam search scorer
+ beam_scorer = BeamSearchScorer(
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ max_length=stopping_criteria.max_length,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ num_beam_groups=generation_config.num_beam_groups,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.group_beam_search(
+ input_ids,
+ beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
+ elif is_constraint_gen_mode:
+ if generation_config.num_return_sequences > generation_config.num_beams:
+ raise ValueError(
+ "`num_return_sequences` has to be smaller or equal to `num_beams`."
+ )
+
+ if stopping_criteria.max_length is None:
+ raise ValueError(
+ "`max_length` needs to be a stopping_criteria for now."
+ )
+
+ if generation_config.num_beams <= 1:
+ raise ValueError(
+ "`num_beams` needs to be greater than 1 for constrained generation."
+ )
+
+ if generation_config.do_sample:
+ raise ValueError(
+ "`do_sample` needs to be false for constrained generation."
+ )
+
+ if (
+ generation_config.num_beam_groups is not None
+ and generation_config.num_beam_groups > 1
+ ):
+ raise ValueError(
+ "`num_beam_groups` not supported yet for constrained generation."
+ )
+
+ final_constraints = []
+ if generation_config.constraints is not None:
+ final_constraints = generation_config.constraints
+
+ if generation_config.force_words_ids is not None:
+
+ def typeerror():
+ raise ValueError(
+ "`force_words_ids` has to either be a `List[List[List[int]]]` or `List[List[int]]`"
+ f"of positive integers, but is {generation_config.force_words_ids}."
+ )
+
+ if (
+ not isinstance(generation_config.force_words_ids, list)
+ or len(generation_config.force_words_ids) == 0
+ ):
+ typeerror()
+
+ for word_ids in generation_config.force_words_ids:
+ if isinstance(word_ids[0], list):
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
+ typeerror()
+ if any(
+ not isinstance(token_ids, list) for token_ids in word_ids
+ ):
+ typeerror()
+ if any(
+ any(
+ (not isinstance(token_id, int) or token_id < 0)
+ for token_id in token_ids
+ )
+ for token_ids in word_ids
+ ):
+ typeerror()
+
+ constraint = DisjunctiveConstraint(word_ids)
+ else:
+ if not isinstance(word_ids, list) or len(word_ids) == 0:
+ typeerror()
+ if any(
+ (not isinstance(token_id, int) or token_id < 0)
+ for token_id in word_ids
+ ):
+ typeerror()
+
+ constraint = PhrasalConstraint(word_ids)
+ final_constraints.append(constraint)
+
+ # 11. prepare beam search scorer
+ constrained_beam_scorer = ConstrainedBeamSearchScorer(
+ constraints=final_constraints,
+ batch_size=batch_size,
+ num_beams=generation_config.num_beams,
+ device=inputs_tensor.device,
+ length_penalty=generation_config.length_penalty,
+ do_early_stopping=generation_config.early_stopping,
+ num_beam_hyps_to_keep=generation_config.num_return_sequences,
+ )
+ # 12. interleave input_ids with `num_beams` additional sequences per batch
+ input_ids, model_kwargs = self._expand_inputs_for_generation(
+ input_ids=input_ids,
+ expand_size=generation_config.num_beams,
+ is_encoder_decoder=self.config.is_encoder_decoder,
+ **model_kwargs,
+ )
+ # 13. run beam search
+ return self.constrained_beam_search(
+ input_ids,
+ constrained_beam_scorer=constrained_beam_scorer,
+ logits_processor=logits_processor,
+ stopping_criteria=stopping_criteria,
+ pad_token_id=generation_config.pad_token_id,
+ eos_token_id=generation_config.eos_token_id,
+ output_scores=generation_config.output_scores,
+ return_dict_in_generate=generation_config.return_dict_in_generate,
+ synced_gpus=synced_gpus,
+ **model_kwargs,
+ )
+
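The cascade of `is_*_gen_mode` flags above resolves each call to exactly one decoding strategy, with constrained generation and contrastive search taking precedence over the rest. A minimal pure-Python sketch of that precedence (a simplification for illustration; the real method reads these values off `StreamGenerationConfig`) could look like:

```python
def pick_generation_mode(num_beams=1, num_beam_groups=1, do_sample=False,
                         do_stream=False, top_k=None, penalty_alpha=None,
                         constraints=None, force_words_ids=None):
    """Mirror the flag cascade in `generate`: first matching mode wins."""
    # constrained generation wins: every other mode flag requires its absence
    if constraints is not None or force_words_ids is not None:
        return "constrained_beam_search"
    # contrastive search likewise excludes the remaining modes
    if (top_k is not None and top_k > 1 and not do_sample
            and penalty_alpha is not None and penalty_alpha > 0):
        return "contrastive_search"
    if num_beams == 1 and num_beam_groups == 1:
        if not do_sample:
            return "greedy_search"  # dispatched before the stream check
        return "sample_stream" if do_stream else "sample"
    if num_beam_groups > 1:
        return "group_beam_search"
    return "beam_sample" if do_sample else "beam_search"
```

Note that `do_stream` only changes the outcome when sampling is enabled, because the greedy branch is dispatched first in the original method.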
+ @torch.no_grad()
+ def sample_stream(
+ self,
+ input_ids: torch.LongTensor,
+ logits_processor: Optional[LogitsProcessorList] = None,
+ stopping_criteria: Optional[StoppingCriteriaList] = None,
+ logits_warper: Optional[LogitsProcessorList] = None,
+ max_length: Optional[int] = None,
+ pad_token_id: Optional[int] = None,
+ eos_token_id: Optional[Union[int, List[int]]] = None,
+ output_attentions: Optional[bool] = None,
+ output_hidden_states: Optional[bool] = None,
+ output_scores: Optional[bool] = None,
+ return_dict_in_generate: Optional[bool] = None,
+ synced_gpus: Optional[bool] = False,
+ **model_kwargs,
+ ) -> Union[SampleOutput, torch.LongTensor]:
+ r"""
+ Generates sequences of token ids for models with a language modeling head using **multinomial sampling** and
+ can be used for text-decoder, text-to-text, speech-to-text, and vision-to-text models.
+
+ In most cases, you do not need to call [`~generation.GenerationMixin.sample_stream`] directly. Use
+ `generate_stream()` instead. For an overview of generation strategies and code examples, check the
+ [following guide](./generation_strategies).
+
+ Parameters:
+ input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
+ The sequence used as a prompt for the generation.
+ logits_processor (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsProcessor`]
+ used to modify the prediction scores of the language modeling head applied at each generation step.
+ stopping_criteria (`StoppingCriteriaList`, *optional*):
+ An instance of [`StoppingCriteriaList`]. List of instances of class derived from [`StoppingCriteria`]
+ used to tell if the generation loop should stop.
+ logits_warper (`LogitsProcessorList`, *optional*):
+ An instance of [`LogitsProcessorList`]. List of instances of class derived from [`LogitsWarper`] used
+ to warp the prediction score distribution of the language modeling head applied before multinomial
+ sampling at each generation step.
+ max_length (`int`, *optional*, defaults to 20):
+ **DEPRECATED**. Use `logits_processor` or `stopping_criteria` directly to cap the number of generated
+ tokens. The maximum length of the sequence to be generated.
+ pad_token_id (`int`, *optional*):
+ The id of the *padding* token.
+ eos_token_id (`Union[int, List[int]]`, *optional*):
+ The id(s) of the *end-of-sequence* token(s).
+ output_attentions (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the attentions tensors of all attention layers. See `attentions` under
+ returned tensors for more details.
+ output_hidden_states (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the hidden states of all layers. See `hidden_states` under returned tensors
+ for more details.
+ output_scores (`bool`, *optional*, defaults to `False`):
+ Whether or not to return the prediction scores. See `scores` under returned tensors for more details.
+ return_dict_in_generate (`bool`, *optional*, defaults to `False`):
+ Whether or not to return a [`~utils.ModelOutput`] instead of a plain tuple.
+ synced_gpus (`bool`, *optional*, defaults to `False`):
+ Whether to continue running the while loop until max_length (needed for ZeRO stage 3)
+ model_kwargs:
+ Additional model specific kwargs will be forwarded to the `forward` function of the model. If model is
+ an encoder-decoder model the kwargs should include `encoder_outputs`.
+
+ Return:
+ [`~generation.SampleDecoderOnlyOutput`], [`~generation.SampleEncoderDecoderOutput`] or `torch.LongTensor`:
+ A `torch.LongTensor` containing the generated tokens (default behaviour) or a
+ [`~generation.SampleDecoderOnlyOutput`] if `model.config.is_encoder_decoder=False` and
+ `return_dict_in_generate=True` or a [`~generation.SampleEncoderDecoderOutput`] if
+ `model.config.is_encoder_decoder=True`.
+
+ Examples:
+
+ ```python
+ >>> from transformers import (
+ ... AutoTokenizer,
+ ... AutoModelForCausalLM,
+ ... LogitsProcessorList,
+ ... MinLengthLogitsProcessor,
+ ... TopKLogitsWarper,
+ ... TemperatureLogitsWarper,
+ ... StoppingCriteriaList,
+ ... MaxLengthCriteria,
+ ... )
+ >>> import torch
+
+ >>> tokenizer = AutoTokenizer.from_pretrained("gpt2")
+ >>> model = AutoModelForCausalLM.from_pretrained("gpt2")
+
+ >>> # set pad_token_id to eos_token_id because GPT-2 does not have a PAD token
+ >>> model.config.pad_token_id = model.config.eos_token_id
+ >>> model.generation_config.pad_token_id = model.config.eos_token_id
+
+ >>> input_prompt = "Today is a beautiful day, and"
+ >>> input_ids = tokenizer(input_prompt, return_tensors="pt").input_ids
+
+ >>> # instantiate logits processors
+ >>> logits_processor = LogitsProcessorList(
+ ... [
+ ... MinLengthLogitsProcessor(15, eos_token_id=model.generation_config.eos_token_id),
+ ... ]
+ ... )
+ >>> # instantiate logits processors
+ >>> logits_warper = LogitsProcessorList(
+ ... [
+ ... TopKLogitsWarper(50),
+ ... TemperatureLogitsWarper(0.7),
+ ... ]
+ ... )
+
+ >>> stopping_criteria = StoppingCriteriaList([MaxLengthCriteria(max_length=20)])
+
+ >>> torch.manual_seed(0)  # doctest: +IGNORE_RESULT
+ >>> # `sample_stream` is a generator yielding (next_tokens, last_hidden_state) pairs.
+ >>> # Note: the model must expose a `final_norm` module for the hidden-state output.
+ >>> stream = model.sample_stream(
+ ...     input_ids,
+ ...     logits_processor=logits_processor,
+ ...     logits_warper=logits_warper,
+ ...     stopping_criteria=stopping_criteria,
+ ... )
+ >>> generated = torch.cat([tokens[:, None] for tokens, _ in stream], dim=-1)
+ >>> tokenizer.batch_decode(generated, skip_special_tokens=True)  # doctest: +SKIP
+ ```"""
+ # init values
+ logits_processor = (
+ logits_processor if logits_processor is not None else LogitsProcessorList()
+ )
+ stopping_criteria = (
+ stopping_criteria
+ if stopping_criteria is not None
+ else StoppingCriteriaList()
+ )
+ if max_length is not None:
+ warnings.warn(
+ "`max_length` is deprecated in this function, use"
+ " `stopping_criteria=StoppingCriteriaList(MaxLengthCriteria(max_length=max_length))` instead.",
+ UserWarning,
+ )
+ stopping_criteria = validate_stopping_criteria(
+ stopping_criteria, max_length
+ )
+ logits_warper = (
+ logits_warper if logits_warper is not None else LogitsProcessorList()
+ )
+ pad_token_id = (
+ pad_token_id
+ if pad_token_id is not None
+ else self.generation_config.pad_token_id
+ )
+ eos_token_id = (
+ eos_token_id
+ if eos_token_id is not None
+ else self.generation_config.eos_token_id
+ )
+ if isinstance(eos_token_id, int):
+ eos_token_id = [eos_token_id]
+ output_scores = (
+ output_scores
+ if output_scores is not None
+ else self.generation_config.output_scores
+ )
+ output_attentions = (
+ output_attentions
+ if output_attentions is not None
+ else self.generation_config.output_attentions
+ )
+ output_hidden_states = (
+ output_hidden_states
+ if output_hidden_states is not None
+ else self.generation_config.output_hidden_states
+ )
+ return_dict_in_generate = (
+ return_dict_in_generate
+ if return_dict_in_generate is not None
+ else self.generation_config.return_dict_in_generate
+ )
+
+ # init attention / hidden states / scores tuples
+ scores = () if (return_dict_in_generate and output_scores) else None
+ decoder_attentions = (
+ () if (return_dict_in_generate and output_attentions) else None
+ )
+ cross_attentions = (
+ () if (return_dict_in_generate and output_attentions) else None
+ )
+ decoder_hidden_states = (
+ () if (return_dict_in_generate and output_hidden_states) else None
+ )
+
+ # keep track of which sequences are already finished
+ unfinished_sequences = input_ids.new(input_ids.shape[0]).fill_(1)
+
+ this_peer_finished = False # used by synced_gpus only
+ # auto-regressive generation
+ while True:
+ if synced_gpus:
+ # Under synced_gpus the `forward` call must continue until all gpus complete their sequence.
+ # The following logic allows an early break if all peers finished generating their sequence
+ this_peer_finished_flag = torch.tensor(
+ 0.0 if this_peer_finished else 1.0
+ ).to(input_ids.device)
+ # send 0.0 if we finished, 1.0 otherwise
+ dist.all_reduce(this_peer_finished_flag, op=dist.ReduceOp.SUM)
+ # did all peers finish? the reduced sum will be 0.0 then
+ if this_peer_finished_flag.item() == 0.0:
+ break
+
+ # prepare model inputs
+ model_inputs = self.prepare_inputs_for_generation(input_ids, **model_kwargs)
+
+ # forward pass to get next token
+ outputs = self(
+ **model_inputs,
+ return_dict=True,
+ output_attentions=output_attentions,
+ output_hidden_states=output_hidden_states,
+ )
+
+ if synced_gpus and this_peer_finished:
+ continue # don't waste resources running the code we don't need
+
+ next_token_logits = outputs.logits[:, -1, :]
+
+ # pre-process distribution
+ next_token_scores = logits_processor(input_ids, next_token_logits)
+ next_token_scores = logits_warper(input_ids, next_token_scores)
+
+ # Store scores, attentions and hidden_states when required
+ if return_dict_in_generate:
+ if output_scores:
+ scores += (next_token_scores,)
+ if output_attentions:
+ decoder_attentions += (
+ (outputs.decoder_attentions,)
+ if self.config.is_encoder_decoder
+ else (outputs.attentions,)
+ )
+ if self.config.is_encoder_decoder:
+ cross_attentions += (outputs.cross_attentions,)
+
+ if output_hidden_states:
+ decoder_hidden_states += (
+ (outputs.decoder_hidden_states,)
+ if self.config.is_encoder_decoder
+ else (outputs.hidden_states,)
+ )
+
+ # sample
+ probs = nn.functional.softmax(next_token_scores, dim=-1)
+ next_tokens = torch.multinomial(probs, num_samples=1).squeeze(1)
+
+ # finished sentences should have their next token be a padding token
+ if eos_token_id is not None:
+ if pad_token_id is None:
+ raise ValueError(
+ "If `eos_token_id` is defined, make sure that `pad_token_id` is defined."
+ )
+ next_tokens = next_tokens * unfinished_sequences + pad_token_id * (
+ 1 - unfinished_sequences
+ )
+ yield next_tokens, self.final_norm(outputs.hidden_states[-1][:, -1])
+ # update generated ids, model inputs, and length for next step
+ input_ids = torch.cat([input_ids, next_tokens[:, None]], dim=-1)
+ model_kwargs = self._update_model_kwargs_for_generation(
+ outputs, model_kwargs, is_encoder_decoder=self.config.is_encoder_decoder
+ )
+
+ # if eos_token was found in one sentence, set sentence to finished
+ if eos_token_id is not None:
+ # a sequence is finished as soon as the token matches *any* eos id
+ # (the previous `sum(...)` only reached zero when it matched every id)
+ unfinished_sequences = unfinished_sequences.mul(
+ torch.stack([next_tokens != i for i in eos_token_id]).all(dim=0).long()
+ )
+
+ # stop when each sentence is finished, or if we exceed the maximum length
+ if unfinished_sequences.max() == 0 or stopping_criteria(input_ids, scores):
+ if not synced_gpus:
+ break
+ else:
+ this_peer_finished = True
+
+
+def init_stream_support():
+ """Overload PreTrainedModel for streaming."""
+ PreTrainedModel.generate_stream = NewGenerationMixin.generate
+ PreTrainedModel.sample_stream = NewGenerationMixin.sample_stream
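With `init_stream_support()` applied, `generate_stream(..., do_stream=True)` returns a generator of `(next_tokens, hidden)` pairs instead of a finished tensor, so a client decodes incrementally. A sketch of the consumption pattern, using a stand-in stream and toy vocabulary rather than a real model:

```python
def consume_stream(stream, decode, eos_id=None):
    """Drain a (token, hidden) generator incrementally, as a streaming UI would."""
    tokens = []
    for token, _hidden in stream:  # a hidden state accompanies every token
        if eos_id is not None and token == eos_id:
            break
        tokens.append(token)
        # a real client would flush decode(tokens) to the screen here
    return decode(tokens)

# stand-in stream and vocabulary, purely for illustration
vocab = {0: "<eos>", 1: "hello", 2: " world"}
fake_stream = ((t, None) for t in [1, 2, 0])
text = consume_stream(fake_stream, lambda ts: "".join(vocab[t] for t in ts), eos_id=0)
# text == "hello world"
```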
diff --git a/T2S/t2s_modules.py b/T2S/t2s_modules.py
new file mode 100755
index 0000000000000000000000000000000000000000..26c897594ebc4fa35d03d9ebd644ae00e2ba960b
--- /dev/null
+++ b/T2S/t2s_modules.py
@@ -0,0 +1,609 @@
+import math
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import torch.nn.init as init
+from einops import rearrange, repeat
+
+
+def zero_module(module):
+ """
+ Zero out the parameters of a module and return it.
+ Used to create zero-initialized ("zero convolution") layers.
+ """
+ for p in module.parameters():
+ p.detach().zero_()
+ return module
+
+
+class GroupNorm32(nn.GroupNorm):
+ def forward(self, x):
+ return super().forward(x.float()).type(x.dtype)
+
+
+def normalization(channels):
+ """
+ Make a standard GroupNorm layer, picking a group count (up to 32) that divides `channels`.
+
+ :param channels: number of input channels.
+ :return: an nn.Module for normalization.
+ """
+ groups = 32
+ if channels <= 16:
+ groups = 8
+ elif channels <= 64:
+ groups = 16
+ while channels % groups != 0:
+ groups = int(groups / 2)
+ assert groups > 2
+ return GroupNorm32(groups, channels)
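The group count chosen by `normalization` can be reproduced with plain integers, which makes the halving-until-divisible rule easy to check in isolation. A sketch of the same selection logic (not a replacement for the module):

```python
def pick_groups(channels: int) -> int:
    """Same rule as `normalization`: start from a size-based default
    and halve until the group count divides `channels`."""
    groups = 32
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    while channels % groups != 0:
        groups //= 2
    assert groups > 2, f"no valid group count > 2 for {channels} channels"
    return groups

# e.g. 24 channels: 16 does not divide 24, so the count halves to 8
```

Note the assertion rejects channel counts such as 10, where halving bottoms out at 2 groups.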
+
+
+class mySequential(nn.Sequential):
+ """Using this to pass mask variable to nn layers"""
+
+ def forward(self, *inputs):
+ for module in self._modules.values():
+ if isinstance(inputs, tuple):
+ inputs = module(*inputs)
+ else:
+ inputs = module(inputs)
+ return inputs
+
+
+class SepConv1D(nn.Module):
+ """Depth wise separable Convolution layer with mask"""
+
+ def __init__(
+ self,
+ nin,
+ nout,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ padding_mode="same",
+ bias=True,
+ ):
+ super(SepConv1D, self).__init__()
+ self.conv1 = nn.Conv1d(
+ nin,
+ nin,
+ kernel_size=kernel_size,
+ stride=stride,
+ groups=nin,
+ dilation=dilation,
+ padding=padding_mode,
+ bias=bias,
+ )
+ self.conv2 = nn.Conv1d(
+ nin, nout, kernel_size=1, stride=1, padding=padding_mode, bias=bias
+ )
+
+ def forward(self, x, mask=None):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ x = self.conv1(x)
+ x = self.conv2(x)
+ return x, mask
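The point of the depthwise/pointwise split in `SepConv1D` is parameter count: the depthwise pass costs `nin * k` weights and the 1x1 pointwise pass `nin * nout`, versus `nin * nout * k` for a dense convolution. A quick sketch of the arithmetic (bias terms included):

```python
def sepconv1d_params(nin: int, nout: int, k: int) -> int:
    depthwise = nin * k + nin      # groups=nin: one k-tap filter per channel (+ bias)
    pointwise = nin * nout + nout  # 1x1 conv mixing channels (+ bias)
    return depthwise + pointwise

def dense_conv1d_params(nin: int, nout: int, k: int) -> int:
    return nin * nout * k + nout

# e.g. 256 -> 256 channels with kernel 9: the separable form is about 8.6x smaller
sep, dense = sepconv1d_params(256, 256, 9), dense_conv1d_params(256, 256, 9)
```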
+
+
+class Conv1DBN(nn.Module):
+ def __init__(
+ self,
+ nin,
+ nout,
+ kernel_size,
+ stride=1,
+ dilation=1,
+ dropout=0.1,
+ padding_mode="same",
+ bias=False,
+ ):
+ super(Conv1DBN, self).__init__()
+ self.conv1 = nn.Conv1d(
+ nin,
+ nout,
+ kernel_size=kernel_size,
+ stride=stride,
+ padding=padding_mode,
+ dilation=dilation,
+ bias=bias,
+ )
+ self.bn = nn.BatchNorm1d(nout)
+ self.drop = nn.Dropout(dropout)
+
+ def forward(self, x, mask=None):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ x = self.conv1(x)
+ x = self.bn(x)
+ x = F.relu(x)
+ x = self.drop(x)
+ return x, mask
+
+
+class Conv1d(nn.Module):
+ """normal conv1d with mask"""
+
+ def __init__(self, nin, nout, kernel_size, padding, bias=True):
+ super(Conv1d, self).__init__()
+ self.l = nn.Conv1d(nin, nout, kernel_size, padding=padding, bias=bias)
+
+ def forward(self, x, mask):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ y = self.l(x)
+ return y, mask
+
+
+class SqueezeExcite(nn.Module):
+    """Squeeze-and-Excitation: the network learns per-channel gates from the pooled summary."""
+
+ def __init__(self, nin, ratio=8):
+ super(SqueezeExcite, self).__init__()
+ self.nin = nin
+ self.ratio = ratio
+
+ self.fc = mySequential(
+ nn.Linear(nin, nin // ratio, bias=True),
+ nn.SiLU(inplace=True),
+ nn.Linear(nin // ratio, nin, bias=True),
+ )
+
+    def forward(self, x, mask=None):
+        if mask is None:
+            mask = torch.ones((x.shape[0], x.shape[-1]), dtype=torch.bool).to(x.device)
+        # zero out padded positions, then average over the valid ones only
+        pad = ~mask
+        x = x.float()
+        x.masked_fill_(pad.unsqueeze(1), 0.0)
+        y = (
+            torch.sum(x, dim=-1, keepdim=True)
+            / mask.unsqueeze(1).sum(dim=-1, keepdim=True)
+        ).type(x.dtype)
+        y = y.transpose(1, -1)
+        y = self.fc(y)
+        y = torch.sigmoid(y)
+        y = y.transpose(1, -1)
+        y = x * y
+        return y, mask
+
+
+class SCBD(nn.Module):
+    """Separable Conv1D + GroupNorm (+ SiLU + Dropout); generally used for middle layers and residual blocks."""
+
+ def __init__(
+ self, nin, nout, kernel_size, p=0.1, rd=True, separable=True, bias=True
+ ):
+ super(SCBD, self).__init__()
+ if separable:
+ self.SC = SepConv1D(nin, nout, kernel_size, bias=bias)
+ else:
+ self.SC = Conv1d(nin, nout, kernel_size, padding="same", bias=bias)
+
+        if rd:  # apply activation (SiLU) and dropout after the norm
+            self.mout = mySequential(
+                normalization(nout),
+                nn.SiLU(),
+                nn.Dropout(p),
+            )
+        else:
+            self.mout = normalization(nout)
+
+ def forward(self, x, mask=None):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ x, _ = self.SC(x, mask)
+ y = self.mout(x)
+ return y, mask
+
+
+class QuartzNetBlock(nn.Module):
+    """ResNet-style block with normalization and dropout, using separable convolutions
+    in the middle layers.
+    If it is the last layer, set se=False and separable=False, and use a projection
+    layer on top of this block.
+    """
+
+ def __init__(
+ self,
+ nin,
+ nout,
+ kernel_size,
+ dropout=0.1,
+ R=5,
+ se=False,
+ ratio=8,
+ separable=False,
+ bias=True,
+ ):
+ super(QuartzNetBlock, self).__init__()
+ self.se = se
+ self.residual = mySequential(
+ nn.Conv1d(nin, nout, kernel_size=1, padding="same", bias=bias),
+ normalization(nout), # nn.BatchNorm1d(nout,eps)
+ )
+        model = []
+
+        for i in range(R - 1):
+            model.append(SCBD(nin, nout, kernel_size, dropout, bias=bias))
+            nin = nout
+
+        # last sub-block: no activation/dropout, so the residual is added before them
+        model.append(
+            SCBD(nin, nout, kernel_size, dropout, rd=False, separable=separable, bias=bias)
+        )
+        self.model = mySequential(*model)
+
+        if self.se:
+            self.se_layer = SqueezeExcite(nin, ratio)
+
+        self.mout = mySequential(nn.SiLU(), nn.Dropout(dropout))
+
+ def forward(self, x, mask=None):
+ if mask is not None:
+ x = x * mask.unsqueeze(1).to(device=x.device)
+ y, _ = self.model(x, mask)
+ if self.se:
+ y, _ = self.se_layer(y, mask)
+ y += self.residual(x)
+ y = self.mout(y)
+ return y, mask
+
+
+class QKVAttentionLegacy(nn.Module):
+ """
+ A module which performs QKV attention. Matches legacy QKVAttention + input/output heads shaping
+ """
+
+ def __init__(self, n_heads):
+ super().__init__()
+ self.n_heads = n_heads
+
+ def forward(self, qkv, mask=None, rel_pos=None):
+ """
+ Apply QKV attention.
+
+ :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
+ :return: an [N x (H * C) x T] tensor after attention.
+ """
+ bs, width, length = qkv.shape
+ assert width % (3 * self.n_heads) == 0
+ ch = width // (3 * self.n_heads)
+ q, k, v = qkv.reshape(bs * self.n_heads, ch * 3, length).split(ch, dim=1)
+ scale = 1 / math.sqrt(math.sqrt(ch))
+ weight = torch.einsum(
+ "bct,bcs->bts", q * scale, k * scale
+ ) # More stable with f16 than dividing afterwards
+ if rel_pos is not None:
+ weight = rel_pos(
+ weight.reshape(bs, self.n_heads, weight.shape[-2], weight.shape[-1])
+ ).reshape(bs * self.n_heads, weight.shape[-2], weight.shape[-1])
+ weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
+ if mask is not None:
+ # The proper way to do this is to mask before the softmax using -inf, but that doesn't work properly on CPUs.
+ mask = mask.repeat(self.n_heads, 1).unsqueeze(1)
+ weight = weight * mask
+ a = torch.einsum("bts,bcs->bct", weight, v)
+
+ return a.reshape(bs, -1, length)
+
+
+class AttentionBlock(nn.Module):
+ """
+ An attention block that allows spatial positions to attend to each other.
+
+ Originally ported from here, but adapted to the N-d case.
+ https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
+ """
+
+ def __init__(
+ self,
+ channels,
+ num_heads=1,
+ num_head_channels=-1,
+ do_checkpoint=True,
+ relative_pos_embeddings=False,
+ ):
+ super().__init__()
+ self.channels = channels
+ self.do_checkpoint = do_checkpoint
+ if num_head_channels == -1:
+ self.num_heads = num_heads
+ else:
+ assert channels % num_head_channels == 0, (
+ f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
+ )
+ self.num_heads = channels // num_head_channels
+ self.norm = normalization(channels)
+ self.qkv = nn.Conv1d(channels, channels * 3, 1)
+ # split heads before split qkv
+ self.attention = QKVAttentionLegacy(self.num_heads)
+
+        self.proj_out = zero_module(
+            nn.Conv1d(channels, channels, 1)
+        )  # zero-init: attention has no effect at the start of training
+        if relative_pos_embeddings:
+            # T5-style relative position bias (cf. relative embeddings in ViT/Swin)
+            self.relative_pos_embeddings = RelativePositionBias(
+                scale=(channels // self.num_heads) ** 0.5,
+                causal=False,
+                heads=self.num_heads,
+                num_buckets=32,
+                max_distance=64,
+            )
+        else:
+            self.relative_pos_embeddings = None
+
+ def forward(self, x, mask=None):
+ b, c, *spatial = x.shape
+ x = x.reshape(b, c, -1)
+ qkv = self.qkv(self.norm(x))
+ h = self.attention(qkv, mask, self.relative_pos_embeddings)
+ h = self.proj_out(h)
+ return (x + h).reshape(b, c, *spatial)
+
+
+class AbsolutePositionalEmbedding(nn.Module):
+ def __init__(self, dim, max_seq_len):
+ super().__init__()
+ self.scale = dim**-0.5
+ self.emb = nn.Embedding(max_seq_len, dim)
+
+ def forward(self, x):
+ n = torch.arange(x.shape[1], device=x.device)
+ pos_emb = self.emb(n)
+ pos_emb = rearrange(pos_emb, "n d -> () n d")
+ return pos_emb * self.scale
+
+
+class FixedPositionalEmbedding(nn.Module):
+ def __init__(self, dim):
+ super().__init__()
+ inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
+ self.register_buffer("inv_freq", inv_freq)
+
+ def forward(self, x, seq_dim=1, offset=0):
+ t = (
+ torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
+ + offset
+ )
+ sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
+ emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
+ return rearrange(emb, "n d -> () n d")
+
+
+class RelativePositionBias(nn.Module):
+ def __init__(self, scale, causal=False, num_buckets=32, max_distance=128, heads=8):
+ super().__init__()
+ self.scale = scale
+ self.causal = causal
+ self.num_buckets = num_buckets
+ self.max_distance = max_distance
+ self.relative_attention_bias = nn.Embedding(num_buckets, heads)
+
+ @staticmethod
+ def _relative_position_bucket(
+ relative_position, causal=True, num_buckets=32, max_distance=128
+ ):
+ ret = 0
+ n = -relative_position
+ if not causal:
+ num_buckets //= 2
+ ret += (n < 0).long() * num_buckets
+ n = torch.abs(n)
+ else:
+ n = torch.max(n, torch.zeros_like(n))
+
+ max_exact = num_buckets // 2
+ is_small = n < max_exact
+
+ val_if_large = (
+ max_exact
+ + (
+ torch.log(n.float() / max_exact)
+ / math.log(max_distance / max_exact)
+ * (num_buckets - max_exact)
+ ).long()
+ )
+ val_if_large = torch.min(
+ val_if_large, torch.full_like(val_if_large, num_buckets - 1)
+ )
+
+ ret += torch.where(is_small, n, val_if_large)
+ return ret
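+    # Illustrative example (non-causal, num_buckets=32): the buckets are split
+    # into two halves of 16, one per direction. A relative offset of -3 lands in
+    # bucket 3 and +3 in bucket 16 + 3 = 19; offsets with magnitude >= 8
+    # (num_buckets // 4) are log-binned up to max_distance.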
+
+ def forward(self, qk_dots):
+ i, j, device = *qk_dots.shape[-2:], qk_dots.device
+ q_pos = torch.arange(i, dtype=torch.long, device=device)
+ k_pos = torch.arange(j, dtype=torch.long, device=device)
+ rel_pos = k_pos[None, :] - q_pos[:, None]
+ rp_bucket = self._relative_position_bucket(
+ rel_pos,
+ causal=self.causal,
+ num_buckets=self.num_buckets,
+ max_distance=self.max_distance,
+ )
+ values = self.relative_attention_bias(rp_bucket)
+ bias = rearrange(values, "i j h -> () h i j")
+ return qk_dots + (bias * self.scale)
+
+
+class MultiHeadAttention(nn.Module):
+ """
+ only for GST
+ input:
+ query --- [N, T_q, query_dim]
+ key --- [N, T_k, key_dim]
+ output:
+ out --- [N, T_q, num_units]
+ """
+
+ def __init__(self, query_dim, key_dim, num_units, num_heads):
+ super().__init__()
+ self.num_units = num_units
+ self.num_heads = num_heads
+ self.key_dim = key_dim
+
+ self.W_query = nn.Linear(
+ in_features=query_dim, out_features=num_units, bias=False
+ )
+ self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
+ self.W_value = nn.Linear(
+ in_features=key_dim, out_features=num_units, bias=False
+ )
+
+ def forward(self, query, key):
+ querys = self.W_query(query) # [N, T_q, num_units]
+ keys = self.W_key(key) # [N, T_k, num_units]
+ values = self.W_value(key)
+
+ split_size = self.num_units // self.num_heads
+ querys = torch.stack(
+ torch.split(querys, split_size, dim=2), dim=0
+ ) # [h, N, T_q, num_units/h]
+ keys = torch.stack(
+ torch.split(keys, split_size, dim=2), dim=0
+ ) # [h, N, T_k, num_units/h]
+ values = torch.stack(
+ torch.split(values, split_size, dim=2), dim=0
+ ) # [h, N, T_k, num_units/h]
+
+ # score = softmax(QK^T / (d_k ** 0.5))
+ scores = torch.matmul(querys, keys.transpose(2, 3)) # [h, N, T_q, T_k]
+ scores = scores / (self.key_dim**0.5)
+ scores = F.softmax(scores, dim=3)
+
+ # out = score * V
+ out = torch.matmul(scores, values) # [h, N, T_q, num_units/h]
+ out = torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(
+ 0
+ ) # [N, T_q, num_units]
+
+ return out
+
+
+class GST(nn.Module):
+ def __init__(
+ self, model_channels=512, style_tokens=100, num_heads=8, in_channels=80
+ ):
+ super(GST, self).__init__()
+ self.model_channels = model_channels
+ self.style_tokens = style_tokens
+ self.num_heads = num_heads
+
+ self.reference_encoder = nn.Sequential(
+ nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2),
+ nn.Conv1d(model_channels, model_channels, 3, padding=1, stride=2),
+ AttentionBlock(
+ model_channels,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ AttentionBlock(
+ model_channels,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ AttentionBlock(
+ model_channels,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ AttentionBlock(
+ model_channels,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ AttentionBlock(
+ model_channels,
+ num_heads,
+ relative_pos_embeddings=True,
+ do_checkpoint=False,
+ ),
+ )
+
+    def forward(self, x):
+        # TODO: support masking for variable-length inputs
+        batch = x.size(0)
+        x = self.reference_encoder(x)  # (N, model_channels, T // 4)
+        x = x.mean(dim=-1)  # average over time
+        # TODO: add normalization?
+        return x.view(batch, -1, 1)
+
+
+if __name__ == "__main__":
+ device = torch.device("cpu")
+ m = GST(512, 10).to(device)
+ mels = torch.rand((16, 80, 1000)).to(device)
+
+ o = m(mels)
+ print(o.shape, "final output")
+
+ from torchinfo import summary
+
+ summary(m, input_data={"x": torch.randn(16, 80, 500).to(device)})
diff --git a/T2S/utilities.py b/T2S/utilities.py
new file mode 100755
index 0000000000000000000000000000000000000000..992e336560a9dddfd2bab7c785806e5a23479d8a
--- /dev/null
+++ b/T2S/utilities.py
@@ -0,0 +1,126 @@
+import librosa.util as librosa_util
+import numpy as np
+import torch
+from scipy.io.wavfile import read
+from scipy.signal import get_window
+
+# import librosa
+from config import config
+
+# Dynamic range of Tacotron-style log-mel spectrograms, used for [-1, 1] normalization.
+TACOTRON_MEL_MAX = 2.3143386840820312
+TACOTRON_MEL_MIN = -11.512925148010254
+
+
+def denormalize_tacotron_mel(norm_mel):
+ return ((norm_mel + 1) / 2) * (
+ TACOTRON_MEL_MAX - TACOTRON_MEL_MIN
+ ) + TACOTRON_MEL_MIN
+
+
+def normalize_tacotron_mel(mel):
+ return 2 * ((mel - TACOTRON_MEL_MIN) / (TACOTRON_MEL_MAX - TACOTRON_MEL_MIN)) - 1
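+# Illustrative sanity check: the two functions are inverses, mapping
+# [TACOTRON_MEL_MIN, TACOTRON_MEL_MAX] <-> [-1, 1].
+# mel = torch.tensor([TACOTRON_MEL_MIN, TACOTRON_MEL_MAX])
+# normalize_tacotron_mel(mel)                            -> tensor([-1., 1.])
+# denormalize_tacotron_mel(normalize_tacotron_mel(mel))  -> mel (up to rounding)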
+
+
+def get_mask_from_lengths(lengths, max_len=None):
+ if not max_len:
+ max_len = torch.max(lengths).item()
+ ids = torch.arange(0, max_len, device=lengths.device, dtype=torch.long)
+ mask = (ids < lengths.unsqueeze(1)).bool()
+ return mask
+
+
+def get_mask(lengths, max_len=None):
+    if not max_len:
+        max_len = torch.max(lengths).item()
+    lens = torch.arange(max_len, device=lengths.device)
+    mask = lens.unsqueeze(0) < lengths.unsqueeze(1)
+    return mask
+
+
+def dynamic_range_compression(x, C=1, clip_val=1e-5):
+ """
+ PARAMS
+ ------
+ C: compression factor
+ """
+ return torch.log(torch.clamp(x, min=clip_val) * C)
+
+
+def dynamic_range_decompression(x, C=1):
+ """
+ PARAMS
+ ------
+ C: compression factor used to compress
+ """
+ return torch.exp(x) / C
+
+
+def window_sumsquare(
+ window,
+ n_frames,
+ hop_length=200,
+ win_length=800,
+ n_fft=800,
+ dtype=np.float32,
+ norm=None,
+):
+ """
+ # from librosa 0.6
+ Compute the sum-square envelope of a window function at a given hop length.
+ This is used to estimate modulation effects induced by windowing
+ observations in short-time fourier transforms.
+ Parameters
+ ----------
+ window : string, tuple, number, callable, or list-like
+ Window specification, as in `get_window`
+ n_frames : int > 0
+ The number of analysis frames
+ hop_length : int > 0
+ The number of samples to advance between frames
+ win_length : [optional]
+ The length of the window function. By default, this matches `n_fft`.
+ n_fft : int > 0
+ The length of each analysis frame.
+ dtype : np.dtype
+ The data type of the output
+ Returns
+ -------
+ wss : np.ndarray, shape=`(n_fft + hop_length * (n_frames - 1))`
+ The sum-squared envelope of the window function
+ """
+ if win_length is None:
+ win_length = n_fft
+
+ n = n_fft + hop_length * (n_frames - 1)
+ x = np.zeros(n, dtype=dtype)
+
+ # Compute the squared window at the desired length
+ win_sq = get_window(window, win_length, fftbins=True)
+ win_sq = librosa_util.normalize(win_sq, norm=norm) ** 2
+ win_sq = librosa_util.pad_center(win_sq, size=n_fft)
+
+ # Fill the envelope
+ for i in range(n_frames):
+ sample = i * hop_length
+ x[sample : min(n, sample + n_fft)] += win_sq[: max(0, min(n_fft, n - sample))]
+ return x
+
+
+def load_wav_to_torch(full_path):
+    sampling_rate, data = read(full_path)
+    return torch.FloatTensor(data), sampling_rate
+
+
+if __name__ == "__main__":
+ lens = torch.tensor([2, 3, 7, 5, 4])
+ mask = get_mask(lens)
+ print(mask)
+ print(mask.shape)
diff --git a/Text/__init__.py b/Text/__init__.py
new file mode 100755
index 0000000000000000000000000000000000000000..86c7066e0a26baf882722df99d3e1294355ab661
--- /dev/null
+++ b/Text/__init__.py
@@ -0,0 +1 @@
+from .symbols import *
\ No newline at end of file
diff --git a/Text/__pycache__/__init__.cpython-310.pyc b/Text/__pycache__/__init__.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..92474e15041ad7f225527c4e3f6b0b2ab0f42569
Binary files /dev/null and b/Text/__pycache__/__init__.cpython-310.pyc differ
diff --git a/Text/__pycache__/symbols.cpython-310.pyc b/Text/__pycache__/symbols.cpython-310.pyc
new file mode 100644
index 0000000000000000000000000000000000000000..2c5b1450972b9ae6f7a6ebef092e197788540fa8
Binary files /dev/null and b/Text/__pycache__/symbols.cpython-310.pyc differ
diff --git a/Text/symbols.py b/Text/symbols.py
new file mode 100755
index 0000000000000000000000000000000000000000..88d605c095c0b19bf52a6762f1bdc22895b445ec
--- /dev/null
+++ b/Text/symbols.py
@@ -0,0 +1,88 @@
+import sys
+import pickle as pkl
+sys.path.append("../")
+from config import config
+
+
+# labels=" abcdefghijklmnopqrstuvwxyz.,:;'()?!\""
+# # labels=" !\"'(),-.:;?[]abcdefghijklmnopqrstuvwxyzàâèéêü’“”"
+
+# with open("../Text/symbols_final.pkl",'rb') as file:
+# labels = pkl.load(file)
+
+english =''' !"#$%&'()*+,-.0123456789:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{} ´»àáâæçèéêíïñôúüœ˜ҙҡүҳ–—‘’“”•…€™️'''
+
+hindi =''' !"'(),-.0123456789;?abcdefghijklmnopqrstuvwxyzँंःअआइईउऊऋऌऍऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऻ़ािीुूृॄॅॆेैॉॊोौ्ॏॐ॒॓॔ॕॖॗक़ख़ग़ज़ड़ढ़फ़ॠॡॢ।॥०॰'''
+
+kannada =''' !"#$%&'()*+,-./0123456789:;<=>?[]^`abcdefghiklmnopqrstuvwxy ½ʼॐ।॥ಂಃಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೆೇೈೊೋೌ್ೕೖೞೠ೦೧೨೩೪೫೬೭೮೯ –‘’“”•…'''
+
+tamil =''' !"%&'()*,-./0123456789:;?[]`abcdefghijklmnopqrstuvwxyz ஃஅஆஇஈஉஊஎஏஐஒஓஔகஙசஜஞடணதநனபமயரறலளழவஷஸஹாிீுூெேைொோௌ்ௗ–‘’“”…'''
+
+assamese =''' !%&'()+,-./0123456789:;?[]_abcdefghijklmnopqrstuvwxyz ʼ।॥ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৷৹–—‘’“”'''
+
+dogri =''' !$%&'()+,-./0123456789:;?@[]`abcdefghijklmnopqrstuvwxyzʼँंःअआइईउऊएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ऽािीुूृेैॉोौ्क़ज़ड़ढ़फ़य़।–—‘’“”…′₹'''
+
+bodo =''' !$%&'()+,-./0123456789:;?[]_abcdefghijklmnopqrstuvwxyz °º½ʼँंःअआइईउऊऋऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलवशषसह़ािीुूृॅॆेैॉॊोौ्ज़ड़फ़य़।०ॽ–‘’“”'''
+
+marathi =''' !'*+,-./0123456789:;?[`z ँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॑ॕक़ज़ड़ढ़फ़य़ॠॡ।॥०१२३४५६७८९ॲ–‘’“”›'''
+
+bengali =''' !',-.0123456789;?acdefghlmnrstuyz।ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃৄেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৵৷৻—‘’'''
+
+telugu =''' !"'*,-./258:;?o ।ఁంఃఅఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరఱలళవశషసహాిీుూృెేైొోౌ్ౖ–‘’”…'''
+
+german =''' '-abcdefghijklmnopqrstuvwxyzßàäèéöü'''
+
+spanish =''' !'-.;?abcdefghijklmnopqrstuvwxyzáéíñóúü'''
+
+french =''' !'-.;?abcdefghijklmnopqrstuvwxyzàâæçèéêëîïôùûüÿœ'''
+
+punjabi =''' !"'(),-.:?bden।ਁਂਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ਼ਾਿੀੁੂੇੈੋੌ੍ਖ਼ਗ਼ਜ਼ੜਫ਼ੰੱੲੳ –‘’“”…'''
+
+sanskrit =''' "ँंःअआइईउऊऋऌऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑ॠॡॢ।॥०१२३४५६७८९॰'''
+
+odia =''' "',-.;।ଁଂଃଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳଵଶଷସହ଼ଽାିୀୁୂୃୄେୈୋୌ୍ୖଡ଼ଢ଼ୟୠୢ୦୧୨୩୪୫୬୭୮୯୰ୱ‘’”'''
+
+urdu =''' !"',-.:`abcdeghiklrtuy،ؑؓؔ؛؟ءآأؤئابتثجحخدذرزسشصضطظعغفقكلمنهوىيًَُِّْٓٔٗ٬ٰٴٹپچڈڑژکگںھہۂۃیےۓ۔।‘’“”…ﭨﮭﮯﯾﷲﷺﺅﺗﺘﺩﺲﻧﻮ'''
+
+gujarati =''' !',-.:;?m ।ઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૄૅેૈૉોૌ્ૐૠૢ૦૧૨૩૪૫૬૭૮૯–‘’…'''
+
+rajasthani =''' !'(),-.0123456789:;?xँंःअआइईउऊऋऍएऐऑओऔकखगघचछजझञटठडढणतथदधनपफबभमयरलळवशषसह़ािीुूृेैॉोौ्क़ख़ग़ज़ड़ढ़फ़ॠ।०१२३७८९‘’…'''
+
+malayalam =''' !,?ംഃഅആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരറലളഴവശഷസഹാിീുൂൃെേൈൊോൌ്ൗൺൻർൽൾ'''
+
+manipuri =''' ,-.mnঁংঅআইঈউঊএঐওঔকখগঘঙচছজঝঞটঠডণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎয়০১২৩৪৫৬৭৮৯ৰৱ৷'''
+
+gujrati =''' !"'(),-.0123456789:?{} âઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૅેૈૉોૌ્ૠ૦૧૨૩૪૫૬૭૮૯–—‘’“”…'''
+
+bhojpuri =''' !"'(),-.012346789:?`abeimpy{}·ँंःअआइईउऊऋऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलळवशषसहऺ़ऽािीुूृॅॆेैॉॊोौ्ॐॕॖक़ख़ग़ज़ड़ढ़फ़य़ॠ।०२६९॰ ‘’'''
+
+italian =''' !"$'()+,-.:;<=>?[]_`abcdefghijklmnopqrstuvwxyz{}~¡«°´µº»ßàáâãäåæèéêëìíîïðñòóôöøùúûþÿāąćčđėęěğīıľłńňōőœřśşšūŭźżžșțəʻʼʾʿː̨́абдеиклмностуцшѐёљңדהוةرسصغليḥṇṛṣṭễ‑–—‘’“”„…′☆♭♯あアカキサザノフリン・ー万三丰古多家峰張旅禅ꞌ'''
+
+arabic =''' !"',-.:;?egt«»،؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٰٓچکیۖۗۘۙۚۛ–“”…ﺃ'''
+
+korean =''' 가각간갈감갑값갓갔강갖같갚개객갠갤갯갱거걱건걷걸검겁것겉게겐겟겠겨격겪견결겸겹겼경곁계고곡곤곧골곰곱곳공곶과곽관괄괌광괘괭괴굉교구국군굳굴굵굶굽굿궁궈권궐궤귀규균그극근글금급긋긍기긴길김깁깃깊까깎깔깝깡깥깨깬꺼껍껏껑께껴꼈꼬꼭꼴꼼꼽꽁꽂꽃꽝꽥꾸꾼꿀꿇꿈꿎꿔꿨꿰뀌뀐끄끅끈끊끌끓끔끗끝끼끽낀낌나낙난날낡남납낫났낭낮낯낳내낸낼냄냅냇냈냉냐냥너넉넌널넓넘넛넣네넥넨넷녀녁년념녔녕녘노녹논놀놈농높놓놨뇌뇨뇽누눈눌눠뉜뉴늉느는늘늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮닳담답닷당닿대댁댐댓더덕던덜덟덤덧덩덫덮데덴델뎅뎌도독돈돋돌돔돕돗동돛돼됐되된될됨됩두둑둔둘둠둥둬뒀뒤뒷듀드득든듣들듬듭듯등디딘딜딥딨딩딪따딱딴딸땀땄땅때땐땠땡떠떡떤떨떴떻떼뗄또똑똥뚜뚝뚫뛰뛴뜨뜯뜸뜻띄띔띠띤띨띵라락란랄람랍랐랑래랙랜램랫랬랭랴략량러럭런럴럼럽렀렁렇레렉렌렘렛려력련렬렴렵렷렸령례로록론롤롬롭롯롱뢰료룡루룬룰룸룹뤄뤘뤼류륙륜률륭르륵른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맡매맥맨맹맺머먹먼멀멈멋멍메멕멘멜멧며면멸명몇모목몫몬몰몸못몽묘무묵묶문묻물뭄뭇뭐뭔뭘뮤뮬므믈미믹민믿밀밋밌밍및밑바박밖반받발밝밤밥방밭배백밴뱀뱃뱅버벅번벌범법벗벚베벤벨벳벼벽변별볍볐병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북분불붉붐붓붕붙뷔뷰브븐블비빅빈빌빗빙빚빛빠빨빵빼뺀뺌뺏뺑뻐뻑뻔뻗뻘뼈뽀뽑뽕뿌뿐뿜쁘쁜쁩삐사삭산살삶삼삽삿샀상새색샌샐샘생샤샬샵샷서석섞선섣설섬섭섯섰성세섹센셀셈셉셋셔션셜셨셰소속손솔솜솟송솥쇄쇠쇤쇼숍수숙순술숨숭숲쉬쉰쉼쉽슈슐스슨슬슴습슷승시식신실싫심십싱싶싸싹싼쌀쌈쌌쌍쌓써썩썰썼쏘쏜쏟쏠쑤쓰쓴쓸씀씁씌씨씩씬씸씻아악안앉않알앓암압앗았앙앞애액앤앨앱앵야약얇양얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엡엣여역연열엷염엽엿였영옆예옛오옥온올옮옳옴옵옷옹와완왈왔왕왜외왼요욕용우욱운울움웁웃웅워원월웠웨웬웹위윈윌윗윙유육윤율융으은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잽쟁저적전절젊젋점접젓정젖제젝젠젤져젼졌조족존졸좀좁종좋좌죄죠주죽준줄줍중줘줬쥐쥔쥘쥬즈즉즌즐즘증지직진질짊짐집짓징짖짙짚짜짝짧짬째쨌쩌쩍쩔쩜쪼쪽쫄쫓쭉쯤찌찍찔찢찧차착찬찮찰참찻창찾채책챌챔챙챠처척천철첨첩첫청체첸첼쳇쳐쳤초촉촌촘촛총촨촬최추축춘출춤충춰취츠측츰층치칙친칠침칩칫칭카칸칼캉캐캔캘캠커컥컨컫컴컵컷컸케켈켐켑켓켜켰코콘콜콤콥콧콩쾌쿄쿠쿡쿨쿼퀴큐크큰클큼키킥킨킬킷킹타탁탄탈탐탑탓탕태택탠탬탱터턱턴털텃텅테텍텐텔템텼토톡톤톨톰통퇴투툴툼퉁튀튜튬트특튼튿틀틈티틱틴틸팀팅파팍팎판팔팜팡패팩팬팰팻팽퍼펀펄펌페펜펠펫펴편펼폄폈평폐포폭폰폴폼표푸푹푼풀품풋풍퓨퓰프픈플픔피픽핀필핏핑하학한할함합핫항해핵핸햇했행향허헌헐험헝헤헨헬헴헷혀혁현혈혐협혔형혜호혹혼홀홈홉홍화확환활황회획횡효후훈훌훤훨훼휘휩휴흉흐흑흔흘흙흠흡흥흩희흰히힌힐힘'''
+
+russian =''' !"'(),-.:;?abcefghiklmnoprstxz«»абвгдежзийклмнопрстуфхцчшщъыьэюяё‑–—“„…−'''
+
+thai =''' กขคฆงจฉชซญณดตถทธนบปผฝพฟภมยรฤลวศษสหอฮะัาำิีึืุูเแโใไ็่้๊๋์'''
+
+japanese =''' !',-.?abcdefghijklmnopqrstuvwxyzμ―‘’“”…☆♡⤴、。々〇〈〉「」『』〜ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐをん゛ゝゞァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロワンヴヵヶ・ー一丁七万丈三上下不与且世丘両並中串丸丹主丼乃久之乏乖乗乙九也乱乳乾亀了予争事二于云互五井亘亜亡交京亭人仁今介仏仕他付仙代令以仮仰仲件任企伊伍伏伐休会伝传伯伴伸似但佇位低住佐体何余作佳併使例侑供依価侮侵便係俄俗保俟信俣修俯俳俺倉個倍倒候借倣値倫倶偉偏做停健側偵偶偽傍傑傘備催債傷傾僅働像僕僚僧儀億儘儚償優允元兄充兆先光克免兎児党入全八公六共兵具典兼内円冊再冒冗写冠冬冴冷凄凌凍凛凝凡処凧凪凶出函刀刃分切刈刊刑列初判別利到制刷券刺刻則削前剛剣剤剥副割創劇力功加劣助努劫励労効勃勇勉動勘務勝募勢勤勧勿匂包化北匠匹区医匿十千午半卑卒卓協南単博占印危即却卵厄厚原厨厭厳去参及友双反収取受叙叛叡口古句叩只叫召可台叱史右叶号司各合吉吊同名后吐向君吟吠否含吸吹吾呂呆呈呉告呑呟周味呻呼命咋和咥咲咳咽哀品員哲唄唆唇唐唖唯唱唸唾商問善喉喋喚喜喧喩喪喫喬喰営嗅嘆嘉嘘嘩嘲噂噌噛器噴囚四回因団困囲図固国圏園土圧在地坂均坊坐坑坪垂型垢垣埃埋城域執培基埼堀堂堅堕堤堪報場堵堺塀塁塊塔塗塙塚塩塵塾境墓増墨壁壇壊壌士壬壮声壱売壷壺変夏夕外多夜夢大天太夫央失奄奇奈奉奏契奔套奥奨奪奮女奴好如妃妄妖妙妥妨妬妹妻姉始姓委姦姫姶姻姿威娘娯婆婚婦婿媒媛媽嫁嫉嫌嬉嬢嬬子孔孕字存孝季孤学孫宅宇守安完宍宗官宙定宛宜宝実客宣室宮宰害宴宵家容宿寂寄密富寒寛寝察寧審寵寸寺対寿封専射将尊尋導小少尖尚尤就尺尻尼尽尾尿局居屈届屋屏屑展属層履屯山岐岡岩岬岸峙峠峡峰島崇崎崔崖崩嵐嵩嶋川州巡巣工左巧巨差己巻巾市布帆希帖帝師席帯帰帳常帽幅幌幕幡干平年并幸幹幻幼幽幾庁広庄床序底庖店府度座庫庭庵康廃廉廊廓延廷建廻廿弁弄弊式弐弓引弘弟弥弦弱張強弾当彙形彦彩彫彰影役彼往征待律後徐徒従得御復循微徳徴徹徽心必忍志忘忙応忠快念怒怖思怠急性怨怪怯恋恐恥恨恩息恰恵悔悟悠患悦悩悪悲悶情惑惚惜惟惨惰想惹愉意愕愚愛感慄慈態慌慎慕慢慣慨慮慰慶憂憎憐憑憚憤憧憩憲憶應懊懐懸成我戒或戚戦戯戸戻房所扇扉手才打払扱扶批承技抄把抑投抗折抜択披抱抵抹押抽担拉拍拒拓拗拘招拝拠拡括拭拳拶拾持指挑挙挟挨挫振挾挿捉捌捕捜捧捨捲捻掃授掌排掘掛掟掠採探接控推措掬掲掴掻揃揄描提揖揚換握揮援揶揺損搬搭携搾摂摘摩撃撒撚撞撤撫播撮擁操擢擦支改攻放政故敏救敗教敢散敦敬数整敵敷文斉斎斐斑斗料斜斧斬断斯新方於施旅族旗既日旦旧旨早旬旭旺昂昆昇昌明易昔星映春昧昨昭是昼時晒晦晩普景晴晶智暇暑暖暗暦暫暮暴曇曖曜曲曳更書曹曽替最月有朋服朕朗望朝期朦朧木未末本札朴机权杉李材村杓杖杜束条来杭杯東杵松板析枕林枚果枝枢枯架柄柏染柔柢柱柳柴柵査柿栃栄栖栗校株核根格栽桁桂桃案桎桐桑桔桜桶梁梅梏梗條梧梨梯械梱梼棄棋棒棚棟森棲椅植椎検楊楢業極楽概榛構槌様槻槽標模権横樹樽橄橋橙機橿檎檜櫃櫻欄欅欒欖欠次欧欲欺歌歎歓止正此武歩歪歯歳歴死歿殊残殖殴段殺殻殿母毎毒比毛毫毯氏民気水氷永汁求汗汚汝江池汲決汽沈沖沙没沢沫河沸油治沼沿況泉泊法泡波泣泥注泪泰泳洋洒洗洞津洪洲活派流浄浅浙浜浦浩浪浮浴海浸消涌涎涙涯液涼淀淡深淵混淹添清渇済渉渋渓減渡渦温測港渾湖湧湯湾湿満源準溜溝溢溶溺滅滋滑滝滞滲滴漁漂漆漏演漕漠漢漫漬潔潜潟潤潮潰澄澤激濁濃濡濫濯瀋瀞瀟瀬灘火灯灰災炉炊炎炒炬炭炸点為烈烏烹焔無焦然焼煉煎煙照煮熊熟熱熾燃燈燕燥燵爆爪爵父爺爽爾片版牙牛牝牟牡牢牧物牲特牽犠犧犬犯状狂狐狙狛狩独狭猛猟猪猫献猶猿獄獅獣獰獲玄率玉王玖珀珂珍珠班現球理琢琥琴瑕瑛瑞璃璧環瓢瓦瓶甘生産甥用田由甲申男町画界畏畑留畜畠畢略番異畳畷疆疑疫疲疵疾病症痛痩痰痴痺瘠療癒癖癪発登發白百的皆皇皮皺皿盆益盗盛盟監盤目盲直相盾省眉看県真眠眸眺眼着睡督睦瞑瞬瞭瞳矛矢知矩短矮矯石砂研砕砥砦砲破砺硝硬确碁碑碗碧確磁磐磨磯礎礫示礼社祈祉祓祖祝神祠祢祥票祭禁禄禍福禽秀私秋科秒秘秦秩称移程税稚種稲稼稽稿穀穂積穏穴究空突窃窓窟窪竈立竜竟章童端競竹笑笛笠符第笹筆筈等筋筑筒答策箆箇箋箔箕算管箱箸節範築篠簔簡簾簿籍籐籠米粉粋粒粕粗粘粛粟粧精糖糞糠糧糸系紀約紅紋納紐純紗紙級紛素索紫細紳紹終組絆経結絞絡給絨統絵絶絹継続綜維綱網綺綻綾綿緊総緑緒線締編緩練縁縄縛縞縦縫縮績繁繊繋織繕繞繫繰繹纂纏缶罠罪置罰署罵羅羊美群義羽翁翅翌習翔翻翼老考者耆而耐耕耳聖聞聴職肆肉肌肘肝股肢肥肩肪肯育肺胃胆背胎胞胡胸能脂脅脆脇脈脊脚脱脳脹腎腐腑腕腰腱腹腿
膚膜膝膨膳膵臆臍臓臣臨自臭至致臼興舌舎舐舗舘舞舟航般船艘艦良色艶芋芝芥芦芯花芳芸芽苅苑苓苔苗苛若苦苫英苺茂茅茎茨茶茹草荒荘荷菅菊菓菜華菰菱萌萎萩落葉著葛葦葬葵蒙蒡蒲蒸蓄蓋蓬蓮蔑蔓蔭蔵蕎蕨薄薇薔薩薬藁藍藝藤藻蘂蘆蘇蘭虎虚虫虹虻蚊蛇蛍蛙蛮蛹蛾蝉蝶融蟲蟹蟻蠣血衆行術街衛衝衡衣表衰衾袋袖被袴裁裂装裏裕補裳裸製裾複褐褪襟襲西要覆覇見規視覗覚覧親観角解触言訂計訊討訓託記訟訣訪設許訳訴診証詐詑評詞詠詣試詩詮詰話詳誇誉誌認誑誓誕誘語誠誤誦説読誰課調談請諏論諦諫諸謀謂謎謙講謝謡謬謹識譜警議譲護讐认话读谷豆豊豚象豪貌貝貞負財貢貧貨販貪貫責貯貰貴買貸費貼貿賀賂賃賄資賊賑賛賞賠賢賦質賭購賽贅贈赤赦走起超越趣足距跡跨路跳践踊踏蹟蹴躇躊躍身車軋軌軍軒軟転軸軽較載輝輩輪輸輾轄辛辞辰辱農辺込辿迅迎近返这迦迫述迷追退送逃逅逆透逐途逗這通逝速造逢連逮週進逸遂遅遇遊運遍過道達違遠遣遥適遭遮遷選遺遼遽避邂還邑那邦邪邸郊郎郡部郵郷都鄙鄭配酒酔酢酬酵酷酸醍醐醒醜醤釈里重野量金釘釜針釣釧鈍鈴鉄鉛鉢鉱鉾銀銃銅銚銭鋤鋭錆錠錦録鍋鍛鍵鎌鎖鏡鐘鑑長門閃閉開閑間関閣閨閲闇闘阅阜阪防阻阿陀附降限陛陝院陣除陥陰陳陵陸険陽隅隆隊隋階随隔隙際障隠隣隷隻雀雄雅集雇雑離難雨雪雫雰雲雷電需震霊霞霧露霹靂青静非面革靴鞍鞘鞠鞭韋韓韮音韻響頂頃項順須預頑頓頗領頬頭頷頸頻頼題額顎顔顕願類顧领風飛飜食飢飯飲飴飼飽飾餃餅養餌餓館饅首香馬馭馳馴駄駅駆駈駐駒駝駿騎騒験騨騰驚骨骸高髪髭鬱鬼魂魅魔魚鮫鮭鮮鯖鯛鰐鰡鰭鰺鰻鳥鳩鳴鴉鴎鴨鴻鶏鶴鷲鷹鸞鹸鹿麗麦麺麻黄黒黙鼓鼠鼻鼾齢齧龍녕러분세안여요하︎!(),./;=?\^abcdfghijklmnoprstuvwyz~「」ゥ'''
+
+
+final_characters_ = ''' !"#$%&'()*+,-./:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{}~ ¡«°´µ·º»½ßàáâãäåæçèéêëìíîïðñòóôöøùúûüþÿāąćčđėęěğīıľłńňōőœřśşšūŭźżžșțəʻʼʾʿː˜̨́μабвгдежзийклмнопрстуфхцчшщъыьэюяѐёљҙҡңүҳדהו،ؑؓؔ؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٓٔٗ٬ٰٴٹپچڈڑژکگںھہۂۃیےۓ۔ۖۗۘۙۚۛँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसहऺऻ़ऽािीुूृॄॅॆेैॉॊोौ्ॏॐ॒॑॓॔ॕॖॗक़ख़ग़ज़ड़ढ़फ़य़ॠॡॢ।॥०१२३४५६७८९॰ॲॽঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃৄেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৵৷৹৻ਁਂਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ਼ਾਿੀੁੂੇੈੋੌ੍ਖ਼ਗ਼ਜ਼ੜਫ਼ੰੱੲੳઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૄૅેૈૉોૌ્ૐૠૢ૦૧૨૩૪૫૬૭૮૯ଁଂଃଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳଵଶଷସହ଼ଽାିୀୁୂୃୄେୈୋୌ୍ୖଡ଼ଢ଼ୟୠୢ୦୧୨୩୪୫୬୭୮୯୰ୱஃஅஆஇஈஉஊஎஏஐஒஓஔகஙசஜஞடணதநனபமயரறலளழவஷஸஹாிீுூெேைொோௌ்ௗఁంఃఅఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరఱలళవశషసహాిీుూృెేైొోౌ్ౖಂಃಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೆೇೈೊೋೌ್ೕೖೞೠ೦೧೨೩೪೫೬೭೮೯ംഃഅആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരറലളഴവശഷസഹാിീുൂൃെേൈൊോൌ്ൗൺൻർൽൾกขคฆงจฉชซญณดตถทธนบปผฝพฟภมยรฤลวศษสหอฮะัาำิีึืุูเแโใไ็่้๊๋์ḥṇṛṣṭễ 
‑–—―‘’“”„•…′›€₹™−☆♡♭♯⤴、。々〇〈〉「」『』〜ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐをん゛ゝゞァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロワンヴヵヶ・ー一丁七万丈三上下不与且世丘両並中丰串丸丹主丼乃久之乏乖乗乙九也乱乳乾亀了予争事二于云互五井亘亜亡交京亭人仁今介仏仕他付仙代令以仮仰仲件任企伊伍伏伐休会伝传伯伴伸似但佇位低住佐体何余作佳併使例侑供依価侮侵便係俄俗保俟信俣修俯俳俺倉個倍倒候借倣値倫倶偉偏做停健側偵偶偽傍傑傘備催債傷傾僅働像僕僚僧儀億儘儚償優允元兄充兆先光克免兎児党入全八公六共兵具典兼内円冊再冒冗写冠冬冴冷凄凌凍凛凝凡処凧凪凶出函刀刃分切刈刊刑列初判別利到制刷券刺刻則削前剛剣剤剥副割創劇力功加劣助努劫励労効勃勇勉動勘務勝募勢勤勧勿匂包化北匠匹区医匿十千午半卑卒卓協南単博占印危即却卵厄厚原厨厭厳去参及友双反収取受叙叛叡口古句叩只叫召可台叱史右叶号司各合吉吊同名后吐向君吟吠否含吸吹吾呂呆呈呉告呑呟周味呻呼命咋和咥咲咳咽哀品員哲唄唆唇唐唖唯唱唸唾商問善喉喋喚喜喧喩喪喫喬喰営嗅嘆嘉嘘嘩嘲噂噌噛器噴囚四回因団困囲図固国圏園土圧在地坂均坊坐坑坪垂型垢垣埃埋城域執培基埼堀堂堅堕堤堪報場堵堺塀塁塊塔塗塙塚塩塵塾境墓増墨壁壇壊壌士壬壮声壱売壷壺変夏夕外多夜夢大天太夫央失奄奇奈奉奏契奔套奥奨奪奮女奴好如妃妄妖妙妥妨妬妹妻姉始姓委姦姫姶姻姿威娘娯婆婚婦婿媒媛媽嫁嫉嫌嬉嬢嬬子孔孕字存孝季孤学孫宅宇守安完宍宗官宙定宛宜宝実客宣室宮宰害宴宵家容宿寂寄密富寒寛寝察寧審寵寸寺対寿封専射将尊尋導小少尖尚尤就尺尻尼尽尾尿局居屈届屋屏屑展属層履屯山岐岡岩岬岸峙峠峡峰島崇崎崔崖崩嵐嵩嶋川州巡巣工左巧巨差己巻巾市布帆希帖帝師席帯帰帳常帽幅幌幕幡干平年并幸幹幻幼幽幾庁広庄床序底庖店府度座庫庭庵康廃廉廊廓延廷建廻廿弁弄弊式弐弓引弘弟弥弦弱張強弾当彙形彦彩彫彰影役彼往征待律後徐徒従得御復循微徳徴徹徽心必忍志忘忙応忠快念怒怖思怠急性怨怪怯恋恐恥恨恩息恰恵悔悟悠患悦悩悪悲悶情惑惚惜惟惨惰想惹愉意愕愚愛感慄慈態慌慎慕慢慣慨慮慰慶憂憎憐憑憚憤憧憩憲憶應懊懐懸成我戒或戚戦戯戸戻房所扇扉手才打払扱扶批承技抄把抑投抗折抜択披抱抵抹押抽担拉拍拒拓拗拘招拝拠拡括拭拳拶拾持指挑挙挟挨挫振挾挿捉捌捕捜捧捨捲捻掃授掌排掘掛掟掠採探接控推措掬掲掴掻揃揄描提揖揚換握揮援揶揺損搬搭携搾摂摘摩撃撒撚撞撤撫播撮擁操擢擦支改攻放政故敏救敗教敢散敦敬数整敵敷文斉斎斐斑斗料斜斧斬断斯新方於施旅族旗既日旦旧旨早旬旭旺昂昆昇昌明易昔星映春昧昨昭是昼時晒晦晩普景晴晶智暇暑暖暗暦暫暮暴曇曖曜曲曳更書曹曽替最月有朋服朕朗望朝期朦朧木未末本札朴机权杉李材村杓杖杜束条来杭杯東杵松板析枕林枚果枝枢枯架柄柏染柔柢柱柳柴柵査柿栃栄栖栗校株核根格栽桁桂桃案桎桐桑桔桜桶梁梅梏梗條梧梨梯械梱梼棄棋棒棚棟森棲椅植椎検楊楢業極楽概榛構槌様槻槽標模権横樹樽橄橋橙機橿檎檜櫃櫻欄欅欒欖欠次欧欲欺歌歎歓止正此武歩歪歯歳歴死歿殊残殖殴段殺殻殿母毎毒比毛毫毯氏民気水氷永汁求汗汚汝江池汲決汽沈沖沙没沢沫河沸油治沼沿況泉泊法泡波泣泥注泪泰泳洋洒洗洞津洪洲活派流浄浅浙浜浦浩浪浮浴海浸消涌涎涙涯液涼淀淡深淵混淹添清渇済渉渋渓減渡渦温測港渾湖湧湯湾湿満源準溜溝溢溶溺滅滋滑滝滞滲滴漁漂漆漏演漕漠漢漫漬潔潜潟潤潮潰澄澤激濁濃濡濫濯瀋瀞瀟瀬灘火灯灰災炉炊炎炒炬炭炸点為烈烏烹焔無焦然焼煉煎煙照煮熊熟熱熾燃燈燕燥燵爆爪爵父爺爽爾片版牙牛牝牟牡牢牧物牲特牽犠犧犬犯状狂狐狙狛狩独狭猛猟猪猫献猶猿獄獅獣獰獲玄率玉王玖珀珂珍珠班現球理琢琥琴瑕瑛瑞璃璧環瓢瓦瓶甘生産甥用田由甲申男町画界畏畑留畜畠畢略番異畳畷疆疑疫疲疵疾病症痛痩痰痴痺瘠療癒癖癪発登發白百的皆皇皮皺皿盆益盗盛盟監盤目盲直相盾省眉看県真眠眸眺眼着睡督睦瞑瞬瞭瞳矛矢知矩短矮矯石砂研砕砥砦砲破砺硝硬确碁碑碗碧確磁磐磨磯礎礫示礼社祈祉祓祖祝神祠祢祥票祭禁禄禅禍福禽秀私秋科秒秘秦秩称移程税稚種稲稼稽稿穀穂積穏穴究空突窃窓窟窪竈立竜竟章童端競竹笑笛笠符第笹筆筈等筋筑筒答策箆箇箋箔箕算管箱箸節範築篠簔簡簾簿籍籐籠米粉粋粒粕粗粘粛粟粧精糖糞糠糧糸系紀約紅紋納紐純紗紙級紛素索紫細紳紹終組絆経結絞絡給絨統絵絶絹継続綜維綱網綺綻綾綿緊総緑緒線締編緩練縁縄縛縞縦縫縮績繁繊繋織繕繞繫繰繹纂纏缶罠罪置罰署罵羅羊美群義羽翁翅翌習翔翻翼老考者耆而耐耕耳聖聞聴職肆肉肌肘肝股肢肥肩肪肯育肺胃胆背胎胞胡胸能脂脅脆脇脈脊脚脱脳脹腎腐腑腕腰腱腹腿膚膜膝膨膳膵臆臍臓臣臨自臭至致臼興舌舎舐舗舘舞舟航般船艘艦良色艶芋
芝芥芦芯花芳芸芽苅苑苓苔苗苛若苦苫英苺茂茅茎茨茶茹草荒荘荷菅菊菓菜華菰菱萌萎萩落葉著葛葦葬葵蒙蒡蒲蒸蓄蓋蓬蓮蔑蔓蔭蔵蕎蕨薄薇薔薩薬藁藍藝藤藻蘂蘆蘇蘭虎虚虫虹虻蚊蛇蛍蛙蛮蛹蛾蝉蝶融蟲蟹蟻蠣血衆行術街衛衝衡衣表衰衾袋袖被袴裁裂装裏裕補裳裸製裾複褐褪襟襲西要覆覇見規視覗覚覧親観角解触言訂計訊討訓託記訟訣訪設許訳訴診証詐詑評詞詠詣試詩詮詰話詳誇誉誌認誑誓誕誘語誠誤誦説読誰課調談請諏論諦諫諸謀謂謎謙講謝謡謬謹識譜警議譲護讐认话读谷豆豊豚象豪貌貝貞負財貢貧貨販貪貫責貯貰貴買貸費貼貿賀賂賃賄資賊賑賛賞賠賢賦質賭購賽贅贈赤赦走起超越趣足距跡跨路跳践踊踏蹟蹴躇躊躍身車軋軌軍軒軟転軸軽較載輝輩輪輸輾轄辛辞辰辱農辺込辿迅迎近返这迦迫述迷追退送逃逅逆透逐途逗這通逝速造逢連逮週進逸遂遅遇遊運遍過道達違遠遣遥適遭遮遷選遺遼遽避邂還邑那邦邪邸郊郎郡部郵郷都鄙鄭配酒酔酢酬酵酷酸醍醐醒醜醤釈里重野量金釘釜針釣釧鈍鈴鉄鉛鉢鉱鉾銀銃銅銚銭鋤鋭錆錠錦録鍋鍛鍵鎌鎖鏡鐘鑑長門閃閉開閑間関閣閨閲闇闘阅阜阪防阻阿陀附降限陛陝院陣除陥陰陳陵陸険陽隅隆隊隋階随隔隙際障隠隣隷隻雀雄雅集雇雑離難雨雪雫雰雲雷電需震霊霞霧露霹靂青静非面革靴鞍鞘鞠鞭韋韓韮音韻響頂頃項順須預頑頓頗領頬頭頷頸頻頼題額顎顔顕願類顧领風飛飜食飢飯飲飴飼飽飾餃餅養餌餓館饅首香馬馭馳馴駄駅駆駈駐駒駝駿騎騒験騨騰驚骨骸高髪髭鬱鬼魂魅魔魚鮫鮭鮮鯖鯛鰐鰡鰭鰺鰻鳥鳩鳴鴉鴎鴨鴻鶏鶴鷲鷹鸞鹸鹿麗麦麺麻黄黒黙鼓鼠鼻鼾齢齧龍ꞌ가각간갈감갑값갓갔강갖같갚개객갠갤갯갱거걱건걷걸검겁것겉게겐겟겠겨격겪견결겸겹겼경곁계고곡곤곧골곰곱곳공곶과곽관괄괌광괘괭괴굉교구국군굳굴굵굶굽굿궁궈권궐궤귀규균그극근글금급긋긍기긴길김깁깃깊까깎깔깝깡깥깨깬꺼껍껏껑께껴꼈꼬꼭꼴꼼꼽꽁꽂꽃꽝꽥꾸꾼꿀꿇꿈꿎꿔꿨꿰뀌뀐끄끅끈끊끌끓끔끗끝끼끽낀낌나낙난날낡남납낫났낭낮낯낳내낸낼냄냅냇냈냉냐냥너넉넌널넓넘넛넣네넥넨넷녀녁년념녔녕녘노녹논놀놈농높놓놨뇌뇨뇽누눈눌눠뉜뉴늉느는늘늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮닳담답닷당닿대댁댐댓더덕던덜덟덤덧덩덫덮데덴델뎅뎌도독돈돋돌돔돕돗동돛돼됐되된될됨됩두둑둔둘둠둥둬뒀뒤뒷듀드득든듣들듬듭듯등디딘딜딥딨딩딪따딱딴딸땀땄땅때땐땠땡떠떡떤떨떴떻떼뗄또똑똥뚜뚝뚫뛰뛴뜨뜯뜸뜻띄띔띠띤띨띵라락란랄람랍랐랑래랙랜램랫랬랭랴략량러럭런럴럼럽렀렁렇레렉렌렘렛려력련렬렴렵렷렸령례로록론롤롬롭롯롱뢰료룡루룬룰룸룹뤄뤘뤼류륙륜률륭르륵른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맡매맥맨맹맺머먹먼멀멈멋멍메멕멘멜멧며면멸명몇모목몫몬몰몸못몽묘무묵묶문묻물뭄뭇뭐뭔뭘뮤뮬므믈미믹민믿밀밋밌밍및밑바박밖반받발밝밤밥방밭배백밴뱀뱃뱅버벅번벌범법벗벚베벤벨벳벼벽변별볍볐병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북분불붉붐붓붕붙뷔뷰브븐블비빅빈빌빗빙빚빛빠빨빵빼뺀뺌뺏뺑뻐뻑뻔뻗뻘뼈뽀뽑뽕뿌뿐뿜쁘쁜쁩삐사삭산살삶삼삽삿샀상새색샌샐샘생샤샬샵샷서석섞선섣설섬섭섯섰성세섹센셀셈셉셋셔션셜셨셰소속손솔솜솟송솥쇄쇠쇤쇼숍수숙순술숨숭숲쉬쉰쉼쉽슈슐스슨슬슴습슷승시식신실싫심십싱싶싸싹싼쌀쌈쌌쌍쌓써썩썰썼쏘쏜쏟쏠쑤쓰쓴쓸씀씁씌씨씩씬씸씻아악안앉않알앓암압앗았앙앞애액앤앨앱앵야약얇양얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엡엣여역연열엷염엽엿였영옆예옛오옥온올옮옳옴옵옷옹와완왈왔왕왜외왼요욕용우욱운울움웁웃웅워원월웠웨웬웹위윈윌윗윙유육윤율융으은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잽쟁저적전절젊젋점접젓정젖제젝젠젤져젼졌조족존졸좀좁종좋좌죄죠주죽준줄줍중줘줬쥐쥔쥘쥬즈즉즌즐즘증지직진질짊짐집짓징짖짙짚짜짝짧짬째쨌쩌쩍쩔쩜쪼쪽쫄쫓쭉쯤찌찍찔찢찧차착찬찮찰참찻창찾채책챌챔챙챠처척천철첨첩첫청체첸첼쳇쳐쳤초촉촌촘촛총촨촬최추축춘출춤충춰취츠측츰층치칙친칠침칩칫칭카칸칼캉캐캔캘캠커컥컨컫컴컵컷컸케켈켐켑켓켜켰코콘콜콤콥콧콩쾌쿄쿠쿡쿨쿼퀴큐크큰클큼키킥킨킬킷킹타탁탄탈탐탑탓탕태택탠탬탱터턱턴털텃텅테텍텐텔템텼토톡톤톨톰통퇴투툴툼퉁튀튜튬트특튼튿틀틈티틱틴틸팀팅파팍팎판팔팜팡패팩팬팰팻팽퍼펀펄펌페펜펠펫펴편펼폄폈평폐포폭폰폴폼표푸푹푼풀품풋풍퓨퓰프픈플픔피픽핀필핏핑하학한할함합핫항해핵핸햇했행향허헌헐험헝헤헨헬헴헷혀혁현혈혐협혔형혜호혹혼홀홈홉홍화확환활황회획횡효후훈훌훤훨훼휘휩휴흉흐흑흔흘흙흠흡흥흩희흰히힌힐힘ﭨﮭﮯﯾﷲﷺ︎️ﺃﺅﺗﺘﺩﺲﻧﻮ!(),./;=?\^abcdfghijklmnoprstuvwyz~「」ゥ'''
+# removed numbers
+
+labels = [i for i in final_characters_]
+
+text_labels = [i for i in labels]
+text_labels += '', '', ''  # three placeholder special tokens (a bare tuple on the right extends the list)
+
+code_labels = [str(i) for i in range(config.semantic_model_centroids)]
+labels += code_labels
+code_labels += '', '', ''  # three placeholder special tokens
+
+labels += '', '', '', '', ''  # five placeholder special tokens
+
+print('length of the labels: ', len(labels))
\ No newline at end of file
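The `+=` lines above rely on a Python detail worth noting: a bare comma-separated tuple on the right-hand side of `list += ...` extends the list element-wise. A minimal illustration (toy values, not the repo's vocabulary):

```python
# `list += tuple` extends the list with the tuple's elements,
# which is how text_labels += '', '', '' appends three entries.
labels = ['a', 'b']
labels += 'x', 'y'

assert labels == ['a', 'b', 'x', 'y']
```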
diff --git a/Text/symbols_final.pkl b/Text/symbols_final.pkl
new file mode 100755
index 0000000000000000000000000000000000000000..21986e3b3b109a7e8186f5875b30c888432d60cd
Binary files /dev/null and b/Text/symbols_final.pkl differ
diff --git a/Text/text_meta.txt b/Text/text_meta.txt
new file mode 100755
index 0000000000000000000000000000000000000000..5ba3c77dd7ddf42f10d1de96bf12c9480f56ed55
--- /dev/null
+++ b/Text/text_meta.txt
@@ -0,0 +1,62 @@
+'odia', 'assamese', 'thai', 'gujrati', 'russian', 'japanese', 'punjabi', 'hindi', 'manipuri', 'korean', 'bhojpuri', 'sanskrit', 'english', 'french', 'bodo', 'malayalam', 'telugu', 'kannada', 'dogri', 'marathi', 'german', 'italian', 'rajasthani', 'spanish', 'arabic', 'urdu', 'gujarati', 'tamil', 'bengali',
+
+final_characters_ = ''' !"#$%&'()*+,-./0123456789:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{}~ ¡«°´µ·º»½ßàáâãäåæçèéêëìíîïðñòóôöøùúûüþÿāąćčđėęěğīıľłńňōőœřśşšūŭźżžșțəʻʼʾʿː˜̨́μабвгдежзийклмнопрстуфхцчшщъыьэюяѐёљҙҡңүҳדהו،ؑؓؔ؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٓٔٗ٬ٰٴٹپچڈڑژکگںھہۂۃیےۓ۔ۖۗۘۙۚۛँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसहऺऻ़ऽािीुूृॄॅॆेैॉॊोौ्ॏॐ॒॑॓॔ॕॖॗक़ख़ग़ज़ड़ढ़फ़य़ॠॡॢ।॥०१२३४५६७८९॰ॲॽঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃৄেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৵৷৹৻ਁਂਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ਼ਾਿੀੁੂੇੈੋੌ੍ਖ਼ਗ਼ਜ਼ੜਫ਼ੰੱੲੳઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૄૅેૈૉોૌ્ૐૠૢ૦૧૨૩૪૫૬૭૮૯ଁଂଃଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳଵଶଷସହ଼ଽାିୀୁୂୃୄେୈୋୌ୍ୖଡ଼ଢ଼ୟୠୢ୦୧୨୩୪୫୬୭୮୯୰ୱஃஅஆஇஈஉஊஎஏஐஒஓஔகஙசஜஞடணதநனபமயரறலளழவஷஸஹாிீுூெேைொோௌ்ௗఁంఃఅఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరఱలళవశషసహాిీుూృెేైొోౌ్ౖಂಃಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೆೇೈೊೋೌ್ೕೖೞೠ೦೧೨೩೪೫೬೭೮೯ംഃഅആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരറലളഴവശഷസഹാിീുൂൃെേൈൊോൌ്ൗൺൻർൽൾกขคฆงจฉชซญณดตถทธนบปผฝพฟภมยรฤลวศษสหอฮะัาำิีึืุูเแโใไ็่้๊๋์ḥṇṛṣṭễ 
‑–—―‘’“”„•…′›€₹™−☆♡♭♯⤴、。々〇〈〉「」『』〜ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐをん゛ゝゞァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロワンヴヵヶ・ー一丁七万丈三上下不与且世丘両並中丰串丸丹主丼乃久之乏乖乗乙九也乱乳乾亀了予争事二于云互五井亘亜亡交京亭人仁今介仏仕他付仙代令以仮仰仲件任企伊伍伏伐休会伝传伯伴伸似但佇位低住佐体何余作佳併使例侑供依価侮侵便係俄俗保俟信俣修俯俳俺倉個倍倒候借倣値倫倶偉偏做停健側偵偶偽傍傑傘備催債傷傾僅働像僕僚僧儀億儘儚償優允元兄充兆先光克免兎児党入全八公六共兵具典兼内円冊再冒冗写冠冬冴冷凄凌凍凛凝凡処凧凪凶出函刀刃分切刈刊刑列初判別利到制刷券刺刻則削前剛剣剤剥副割創劇力功加劣助努劫励労効勃勇勉動勘務勝募勢勤勧勿匂包化北匠匹区医匿十千午半卑卒卓協南単博占印危即却卵厄厚原厨厭厳去参及友双反収取受叙叛叡口古句叩只叫召可台叱史右叶号司各合吉吊同名后吐向君吟吠否含吸吹吾呂呆呈呉告呑呟周味呻呼命咋和咥咲咳咽哀品員哲唄唆唇唐唖唯唱唸唾商問善喉喋喚喜喧喩喪喫喬喰営嗅嘆嘉嘘嘩嘲噂噌噛器噴囚四回因団困囲図固国圏園土圧在地坂均坊坐坑坪垂型垢垣埃埋城域執培基埼堀堂堅堕堤堪報場堵堺塀塁塊塔塗塙塚塩塵塾境墓増墨壁壇壊壌士壬壮声壱売壷壺変夏夕外多夜夢大天太夫央失奄奇奈奉奏契奔套奥奨奪奮女奴好如妃妄妖妙妥妨妬妹妻姉始姓委姦姫姶姻姿威娘娯婆婚婦婿媒媛媽嫁嫉嫌嬉嬢嬬子孔孕字存孝季孤学孫宅宇守安完宍宗官宙定宛宜宝実客宣室宮宰害宴宵家容宿寂寄密富寒寛寝察寧審寵寸寺対寿封専射将尊尋導小少尖尚尤就尺尻尼尽尾尿局居屈届屋屏屑展属層履屯山岐岡岩岬岸峙峠峡峰島崇崎崔崖崩嵐嵩嶋川州巡巣工左巧巨差己巻巾市布帆希帖帝師席帯帰帳常帽幅幌幕幡干平年并幸幹幻幼幽幾庁広庄床序底庖店府度座庫庭庵康廃廉廊廓延廷建廻廿弁弄弊式弐弓引弘弟弥弦弱張強弾当彙形彦彩彫彰影役彼往征待律後徐徒従得御復循微徳徴徹徽心必忍志忘忙応忠快念怒怖思怠急性怨怪怯恋恐恥恨恩息恰恵悔悟悠患悦悩悪悲悶情惑惚惜惟惨惰想惹愉意愕愚愛感慄慈態慌慎慕慢慣慨慮慰慶憂憎憐憑憚憤憧憩憲憶應懊懐懸成我戒或戚戦戯戸戻房所扇扉手才打払扱扶批承技抄把抑投抗折抜択披抱抵抹押抽担拉拍拒拓拗拘招拝拠拡括拭拳拶拾持指挑挙挟挨挫振挾挿捉捌捕捜捧捨捲捻掃授掌排掘掛掟掠採探接控推措掬掲掴掻揃揄描提揖揚換握揮援揶揺損搬搭携搾摂摘摩撃撒撚撞撤撫播撮擁操擢擦支改攻放政故敏救敗教敢散敦敬数整敵敷文斉斎斐斑斗料斜斧斬断斯新方於施旅族旗既日旦旧旨早旬旭旺昂昆昇昌明易昔星映春昧昨昭是昼時晒晦晩普景晴晶智暇暑暖暗暦暫暮暴曇曖曜曲曳更書曹曽替最月有朋服朕朗望朝期朦朧木未末本札朴机权杉李材村杓杖杜束条来杭杯東杵松板析枕林枚果枝枢枯架柄柏染柔柢柱柳柴柵査柿栃栄栖栗校株核根格栽桁桂桃案桎桐桑桔桜桶梁梅梏梗條梧梨梯械梱梼棄棋棒棚棟森棲椅植椎検楊楢業極楽概榛構槌様槻槽標模権横樹樽橄橋橙機橿檎檜櫃櫻欄欅欒欖欠次欧欲欺歌歎歓止正此武歩歪歯歳歴死歿殊残殖殴段殺殻殿母毎毒比毛毫毯氏民気水氷永汁求汗汚汝江池汲決汽沈沖沙没沢沫河沸油治沼沿況泉泊法泡波泣泥注泪泰泳洋洒洗洞津洪洲活派流浄浅浙浜浦浩浪浮浴海浸消涌涎涙涯液涼淀淡深淵混淹添清渇済渉渋渓減渡渦温測港渾湖湧湯湾湿満源準溜溝溢溶溺滅滋滑滝滞滲滴漁漂漆漏演漕漠漢漫漬潔潜潟潤潮潰澄澤激濁濃濡濫濯瀋瀞瀟瀬灘火灯灰災炉炊炎炒炬炭炸点為烈烏烹焔無焦然焼煉煎煙照煮熊熟熱熾燃燈燕燥燵爆爪爵父爺爽爾片版牙牛牝牟牡牢牧物牲特牽犠犧犬犯状狂狐狙狛狩独狭猛猟猪猫献猶猿獄獅獣獰獲玄率玉王玖珀珂珍珠班現球理琢琥琴瑕瑛瑞璃璧環瓢瓦瓶甘生産甥用田由甲申男町画界畏畑留畜畠畢略番異畳畷疆疑疫疲疵疾病症痛痩痰痴痺瘠療癒癖癪発登發白百的皆皇皮皺皿盆益盗盛盟監盤目盲直相盾省眉看県真眠眸眺眼着睡督睦瞑瞬瞭瞳矛矢知矩短矮矯石砂研砕砥砦砲破砺硝硬确碁碑碗碧確磁磐磨磯礎礫示礼社祈祉祓祖祝神祠祢祥票祭禁禄禅禍福禽秀私秋科秒秘秦秩称移程税稚種稲稼稽稿穀穂積穏穴究空突窃窓窟窪竈立竜竟章童端競竹笑笛笠符第笹筆筈等筋筑筒答策箆箇箋箔箕算管箱箸節範築篠簔簡簾簿籍籐籠米粉粋粒粕粗粘粛粟粧精糖糞糠糧糸系紀約紅紋納紐純紗紙級紛素索紫細紳紹終組絆経結絞絡給絨統絵絶絹継続綜維綱網綺綻綾綿緊総緑緒線締編緩練縁縄縛縞縦縫縮績繁繊繋織繕繞繫繰繹纂纏缶罠罪置罰署罵羅羊美群義羽翁翅翌習翔翻翼老考者耆而耐耕耳聖聞聴職肆肉肌肘肝股肢肥肩肪肯育肺胃胆背胎胞胡胸能脂脅脆脇脈脊脚脱脳脹腎腐腑腕腰腱腹腿膚膜膝膨膳膵臆臍臓臣臨自臭至致臼興舌舎舐舗舘舞舟航般船艘艦良色艶芋
芝芥芦芯花芳芸芽苅苑苓苔苗苛若苦苫英苺茂茅茎茨茶茹草荒荘荷菅菊菓菜華菰菱萌萎萩落葉著葛葦葬葵蒙蒡蒲蒸蓄蓋蓬蓮蔑蔓蔭蔵蕎蕨薄薇薔薩薬藁藍藝藤藻蘂蘆蘇蘭虎虚虫虹虻蚊蛇蛍蛙蛮蛹蛾蝉蝶融蟲蟹蟻蠣血衆行術街衛衝衡衣表衰衾袋袖被袴裁裂装裏裕補裳裸製裾複褐褪襟襲西要覆覇見規視覗覚覧親観角解触言訂計訊討訓託記訟訣訪設許訳訴診証詐詑評詞詠詣試詩詮詰話詳誇誉誌認誑誓誕誘語誠誤誦説読誰課調談請諏論諦諫諸謀謂謎謙講謝謡謬謹識譜警議譲護讐认话读谷豆豊豚象豪貌貝貞負財貢貧貨販貪貫責貯貰貴買貸費貼貿賀賂賃賄資賊賑賛賞賠賢賦質賭購賽贅贈赤赦走起超越趣足距跡跨路跳践踊踏蹟蹴躇躊躍身車軋軌軍軒軟転軸軽較載輝輩輪輸輾轄辛辞辰辱農辺込辿迅迎近返这迦迫述迷追退送逃逅逆透逐途逗這通逝速造逢連逮週進逸遂遅遇遊運遍過道達違遠遣遥適遭遮遷選遺遼遽避邂還邑那邦邪邸郊郎郡部郵郷都鄙鄭配酒酔酢酬酵酷酸醍醐醒醜醤釈里重野量金釘釜針釣釧鈍鈴鉄鉛鉢鉱鉾銀銃銅銚銭鋤鋭錆錠錦録鍋鍛鍵鎌鎖鏡鐘鑑長門閃閉開閑間関閣閨閲闇闘阅阜阪防阻阿陀附降限陛陝院陣除陥陰陳陵陸険陽隅隆隊隋階随隔隙際障隠隣隷隻雀雄雅集雇雑離難雨雪雫雰雲雷電需震霊霞霧露霹靂青静非面革靴鞍鞘鞠鞭韋韓韮音韻響頂頃項順須預頑頓頗領頬頭頷頸頻頼題額顎顔顕願類顧领風飛飜食飢飯飲飴飼飽飾餃餅養餌餓館饅首香馬馭馳馴駄駅駆駈駐駒駝駿騎騒験騨騰驚骨骸高髪髭鬱鬼魂魅魔魚鮫鮭鮮鯖鯛鰐鰡鰭鰺鰻鳥鳩鳴鴉鴎鴨鴻鶏鶴鷲鷹鸞鹸鹿麗麦麺麻黄黒黙鼓鼠鼻鼾齢齧龍ꞌ가각간갈감갑값갓갔강갖같갚개객갠갤갯갱거걱건걷걸검겁것겉게겐겟겠겨격겪견결겸겹겼경곁계고곡곤곧골곰곱곳공곶과곽관괄괌광괘괭괴굉교구국군굳굴굵굶굽굿궁궈권궐궤귀규균그극근글금급긋긍기긴길김깁깃깊까깎깔깝깡깥깨깬꺼껍껏껑께껴꼈꼬꼭꼴꼼꼽꽁꽂꽃꽝꽥꾸꾼꿀꿇꿈꿎꿔꿨꿰뀌뀐끄끅끈끊끌끓끔끗끝끼끽낀낌나낙난날낡남납낫났낭낮낯낳내낸낼냄냅냇냈냉냐냥너넉넌널넓넘넛넣네넥넨넷녀녁년념녔녕녘노녹논놀놈농높놓놨뇌뇨뇽누눈눌눠뉜뉴늉느는늘늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮닳담답닷당닿대댁댐댓더덕던덜덟덤덧덩덫덮데덴델뎅뎌도독돈돋돌돔돕돗동돛돼됐되된될됨됩두둑둔둘둠둥둬뒀뒤뒷듀드득든듣들듬듭듯등디딘딜딥딨딩딪따딱딴딸땀땄땅때땐땠땡떠떡떤떨떴떻떼뗄또똑똥뚜뚝뚫뛰뛴뜨뜯뜸뜻띄띔띠띤띨띵라락란랄람랍랐랑래랙랜램랫랬랭랴략량러럭런럴럼럽렀렁렇레렉렌렘렛려력련렬렴렵렷렸령례로록론롤롬롭롯롱뢰료룡루룬룰룸룹뤄뤘뤼류륙륜률륭르륵른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맡매맥맨맹맺머먹먼멀멈멋멍메멕멘멜멧며면멸명몇모목몫몬몰몸못몽묘무묵묶문묻물뭄뭇뭐뭔뭘뮤뮬므믈미믹민믿밀밋밌밍및밑바박밖반받발밝밤밥방밭배백밴뱀뱃뱅버벅번벌범법벗벚베벤벨벳벼벽변별볍볐병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북분불붉붐붓붕붙뷔뷰브븐블비빅빈빌빗빙빚빛빠빨빵빼뺀뺌뺏뺑뻐뻑뻔뻗뻘뼈뽀뽑뽕뿌뿐뿜쁘쁜쁩삐사삭산살삶삼삽삿샀상새색샌샐샘생샤샬샵샷서석섞선섣설섬섭섯섰성세섹센셀셈셉셋셔션셜셨셰소속손솔솜솟송솥쇄쇠쇤쇼숍수숙순술숨숭숲쉬쉰쉼쉽슈슐스슨슬슴습슷승시식신실싫심십싱싶싸싹싼쌀쌈쌌쌍쌓써썩썰썼쏘쏜쏟쏠쑤쓰쓴쓸씀씁씌씨씩씬씸씻아악안앉않알앓암압앗았앙앞애액앤앨앱앵야약얇양얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엡엣여역연열엷염엽엿였영옆예옛오옥온올옮옳옴옵옷옹와완왈왔왕왜외왼요욕용우욱운울움웁웃웅워원월웠웨웬웹위윈윌윗윙유육윤율융으은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잽쟁저적전절젊젋점접젓정젖제젝젠젤져젼졌조족존졸좀좁종좋좌죄죠주죽준줄줍중줘줬쥐쥔쥘쥬즈즉즌즐즘증지직진질짊짐집짓징짖짙짚짜짝짧짬째쨌쩌쩍쩔쩜쪼쪽쫄쫓쭉쯤찌찍찔찢찧차착찬찮찰참찻창찾채책챌챔챙챠처척천철첨첩첫청체첸첼쳇쳐쳤초촉촌촘촛총촨촬최추축춘출춤충춰취츠측츰층치칙친칠침칩칫칭카칸칼캉캐캔캘캠커컥컨컫컴컵컷컸케켈켐켑켓켜켰코콘콜콤콥콧콩쾌쿄쿠쿡쿨쿼퀴큐크큰클큼키킥킨킬킷킹타탁탄탈탐탑탓탕태택탠탬탱터턱턴털텃텅테텍텐텔템텼토톡톤톨톰통퇴투툴툼퉁튀튜튬트특튼튿틀틈티틱틴틸팀팅파팍팎판팔팜팡패팩팬팰팻팽퍼펀펄펌페펜펠펫펴편펼폄폈평폐포폭폰폴폼표푸푹푼풀품풋풍퓨퓰프픈플픔피픽핀필핏핑하학한할함합핫항해핵핸햇했행향허헌헐험헝헤헨헬헴헷혀혁현혈혐협혔형혜호혹혼홀홈홉홍화확환활황회획횡효후훈훌훤훨훼휘휩휴흉흐흑흔흘흙흠흡흥흩희흰히힌힐힘ﭨﮭﮯﯾﷲﷺ︎️ﺃﺅﺗﺘﺩﺲﻧﻮ!(),./;=?\^abcdfghijklmnoprstuvwyz~「」ゥ'''
+
+english =''' !"#$%&'()*+,-.0123456789:;<=>?@[\]^_`abcdefghijklmnopqrstuvwxyz{} ´»àáâæçèéêíïñôúüœ˜ҙҡүҳ–—‘’“”•…€™️'''
+
+hindi =''' !"'(),-.0123456789;?abcdefghijklmnopqrstuvwxyzँंःअआइईउऊऋऌऍऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसहऻ़ािीुूृॄॅॆेैॉॊोौ्ॏॐ॒॓॔ॕॖॗक़ख़ग़ज़ड़ढ़फ़ॠॡॢ।॥०॰'''
+
+kannada =''' !"#$%&'()*+,-./0123456789:;<=>?[]^`abcdefghiklmnopqrstuvwxy ½ʼॐ।॥ಂಃಅಆಇಈಉಊಋಎಏಐಒಓಔಕಖಗಘಙಚಛಜಝಞಟಠಡಢಣತಥದಧನಪಫಬಭಮಯರಱಲಳವಶಷಸಹ಼ಽಾಿೀುೂೃೆೇೈೊೋೌ್ೕೖೞೠ೦೧೨೩೪೫೬೭೮೯ –‘’“”•…'''
+
+tamil =''' !"%&'()*,-./0123456789:;?[]`abcdefghijklmnopqrstuvwxyz ஃஅஆஇஈஉஊஎஏஐஒஓஔகஙசஜஞடணதநனபமயரறலளழவஷஸஹாிீுூெேைொோௌ்ௗ–‘’“”…'''
+
+assamese =''' !%&'()+,-./0123456789:;?[]_abcdefghijklmnopqrstuvwxyz ʼ।॥ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৷৹–—‘’“”'''
+
+dogri =''' !$%&'()+,-./0123456789:;?@[]`abcdefghijklmnopqrstuvwxyzʼँंःअआइईउऊएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलवशषसह़ऽािीुूृेैॉोौ्क़ज़ड़ढ़फ़य़।–—‘’“”…′₹'''
+
+bodo =''' !$%&'()+,-./0123456789:;?[]_abcdefghijklmnopqrstuvwxyz °º½ʼँंःअआइईउऊऋऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलवशषसह़ािीुूृॅॆेैॉॊोौ्ज़ड़फ़य़।०ॽ–‘’“”'''
+
+marathi =''' !'*+,-./0123456789:;?[`z ँंःअआइईउऊऋऌऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनऩपफबभमयरऱलळऴवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॑ॕक़ज़ड़ढ़फ़य़ॠॡ।॥०१२३४५६७८९ॲ–‘’“”›'''
+
+bengali =''' !',-.0123456789;?acdefghlmnrstuyz।ঁংঃঅআইঈউঊঋএঐওঔকখগঘঙচছজঝঞটঠডঢণতথদধনপফবভমযরলশষসহ়ঽািীুূৃৄেৈোৌ্ৎৗড়ঢ়য়০১২৩৪৫৬৭৮৯ৰৱ৵৷৻—‘’'''
+
+telugu =''' !"'*,-./258:;?o ।ఁంఃఅఆఇఈఉఊఋఎఏఐఒఓఔకఖగఘఙచఛజఝఞటఠడఢణతథదధనపఫబభమయరఱలళవశషసహాిీుూృెేైొోౌ్ౖ–‘’”…'''
+
+german =''' '-abcdefghijklmnopqrstuvwxyzßàäèéöü'''
+
+spanish =''' !'-.;?abcdefghijklmnopqrstuvwxyzáéíñóúü'''
+
+french =''' !'-.;?abcdefghijklmnopqrstuvwxyzàâæçèéêëîïôùûüÿœ'''
+
+punjabi =''' !"'(),-.:?bden।ਁਂਅਆਇਈਉਊਏਐਓਔਕਖਗਘਙਚਛਜਝਞਟਠਡਢਣਤਥਦਧਨਪਫਬਭਮਯਰਲਵਸ਼ਸਹ਼ਾਿੀੁੂੇੈੋੌ੍ਖ਼ਗ਼ਜ਼ੜਫ਼ੰੱੲੳ –‘’“”…'''
+
+sanskrit =''' "ँंःअआइईउऊऋऌऎएऐऑओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरलळवशषसह़ऽािीुूृॄॅॆेैॉॊोौ्ॐ॒॑ॠॡॢ।॥०१२३४५६७८९॰'''
+
+odia =''' "',-.;।ଁଂଃଅଆଇଈଉଊଋଏଐଓଔକଖଗଘଙଚଛଜଝଞଟଠଡଢଣତଥଦଧନପଫବଭମଯରଲଳଵଶଷସହ଼ଽାିୀୁୂୃୄେୈୋୌ୍ୖଡ଼ଢ଼ୟୠୢ୦୧୨୩୪୫୬୭୮୯୰ୱ‘’”'''
+
+urdu =''' !"',-.:`abcdeghiklrtuy،ؑؓؔ؛؟ءآأؤئابتثجحخدذرزسشصضطظعغفقكلمنهوىيًَُِّْٓٔٗ٬ٰٴٹپچڈڑژکگںھہۂۃیےۓ۔।‘’“”…ﭨﮭﮯﯾﷲﷺﺅﺗﺘﺩﺲﻧﻮ'''
+
+gujarati =''' !',-.:;?m ।ઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘઙચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૄૅેૈૉોૌ્ૐૠૢ૦૧૨૩૪૫૬૭૮૯–‘’…'''
+
+rajasthani =''' !'(),-.0123456789:;?xँंःअआइईउऊऋऍएऐऑओऔकखगघचछजझञटठडढणतथदधनपफबभमयरलळवशषसह़ािीुूृेैॉोौ्क़ख़ग़ज़ड़ढ़फ़ॠ।०१२३७८९‘’…'''
+
+malayalam =''' !,?ംഃഅആഇഈഉഊഋഎഏഐഒഓഔകഖഗഘങചഛജഝഞടഠഡഢണതഥദധനപഫബഭമയരറലളഴവശഷസഹാിീുൂൃെേൈൊോൌ്ൗൺൻർൽൾ'''
+
+manipuri =''' ,-.mnঁংঅআইঈউঊএঐওঔকখগঘঙচছজঝঞটঠডণতথদধনপফবভমযরলশষসহ়ািীুূৃেৈোৌ্ৎয়০১২৩৪৫৬৭৮৯ৰৱ৷'''
+
+gujrati =''' !"'(),-.0123456789:?{} âઁંઃઅઆઇઈઉઊઋઍએઐઑઓઔકખગઘચછજઝઞટઠડઢણતથદધનપફબભમયરલળવશષસહ઼ાિીુૂૃૅેૈૉોૌ્ૠ૦૧૨૩૪૫૬૭૮૯–—‘’“”…'''
+
+bhojpuri =''' !"'(),-.012346789:?`abeimpy{}·ँंःअआइईउऊऋऍऎएऐऑऒओऔकखगघङचछजझञटठडढणतथदधनपफबभमयरऱलळवशषसहऺ़ऽािीुूृॅॆेैॉॊोौ्ॐॕॖक़ख़ग़ज़ड़ढ़फ़य़ॠ।०२६९॰ ‘’'''
+
+italian =''' !"$'()+,-.:;<=>?[]_`abcdefghijklmnopqrstuvwxyz{}~¡«°´µº»ßàáâãäåæèéêëìíîïðñòóôöøùúûþÿāąćčđėęěğīıľłńňōőœřśşšūŭźżžșțəʻʼʾʿː̨́абдеиклмностуцшѐёљңדהוةرسصغليḥṇṛṣṭễ‑–—‘’“”„…′☆♭♯あアカキサザノフリン・ー万三丰古多家峰張旅禅ꞌ'''
+
+arabic =''' !"',-.:;?egt«»،؛؟ءآأؤإئابةتثجحخدذرزسشصضطظعغـفقكلمنهوىيًٌٍَُِّْٰٓچکیۖۗۘۙۚۛ–“”…ﺃ'''
+
+korean =''' 가각간갈감갑값갓갔강갖같갚개객갠갤갯갱거걱건걷걸검겁것겉게겐겟겠겨격겪견결겸겹겼경곁계고곡곤곧골곰곱곳공곶과곽관괄괌광괘괭괴굉교구국군굳굴굵굶굽굿궁궈권궐궤귀규균그극근글금급긋긍기긴길김깁깃깊까깎깔깝깡깥깨깬꺼껍껏껑께껴꼈꼬꼭꼴꼼꼽꽁꽂꽃꽝꽥꾸꾼꿀꿇꿈꿎꿔꿨꿰뀌뀐끄끅끈끊끌끓끔끗끝끼끽낀낌나낙난날낡남납낫났낭낮낯낳내낸낼냄냅냇냈냉냐냥너넉넌널넓넘넛넣네넥넨넷녀녁년념녔녕녘노녹논놀놈농높놓놨뇌뇨뇽누눈눌눠뉜뉴늉느는늘늠능늦늪늬니닉닌닐님닙닛닝다닥닦단닫달닭닮닳담답닷당닿대댁댐댓더덕던덜덟덤덧덩덫덮데덴델뎅뎌도독돈돋돌돔돕돗동돛돼됐되된될됨됩두둑둔둘둠둥둬뒀뒤뒷듀드득든듣들듬듭듯등디딘딜딥딨딩딪따딱딴딸땀땄땅때땐땠땡떠떡떤떨떴떻떼뗄또똑똥뚜뚝뚫뛰뛴뜨뜯뜸뜻띄띔띠띤띨띵라락란랄람랍랐랑래랙랜램랫랬랭랴략량러럭런럴럼럽렀렁렇레렉렌렘렛려력련렬렴렵렷렸령례로록론롤롬롭롯롱뢰료룡루룬룰룸룹뤄뤘뤼류륙륜률륭르륵른를름릅릇릉릎리릭린릴림립릿링마막만많맏말맑맘맙맛망맞맡매맥맨맹맺머먹먼멀멈멋멍메멕멘멜멧며면멸명몇모목몫몬몰몸못몽묘무묵묶문묻물뭄뭇뭐뭔뭘뮤뮬므믈미믹민믿밀밋밌밍및밑바박밖반받발밝밤밥방밭배백밴뱀뱃뱅버벅번벌범법벗벚베벤벨벳벼벽변별볍볐병볕보복볶본볼봄봅봇봉봐봤뵈뵙부북분불붉붐붓붕붙뷔뷰브븐블비빅빈빌빗빙빚빛빠빨빵빼뺀뺌뺏뺑뻐뻑뻔뻗뻘뼈뽀뽑뽕뿌뿐뿜쁘쁜쁩삐사삭산살삶삼삽삿샀상새색샌샐샘생샤샬샵샷서석섞선섣설섬섭섯섰성세섹센셀셈셉셋셔션셜셨셰소속손솔솜솟송솥쇄쇠쇤쇼숍수숙순술숨숭숲쉬쉰쉼쉽슈슐스슨슬슴습슷승시식신실싫심십싱싶싸싹싼쌀쌈쌌쌍쌓써썩썰썼쏘쏜쏟쏠쑤쓰쓴쓸씀씁씌씨씩씬씸씻아악안앉않알앓암압앗았앙앞애액앤앨앱앵야약얇양얗얘어억언얹얻얼얽엄업없엇었엉엎에엑엔엘엠엡엣여역연열엷염엽엿였영옆예옛오옥온올옮옳옴옵옷옹와완왈왔왕왜외왼요욕용우욱운울움웁웃웅워원월웠웨웬웹위윈윌윗윙유육윤율융으은을음읍응의이익인일읽잃임입잇있잉잊잎자작잔잖잘잠잡잣장잦재잭잰잽쟁저적전절젊젋점접젓정젖제젝젠젤져젼졌조족존졸좀좁종좋좌죄죠주죽준줄줍중줘줬쥐쥔쥘쥬즈즉즌즐즘증지직진질짊짐집짓징짖짙짚짜짝짧짬째쨌쩌쩍쩔쩜쪼쪽쫄쫓쭉쯤찌찍찔찢찧차착찬찮찰참찻창찾채책챌챔챙챠처척천철첨첩첫청체첸첼쳇쳐쳤초촉촌촘촛총촨촬최추축춘출춤충춰취츠측츰층치칙친칠침칩칫칭카칸칼캉캐캔캘캠커컥컨컫컴컵컷컸케켈켐켑켓켜켰코콘콜콤콥콧콩쾌쿄쿠쿡쿨쿼퀴큐크큰클큼키킥킨킬킷킹타탁탄탈탐탑탓탕태택탠탬탱터턱턴털텃텅테텍텐텔템텼토톡톤톨톰통퇴투툴툼퉁튀튜튬트특튼튿틀틈티틱틴틸팀팅파팍팎판팔팜팡패팩팬팰팻팽퍼펀펄펌페펜펠펫펴편펼폄폈평폐포폭폰폴폼표푸푹푼풀품풋풍퓨퓰프픈플픔피픽핀필핏핑하학한할함합핫항해핵핸햇했행향허헌헐험헝헤헨헬헴헷혀혁현혈혐협혔형혜호혹혼홀홈홉홍화확환활황회획횡효후훈훌훤훨훼휘휩휴흉흐흑흔흘흙흠흡흥흩희흰히힌힐힘'''
+
+russian =''' !"'(),-.:;?abcefghiklmnoprstxz«»абвгдежзийклмнопрстуфхцчшщъыьэюяё‑–—“„…−'''
+
+thai =''' กขคฆงจฉชซญณดตถทธนบปผฝพฟภมยรฤลวศษสหอฮะัาำิีึืุูเแโใไ็่้๊๋์'''
+
+japanese =''' !',-.?abcdefghijklmnopqrstuvwxyzμ―‘’“”…☆♡⤴、。々〇〈〉「」『』〜ぁあぃいぅうぇえぉおかがきぎくぐけげこごさざしじすずせぜそぞただちっつづてでとどなにぬねのはばぱひびぴふぶぷへべぺほぼぽまみむめもゃやゅゆょよらりるれろゎわゐをん゛ゝゞァアィイゥウェエォオカガキギクグケゲコゴサザシジスズセゼソゾタダチッツテデトドナニヌネノハバパヒビピフブプヘベペホボポマミムメモャヤュユョヨラリルレロワンヴヵヶ・ー一丁七万丈三上下不与且世丘両並中串丸丹主丼乃久之乏乖乗乙九也乱乳乾亀了予争事二于云互五井亘亜亡交京亭人仁今介仏仕他付仙代令以仮仰仲件任企伊伍伏伐休会伝传伯伴伸似但佇位低住佐体何余作佳併使例侑供依価侮侵便係俄俗保俟信俣修俯俳俺倉個倍倒候借倣値倫倶偉偏做停健側偵偶偽傍傑傘備催債傷傾僅働像僕僚僧儀億儘儚償優允元兄充兆先光克免兎児党入全八公六共兵具典兼内円冊再冒冗写冠冬冴冷凄凌凍凛凝凡処凧凪凶出函刀刃分切刈刊刑列初判別利到制刷券刺刻則削前剛剣剤剥副割創劇力功加劣助努劫励労効勃勇勉動勘務勝募勢勤勧勿匂包化北匠匹区医匿十千午半卑卒卓協南単博占印危即却卵厄厚原厨厭厳去参及友双反収取受叙叛叡口古句叩只叫召可台叱史右叶号司各合吉吊同名后吐向君吟吠否含吸吹吾呂呆呈呉告呑呟周味呻呼命咋和咥咲咳咽哀品員哲唄唆唇唐唖唯唱唸唾商問善喉喋喚喜喧喩喪喫喬喰営嗅嘆嘉嘘嘩嘲噂噌噛器噴囚四回因団困囲図固国圏園土圧在地坂均坊坐坑坪垂型垢垣埃埋城域執培基埼堀堂堅堕堤堪報場堵堺塀塁塊塔塗塙塚塩塵塾境墓増墨壁壇壊壌士壬壮声壱売壷壺変夏夕外多夜夢大天太夫央失奄奇奈奉奏契奔套奥奨奪奮女奴好如妃妄妖妙妥妨妬妹妻姉始姓委姦姫姶姻姿威娘娯婆婚婦婿媒媛媽嫁嫉嫌嬉嬢嬬子孔孕字存孝季孤学孫宅宇守安完宍宗官宙定宛宜宝実客宣室宮宰害宴宵家容宿寂寄密富寒寛寝察寧審寵寸寺対寿封専射将尊尋導小少尖尚尤就尺尻尼尽尾尿局居屈届屋屏屑展属層履屯山岐岡岩岬岸峙峠峡峰島崇崎崔崖崩嵐嵩嶋川州巡巣工左巧巨差己巻巾市布帆希帖帝師席帯帰帳常帽幅幌幕幡干平年并幸幹幻幼幽幾庁広庄床序底庖店府度座庫庭庵康廃廉廊廓延廷建廻廿弁弄弊式弐弓引弘弟弥弦弱張強弾当彙形彦彩彫彰影役彼往征待律後徐徒従得御復循微徳徴徹徽心必忍志忘忙応忠快念怒怖思怠急性怨怪怯恋恐恥恨恩息恰恵悔悟悠患悦悩悪悲悶情惑惚惜惟惨惰想惹愉意愕愚愛感慄慈態慌慎慕慢慣慨慮慰慶憂憎憐憑憚憤憧憩憲憶應懊懐懸成我戒或戚戦戯戸戻房所扇扉手才打払扱扶批承技抄把抑投抗折抜択披抱抵抹押抽担拉拍拒拓拗拘招拝拠拡括拭拳拶拾持指挑挙挟挨挫振挾挿捉捌捕捜捧捨捲捻掃授掌排掘掛掟掠採探接控推措掬掲掴掻揃揄描提揖揚換握揮援揶揺損搬搭携搾摂摘摩撃撒撚撞撤撫播撮擁操擢擦支改攻放政故敏救敗教敢散敦敬数整敵敷文斉斎斐斑斗料斜斧斬断斯新方於施旅族旗既日旦旧旨早旬旭旺昂昆昇昌明易昔星映春昧昨昭是昼時晒晦晩普景晴晶智暇暑暖暗暦暫暮暴曇曖曜曲曳更書曹曽替最月有朋服朕朗望朝期朦朧木未末本札朴机权杉李材村杓杖杜束条来杭杯東杵松板析枕林枚果枝枢枯架柄柏染柔柢柱柳柴柵査柿栃栄栖栗校株核根格栽桁桂桃案桎桐桑桔桜桶梁梅梏梗條梧梨梯械梱梼棄棋棒棚棟森棲椅植椎検楊楢業極楽概榛構槌様槻槽標模権横樹樽橄橋橙機橿檎檜櫃櫻欄欅欒欖欠次欧欲欺歌歎歓止正此武歩歪歯歳歴死歿殊残殖殴段殺殻殿母毎毒比毛毫毯氏民気水氷永汁求汗汚汝江池汲決汽沈沖沙没沢沫河沸油治沼沿況泉泊法泡波泣泥注泪泰泳洋洒洗洞津洪洲活派流浄浅浙浜浦浩浪浮浴海浸消涌涎涙涯液涼淀淡深淵混淹添清渇済渉渋渓減渡渦温測港渾湖湧湯湾湿満源準溜溝溢溶溺滅滋滑滝滞滲滴漁漂漆漏演漕漠漢漫漬潔潜潟潤潮潰澄澤激濁濃濡濫濯瀋瀞瀟瀬灘火灯灰災炉炊炎炒炬炭炸点為烈烏烹焔無焦然焼煉煎煙照煮熊熟熱熾燃燈燕燥燵爆爪爵父爺爽爾片版牙牛牝牟牡牢牧物牲特牽犠犧犬犯状狂狐狙狛狩独狭猛猟猪猫献猶猿獄獅獣獰獲玄率玉王玖珀珂珍珠班現球理琢琥琴瑕瑛瑞璃璧環瓢瓦瓶甘生産甥用田由甲申男町画界畏畑留畜畠畢略番異畳畷疆疑疫疲疵疾病症痛痩痰痴痺瘠療癒癖癪発登發白百的皆皇皮皺皿盆益盗盛盟監盤目盲直相盾省眉看県真眠眸眺眼着睡督睦瞑瞬瞭瞳矛矢知矩短矮矯石砂研砕砥砦砲破砺硝硬确碁碑碗碧確磁磐磨磯礎礫示礼社祈祉祓祖祝神祠祢祥票祭禁禄禍福禽秀私秋科秒秘秦秩称移程税稚種稲稼稽稿穀穂積穏穴究空突窃窓窟窪竈立竜竟章童端競竹笑笛笠符第笹筆筈等筋筑筒答策箆箇箋箔箕算管箱箸節範築篠簔簡簾簿籍籐籠米粉粋粒粕粗粘粛粟粧精糖糞糠糧糸系紀約紅紋納紐純紗紙級紛素索紫細紳紹終組絆経結絞絡給絨統絵絶絹継続綜維綱網綺綻綾綿緊総緑緒線締編緩練縁縄縛縞縦縫縮績繁繊繋織繕繞繫繰繹纂纏缶罠罪置罰署罵羅羊美群義羽翁翅翌習翔翻翼老考者耆而耐耕耳聖聞聴職肆肉肌肘肝股肢肥肩肪肯育肺胃胆背胎胞胡胸能脂脅脆脇脈脊脚脱脳脹腎腐腑腕腰腱腹腿
膚膜膝膨膳膵臆臍臓臣臨自臭至致臼興舌舎舐舗舘舞舟航般船艘艦良色艶芋芝芥芦芯花芳芸芽苅苑苓苔苗苛若苦苫英苺茂茅茎茨茶茹草荒荘荷菅菊菓菜華菰菱萌萎萩落葉著葛葦葬葵蒙蒡蒲蒸蓄蓋蓬蓮蔑蔓蔭蔵蕎蕨薄薇薔薩薬藁藍藝藤藻蘂蘆蘇蘭虎虚虫虹虻蚊蛇蛍蛙蛮蛹蛾蝉蝶融蟲蟹蟻蠣血衆行術街衛衝衡衣表衰衾袋袖被袴裁裂装裏裕補裳裸製裾複褐褪襟襲西要覆覇見規視覗覚覧親観角解触言訂計訊討訓託記訟訣訪設許訳訴診証詐詑評詞詠詣試詩詮詰話詳誇誉誌認誑誓誕誘語誠誤誦説読誰課調談請諏論諦諫諸謀謂謎謙講謝謡謬謹識譜警議譲護讐认话读谷豆豊豚象豪貌貝貞負財貢貧貨販貪貫責貯貰貴買貸費貼貿賀賂賃賄資賊賑賛賞賠賢賦質賭購賽贅贈赤赦走起超越趣足距跡跨路跳践踊踏蹟蹴躇躊躍身車軋軌軍軒軟転軸軽較載輝輩輪輸輾轄辛辞辰辱農辺込辿迅迎近返这迦迫述迷追退送逃逅逆透逐途逗這通逝速造逢連逮週進逸遂遅遇遊運遍過道達違遠遣遥適遭遮遷選遺遼遽避邂還邑那邦邪邸郊郎郡部郵郷都鄙鄭配酒酔酢酬酵酷酸醍醐醒醜醤釈里重野量金釘釜針釣釧鈍鈴鉄鉛鉢鉱鉾銀銃銅銚銭鋤鋭錆錠錦録鍋鍛鍵鎌鎖鏡鐘鑑長門閃閉開閑間関閣閨閲闇闘阅阜阪防阻阿陀附降限陛陝院陣除陥陰陳陵陸険陽隅隆隊隋階随隔隙際障隠隣隷隻雀雄雅集雇雑離難雨雪雫雰雲雷電需震霊霞霧露霹靂青静非面革靴鞍鞘鞠鞭韋韓韮音韻響頂頃項順須預頑頓頗領頬頭頷頸頻頼題額顎顔顕願類顧领風飛飜食飢飯飲飴飼飽飾餃餅養餌餓館饅首香馬馭馳馴駄駅駆駈駐駒駝駿騎騒験騨騰驚骨骸高髪髭鬱鬼魂魅魔魚鮫鮭鮮鯖鯛鰐鰡鰭鰺鰻鳥鳩鳴鴉鴎鴨鴻鶏鶴鷲鷹鸞鹸鹿麗麦麺麻黄黒黙鼓鼠鼻鼾齢齧龍녕러분세안여요하︎!(),./;=?\^abcdfghijklmnoprstuvwyz~「」ゥ'''
+
diff --git a/app.py b/app.py
new file mode 100755
index 0000000000000000000000000000000000000000..a3455d9c35d06e0f36f4c1d64744bfb2d832b410
--- /dev/null
+++ b/app.py
@@ -0,0 +1,67 @@
+import sys, os, torch
+# sys.path.append("Testing/")
+import gradio as gr
+from inference import infer, prepare_inputs, load_t2s_model, load_cfm, create_wav_header
+from tqdm import tqdm
+
+# Setup
+os.makedirs("generated_samples/", exist_ok=True)
+device = "cuda" if torch.cuda.is_available() else "cpu"
+print("Using device", device)
+
+# Model checkpoints
+m1_checkpoint = "pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt"
+m2_checkpoint = "pretrained_checkpoint/m2.pt"
+vocoder_checkpoint = 'pretrained_checkpoint/700_580k_multilingual_infer_ready/'
+
+
+# Load models
+FM, vocoder, m2, mu, std = load_cfm(m2_checkpoint, vocoder_checkpoint, device)
+m1 = load_t2s_model(m1_checkpoint, device)
+
+
+# Speaker reference clips
+speaker_refs = {
+ "Speaker1": [
+ "speakers/female1/train_hindifemale_02794.wav",
+ "speakers/female1/train_hindifemale_04167.wav",
+ "speakers/female1/train_hindifemale_02795.wav"
+ ]
+}
+
+# Available languages (can be extended)
+available_languages = ["hindi"]
+
+# Inference function
+def generate_audio(text, speaker_name, language):
+ if speaker_name not in speaker_refs:
+ raise gr.Error(f"Reference clips not available for {speaker_name}")  # a plain string return would break the gr.Audio output
+
+ ref_clips = speaker_refs[speaker_name]
+
+ text_ids, code_ids, language_code, ref_mels_m1, ref_mels_m2 = prepare_inputs(
+ text.lower(),
+ ref_clips_m1=ref_clips,
+ ref_clips_m2=ref_clips,
+ language=language,
+ device=device
+ )
+
+ audio_wav = infer(m1, m2, vocoder, FM, mu, std, text_ids, code_ids, language_code, ref_mels_m1, ref_mels_m2, device)
+ return 24000, audio_wav  # (sample_rate, int16 array) tuple for gr.Audio
+
+# Gradio UI
+interface = gr.Interface(
+ fn=generate_audio,
+ inputs=[
+ gr.Textbox(label="Enter Text"),
+ gr.Dropdown(choices=list(speaker_refs.keys()), label="Select Speaker"),
+ gr.Dropdown(choices=available_languages, label="Select Language")
+ ],
+ outputs=gr.Audio(label="Generated Speech"),
+ title="MAHATTSv2 Demo",
+ description="Enter text, choose a speaker and language to generate speech."
+)
+
+interface.launch(share=True,server_port=9999)
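`generate_audio` returns the `(sample_rate, int16 ndarray)` tuple that `gr.Audio` accepts directly. A sketch of that output convention using a fabricated 1-second tone in place of the model output:

```python
import numpy as np

# Stand-in for the model's output: a 1-second 440 Hz tone at 24 kHz,
# scaled to int16 PCM as gr.Audio expects for (rate, ndarray) input.
sr = 24000
t = np.arange(sr) / sr
wav = (np.sin(2 * np.pi * 440.0 * t) * 32767.0).astype(np.int16)
audio_out = (sr, wav)

assert audio_out[0] == 24000
assert audio_out[1].dtype == np.int16 and audio_out[1].shape == (24000,)
```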
diff --git a/bigvgan_v2_24khz_100band_256x b/bigvgan_v2_24khz_100band_256x
new file mode 160000
index 0000000000000000000000000000000000000000..c329ede9e9bbc100ddf5c91e2330a61921262370
--- /dev/null
+++ b/bigvgan_v2_24khz_100band_256x
@@ -0,0 +1 @@
+Subproject commit c329ede9e9bbc100ddf5c91e2330a61921262370
diff --git a/config.py b/config.py
new file mode 100755
index 0000000000000000000000000000000000000000..62010255c1f82d2c56bc08b115d6bef9e748cd67
--- /dev/null
+++ b/config.py
@@ -0,0 +1,111 @@
+import os
+import torch
+
+dir_path = os.path.abspath(os.path.dirname(__file__))
+
+class config:
+
+ data_path='./'
+ model_name="OSS Shogun"
+ user_name = "xxx" #for wandb
+
+ desc = '''
+ OSS
+ '''
+
+ train_file = "Semantic_tokens/SEMANTICS_/train.txt"
+ val_file = "Semantic_tokens/SEMANTICS_/val.txt"
+
+ norms = torch.load(f"{dir_path}/mel_norms.pt")
+ mu = norms["mean_val"]
+ std = norms["std"]
+ scale = True
+ semantic_model_centroids = 10000 + 1
+ seed_value = 3407
+
+ t2s_checkpoint = "/omega/Models/FT_ENGLISH/T2S/1_latest.pt"
+ ts_finetuning = True
+ ts_wandb_logs = False
+ text_loss_weight = 0.01
+ t2s_position = 8192
+ ts_batch_size = 1
+ ts_epochs = 10
+ ts_lr = 1e-5
+ ts_weight_decay = 1e-4
+ ts_eval_epoch = 1
+ ts_num_workers = 8
+ ts_gradient_accumulation_steps = 1 # EfBS of 128 for finetuning, 256 for pretraining, around 9k steps for sft for 2 epochs
+ ts_eval_step = 10000
+
+ langs = [
+ "odia",
+ "assamese",
+ "thai",
+ "gujrati",
+ "russian",
+ "japanese",
+ "punjabi",
+ "hindi",
+ "manipuri",
+ "korean",
+ "bhojpuri",
+ "sanskrit",
+ "english",
+ "french",
+ "bodo",
+ "malayalam",
+ "telugu",
+ "kannada",
+ "dogri",
+ "marathi",
+ "german",
+ "italian",
+ "rajasthani",
+ "spanish",
+ "arabic",
+ "urdu",
+ "gujarati",
+ "tamil",
+ "bengali",
+ ]
+
+ lang_index = {i: j for j, i in enumerate(langs)}  # language name -> integer id
+
+ # Train s2a
+ sa_wandb_logs = False
+ joint_training = False  # joint training is not supported
+ checkpoint = "/omega/Models/" + "FT_ENGLISH/"
+ sa_timesteps_max = 1000
+ sa_batch_size = 32
+ sa_epochs = 5000000
+ gradient_accumulation_steps = 4
+ sa_lr = 1e-4
+ sa_weight_decay = 1e-2
+ sa_eval_step = 10000
+ sa_infer = True
+ sa_infer_epoch = 1
+ sa_num_workers = 24
+
+ # Train Dvae (not using)
+ dvae_wandb_logs = True
+ dvae_batch_size = 128
+ dvae_epochs = 5000
+ dvae_lr = 3e-4
+ dvae_weight_decay = 1e-2
+ dvae_eval_epoch = 1
+ dvae_infer = True
+ dvae_infer_epoch = 1
+ dvae_num_workers = 16
+
+ # Acoustic Properties, Do not change
+ CLIP_LENGTH = 500
+ MAX_WAV_VALUE = 32768.0 - 1
+ filter_length = 1024
+ hop_length = 256 # 256
+ window = "hann"
+ win_length = 1024
+ n_mel_channels = 100
+ sampling_rate = 24000
+ mel_fmin = 0.0
+ mel_fmax = None
+ normalize = True
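`config.lang_index` is built with `{i: j for j, i in enumerate(langs)}`, where `i` is the language name and `j` its position, so it maps name to integer id in list order. A check on a truncated copy of the list:

```python
# Truncated copy of config.langs; same comprehension as in config.py.
langs = ["odia", "assamese", "thai", "gujrati"]
lang_index = {i: j for j, i in enumerate(langs)}  # name -> integer id

assert lang_index["odia"] == 0
assert lang_index["thai"] == 2
assert lang_index["gujrati"] == 3
```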
diff --git a/generated_samples/0_hindi.wav b/generated_samples/0_hindi.wav
new file mode 100644
index 0000000000000000000000000000000000000000..6eacb028a2388b2edd1ed4812b4909115d080287
--- /dev/null
+++ b/generated_samples/0_hindi.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f7e0dda3c6d53b0658aff88ec422608f6fda9e3a9626d49338a700c67d296fe
+size 966188
diff --git a/inference.py b/inference.py
new file mode 100755
index 0000000000000000000000000000000000000000..e50e3a11e79803ad1a56eb1769d1f3cecfa299d8
--- /dev/null
+++ b/inference.py
@@ -0,0 +1,217 @@
+import os, sys, time, struct
+import torch, torchaudio
+import numpy as np               # used in infer()
+from tqdm import tqdm            # used in the __main__ loop
+from pydub import AudioSegment   # used in get_processed_clips()
+
+# sys.path.append("S2A/bigvgan_v2_24khz_100band_256x")
+# sys.path.append("S2A/")
+# sys.path.append("T2S/")
+# sys.path.append("hifi-gan/")
+
+from S2A.inference import *
+from S2A.diff_model import DiffModel
+from T2S.autoregressive import TS_model
+from T2S.mel_spec import get_mel_spectrogram
+from Text import labels,text_labels,code_labels
+from config import config
+
+# text encoder/decoder
+text_enc = {j:i for i,j in enumerate(text_labels)}
+text_dec = {i:j for i,j in enumerate(text_labels)}
+
+# code encoder/decoder
+code_enc = {j:i for i,j in enumerate(code_labels)}
+code_dec = {i:j for i,j in enumerate(code_labels)}
+
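The paired dict comprehensions above invert each other: one maps symbol to id, the other id back to symbol. A roundtrip over a toy vocabulary:

```python
# Same enc/dec pattern as text_enc/text_dec, on a toy vocabulary.
vocab = ['a', 'b', 'c']
enc = {j: i for i, j in enumerate(vocab)}  # symbol -> id
dec = {i: j for i, j in enumerate(vocab)}  # id -> symbol

ids = [enc[ch] for ch in 'cab']
assert ids == [2, 0, 1]
assert ''.join(dec[i] for i in ids) == 'cab'
```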
+def create_wav_header(sample_rate = 24000, bits_per_sample=16, channels=1):
+ # "RIFF" chunk descriptor
+ chunk_id = b'RIFF'
+ chunk_size = 0xFFFFFFFF # Placeholder for chunk size (unknown during streaming)
+ format = b'WAVE'
+
+ # "fmt " sub-chunk (16 bytes for PCM format)
+ subchunk1_id = b'fmt '
+ subchunk1_size = 16 # PCM format
+ audio_format = 1 # PCM = 1 (linear quantization)
+ num_channels = channels
+ byte_rate = sample_rate * num_channels * bits_per_sample // 8
+ block_align = num_channels * bits_per_sample // 8
+
+ # "data" sub-chunk
+ subchunk2_id = b'data'
+ subchunk2_size = 0xFFFFFFFF # Placeholder for data size (unknown during streaming)
+
+ # Pack the header into a byte object using struct
+ header = struct.pack('<4sI4s4sIHHIIHH4sI',
+ chunk_id,
+ chunk_size,
+ format,
+ subchunk1_id,
+ subchunk1_size,
+ audio_format,
+ num_channels,
+ sample_rate,
+ byte_rate,
+ block_align,
+ bits_per_sample,
+ subchunk2_id,
+ subchunk2_size)
+
+ return header
+
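The `struct.pack` call above emits the canonical 44-byte PCM WAV header, with both size fields set to `0xFFFFFFFF` because the stream length is unknown. A self-contained check of that layout (function name here is illustrative, not the repo's):

```python
import struct

# Mirrors create_wav_header: 0xFFFFFFFF size fields mark a stream of
# unknown length; the packed layout is the 44-byte PCM WAV header.
def wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
    byte_rate = sample_rate * channels * bits_per_sample // 8
    block_align = channels * bits_per_sample // 8
    return struct.pack('<4sI4s4sIHHIIHH4sI',
                       b'RIFF', 0xFFFFFFFF, b'WAVE',
                       b'fmt ', 16, 1, channels,
                       sample_rate, byte_rate, block_align, bits_per_sample,
                       b'data', 0xFFFFFFFF)

h = wav_header()
assert len(h) == 44
assert h[:4] == b'RIFF' and h[8:12] == b'WAVE' and h[36:40] == b'data'
```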
+
+def get_processed_clips(ref_clips):
+ frame_rate = 24000
+ new_ref_clips = []
+ for i in ref_clips:
+ if '_proc.wav' in i:
+ new_ref_clips.append(i)
+ continue
+ audio = AudioSegment.from_file(i)
+ audio = audio.set_channels(1)
+ audio = audio.set_frame_rate(frame_rate).set_sample_width(2)
+ audio.export(i[:-4]+'_proc.wav',format='wav')
+ new_ref_clips.append(i[:-4]+'_proc.wav')
+
+ return new_ref_clips
+
+def get_ref_mels(ref_clips):
+ ref_mels = []
+ for i in ref_clips:
+ audio_norm,sampling_rate = torchaudio.load(i)
+ ref_mels.append(get_mel_spectrogram(audio_norm,sampling_rate).squeeze(0)[:, :500])
+
+ ref_mels_padded = torch.randn((len(ref_mels), 100, 500)) * 1e-9  # near-zero noise padding up to 500 (CLIP_LENGTH) frames
+ for i, mel in enumerate(ref_mels):
+ ref_mels_padded[i, :, : mel.size(1)] = mel
+ return ref_mels_padded.unsqueeze(0)
+
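`get_ref_mels` right-pads variable-length `(100, T)` mel clips into a fixed `(N, 100, 500)` batch. The same idea, sketched with numpy and fabricated mel lengths (zeros instead of near-zero noise):

```python
import numpy as np

# Two fake mel clips of different lengths, 100 mel bins each.
mels = [np.ones((100, 120)), np.ones((100, 500))]

# Right-pad into a fixed (N, 100, 500) batch, as get_ref_mels does.
padded = np.zeros((len(mels), 100, 500), dtype=np.float32)
for i, m in enumerate(mels):
    padded[i, :, :m.shape[1]] = m

assert padded.shape == (2, 100, 500)
assert float(padded[0, :, 120:].sum()) == 0.0   # padding region untouched
```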
+def load_cfm(checkpoint,vocoder_checkpoint=None,device="cpu"):
+ FM = BASECFM()
+ if vocoder_checkpoint is None:
+ hifi = bigvgan.BigVGAN.from_pretrained('nvidia/bigvgan_v2_24khz_100band_256x', use_cuda_kernel=False)
+ else:
+ hifi = bigvgan.BigVGAN.from_pretrained(vocoder_checkpoint, use_cuda_kernel=False)
+
+ hifi.remove_weight_norm()
+ hifi = hifi.eval().to(device)
+
+ model = DiffModel(input_channels=100,
+ output_channels=100,
+ model_channels=512,
+ num_heads=8,
+ dropout=0.1,
+ num_layers=8,
+ enable_fp16=False,
+ condition_free_per=0.0,
+ multispeaker=True,
+ style_tokens=100,
+ training=False,
+ ar_active=False)
+
+ ckpt = torch.load(checkpoint, map_location=torch.device('cpu'))  # load once; holds both weights and mel norms
+ model.load_state_dict(ckpt['model'])
+ model.eval().to(device)
+ mu = ckpt['norms']['mean_val']
+ std = ckpt['norms']['std']
+
+ return FM,hifi,model,mu,std
+
+def load_t2s_model(checkpoint,device):
+ model = TS_model(n_embed= 1024, n_layer= 30, n_head = 16)
+ model.load_state_dict(torch.load(checkpoint,map_location=torch.device('cpu'))['model'],strict=True)
+ model.eval()
+ model.to(device)
+ model.init_gpt_for_inference()
+ return model
+
+def prepare_inputs(text,ref_clips_m1,ref_clips_m2,language,device):
+ code_ids = [code_enc[""]]
+ text_ids = (
+ torch.tensor(
+ [text_enc[""]]
+ + [text_enc[i] for i in text.strip()]
+ + [text_enc[""]]
+ )
+ .to(device)
+ .unsqueeze(0)
+ )
+ language = (
+ torch.tensor(config.lang_index[language]).to(device).unsqueeze(0)
+ )
+
+ ref_mels_m1 = get_ref_mels(get_processed_clips(ref_clips_m1))
+ ref_mels_m2 = get_ref_mels(get_processed_clips(ref_clips_m2))
+
+ return text_ids,code_ids,language,ref_mels_m1,ref_mels_m2
+
+
+def infer(m1,m2,vocoder,FM,mu,std,text_ids,code_ids,language,ref_mels_m1,ref_mels_m2,device):
+ with torch.no_grad():
+ cond_latents = m1.get_speaker_latent(ref_mels_m1.to(device))
+ code_emb = m1.generate(
+ language.to(device), cond_latents.to(device), text_ids.to(device), code_ids, **{
+ "temperature": 0.8,
+ "length_penalty": None,
+ "repetition_penalty": None,
+ "top_k": 50,
+ "top_p": 0.8,
+ "do_sample": True,
+ "num_beams": 1,
+ "max_new_tokens": 1500
+ }
+ )[:, :-1]
+
+ # target mel length: roughly 93 mel frames per 50 semantic tokens (~1.86x), plus one frame
+ mel = FM(m2, code_emb+1, (1, 100, int(1+93*(code_emb.shape[-1]+1)/50)), ref_mels_m2.to(device), n_timesteps=20, temperature=1.0)
+ mel = denormalize_tacotron_mel(mel,mu,std)
+ audio = vocoder(mel)
+ audio = audio.squeeze(0).detach().cpu()
+ audio = audio * 32767.0
+ audio = (
+ audio.numpy().reshape(-1).astype(np.int16)
+ )
+
+ return audio
+
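`infer` converts the vocoder's float output (roughly in [-1, 1]) to int16 PCM by scaling by 32767 and casting. The same conversion on fabricated samples, showing that the cast truncates toward zero:

```python
import numpy as np

# Float audio in [-1, 1] scaled to int16 PCM, as in infer().
audio = np.array([-1.0, 0.0, 0.5, 1.0])
pcm = (audio * 32767.0).astype(np.int16)  # astype truncates toward zero

assert pcm.tolist() == [-32767, 0, 16383, 32767]
```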
+if __name__ == '__main__':
+ os.makedirs("generated_samples/",exist_ok=True)
+ start = time.time()
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ m1_checkpoint = "pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt"
+ m2_checkpoint = "pretrained_checkpoint/m2.pt"
+ vocoder_checkpoint = 'pretrained_checkpoint/700_580k_multilingual_infer_ready/'
+
+ FM,vocoder,m2,mu,std = load_cfm(m2_checkpoint,vocoder_checkpoint,device)
+ m1 = load_t2s_model(m1_checkpoint,device)
+
+ model_loading_time = time.time()
+
+ output_file_path = "test.wav"
+
+ texts = ["यह एक उदाहरणात्मक हिंदी पाठ है जिसका उद्देश्य भाषा की संरचना और शब्दों के प्रवाह को समझना है। भारत एक विविधताओं से भरा देश है जहाँ अनेक भाषाएँ, धर्म, और संस्कृतियाँ एक साथ मिलकर रहते हैं। यहाँ की परंपराएँ, त्योहार और भोजन इसकी सांस्कृतिक समृद्धि को दर्शाते हैं।"]
+ languages=['hindi']
+
+ ref_clips = [
+ 'speakers/female1/train_hindifemale_02794.wav',
+ 'speakers/female1/train_hindifemale_04167.wav',
+ 'speakers/female1/train_hindifemale_02795.wav'
+ ]
+
+ for n,(lang,text) in tqdm(enumerate(zip(languages,texts))):
+ text_ids, code_ids, language, ref_mels_m1, ref_mels_m2 = prepare_inputs(
+ text.lower(),
+ ref_clips_m1=ref_clips,
+ ref_clips_m2=ref_clips,
+ language=lang,
+ device=device)
+ audio_wav = infer(m1,m2,vocoder,FM,mu,std,text_ids,code_ids,language,ref_mels_m1,ref_mels_m2,device)
+
+ with open(f"generated_samples/{n}_{lang}.wav",'wb') as file:
+ file.write(create_wav_header(sample_rate = 24000, bits_per_sample=16, channels=1))
+ file.write(audio_wav.tobytes())
+
+ audio_generation_time=time.time()
+
+ print()
+ print(text)
+ print(f"Total time: {audio_generation_time-start:.2f}s")
+ print(f"Model loading time: {model_loading_time-start:.2f}s")
+ print(f"Audio generation time: {audio_generation_time-model_loading_time:.2f}s")
\ No newline at end of file
diff --git a/mel_norms.pt b/mel_norms.pt
new file mode 100755
index 0000000000000000000000000000000000000000..41652df14fa0efec6b9bebe3cff3bbfbd483f209
Binary files /dev/null and b/mel_norms.pt differ
diff --git a/pretrained_checkpoint/700_580k_multilingual_infer_ready/bigvgan_generator.pt b/pretrained_checkpoint/700_580k_multilingual_infer_ready/bigvgan_generator.pt
new file mode 100755
index 0000000000000000000000000000000000000000..b49553edb8694ab19989a384c5347853959e8885
--- /dev/null
+++ b/pretrained_checkpoint/700_580k_multilingual_infer_ready/bigvgan_generator.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:e612993a903a446054367d1b7701577d35948634324dd7ee3c00ff5288a986c5
+size 450080461
diff --git a/pretrained_checkpoint/700_580k_multilingual_infer_ready/config.json b/pretrained_checkpoint/700_580k_multilingual_infer_ready/config.json
new file mode 100755
index 0000000000000000000000000000000000000000..9bb34ec5f85700a867d8768e67835c68736be857
--- /dev/null
+++ b/pretrained_checkpoint/700_580k_multilingual_infer_ready/config.json
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:96920a4f61526edb5068b3dbd35f739d20039f129bfc71ccbd6182a4e74befaf
+size 1401
diff --git a/pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt b/pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt
new file mode 100755
index 0000000000000000000000000000000000000000..c2b44f585de24004d812212e8cc422ea2cf90ddc
--- /dev/null
+++ b/pretrained_checkpoint/m1_gemma_benchmark_1_latest_weights.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:ef330d44d3c940875257f80c4a46cfacb4cbcd3d296b9dc65c6bcb0082a95ee9
+size 3303697806
diff --git a/pretrained_checkpoint/m2.pt b/pretrained_checkpoint/m2.pt
new file mode 100755
index 0000000000000000000000000000000000000000..e104ed0b35a78ea933b351d8f4476d406df8ecc6
--- /dev/null
+++ b/pretrained_checkpoint/m2.pt
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:978e0242353299ab14bd2eb8981439f0c58fdcf2f9ecf2b99ce986c50834b844
+size 863312522
diff --git a/requirements.txt b/requirements.txt
new file mode 100755
index 0000000000000000000000000000000000000000..00cdad9713e5bde226c196d77f7c31ff4f53efa2
--- /dev/null
+++ b/requirements.txt
@@ -0,0 +1,12 @@
+torch
+torchaudio
+torchvision
+transformers
+accelerate
+einops
+inflect
+pydub
+unidecode
+matplotlib
+librosa
+gradio
+fastapi
+uvicorn
\ No newline at end of file
diff --git a/server.py b/server.py
new file mode 100755
index 0000000000000000000000000000000000000000..735fc09da4528424c8a7a175e41a8631aba091e1
--- /dev/null
+++ b/server.py
@@ -0,0 +1,226 @@
+import os
+import sys
+import time
+import struct
+import random
+from uuid import uuid4
+from typing import List, Optional
+
+import torch
+import torchaudio
+from fastapi import FastAPI, UploadFile, File, Form, HTTPException
+from fastapi.responses import FileResponse, JSONResponse
+import uvicorn
+
+# append model paths
+sys.path.append("S2A/bigvgan_v2_24khz_100band_256x")
+sys.path.append("S2A/")
+sys.path.append("T2S/")
+sys.path.append("hifi-gan/")
+
+# from S2A.inference import *
+# from T2S.autoregressive import TS_model
+# from T2S.mel_spec import get_mel_spectrogram
+# from Text import labels, text_labels, code_labels
+from config import config
+from torch.cuda.amp import autocast
+from inference import *
+# directories for saving uploads and generated audio
+UPLOAD_DIR = "uploads"
+OUTPUT_DIR = "generated_samples"
+os.makedirs(UPLOAD_DIR, exist_ok=True)
+os.makedirs(OUTPUT_DIR, exist_ok=True)
+
+# text/code encoders
+text_enc = {j: i for i, j in enumerate(text_labels)}
+code_enc = {j: i for i, j in enumerate(code_labels)}
+
+# inference globals
+FM = None
+vocoder = None
+m2 = None
+mu = None
+std = None
+m1 = None
+DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+
+# load models on startup
+def load_models(
+ m1_ckpt: str,
+ m2_ckpt: str,
+ vocoder_ckpt: Optional[str]
+):
+ global FM, vocoder, m2, mu, std, m1
+ FM,vocoder,m2,mu,std = load_cfm(m2_ckpt,vocoder_ckpt,DEVICE)
+ m1 = load_t2s_model(m1_ckpt,DEVICE)
+
+# utility: WAV header
+def create_wav_header(sample_rate=24000, bits_per_sample=16, channels=1):
+ chunk_id = b'RIFF'
+ chunk_size = 0xFFFFFFFF
+ format_tag = b'WAVE'
+ subchunk1_id = b'fmt '
+ subchunk1_size = 16
+ audio_format = 1
+ num_channels = channels
+ byte_rate = sample_rate * num_channels * bits_per_sample // 8
+ block_align = num_channels * bits_per_sample // 8
+ subchunk2_id = b'data'
+ subchunk2_size = 0xFFFFFFFF
+ header = struct.pack(
+ '<4sI4s4sIHHIIHH4sI',
+ chunk_id,
+ chunk_size,
+ format_tag,
+ subchunk1_id,
+ subchunk1_size,
+ audio_format,
+ num_channels,
+ sample_rate,
+ byte_rate,
+ block_align,
+ bits_per_sample,
+ subchunk2_id,
+ subchunk2_size,
+ )
+ return header
+
+# # prepare mels
+# def get_processed_clips(ref_clips: List[str]):
+# frame_rate = 24000
+# new_clips = []
+# from pydub import AudioSegment
+
+# for path in ref_clips:
+# if path.endswith('_proc.wav'):
+# new_clips.append(path)
+# continue
+# audio = AudioSegment.from_file(path)
+# audio = audio.set_channels(1).set_frame_rate(frame_rate).set_sample_width(2)
+# out = path.rstrip('.') + '_proc.wav'
+# audio.export(out, format='wav')
+# new_clips.append(out)
+# return new_clips
+
+# def get_ref_mels(ref_clips: List[str]):
+# ref_mels = []
+# for p in ref_clips:
+# audio_norm, sr = torchaudio.load(p)
+# ref_mels.append(get_mel_spectrogram(audio_norm, sr).squeeze(0)[:, :1024])
+# # pad to (len,100,500)
+# padded = torch.randn((len(ref_mels), 100, 1024)) * 1e-9
+# for i, mel in enumerate(ref_mels):
+# padded[i, :, : mel.size(1)] = mel
+# return padded.unsqueeze(0)
+
+app = FastAPI(title="T2S+CFM Inference API")
+
+@app.on_event("startup")
+def on_startup():
+ # configure these paths as needed
+ m1_checkpoint = os.getenv('M1_CKPT', "/delta/MahaTTS/models/m1_gemma_benchmark_1_latest_weights.pt")
+ # m1_checkpoint.append((os.getenv('M1_CKPT', "/delta/horizon/133939_7_latest.pt"),"pt-1"))
+ # m1_checkpoint.append((os.getenv('M1_CKPT', "/delta/horizon/137877_8_latest.pt"),"pt-2"))
+
+ m2_checkpoint = os.getenv('M2_CKPT', '/delta/model_gemma/_latest_700000.pt')
+ vocoder_checkpoint = os.getenv('VOCODER_CKPT', '/delta/model_gemma/700_580k_multilingual_infer_ready/')
+ load_models(m1_checkpoint, m2_checkpoint, vocoder_checkpoint)
+
+@app.post("/infer")
+async def infer_endpoint(
+ text: str = Form(..., description="Input text to synthesize"),
+ language: str = Form(..., description="Language code, e.g. 'hindi' or 'english'"),
+ seed: int = Form(0),
+ temperature: float = Form(0.8),
+ length_penalty: Optional[float] = Form(None),
+ repetition_penalty: Optional[float] = Form(None),
+ top_k: int = Form(50),
+ top_p: float = Form(0.8),
+ do_sample: bool = Form(True),
+ num_beams: int = Form(1),
+ n_timesteps: int = Form(20),
+ no_repeat_ngram_size: Optional[int] = Form(None),
+ ref_clips_m1: List[UploadFile] = File(..., description="Reference audio files for m1"),
+ ref_clips_m2: List[UploadFile] = File(..., description="Reference audio files for m2"),
+ model_name: str = Form("pt-2"),
+):
+
+ print(text)
+ # save uploaded reference clips
+ def save_files(files):
+ paths = []
+ for f in files:
+ fname = f"{uuid4().hex}_{f.filename}"
+ fpath = os.path.join(UPLOAD_DIR, fname)
+ with open(fpath, "wb") as out:
+ out.write(f.file.read())
+ paths.append(fpath)
+ return paths
+
+ m1_paths = save_files(ref_clips_m1)
+ m2_paths = save_files(ref_clips_m2)
+
+ # prepare inputs
+ text_ids, code_ids, lang_tensor, ref_mels1, ref_mels2 = prepare_inputs(
+ text.lower().strip(), m1_paths, m2_paths, language, device=str(DEVICE)
+ )
+
+ # set RNG seeds from the request
+ torch.manual_seed(seed)
+ torch.cuda.manual_seed_all(seed)
+ np.random.seed(seed)
+ random.seed(seed)
+ # a repetition_penalty of 0 from the form means "disabled"
+ if repetition_penalty == 0:
+     repetition_penalty = None
+ print("repetition_penalty:", repetition_penalty)
+ print("no_repeat_ngram_size:", no_repeat_ngram_size)
+ # generate code embedding
+ with torch.no_grad(), autocast(dtype=torch.bfloat16):
+ cond_latents = m1.get_speaker_latent(ref_mels1.to(DEVICE))
+ code_emb = m1.generate(
+ lang_tensor.to(DEVICE), cond_latents.to(DEVICE), text_ids.to(DEVICE), code_ids,
+ temperature=temperature,
+ length_penalty=length_penalty,
+ repetition_penalty=repetition_penalty,
+ top_k=top_k,
+ top_p=top_p,
+ do_sample=do_sample,
+ num_beams=num_beams,
+ no_repeat_ngram_size=no_repeat_ngram_size,
+ max_new_tokens=1500,
+ renormalize_logits=True,
+ penalty_alpha=0,
+ )[:, :-1]
+ mel = FM(m2, code_emb + 1, (1, 100, int(1 + 93 * (code_emb.shape[-1] + 1) / 50)), ref_mels2.to(DEVICE), n_timesteps=n_timesteps, temperature=1.0)
+ mel = denormalize_tacotron_mel(mel, mu, std)
+ audio = vocoder(mel)
+ audio = audio.squeeze(0).detach().cpu()
+ # clamp to [-1, 1] before scaling so peaks don't wrap around in int16
+ audio = torch.clamp(audio, -1.0, 1.0) * 32767.0
+ audio_int16 = audio.to(torch.float32).numpy().reshape(-1).astype(np.int16)
+ # save output wav
+ out_name = f"{uuid4().hex}.wav"
+ out_path = os.path.join(OUTPUT_DIR, out_name)
+ with open(out_path, "wb") as wf:
+ wf.write(create_wav_header())
+ wf.write(audio_int16.tobytes())
+
+ return FileResponse(out_path, media_type="audio/wav", filename=out_name)
+
+
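The target mel length passed to `FM(...)` in the endpoint above, `int(1 + 93 * (code_emb.shape[-1] + 1) / 50)`, maps the semantic-token count to a mel frame count (roughly 1.86 frames per token). Pulled out as a pure helper for clarity; the name `mel_frames` is illustrative, not part of the codebase:

```python
def mel_frames(n_codes: int) -> int:
    """Target mel-spectrogram length for a given semantic-code length,
    mirroring the expression used in the FM(...) call above."""
    return int(1 + 93 * (n_codes + 1) / 50)

assert mel_frames(0) == 2     # even an empty code sequence allocates a couple of frames
assert mel_frames(50) == 95   # ~1.86 mel frames per code token
```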
+if __name__ == "__main__":
+ uvicorn.run(app, host="0.0.0.0", port=6000)
+ # use ngrok
\ No newline at end of file
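For reference, a hedged client-side sketch of calling the `/infer` endpoint above. The host, port, and audio paths are assumptions, and the actual POST needs the third-party `requests` package; only the payload assembly runs here:

```python
def build_form(text, language, ref_m1_paths, ref_m2_paths, **overrides):
    """Assemble (data, files) for requests.post(url, data=data, files=files).

    In real use each file entry would be ("ref_clips_m1", open(path, "rb"));
    plain path strings are kept here so the sketch runs without audio on disk.
    """
    data = {"text": text, "language": language, "seed": 0,
            "temperature": 0.8, "top_k": 50, "top_p": 0.8}
    data.update(overrides)
    # the field name is repeated per file so FastAPI collects a List[UploadFile]
    files = [("ref_clips_m1", p) for p in ref_m1_paths]
    files += [("ref_clips_m2", p) for p in ref_m2_paths]
    return data, files

data, files = build_form(
    "namaste duniya", "hindi",
    ["speakers/female1/train_hindifemale_02794_proc.wav"],
    ["speakers/female1/train_hindifemale_02794_proc.wav"],
    temperature=0.7,
)
assert data["temperature"] == 0.7
assert [name for name, _ in files] == ["ref_clips_m1", "ref_clips_m2"]
```

With a running server, the actual call would be along the lines of `requests.post("http://localhost:6000/infer", data=data, files=files)`, saving `response.content` as the returned WAV.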
diff --git a/speakers/female1/train_hindifemale_02794.wav b/speakers/female1/train_hindifemale_02794.wav
new file mode 100755
index 0000000000000000000000000000000000000000..3f801210a35e7e8f6fec290c340dd9f4e4447fd9
--- /dev/null
+++ b/speakers/female1/train_hindifemale_02794.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25b1ea4f48dec8b3e238dc786a5ba43a21587d93b1a4c5c0391f227fd6fccd88
+size 197940
diff --git a/speakers/female1/train_hindifemale_02794_proc.wav b/speakers/female1/train_hindifemale_02794_proc.wav
new file mode 100644
index 0000000000000000000000000000000000000000..3f801210a35e7e8f6fec290c340dd9f4e4447fd9
--- /dev/null
+++ b/speakers/female1/train_hindifemale_02794_proc.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:25b1ea4f48dec8b3e238dc786a5ba43a21587d93b1a4c5c0391f227fd6fccd88
+size 197940
diff --git a/speakers/female1/train_hindifemale_02795.wav b/speakers/female1/train_hindifemale_02795.wav
new file mode 100755
index 0000000000000000000000000000000000000000..015c55e67b40d58380375de84a66b160869e3ea4
--- /dev/null
+++ b/speakers/female1/train_hindifemale_02795.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f5e5e39f2de0631b33bada4a44f8f619145f4415c72fd5f941af2254f310ada
+size 291224
diff --git a/speakers/female1/train_hindifemale_02795_proc.wav b/speakers/female1/train_hindifemale_02795_proc.wav
new file mode 100644
index 0000000000000000000000000000000000000000..015c55e67b40d58380375de84a66b160869e3ea4
--- /dev/null
+++ b/speakers/female1/train_hindifemale_02795_proc.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4f5e5e39f2de0631b33bada4a44f8f619145f4415c72fd5f941af2254f310ada
+size 291224
diff --git a/speakers/female1/train_hindifemale_04167.wav b/speakers/female1/train_hindifemale_04167.wav
new file mode 100755
index 0000000000000000000000000000000000000000..9f74943fed7ae94b4de4a2ee4dfd081e4f47e67f
--- /dev/null
+++ b/speakers/female1/train_hindifemale_04167.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4692a763a595f959d4854d33dde9cfcf187e892160ffaa7e075545c812c59403
+size 382240
diff --git a/speakers/female1/train_hindifemale_04167_proc.wav b/speakers/female1/train_hindifemale_04167_proc.wav
new file mode 100644
index 0000000000000000000000000000000000000000..9f74943fed7ae94b4de4a2ee4dfd081e4f47e67f
--- /dev/null
+++ b/speakers/female1/train_hindifemale_04167_proc.wav
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:4692a763a595f959d4854d33dde9cfcf187e892160ffaa7e075545c812c59403
+size 382240