# MahaTTSv2 / S2A / modules.py
# rasenganai
# init
# 41bc8a8
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.nn.init as init
from einops import rearrange, repeat
from torch.nn.utils import weight_norm
def zero_module(module):
    """
    Reset every parameter of ``module`` to zero and return the module.

    Used to build "zero convolutions": layers whose initial output is zero.
    """
    for param in module.parameters():
        param.detach().zero_()
    return module
class GroupNorm32(nn.GroupNorm):
    """GroupNorm computed in float32, with the result cast back to the input dtype."""

    def forward(self, x):
        normed = super().forward(x.float())
        return normed.type(x.dtype)
def normalization(channels):
    """
    Build a GroupNorm32 layer whose group count evenly divides ``channels``.

    The group count starts at 8, 16, or 32 depending on channel width and is
    halved until it divides ``channels`` (must end above 2, i.e. at least 4).

    :param channels: number of input channels.
    :return: an nn.Module for normalization.
    """
    if channels <= 16:
        groups = 8
    elif channels <= 64:
        groups = 16
    else:
        groups = 32
    while channels % groups:
        groups //= 2
    assert groups > 2
    return GroupNorm32(groups, channels)
class mySequential(nn.Sequential):
    """nn.Sequential variant that threads tuple outputs (e.g. (x, mask)) through layers."""

    def forward(self, *inputs):
        out = inputs
        for layer in self._modules.values():
            # A tuple result is unpacked into the next layer's arguments;
            # a plain tensor is passed through as a single argument.
            if type(out) is tuple:
                out = layer(*out)
            else:
                out = layer(out)
        return out
class SepConv1D(nn.Module):
    """Depthwise-separable 1-D convolution that also propagates a time mask."""

    def __init__(
        self,
        nin,
        nout,
        kernel_size,
        stride=1,
        dilation=1,
        padding_mode="same",
        bias=False,
    ):
        super(SepConv1D, self).__init__()
        # Depthwise stage: one filter per input channel (groups=nin).
        self.conv1 = nn.Conv1d(
            nin,
            nin,
            kernel_size=kernel_size,
            stride=stride,
            groups=nin,
            dilation=dilation,
            padding=padding_mode,
            bias=bias,
        )
        # Pointwise stage: 1x1 convolution mixing channels.
        self.conv2 = nn.Conv1d(
            nin, nout, kernel_size=1, stride=1, padding=padding_mode, bias=bias
        )

    def forward(self, x, mask=None):
        # Zero padded frames before convolving; mask is (B, T), True = valid.
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        return self.conv2(self.conv1(x)), mask
class Conv1DBN(nn.Module):
    """Conv1d -> BatchNorm -> SiLU -> Dropout block with optional time mask."""

    def __init__(
        self,
        nin,
        nout,
        kernel_size,
        stride=1,
        dilation=1,
        dropout=0.1,
        padding_mode="same",
        bias=False,
    ):
        super(Conv1DBN, self).__init__()
        self.conv1 = nn.Conv1d(
            nin,
            nout,
            kernel_size=kernel_size,
            stride=stride,
            padding=padding_mode,
            dilation=dilation,
            bias=bias,
        )
        self.bn = nn.BatchNorm1d(nout)
        self.drop = nn.Dropout(dropout)

    def forward(self, x, mask=None):
        # Zero padded frames before convolving; mask is (B, T), True = valid.
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        y = self.drop(F.silu(self.bn(self.conv1(x))))
        return y, mask
class Conv1d(nn.Module):
    """Plain nn.Conv1d wrapper whose forward also carries a time mask through."""

    def __init__(self, nin, nout, kernel_size, padding, bias=False):
        super(Conv1d, self).__init__()
        self.l = nn.Conv1d(nin, nout, kernel_size, padding=padding, bias=bias)

    def forward(self, x, mask):
        # Zero padded frames before convolving; mask is (B, T), True = valid.
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        return self.l(x), mask
class SqueezeExcite(nn.Module):
    """
    Squeeze-and-Excitation over channels: a sigmoid per-channel gate computed
    from the masked time-average of the input ("let the CNN decide how to add
    across channels").
    """

    def __init__(self, nin, ratio=8):
        """
        :param nin: number of input channels.
        :param ratio: bottleneck reduction ratio of the gating MLP.
        """
        super(SqueezeExcite, self).__init__()
        self.nin = nin
        self.ratio = ratio
        # fc only ever receives a single tensor, so a plain nn.Sequential is
        # sufficient (state_dict keys are unchanged: fc.0.*, fc.2.*).
        self.fc = nn.Sequential(
            nn.Linear(nin, nin // ratio, bias=True),
            nn.SiLU(inplace=True),
            nn.Linear(nin // ratio, nin, bias=True),
        )

    def forward(self, x, mask=None):
        """
        :param x: (B, C, T) features.
        :param mask: (B, T) bool, True on valid frames; defaults to all-valid.
        :return: (gated features, mask).
        """
        if mask is None:
            mask = torch.ones((x.shape[0], x.shape[-1]), dtype=torch.bool).to(x.device)
        # BUGFIX: the previous in-place `masked_fill_` could clobber the
        # caller's tensor, because `x.float()` returns x itself (not a copy)
        # when x is already float32. Use the out-of-place variant instead.
        x = x.float().masked_fill(~mask.unsqueeze(1), 0.0)
        # Channel means over valid frames only.
        y = (
            torch.sum(x, dim=-1, keepdim=True)
            / mask.unsqueeze(1).sum(dim=-1, keepdim=True)
        ).type(x.dtype)
        y = y.transpose(1, -1)  # (B, 1, C) so the Linear layers see channels last
        y = self.fc(y)
        y = torch.sigmoid(y)    # per-channel gate in (0, 1)
        y = y.transpose(1, -1)  # back to (B, C, 1), broadcast over time
        return x * y, mask
class SCBD(nn.Module):
    """SeparableConv1D + normalization (+ SiLU/Dropout); used for middle layers and resnet blocks."""

    def __init__(
        self,
        nin,
        nout,
        kernel_size,
        p=0.1,
        rd=True,
        separable=True,
        bias=False,
        eps=0.001,
    ):
        """
        :param nin: input channels.
        :param nout: output channels.
        :param kernel_size: convolution kernel width.
        :param p: dropout probability (only used when rd=True).
        :param rd: if True, append SiLU + Dropout after normalization.
        :param separable: use a depthwise-separable conv instead of a plain Conv1d.
        :param bias: convolution bias.
        :param eps: BUGFIX — accepted for compatibility with callers
            (QuartzNetBlock passes eps=0.001, which previously raised a
            TypeError). Currently unused: normalization() builds GroupNorm
            with its default eps. TODO: thread through if BatchNorm returns.
        """
        super(SCBD, self).__init__()
        if separable:
            self.SC = SepConv1D(nin, nout, kernel_size, bias=bias)
        else:
            self.SC = Conv1d(nin, nout, kernel_size, padding="same", bias=bias)
        if rd:  # SiLU and Dropout
            self.mout = mySequential(
                normalization(nout),
                nn.SiLU(),
                nn.Dropout(p),
            )
        else:
            self.mout = normalization(nout)

    def forward(self, x, mask=None):
        # Zero padded frames before convolving; mask is (B, T), True = valid.
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        x, _ = self.SC(x, mask)
        y = self.mout(x)
        return y, mask
class QuartzNetBlock(nn.Module):
    """Resnet-style block of R SCBD layers with a 1x1-projected skip connection.

    If it is the last layer, set se=False and separable=False, and use a
    projection layer on top of this.
    """

    def __init__(
        self,
        nin,
        nout,
        kernel_size,
        dropout=0.1,
        R=5,
        se=False,
        ratio=8,
        separable=False,
        bias=False,
    ):
        """
        :param R: number of stacked SCBD sub-layers.
        :param se: append a SqueezeExcite gate after the conv stack.
        :param ratio: SqueezeExcite reduction ratio.
        :param separable: whether the final SCBD uses a separable conv.
        """
        super(QuartzNetBlock, self).__init__()
        self.se = se
        # 1x1 projection so the skip connection matches nout channels.
        self.residual = mySequential(
            nn.Conv1d(nin, nout, kernel_size=1, padding="same", bias=bias),
            normalization(nout),
        )
        model = []
        for i in range(R - 1):
            # BUGFIX: SCBD() has no `eps` parameter; passing eps=0.001 here
            # raised a TypeError at construction time.
            model.append(SCBD(nin, nout, kernel_size, dropout, bias=bias))
            nin = nout
        if separable:
            model.append(SCBD(nin, nout, kernel_size, dropout, rd=False, bias=bias))
        else:
            model.append(
                SCBD(
                    nin,
                    nout,
                    kernel_size,
                    dropout,
                    rd=False,
                    separable=False,
                    bias=bias,
                )
            )
        self.model = mySequential(*model)
        if self.se:
            self.se_layer = SqueezeExcite(nin, ratio)
        self.mout = mySequential(nn.SiLU(), nn.Dropout(dropout))

    def forward(self, x, mask=None):
        # Zero padded frames before convolving; mask is (B, T), True = valid.
        if mask is not None:
            x = x * mask.unsqueeze(1).to(device=x.device)
        y, _ = self.model(x, mask)
        if self.se:
            y, _ = self.se_layer(y, mask)
        y += self.residual(x)
        y = self.mout(y)
        return y, mask
class QKVAttentionLegacy(nn.Module):
    """
    QKV attention matching the legacy head/channel layout: heads are split
    before the q/k/v split.
    """

    def __init__(self, n_heads):
        super().__init__()
        self.n_heads = n_heads

    def forward(self, qkv, mask=None, rel_pos=None):
        """
        Apply QKV attention.
        :param qkv: an [N x (H * 3 * C) x T] tensor of Qs, Ks, and Vs.
        :return: an [N x (H * C) x T] tensor after attention.
        """
        bs, width, length = qkv.shape
        heads = self.n_heads
        assert width % (3 * heads) == 0
        ch = width // (3 * heads)
        q, k, v = qkv.reshape(bs * heads, ch * 3, length).split(ch, dim=1)
        # Scale q and k each by ch**-0.25: more stable in fp16 than dividing
        # the product afterwards.
        scale = 1 / math.sqrt(math.sqrt(ch))
        weight = torch.einsum("bct,bcs->bts", q * scale, k * scale)
        if rel_pos is not None:
            shaped = weight.reshape(bs, heads, weight.shape[-2], weight.shape[-1])
            weight = rel_pos(shaped).reshape(
                bs * heads, weight.shape[-2], weight.shape[-1]
            )
        weight = torch.softmax(weight.float(), dim=-1).type(weight.dtype)
        if mask is not None:
            # Masking after the softmax: masking before with -inf does not
            # behave properly on CPUs.
            weight = weight * mask.repeat(heads, 1).unsqueeze(1)
        out = torch.einsum("bts,bcs->bct", weight, v)
        return out.reshape(bs, -1, length)
class AttentionBlock(nn.Module):
    """
    An attention block that allows spatial positions to attend to each other.
    Originally ported from here, but adapted to the N-d case.
    https://github.com/hojonathanho/diffusion/blob/1e0dceb3b3495bbe19116a5e1b3596cd0706c543/diffusion_tf/models/unet.py#L66.
    """

    def __init__(
        self,
        channels,
        num_heads=1,
        num_head_channels=-1,
        do_checkpoint=True,
        relative_pos_embeddings=False,
    ):
        super().__init__()
        self.channels = channels
        self.do_checkpoint = do_checkpoint
        if num_head_channels == -1:
            self.num_heads = num_heads
        else:
            assert channels % num_head_channels == 0, (
                f"q,k,v channels {channels} is not divisible by num_head_channels {num_head_channels}"
            )
            self.num_heads = channels // num_head_channels
        self.norm = normalization(channels)
        self.qkv = nn.Conv1d(channels, channels * 3, 1, bias=False)
        # split heads before split qkv
        self.attention = QKVAttentionLegacy(self.num_heads)
        # Zero-initialized projection: the block starts out as an identity
        # mapping (no effect of attention in the initial stages).
        self.proj_out = zero_module(nn.Conv1d(channels, channels, 1, bias=False))
        # NOTE(review): the relative_pos_embeddings flag is ignored — the bias
        # is always built (the `if` was commented out upstream).
        # BUGFIX: use self.num_heads instead of the raw num_heads argument;
        # when num_head_channels is given the two differ and the bias head
        # count disagreed with the attention head count (shape mismatch in
        # QKVAttentionLegacy). Identical behavior when num_head_channels=-1.
        self.relative_pos_embeddings = RelativePositionBias(
            scale=(channels // self.num_heads) ** 0.5,
            causal=False,
            heads=self.num_heads,
            num_buckets=64,
            max_distance=128,
        )

    def forward(self, x, mask=None):
        """
        :param x: (B, C, *spatial) features; flattened to (B, C, T) internally.
        :return: x plus the attention output, reshaped back to the input shape.
        """
        b, c, *spatial = x.shape
        x = x.reshape(b, c, -1)
        qkv = self.qkv(self.norm(x))
        h = self.attention(qkv, mask, self.relative_pos_embeddings)
        h = self.proj_out(h)
        return (x + h).reshape(b, c, *spatial)
class AbsolutePositionalEmbedding(nn.Module):
    """Learned absolute positional embedding, scaled by dim**-0.5."""

    def __init__(self, dim, max_seq_len):
        """
        :param dim: embedding dimension.
        :param max_seq_len: maximum supported sequence length.
        """
        super().__init__()
        self.scale = dim**-0.5
        self.emb = nn.Embedding(max_seq_len, dim)

    def forward(self, x):
        """
        :param x: tensor whose dim 1 is the sequence length.
        :return: (1, seq_len, dim) scaled positional embeddings.
        """
        n = torch.arange(x.shape[1], device=x.device)
        # einops rearrange("n d -> () n d") replaced with a plain unsqueeze:
        # same result, no third-party dependency for a trivial reshape.
        pos_emb = self.emb(n).unsqueeze(0)
        return pos_emb * self.scale
class FixedPositionalEmbedding(nn.Module):
    """Sinusoidal (non-learned) positional embedding."""

    def __init__(self, dim):
        super().__init__()
        # Standard transformer frequencies: 1 / 10000^(2i/dim).
        inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim))
        self.register_buffer("inv_freq", inv_freq)

    def forward(self, x, seq_dim=1, offset=0):
        """
        :param x: tensor whose size along seq_dim is the sequence length.
        :param seq_dim: which dimension of x holds the sequence.
        :param offset: starting position index.
        :return: (1, seq_len, dim) embeddings, [sin | cos] halves concatenated.
        """
        t = (
            torch.arange(x.shape[seq_dim], device=x.device).type_as(self.inv_freq)
            + offset
        )
        sinusoid_inp = torch.einsum("i , j -> i j", t, self.inv_freq)
        emb = torch.cat((sinusoid_inp.sin(), sinusoid_inp.cos()), dim=-1)
        # einops rearrange("n d -> () n d") replaced with a plain unsqueeze:
        # same result, no third-party dependency for a trivial reshape.
        return emb.unsqueeze(0)
class RelativePositionBias(nn.Module):
    """T5-style bucketed relative position bias added to attention logits."""

    def __init__(self, scale, causal=False, num_buckets=16, max_distance=32, heads=8):
        """
        :param scale: multiplier applied to the bias before adding it.
        :param causal: if False, buckets are split between past and future.
        :param num_buckets: number of distance buckets.
        :param max_distance: distances beyond this share the last bucket.
        :param heads: number of attention heads (one bias value per head).
        """
        super().__init__()
        self.scale = scale
        self.causal = causal
        self.num_buckets = num_buckets
        self.max_distance = max_distance
        self.relative_attention_bias = nn.Embedding(num_buckets, heads)

    @staticmethod
    def _relative_position_bucket(
        relative_position, causal=True, num_buckets=16, max_distance=32
    ):
        """Map signed relative positions to bucket indices (exact for small
        distances, log-spaced for large ones)."""
        ret = 0
        n = -relative_position
        if not causal:
            # Half the buckets encode direction (future vs past).
            num_buckets //= 2
            ret += (n < 0).long() * num_buckets
            n = torch.abs(n)
        else:
            n = torch.max(n, torch.zeros_like(n))
        max_exact = num_buckets // 2
        is_small = n < max_exact
        # Log-spaced buckets for distances in [max_exact, max_distance).
        val_if_large = (
            max_exact
            + (
                torch.log(n.float() / max_exact)
                / math.log(max_distance / max_exact)
                * (num_buckets - max_exact)
            ).long()
        )
        val_if_large = torch.min(
            val_if_large, torch.full_like(val_if_large, num_buckets - 1)
        )
        ret += torch.where(is_small, n, val_if_large)
        return ret

    def forward(self, qk_dots):
        """
        :param qk_dots: (..., heads, T_q, T_k) attention logits.
        :return: logits with the scaled relative-position bias added.
        """
        i, j, device = *qk_dots.shape[-2:], qk_dots.device
        q_pos = torch.arange(i, dtype=torch.long, device=device)
        k_pos = torch.arange(j, dtype=torch.long, device=device)
        rel_pos = k_pos[None, :] - q_pos[:, None]
        rp_bucket = self._relative_position_bucket(
            rel_pos,
            causal=self.causal,
            num_buckets=self.num_buckets,
            max_distance=self.max_distance,
        )
        values = self.relative_attention_bias(rp_bucket)  # (i, j, heads)
        # einops rearrange("i j h -> () h i j") replaced with permute +
        # unsqueeze: same result, no third-party dependency.
        bias = values.permute(2, 0, 1).unsqueeze(0)
        return qk_dots + (bias * self.scale)
class MultiHeadAttention(nn.Module):
    """
    only for GST
    input:
        query --- [N, T_q, query_dim]
        key --- [N, T_k, key_dim]
    output:
        out --- [N, T_q, num_units]
    """

    def __init__(self, query_dim, key_dim, num_units, num_heads):
        super().__init__()
        self.num_units = num_units
        self.num_heads = num_heads
        self.key_dim = key_dim
        self.W_query = nn.Linear(
            in_features=query_dim, out_features=num_units, bias=False
        )
        self.W_key = nn.Linear(in_features=key_dim, out_features=num_units, bias=False)
        self.W_value = nn.Linear(
            in_features=key_dim, out_features=num_units, bias=False
        )

    def _split_heads(self, t):
        # [N, T, num_units] -> [h, N, T, num_units/h]
        per_head = self.num_units // self.num_heads
        return torch.stack(torch.split(t, per_head, dim=2), dim=0)

    def forward(self, query, key):
        q = self._split_heads(self.W_query(query))  # [h, N, T_q, num_units/h]
        k = self._split_heads(self.W_key(key))      # [h, N, T_k, num_units/h]
        v = self._split_heads(self.W_value(key))    # [h, N, T_k, num_units/h]
        # Scaled dot-product: softmax(QK^T / key_dim**0.5) over T_k.
        scores = torch.matmul(q, k.transpose(2, 3))
        scores = scores / (self.key_dim**0.5)
        scores = F.softmax(scores, dim=3)
        out = torch.matmul(scores, v)  # [h, N, T_q, num_units/h]
        # Re-concatenate heads along the channel dimension.
        return torch.cat(torch.split(out, 1, dim=0), dim=3).squeeze(0)
class GST(nn.Module):
    """
    Reference encoder for global style conditioning: two strided Conv1d
    layers (roughly 4x downsample in time) followed by five self-attention
    blocks over the downsampled sequence.
    """

    def __init__(
        self, model_channels=512, style_tokens=100, num_heads=8, in_channels=100
    ):
        super(GST, self).__init__()
        self.model_channels = model_channels
        self.style_tokens = style_tokens
        self.num_heads = num_heads
        attention_stack = [
            AttentionBlock(
                model_channels * 2,
                num_heads,
                relative_pos_embeddings=True,
                do_checkpoint=False,
            )
            for _ in range(5)
        ]
        self.reference_encoder = nn.Sequential(
            nn.Conv1d(in_channels, model_channels, 3, padding=1, stride=2, bias=False),
            nn.Conv1d(
                model_channels, model_channels * 2, 3, padding=1, stride=2, bias=False
            ),
            *attention_stack,
        )

    def forward(self, x):
        """
        :param x: (N, in_channels, T) reference features.
        :return: (N, model_channels * 2, ~T // 4) encoded reference.
        """
        # TODO: masking of padded frames is not applied here.
        return self.reference_encoder(x)
# class GST(nn.Module):
# """
# inputs --- [N, Ty/r, n_mels*r] mels
# outputs --- [N, ref_enc_gru_size]
# """
# def __init__(self, spec_channels=80, gin_channels=512, layernorm=True):
# super().__init__()
# self.spec_channels = spec_channels
# ref_enc_filters = [32, 32, 64, 64, 128, 128]
# K = len(ref_enc_filters)
# filters = [1] + ref_enc_filters
# convs = [
# weight_norm(
# nn.Conv2d(
# in_channels=filters[i],
# out_channels=filters[i + 1],
# kernel_size=(3, 3),
# stride=(2, 2),
# padding=(1, 1),
# )
# )
# for i in range(K)
# ]
# self.convs = nn.ModuleList(convs)
# out_channels = self.calculate_channels(spec_channels, 3, 2, 1, K)
# self.gru = nn.GRU(
# input_size=ref_enc_filters[-1] * out_channels,
# hidden_size=256 // 2,
# batch_first=True,
# )
# self.proj = nn.Linear(128, gin_channels)
# if layernorm:
# self.layernorm = nn.LayerNorm(self.spec_channels)
# else:
# self.layernorm = None
# def forward(self, inputs, mask=None):
# N = inputs.size(0)
# out = inputs.view(N, 1, -1, self.spec_channels) # [N, 1, Ty, n_freqs]
# if self.layernorm is not None:
# out = self.layernorm(out)
# for conv in self.convs:
# out = conv(out)
# # out = wn(out)
# out = F.silu(out) # [N, 128, Ty//2^K, n_mels//2^K]
# out = out.transpose(1, 2) # [N, Ty//2^K, 128, n_mels//2^K]
# T = out.size(1)
# N = out.size(0)
# out = out.contiguous().view(N, T, -1) # [N, Ty//2^K, 128*n_mels//2^K]
# self.gru.flatten_parameters()
# memory, out = self.gru(out) # out --- [1, N, 128]
# return self.proj(out.squeeze(0))
# def calculate_channels(self, L, kernel_size, stride, pad, n_convs):
# for i in range(n_convs):
# L = (L - kernel_size + 2 * pad) // stride + 1
# return L
if __name__ == "__main__":
    # Smoke test: push a random mel batch through the GST reference encoder.
    device = torch.device("cpu")
    # BUGFIX: the test mels below have 80 channels, so in_channels must be 80;
    # the default of 100 made the first Conv1d reject this input.
    m = GST(512, 10, in_channels=80).to(device)
    mels = torch.rand((16, 80, 1000)).to(device)
    o = m(mels)
    print(o.shape, "final output")
    from torchinfo import summary

    summary(m, input_data={"x": torch.randn(16, 80, 500).to(device)})