Spaces: Running on T4
| """ | |
| Copied from https://github.com/KdaiP/StableTTS by https://github.com/KdaiP | |
| https://github.com/KdaiP/StableTTS/blob/eebb177ebf195fd1246dedabec4ef69d9351a4f8/models/dit.py | |
| Code is under MIT License | |
| """ | |
| # References: | |
| # https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/components/transformer.py | |
| # https://github.com/jaywalnut310/vits/blob/main/attentions.py | |
| # https://github.com/pytorch-labs/gpt-fast/blob/main/model.py | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
class FFN(nn.Module):
    """Masked two-layer convolutional feed-forward network.

    ``conv_1`` expands ``in_channels`` to ``filter_channels``; a tanh-approximated
    GELU and dropout follow; ``conv_2`` projects to ``out_channels``. Both
    convolutions pad by ``kernel_size // 2`` on each side, which preserves the
    time length for odd kernel sizes.

    Args:
        in_channels: channel count of the input tensor.
        out_channels: channel count of the output tensor.
        filter_channels: hidden (expanded) channel count.
        kernel_size: temporal kernel size shared by both convolutions.
        p_dropout: dropout probability applied after the activation.
        gin_channels: stored as an attribute; not used by this class.
    """

    def __init__(self, in_channels, out_channels, filter_channels, kernel_size, p_dropout=0., gin_channels=0):
        super().__init__()
        # Keep the constructor arguments available as attributes.
        self.in_channels, self.out_channels = in_channels, out_channels
        self.filter_channels = filter_channels
        self.kernel_size = kernel_size
        self.p_dropout = p_dropout
        self.gin_channels = gin_channels

        pad = kernel_size // 2
        self.conv_1 = nn.Conv1d(in_channels, filter_channels, kernel_size, padding=pad)
        self.conv_2 = nn.Conv1d(filter_channels, out_channels, kernel_size, padding=pad)
        self.drop = nn.Dropout(p_dropout)
        self.act1 = nn.GELU(approximate="tanh")

    def forward(self, x, x_mask):
        """Run the network on ``x`` ([batch, in_channels, time]).

        ``x_mask`` must broadcast against ``x`` (e.g. [batch, 1, time]). It is
        applied to the input of each convolution and to the output, so masked
        positions of the result are zero.
        """
        hidden = self.drop(self.act1(self.conv_1(x * x_mask)))
        return self.conv_2(hidden * x_mask) * x_mask
class MultiHeadAttention(nn.Module):
    """Multi-head attention over channel-first sequences with rotary position embeddings.

    Queries, keys and values are produced by 1x1 convolutions of the input and
    split into ``n_heads`` heads of ``channels // n_heads`` features each. Queries
    and keys are rotary-embedded. The heads are combined with
    ``F.scaled_dot_product_attention``, and the result is projected to
    ``out_channels`` by another 1x1 convolution.

    Args:
        channels: input/attention channel count; must be divisible by ``n_heads``.
        out_channels: channel count produced by the output projection.
        n_heads: number of attention heads.
        p_dropout: attention dropout probability, used only in training mode.
    """

    def __init__(self, channels, out_channels, n_heads, p_dropout=0.):
        super().__init__()
        assert channels % n_heads == 0
        self.channels = channels
        self.out_channels = out_channels
        self.n_heads = n_heads
        self.p_dropout = p_dropout
        head_dim = channels // n_heads
        self.k_channels = head_dim

        self.conv_q = torch.nn.Conv1d(channels, channels, 1)
        self.conv_k = torch.nn.Conv1d(channels, channels, 1)
        self.conv_v = torch.nn.Conv1d(channels, channels, 1)

        # Rotary embeddings (https://nn.labml.ai/transformers/rope/index.html)
        # cover int(head_dim * 0.5) features of each head; RotaryPositionalEmbeddings
        # passes the remaining features through unrotated.
        rope_dims = head_dim * 0.5
        self.query_rotary_pe = RotaryPositionalEmbeddings(rope_dims)
        self.key_rotary_pe = RotaryPositionalEmbeddings(rope_dims)

        self.conv_o = torch.nn.Conv1d(channels, out_channels, 1)
        # Not used in forward: dropout is applied inside scaled_dot_product_attention.
        self.drop = torch.nn.Dropout(p_dropout)

        for proj in (self.conv_q, self.conv_k, self.conv_v):
            torch.nn.init.xavier_uniform_(proj.weight)

    def forward(self, x, attn_mask=None):
        """Self-attend over ``x`` ([batch, channels, time]).

        ``attn_mask`` is forwarded unchanged to ``F.scaled_dot_product_attention``.
        Returns a [batch, out_channels, time] tensor.
        """
        q, k, v = (proj(x) for proj in (self.conv_q, self.conv_k, self.conv_v))
        return self.conv_o(self.attention(q, k, v, mask=attn_mask))

    def attention(self, query, key, value, mask=None):
        """Scaled dot-product attention over already-projected tensors.

        Args:
            query: [batch, channels, t_t] tensor.
            key, value: [batch, channels, t_s] tensors.
            mask: forwarded unchanged as ``attn_mask`` to
                ``F.scaled_dot_product_attention``.

        Returns:
            [batch, channels, t_t] tensor with the heads merged back into channels.
        """
        batch, channels, src_len = key.size()
        tgt_len = query.size(2)

        def to_heads(t, length):
            # [b, d, t] -> [b, n_heads, t, d // n_heads]
            return t.view(batch, self.n_heads, self.k_channels, length).transpose(2, 3)

        q = self.query_rotary_pe(to_heads(query, tgt_len))
        k = self.key_rotary_pe(to_heads(key, src_len))
        v = to_heads(value, src_len)
        out = F.scaled_dot_product_attention(
            q, k, v, attn_mask=mask, dropout_p=self.p_dropout if self.training else 0
        )
        # [b, n_heads, t_t, d_k] -> [b, n_heads, d_k, t_t] -> [b, d, t_t]
        return out.transpose(2, 3).contiguous().view(batch, channels, tgt_len)
# modified from https://github.com/sh-lee-prml/HierSpeechpp/blob/main/modules.py#L390
class DiTConVBlock(nn.Module):
    """
    A DiT block with adaptive layer norm zero (adaLN-Zero) conditioning.

    Works on channel-first tensors with ``hidden_channels + out_channels``
    channels. It applies a residual attention sub-block followed by a residual
    convolutional FFN sub-block. When a conditioning vector ``c`` is given, each
    sub-block's normalised input is shifted/scaled, and its output gated, by
    tensors predicted from ``c``.

    Args:
        hidden_channels, out_channels: their sum is the block's channel count.
        filter_channels: hidden channel count of the FFN.
        num_heads: number of attention heads.
        kernel_size: FFN convolution kernel size.
        p_dropout: dropout used by the attention and the FFN.
        gin_channels: feature size of ``c``; when it already equals the block's
            channel count, the first projection of ``adaLN_modulation`` is an identity.
    """

    def __init__(self, hidden_channels, out_channels, filter_channels, num_heads, kernel_size=3, p_dropout=0.1, gin_channels=0):
        super().__init__()
        channels = hidden_channels + out_channels
        self.norm1 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
        self.attn = MultiHeadAttention(channels, channels, num_heads, p_dropout)
        self.norm2 = nn.LayerNorm(channels, elementwise_affine=False, eps=1e-6)
        self.mlp = FFN(channels, channels, filter_channels, kernel_size, p_dropout=p_dropout)
        self.adaLN_modulation = nn.Sequential(
            nn.Linear(gin_channels, channels) if gin_channels != channels else nn.Identity(),
            nn.SiLU(),
            nn.Linear(channels, 6 * channels, bias=True),
        )

    def forward(self, x, c, x_mask):
        """
        Args:
            x : [batch_size, channel, time]
            c : [batch_size, gin_channels] conditioning vector, or None to run the
                unconditioned pre-norm residual path
            x_mask : [batch_size, 1, time]
        return the same shape as x
        """
        x = x * x_mask
        # shape: [batch_size, 1, time, time]
        # NOTE(review): attn_mask is passed straight to scaled_dot_product_attention.
        # If x_mask is a float 0/1 tensor, SDPA adds it to the attention scores as a
        # bias rather than excluding padded keys; casting to bool would exclude them
        # but may produce NaN for fully padded query rows. Left unchanged — confirm
        # which behavior is intended.
        attn_mask = x_mask.unsqueeze(1) * x_mask.unsqueeze(-1)
        if c is not None:
            # Six modulation tensors, each of shape [batch_size, channel, 1].
            shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = (
                self.adaLN_modulation(c).unsqueeze(2).chunk(6, dim=1)
            )
            h = self.modulate(self._channel_layer_norm(self.norm1, x), shift_msa, scale_msa)
            x = x + gate_msa * self.attn(h, attn_mask) * x_mask
            h = self.modulate(self._channel_layer_norm(self.norm2, x), shift_mlp, scale_mlp)
            x = x + gate_mlp * self.mlp(h, x_mask) * x_mask
        else:
            # no condition version
            x = x + self.attn(self._channel_layer_norm(self.norm1, x), attn_mask)
            # The FFN branch uses norm2, mirroring the conditioned path. Both norms are
            # parameter-free LayerNorms with identical settings, so either gives the
            # same result.
            x = x + self.mlp(self._channel_layer_norm(self.norm2, x), x_mask)
        return x

    @staticmethod
    def _channel_layer_norm(norm, x):
        """Apply ``norm`` over the channel axis of a [batch, channel, time] tensor."""
        return norm(x.transpose(1, 2)).transpose(1, 2)

    @staticmethod
    def modulate(x, shift, scale):
        """Feature-wise modulation used by adaLN: ``x * (1 + scale) + shift``.

        A staticmethod, so ``self.modulate(x, shift, scale)`` receives exactly
        these three arguments.
        """
        return x * (1 + scale) + shift
class RotaryPositionalEmbeddings(nn.Module):
    r"""
    ## RoPE module
    Rotary encoding transforms pairs of features by rotating in the 2D plane.
    That is, it organizes the $d$ features as $\frac{d}{2}$ pairs.
    Each pair can be considered a coordinate in a 2D plane, and the encoding will rotate it
    by an angle depending on the position of the token.
    """

    def __init__(self, d: int, base: int = 10_000):
        r"""
        * `d` is the number of features $d$ to rotate. It is converted with ``int``,
          so a float such as ``head_dim * 0.5`` is accepted. Features beyond the
          first `d` are passed through unchanged.
        * `base` is the constant used for calculating $\Theta$
        """
        super().__init__()
        self.base = base
        self.d = int(d)
        # cos/sin tables, built lazily by _build_cache on first use.
        self.cos_cached = None
        self.sin_cached = None

    def _build_cache(self, x: torch.Tensor):
        r"""
        Cache $\cos$ and $\sin$ values for positions ``0 .. x.shape[0] - 1``.

        The cache is reused only while it is long enough AND lives on ``x``'s
        device. Otherwise it is rebuilt: the tables are plain tensor attributes
        (not buffers), so ``Module.to()`` does not move them with the module.
        """
        if (
            self.cos_cached is not None
            and self.cos_cached.device == x.device
            and x.shape[0] <= self.cos_cached.shape[0]
        ):
            return
        # Get sequence length
        seq_len = x.shape[0]
        # $\Theta = {\theta_i = base^{-\frac{2(i-1)}{d}}, i \in [1, 2, ..., \frac{d}{2}]}$
        theta = 1.0 / (self.base ** (torch.arange(0, self.d, 2).float() / self.d)).to(x.device)
        # Create position indexes `[0, 1, ..., seq_len - 1]`
        seq_idx = torch.arange(seq_len, device=x.device).float()
        # Calculate the product of position index and $\theta_i$
        idx_theta = torch.einsum("n,d->nd", seq_idx, theta)
        # Concatenate so that for row $m$ we have
        # $[m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}, m \theta_0, m \theta_1, ..., m \theta_{\frac{d}{2}}]$
        idx_theta2 = torch.cat([idx_theta, idx_theta], dim=1)
        # Shape [seq_len, 1, 1, d] so the tables broadcast over batch and heads.
        self.cos_cached = idx_theta2.cos()[:, None, None, :]
        self.sin_cached = idx_theta2.sin()[:, None, None, :]

    def _neg_half(self, x: torch.Tensor):
        # $\frac{d}{2}$
        d_2 = self.d // 2
        # Calculate $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        return torch.cat([-x[:, :, :, d_2:], x[:, :, :, :d_2]], dim=-1)

    def forward(self, x: torch.Tensor):
        """
        * `x` is the per-head query or key tensor with shape
          `[batch_size, n_heads, seq_len, d_head]`; the result has the same shape.
          Only the first `self.d` features of the last axis are rotated.
        """
        x = x.permute(2, 0, 1, 3)  # b h t d -> t b h d
        # Cache $\cos$ and $\sin$ values
        self._build_cache(x)
        seq_len = x.shape[0]
        # Split the features, we can choose to apply rotary embeddings only to a partial set of features.
        x_rope, x_pass = x[..., : self.d], x[..., self.d:]
        # Calculate
        # $[-x^{(\frac{d}{2} + 1)}, -x^{(\frac{d}{2} + 2)}, ..., -x^{(d)}, x^{(1)}, x^{(2)}, ..., x^{(\frac{d}{2})}]$
        neg_half_x = self._neg_half(x_rope)
        x_rope = (x_rope * self.cos_cached[:seq_len]) + (neg_half_x * self.sin_cached[:seq_len])
        return torch.cat((x_rope, x_pass), dim=-1).permute(1, 2, 0, 3)  # t b h d -> b h t d
class Transpose(nn.Identity):
    """Swap dimensions 1 and 2 of the input, e.g. (N, T, D) -> (N, D, T)."""

    def forward(self, input: torch.Tensor) -> torch.Tensor:
        return torch.transpose(input, 1, 2)