# MahaTTSv2 / S2A / diff_model.py
# (source: rasenganai — commit "init", 41bc8a8)
import math
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch import autocast
from config import config
from .modules import GST, AttentionBlock, mySequential, normalization
def timestep_embedding(timesteps, dim, max_period=10000):
    """
    Create sinusoidal timestep embeddings.

    :param timesteps: a 1-D Tensor of N indices, one per batch element.
                      These may be fractional.
    :param dim: the dimension of the output.
    :param max_period: controls the minimum frequency of the embeddings.
    :return: an [N x dim] Tensor of positional embeddings.
    """
    half = dim // 2
    # Geometric ladder of frequencies from 1 down to ~1/max_period.
    exponents = -math.log(max_period) * torch.arange(
        start=0, end=half, dtype=torch.float32
    ) / half
    freqs = torch.exp(exponents).to(device=timesteps.device)
    # Outer product: one row of phase arguments per timestep.
    args = timesteps[:, None].float() * freqs[None]
    embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1)
    if dim % 2:
        # Odd dim: cos/sin halves only cover dim-1 columns; pad one zero column.
        pad = torch.zeros_like(embedding[:, :1])
        embedding = torch.cat([embedding, pad], dim=-1)
    return embedding
class TimestepBlock(nn.Module):
    """Interface for modules whose forward() also takes a timestep embedding.

    TimestepEmbedSequential uses an isinstance check against this class to
    decide whether to pass the embedding through to a child module.
    """

    def forward(self, x, emb):
        """
        Apply the module to `x` given `emb` timestep embeddings.
        """
        # BUG FIX: the original body was only a docstring, so calling an
        # un-overridden forward() silently returned None. Fail loudly so a
        # missing override is caught at the call site instead.
        raise NotImplementedError(
            "TimestepBlock subclasses must implement forward(x, emb)"
        )
class TimestepEmbedSequential(nn.Sequential, TimestepBlock):
    """A Sequential container that routes the timestep embedding to children
    that accept it (TimestepBlock instances) and calls the rest normally."""

    def forward(self, x, emb):
        for module in self:
            if isinstance(module, TimestepBlock):
                x = module(x, emb)
            else:
                x = module(x)
        return x
class QuartzNetBlock(TimestepBlock):
    """Residual conv block conditioned on a timestep embedding.

    Similar to a ResNet block with norm and dropout, using a separable conv in
    the middle. If it's the last layer, set se=False and separable=False, and
    use a projection layer on top of this.

    NOTE(review): `R`, `ratio` and `separable` are accepted but never read, and
    `se` is stored without being used by any layer here — confirm whether
    squeeze-excite / repeat support was meant to be wired in.
    """

    def __init__(
        self,
        nin,
        nout,
        emb_channels,
        kernel_size=3,
        dropout=0.1,
        R=1,
        se=True,
        ratio=8,
        separable=False,
        bias=True,
        use_scale_shift_norm=True,
    ):
        super(QuartzNetBlock, self).__init__()
        self.use_scale_shift_norm = use_scale_shift_norm
        self.se = se
        # Pointwise projection to nout channels, then activation + norm.
        self.in_layers = mySequential(
            nn.Conv1d(nin, nout, kernel_size=1, padding="same", bias=bias),
            nn.SiLU(),
            normalization(nout),
        )
        # Skip connection: identity when channel counts match, 1x1 conv otherwise.
        if nin == nout:
            self.residual = nn.Identity()
        else:
            self.residual = nn.Conv1d(
                nin, nout, kernel_size=1, padding="same", bias=bias
            )
        nin = nout
        self.model = nn.Sequential(
            nn.Conv1d(nin, nout, kernel_size, padding="same"),
            nn.SiLU(),
            normalization(nout),
            nn.Dropout(p=dropout),
        )
        # Projects the timestep embedding to (scale, shift) pairs (2 * nout)
        # when use_scale_shift_norm, or to a plain additive bias (nout) otherwise.
        self.emb_layers = nn.Sequential(
            nn.Linear(
                emb_channels,
                2 * nout if use_scale_shift_norm else nout,
            ),
            nn.SiLU(),
        )

    def forward(self, x, emb, mask=None):
        """Apply the block to `x` (N, C, T) given timestep embedding `emb`.

        `mask` is accepted for interface compatibility but unused.
        """
        x_new = self.in_layers(x)
        emb = self.emb_layers(emb)
        # Append trailing singleton dims so the embedding broadcasts over time.
        while len(emb.shape) < len(x_new.shape):
            emb = emb[..., None]
        if self.use_scale_shift_norm:
            # FiLM-style conditioning: emb carries 2*nout channels.
            scale, shift = torch.chunk(emb, 2, dim=1)
            x_new = x_new * (1 + scale) + shift
        else:
            # BUG FIX: the original chunked `emb` unconditionally, but with
            # use_scale_shift_norm=False the projection only emits `nout`
            # channels, so the two chunks are half-width and cannot broadcast
            # against x_new. Use additive conditioning in that case.
            x_new = x_new + emb
        y = self.model(x_new)
        return y + self.residual(x)
class QuartzAttn(TimestepBlock):
    """A timestep-conditioned QuartzNet residual block followed by
    self-attention with relative position embeddings."""

    def __init__(self, model_channels, dropout, num_heads):
        super().__init__()
        self.resblk = QuartzNetBlock(
            model_channels,
            model_channels,
            model_channels,
            dropout=dropout,
            use_scale_shift_norm=True,
        )
        self.attn = AttentionBlock(
            model_channels,
            num_heads,
            relative_pos_embeddings=True,
        )

    def forward(self, x, time_emb):
        # Conditioned residual conv first, then attention over the sequence.
        h = self.resblk(x, time_emb)
        return self.attn(h)
class QuartzNet9x5(nn.Module):
    """Trunk network: ten timestep-conditioned QuartzNet blocks interleaved
    with attention, three more conditioned blocks, then a conv head that maps
    model_channels down to 100 output channels."""

    def __init__(self, model_channels, num_heads, dropout=0.1, enable_fp16=False):
        super().__init__()
        self.enable_fp16 = enable_fp16
        blocks = []
        attns = []
        # Ten (block, attention) pairs, all with kernel size 3.
        for kernel in [3] * 10:
            blocks.append(
                QuartzNetBlock(
                    model_channels,
                    model_channels,
                    model_channels,
                    kernel_size=kernel,
                    dropout=dropout,
                    R=5,
                    se=True,
                )
            )
            attns.append(
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True)
            )
        self.quartznet = nn.ModuleList(blocks)
        self.attn = nn.ModuleList(attns)
        # Three more conditioned blocks without attention.
        self.conv2 = nn.ModuleList(
            QuartzNetBlock(
                model_channels,
                model_channels,
                model_channels,
                kernel_size=3,
                dropout=dropout,
                R=3,
                separable=False,
            )
            for _ in range(3)
        )
        # Output head; the final 1x1 conv fixes the output at 100 channels.
        self.conv3 = nn.Sequential(
            nn.Conv1d(model_channels, model_channels, 3, padding="same"),
            nn.SiLU(),
            normalization(model_channels),
            nn.Conv1d(model_channels, 100, 1, padding="same"),
        )

    def forward(self, x, time_emb):
        for block, attn in zip(self.quartznet, self.attn):
            x = attn(block(x, time_emb))
        for block in self.conv2:
            x = block(x, time_emb)
        return self.conv3(x)
class DiffModel(nn.Module):
    """
    Diffusion denoiser.

    Given a noisy input `x` (N, input_channels, T), a diffusion timestep `t`
    and semantic-code conditioning, produces a denoising prediction through a
    QuartzNet9x5 trunk. Conditioning can be modulated by a speaker latent
    (GST) and replaced by a learned unconditioned embedding for
    classifier-free guidance.
    """

    def __init__(
        self,
        input_channels=80,
        output_channels=160,
        model_channels=256,
        num_heads=8,
        dropout=0.1,
        num_layers=8,
        multispeaker=True,
        style_tokens=100,
        enable_fp16=False,
        condition_free_per=0.1,
        training=False,
        ar_active=False,
        in_latent_channels=10004,
    ):
        super().__init__()
        self.input_channels = input_channels
        self.model_channels = model_channels
        # NOTE(review): output_channels and num_layers are stored but never
        # read here — the trunk's final conv fixes the output width. Confirm
        # whether they were meant to be wired through.
        self.output_channels = output_channels
        self.num_heads = num_heads
        self.dropout = dropout
        self.num_layers = num_layers
        self.enable_fp16 = enable_fp16
        self.condition_free_per = condition_free_per
        # NOTE(review): this shadows nn.Module.training, which .train()/.eval()
        # manage; kept as-is since forward() gates condition-free dropout on it.
        self.training = training
        self.multispeaker = multispeaker
        self.ar_active = ar_active
        self.in_latent_channels = in_latent_channels
        if not self.ar_active:
            # Discrete semantic codes: embedding table (+1 for an extra token)
            # followed by an attention stack.
            self.code_emb = nn.Embedding(
                config.semantic_model_centroids + 1, model_channels
            )
            self.code_converter = mySequential(
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            )
        else:
            # Continuous AR latents: project channels down first, then attend.
            self.code_converter = mySequential(
                nn.Conv1d(
                    self.in_latent_channels, model_channels, 3, padding=1, bias=True
                ),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
                AttentionBlock(model_channels, num_heads, relative_pos_embeddings=True),
            )
        if self.multispeaker:
            self.GST = GST(
                model_channels, style_tokens, num_heads, in_channels=input_channels
            )
        self.code_norm = normalization(model_channels)
        self.time_norm = normalization(model_channels)
        self.code_time_norm = normalization(model_channels)
        self.time_embed = mySequential(
            nn.Linear(model_channels, model_channels),
            nn.SiLU(),
            nn.Linear(model_channels, model_channels),
        )
        self.input_block = nn.Conv1d(input_channels, model_channels, 3, 1, 1, bias=True)
        # Learned embedding used for the classifier-free (unconditioned) branch.
        self.unconditioned_embedding = nn.Parameter(torch.randn(1, model_channels, 1))
        self.integrating_conv = nn.Conv1d(
            model_channels * 2, model_channels, kernel_size=1
        )
        self.code_time = TimestepEmbedSequential(
            QuartzAttn(model_channels, dropout, num_heads),
            QuartzAttn(model_channels, dropout, num_heads),
            QuartzAttn(model_channels, dropout, num_heads),
        )
        # BUG FIX: QuartzNet9x5's signature is
        # (model_channels, num_heads, dropout=0.1, enable_fp16=False); the
        # original passed (self.enable_fp16, self.dropout) positionally, so
        # dropout received False and enable_fp16 received 0.1. Pass by keyword
        # so each value lands on the right parameter.
        self.layers = QuartzNet9x5(
            model_channels,
            num_heads,
            dropout=self.dropout,
            enable_fp16=self.enable_fp16,
        )

    def get_speaker_latent(self, ref_mels):
        """Average GST embeddings over the provided reference mels.

        :param ref_mels: reference mels; a 3-D tensor is treated as a single
            clip and unsqueezed to (N, 1, n_mels, T).
        :return: speaker latent of shape (N, C, 1), where C is the GST output
            width (presumably 2 * model_channels, since forward() chunks it
            into scale/shift — TODO confirm against GST).
        """
        ref_mels = ref_mels.unsqueeze(1) if len(ref_mels.shape) == 3 else ref_mels
        conds = []
        for j in range(ref_mels.shape[1]):
            conds.append(self.GST(ref_mels[:, j, :, :]))
        conds = torch.cat(conds, dim=-1)
        conds = conds.mean(dim=-1)
        return conds.unsqueeze(2)

    def forward(
        self,
        x,
        t,
        code_emb,
        ref_clips=None,
        speaker_latents=None,
        conditioning_free=False,
    ):
        """Predict the denoised output for `x` at timestep `t`.

        :param x: noisy input, (N, input_channels, T).
        :param t: diffusion timesteps, shape (N,).
        :param code_emb: semantic code ids (when ar_active is False) or AR
            latents (N, in_latent_channels, S) otherwise.
        :param ref_clips: optional reference mels; overrides speaker_latents.
        :param speaker_latents: optional precomputed speaker latents.
        :param conditioning_free: replace code conditioning with the learned
            unconditioned embedding (classifier-free guidance branch).
        """
        # (N,) -> (N, 1) -> sinusoidal + MLP; permute so normalization runs
        # over channels, then squeeze back to (N, model_channels).
        time_embed = self.time_norm(
            self.time_embed(
                timestep_embedding(t.unsqueeze(-1), self.model_channels)
            ).permute(0, 2, 1)
        ).squeeze(2)
        if conditioning_free:
            code_embed = self.unconditioned_embedding.repeat(x.shape[0], 1, x.shape[-1])
        else:
            if not self.ar_active:
                code_embed = self.code_norm(
                    self.code_converter(self.code_emb(code_emb).permute(0, 2, 1))
                )
            else:
                code_embed = self.code_norm(self.code_converter(code_emb))
            if self.multispeaker:
                assert speaker_latents is not None or ref_clips is not None
                if ref_clips is not None:
                    speaker_latents = self.get_speaker_latent(ref_clips)
                # FiLM-style speaker modulation of the code conditioning.
                cond_scale, cond_shift = torch.chunk(speaker_latents, 2, dim=1)
                code_embed = code_embed * (1 + cond_scale) + cond_shift
            # Randomly drop conditioning per batch element during training so
            # the model also learns the unconditioned distribution.
            if self.training and self.condition_free_per > 0:
                unconditioned_batches = (
                    torch.rand((code_embed.shape[0], 1, 1), device=code_embed.device)
                    < self.condition_free_per
                )
                code_embed = torch.where(
                    unconditioned_batches,
                    self.unconditioned_embedding.repeat(code_embed.shape[0], 1, 1),
                    code_embed,
                )
        # Stretch the code conditioning to the input's time resolution.
        expanded_code_emb = F.interpolate(code_embed, size=x.shape[-1], mode="linear")
        x_cond = self.code_time_norm(self.code_time(expanded_code_emb, time_embed))
        x = self.input_block(x)
        x = torch.cat([x, x_cond], dim=1)
        x = self.integrating_conv(x)
        out = self.layers(x, time_embed)
        return out