|
|
import math |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
|
|
|
from utils.torch_utilities import concat_non_padding, restore_from_concat |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class LayerNorm(nn.LayerNorm): |
|
|
"""Layer normalization module. |
|
|
:param int nout: output dim size |
|
|
:param int dim: dimension to be normalized |
|
|
""" |
|
|
def __init__(self, nout, dim=-1): |
|
|
"""Construct an LayerNorm object.""" |
|
|
super(LayerNorm, self).__init__(nout, eps=1e-12) |
|
|
self.dim = dim |
|
|
|
|
|
def forward(self, x): |
|
|
"""Apply layer normalization. |
|
|
:param torch.Tensor x: input tensor |
|
|
:return: layer normalized tensor |
|
|
        :rtype: torch.Tensor
|
|
""" |
|
|
if self.dim == -1: |
|
|
return super(LayerNorm, self).forward(x) |
|
|
        return super(LayerNorm, self).forward(
            x.transpose(1, -1)
        ).transpose(1, -1)
|
|
|
|
|
|
|
|
class DurationPredictor(nn.Module): |
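    """Convolutional duration predictor.

    A stack of ``n_layers`` blocks (ConstantPad1d -> Conv1d -> ReLU ->
    LayerNorm -> Dropout) followed by a per-frame linear projection to a
    single value. With ``padding="SAME"`` the padding is symmetric and keeps
    the sequence length (for odd kernel sizes); any other value gives causal,
    left-only padding. The output is multiplied by the mask, so padded frames
    are zero. Whether the predicted value is a raw or log duration is up to
    the training target.
    """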
|
|
def __init__( |
|
|
self, |
|
|
in_channels: int, |
|
|
filter_channels: int, |
|
|
n_layers: int = 2, |
|
|
kernel_size: int = 3, |
|
|
p_dropout: float = 0.1, |
|
|
padding: str = "SAME" |
|
|
): |
|
|
super(DurationPredictor, self).__init__() |
|
|
self.conv = nn.ModuleList() |
|
|
self.kernel_size = kernel_size |
|
|
self.padding = padding |
|
|
for idx in range(n_layers): |
|
|
in_chans = in_channels if idx == 0 else filter_channels |
|
|
            self.conv += [
                nn.Sequential(
                    # "SAME" pads symmetrically to keep the length; otherwise
                    # pad causally (kernel_size - 1 frames on the left only).
                    nn.ConstantPad1d(
                        ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                        if padding == "SAME" else (kernel_size - 1, 0),
                        0
                    ),
                    nn.Conv1d(
                        in_chans,
                        filter_channels,
                        kernel_size,
                        stride=1,
                        padding=0
                    ),
                    nn.ReLU(),
                    LayerNorm(filter_channels, dim=1),
                    nn.Dropout(p_dropout)
                )
            ]
|
|
self.linear = nn.Linear(filter_channels, 1) |
|
|
|
|
|
def forward(self, x: torch.Tensor, x_mask: torch.Tensor): |
|
|
|
|
|
x = x.transpose(1, -1) |
|
|
x_mask = x_mask.unsqueeze(1).to(x.device) |
|
|
for f in self.conv: |
|
|
x = f(x) |
|
|
x = x * x_mask.float() |
|
|
|
|
|
        x = self.linear(x.transpose(1, -1)) * x_mask.transpose(1, -1).float()
|
|
return x |
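
# A minimal shape sketch for DurationPredictor (hypothetical sizes, not taken
# from any configuration in this repository):
#
#   predictor = DurationPredictor(in_channels=256, filter_channels=256)
#   feats = torch.randn(2, 50, 256)       # (batch, time, in_channels)
#   mask = torch.ones(2, 50)              # 1 = valid frame, 0 = padding
#   durations = predictor(feats, mask)    # -> (2, 50, 1), zeroed at padding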
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ContentAdapterBase(nn.Module): |
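    """Base class for content adapters; stores the output dim ``d_out``."""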
|
|
def __init__(self, d_out): |
|
|
super().__init__() |
|
|
self.d_out = d_out |
|
|
|
|
|
|
|
|
class SinusoidalPositionalEmbedding(nn.Module): |
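    """Fixed sine/cosine positional encoding (Vaswani et al., 2017).

    The table is precomputed up to ``max_len`` positions, added to batch-first
    ``(batch, time, d_model)`` inputs, and followed by dropout.
    """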
|
|
def __init__(self, d_model, dropout, max_len=1000): |
|
|
super().__init__() |
|
|
self.dropout = nn.Dropout(dropout) |
|
|
pe = torch.zeros(max_len, d_model) |
|
|
position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1) |
|
|
div_term = torch.exp( |
|
|
torch.arange(0, d_model, 2).float() * |
|
|
(-math.log(10000.0) / d_model) |
|
|
) |
|
|
pe[:, 0::2] = torch.sin(position * div_term) |
|
|
pe[:, 1::2] = torch.cos(position * div_term) |
|
|
        # (1, max_len, d_model): broadcastable over batch-first (B, T, D) input.
        pe = pe.unsqueeze(0)
|
|
self.register_buffer('pe', pe) |
|
|
|
|
|
def forward(self, x): |
|
|
        x = x + self.pe[:, :x.size(1)]
|
|
return self.dropout(x) |
|
|
|
|
|
|
|
|
class ContentAdapter(ContentAdapterBase): |
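    """Transformer-encoder content adapter.

    Prepends a learnable CLS embedding to the input sequence, encodes it with
    a stack of ``nn.TransformerEncoderLayer`` blocks, and then:

    * predicts durations with ``duration_predictor`` on a gradient-rescaled
      copy of the encoder output (the CLS position is read out as a global
      duration, the remaining positions as per-token durations), and
    * projects the encoder output to ``d_out`` with a 1x1 convolution.

    Returns ``(content, mask, global_duration, local_duration)`` with the CLS
    position stripped from ``content`` and ``mask``.
    """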
|
|
def __init__( |
|
|
self, |
|
|
d_model: int, |
|
|
d_out: int, |
|
|
num_layers: int, |
|
|
num_heads: int, |
|
|
duration_predictor: DurationPredictor, |
|
|
dropout: float = 0.1, |
|
|
norm_first: bool = False, |
|
|
activation: str = "gelu", |
|
|
duration_grad_scale: float = 0.0, |
|
|
): |
|
|
super().__init__(d_out) |
|
|
self.duration_grad_scale = duration_grad_scale |
|
|
self.cls_embed = nn.Parameter(torch.randn(d_model)) |
|
|
if hasattr(torch, "npu") and torch.npu.is_available(): |
|
|
enable_nested_tensor = False |
|
|
else: |
|
|
enable_nested_tensor = True |
|
|
encoder_layer = nn.TransformerEncoderLayer( |
|
|
d_model=d_model, |
|
|
nhead=num_heads, |
|
|
dim_feedforward=4 * d_model, |
|
|
dropout=dropout, |
|
|
activation=activation, |
|
|
norm_first=norm_first, |
|
|
batch_first=True |
|
|
) |
|
|
self.encoder_layers = nn.TransformerEncoder( |
|
|
encoder_layer=encoder_layer, |
|
|
num_layers=num_layers, |
|
|
enable_nested_tensor=enable_nested_tensor |
|
|
) |
|
|
self.duration_predictor = duration_predictor |
|
|
self.content_proj = nn.Conv1d(d_model, d_out, 1) |
|
|
|
|
|
def forward(self, x, x_mask): |
|
|
batch_size = x.size(0) |
|
|
cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1) |
|
|
cls_embed = cls_embed.to(x.device).unsqueeze(1) |
|
|
x = torch.cat([cls_embed, x], dim=1) |
|
|
|
|
|
        cls_mask = torch.ones(
            batch_size, 1, dtype=x_mask.dtype, device=x_mask.device
        )
|
|
x_mask = torch.cat([cls_mask, x_mask], dim=1) |
|
|
x = self.encoder_layers(x, src_key_padding_mask=~x_mask.bool()) |
|
|
        # Rescale the gradient flowing from the duration predictor back into
        # the encoder: the forward value is unchanged, but only a fraction
        # `duration_grad_scale` of the gradient reaches `x`.
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
|
|
duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1) |
|
|
content = self.content_proj(x.transpose(1, 2)).transpose(1, 2) |
|
|
return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:] |
|
|
|
|
|
|
|
|
class PrefixAdapter(ContentAdapterBase): |
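    """Content adapter that conditions on an instruction prefix.

    Content and instruction are first mapped to ``d_model`` by separate MLPs.
    A learnable CLS embedding is prepended to the content, the instruction and
    content streams are joined with ``concat_non_padding``, and the joint
    sequence runs through a Transformer encoder (optionally followed by a
    final LayerNorm). ``restore_from_concat`` recovers the content part, after
    which durations and the ``d_out`` projection are produced as in
    ``ContentAdapter``.
    """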
|
|
def __init__( |
|
|
self, |
|
|
content_dim: int, |
|
|
d_model: int, |
|
|
d_out: int, |
|
|
prefix_dim: int, |
|
|
num_layers: int, |
|
|
num_heads: int, |
|
|
duration_predictor: DurationPredictor, |
|
|
dropout: float = 0.1, |
|
|
norm_first: bool = False, |
|
|
use_last_norm: bool = True, |
|
|
activation: str = "gelu", |
|
|
duration_grad_scale: float = 0.1, |
|
|
): |
|
|
super().__init__(d_out) |
|
|
self.duration_grad_scale = duration_grad_scale |
|
|
self.prefix_mlp = nn.Sequential( |
|
|
nn.Linear(prefix_dim, d_model), nn.ReLU(), nn.Dropout(dropout), |
|
|
nn.Linear(d_model, d_model) |
|
|
) |
|
|
self.content_mlp = nn.Sequential( |
|
|
nn.Linear(content_dim, d_model), nn.ReLU(), nn.Dropout(dropout), |
|
|
nn.Linear(d_model, d_model) |
|
|
) |
|
|
layer = nn.TransformerEncoderLayer( |
|
|
d_model=d_model, |
|
|
nhead=num_heads, |
|
|
dim_feedforward=4 * d_model, |
|
|
dropout=dropout, |
|
|
activation=activation, |
|
|
batch_first=True, |
|
|
norm_first=norm_first |
|
|
) |
|
|
if hasattr(torch, "npu") and torch.npu.is_available(): |
|
|
enable_nested_tensor = False |
|
|
else: |
|
|
enable_nested_tensor = True |
|
|
self.cls_embed = nn.Parameter(torch.randn(d_model)) |
|
|
|
|
|
self.layers = nn.TransformerEncoder( |
|
|
encoder_layer=layer, |
|
|
num_layers=num_layers, |
|
|
enable_nested_tensor=enable_nested_tensor |
|
|
) |
|
|
self.use_last_norm = use_last_norm |
|
|
if self.use_last_norm: |
|
|
self.last_norm = nn.LayerNorm(d_model) |
|
|
self.duration_predictor = duration_predictor |
|
|
self.content_proj = nn.Conv1d(d_model, d_out, 1) |
|
|
nn.init.normal_(self.cls_embed, 0., 0.02) |
|
|
nn.init.xavier_uniform_(self.content_proj.weight) |
|
|
nn.init.constant_(self.content_proj.bias, 0.) |
|
|
|
|
|
def forward(self, content, content_mask, instruction, instruction_mask): |
|
|
batch_size = content.size(0) |
|
|
cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1) |
|
|
cls_embed = cls_embed.to(content.device).unsqueeze(1) |
|
|
content = self.content_mlp(content) |
|
|
x = torch.cat([cls_embed, content], dim=1) |
|
|
        cls_mask = torch.ones(
            batch_size, 1, dtype=torch.bool, device=content_mask.device
        )
|
|
x_mask = torch.cat([cls_mask, content_mask], dim=1) |
|
|
|
|
|
prefix = self.prefix_mlp(instruction) |
|
|
seq, seq_mask, perm = concat_non_padding( |
|
|
prefix, instruction_mask, x, x_mask |
|
|
) |
|
|
|
|
|
x = self.layers(seq, src_key_padding_mask=~seq_mask.bool()) |
|
|
if self.use_last_norm: |
|
|
x = self.last_norm(x) |
|
|
_, x = restore_from_concat(x, instruction_mask, x_mask, perm) |
|
|
|
|
|
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
|
|
duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1) |
|
|
content = self.content_proj(x.transpose(1, 2)).transpose(1, 2) |
|
|
return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:] |
|
|
|
|
|
|
|
|
class CrossAttentionAdapter(ContentAdapterBase): |
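    """Content adapter that attends from content queries to a prefix.

    A single ``nn.MultiheadAttention`` layer (content as query, prefix as
    key/value) with a residual connection and LayerNorm. A masked mean over
    time feeds ``global_duration_mlp`` for the utterance-level duration, the
    ``duration_predictor`` gives per-token durations, and a 1x1 convolution
    projects the features to ``d_out``.
    """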
|
|
def __init__( |
|
|
self, |
|
|
d_out: int, |
|
|
content_dim: int, |
|
|
prefix_dim: int, |
|
|
num_heads: int, |
|
|
duration_predictor: DurationPredictor, |
|
|
dropout: float = 0.1, |
|
|
duration_grad_scale: float = 0.1, |
|
|
): |
|
|
super().__init__(d_out) |
|
|
self.attn = nn.MultiheadAttention( |
|
|
embed_dim=content_dim, |
|
|
num_heads=num_heads, |
|
|
dropout=dropout, |
|
|
kdim=prefix_dim, |
|
|
vdim=prefix_dim, |
|
|
batch_first=True, |
|
|
) |
|
|
self.duration_grad_scale = duration_grad_scale |
|
|
self.duration_predictor = duration_predictor |
|
|
self.global_duration_mlp = nn.Sequential( |
|
|
nn.Linear(content_dim, content_dim), nn.ReLU(), |
|
|
nn.Dropout(dropout), nn.Linear(content_dim, 1) |
|
|
) |
|
|
self.norm = nn.LayerNorm(content_dim) |
|
|
self.content_proj = nn.Conv1d(content_dim, d_out, 1) |
|
|
|
|
|
def forward(self, content, content_mask, prefix, prefix_mask): |
|
|
attn_output, attn_output_weights = self.attn( |
|
|
query=content, |
|
|
key=prefix, |
|
|
value=prefix, |
|
|
key_padding_mask=~prefix_mask.bool() |
|
|
) |
|
|
attn_output = attn_output * content_mask.unsqueeze(-1).float() |
|
|
x = self.norm(attn_output + content) |
|
|
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
|
|
        # Masked mean over time for the utterance-level (global) duration.
        x_aggregated = (
            x_grad_rescaled * content_mask.unsqueeze(-1).float()
        ).sum(dim=1) / content_mask.sum(dim=1, keepdim=True).float()
|
|
global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1) |
|
|
local_duration = self.duration_predictor( |
|
|
x_grad_rescaled, content_mask |
|
|
).squeeze(-1) |
|
|
content = self.content_proj(x.transpose(1, 2)).transpose(1, 2) |
|
|
return content, content_mask, global_duration, local_duration |
|
|
|
|
|
|
|
|
class ExperimentalCrossAttentionAdapter(ContentAdapterBase): |
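    """Variant of ``CrossAttentionAdapter`` with extra projections.

    Both streams are passed through an MLP plus LayerNorm before the
    cross-attention, the ``d_out`` projection is an MLP followed by a final
    LayerNorm instead of a 1x1 convolution, and all linear layers are
    Xavier-initialized.
    """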
|
|
def __init__( |
|
|
self, |
|
|
d_out: int, |
|
|
content_dim: int, |
|
|
prefix_dim: int, |
|
|
num_heads: int, |
|
|
duration_predictor: DurationPredictor, |
|
|
dropout: float = 0.1, |
|
|
duration_grad_scale: float = 0.1, |
|
|
): |
|
|
super().__init__(d_out) |
|
|
self.content_mlp = nn.Sequential( |
|
|
nn.Linear(content_dim, content_dim), |
|
|
nn.ReLU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(content_dim, content_dim), |
|
|
) |
|
|
self.content_norm = nn.LayerNorm(content_dim) |
|
|
self.prefix_mlp = nn.Sequential( |
|
|
nn.Linear(prefix_dim, prefix_dim), |
|
|
nn.ReLU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(prefix_dim, prefix_dim), |
|
|
) |
|
|
        # Normalizes the prefix stream, which has prefix_dim features.
        self.prefix_norm = nn.LayerNorm(prefix_dim)
|
|
self.attn = nn.MultiheadAttention( |
|
|
embed_dim=content_dim, |
|
|
num_heads=num_heads, |
|
|
dropout=dropout, |
|
|
kdim=prefix_dim, |
|
|
vdim=prefix_dim, |
|
|
batch_first=True, |
|
|
) |
|
|
self.duration_grad_scale = duration_grad_scale |
|
|
self.duration_predictor = duration_predictor |
|
|
self.global_duration_mlp = nn.Sequential( |
|
|
nn.Linear(content_dim, content_dim), nn.ReLU(), |
|
|
nn.Dropout(dropout), nn.Linear(content_dim, 1) |
|
|
) |
|
|
self.content_proj = nn.Sequential( |
|
|
nn.Linear(content_dim, d_out), |
|
|
nn.ReLU(), |
|
|
nn.Dropout(dropout), |
|
|
nn.Linear(d_out, d_out), |
|
|
) |
|
|
self.norm1 = nn.LayerNorm(content_dim) |
|
|
self.norm2 = nn.LayerNorm(d_out) |
|
|
self.init_weights() |
|
|
|
|
|
def init_weights(self): |
|
|
def _init_weights(module): |
|
|
if isinstance(module, nn.Linear): |
|
|
nn.init.xavier_uniform_(module.weight) |
|
|
if module.bias is not None: |
|
|
nn.init.constant_(module.bias, 0.) |
|
|
|
|
|
self.apply(_init_weights) |
|
|
|
|
|
def forward(self, content, content_mask, prefix, prefix_mask): |
|
|
content = self.content_mlp(content) |
|
|
content = self.content_norm(content) |
|
|
prefix = self.prefix_mlp(prefix) |
|
|
prefix = self.prefix_norm(prefix) |
|
|
attn_output, attn_weights = self.attn( |
|
|
query=content, |
|
|
key=prefix, |
|
|
value=prefix, |
|
|
key_padding_mask=~prefix_mask.bool(), |
|
|
) |
|
|
attn_output = attn_output * content_mask.unsqueeze(-1).float() |
|
|
x = attn_output + content |
|
|
x = self.norm1(x) |
|
|
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
|
|
        x_aggregated = (
            x_grad_rescaled * content_mask.unsqueeze(-1).float()
        ).sum(dim=1) / content_mask.sum(dim=1, keepdim=True).float()
|
|
global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1) |
|
|
local_duration = self.duration_predictor( |
|
|
x_grad_rescaled, content_mask |
|
|
).squeeze(-1) |
|
|
content = self.content_proj(x) |
|
|
content = self.norm2(content) |
|
|
return content, content_mask, global_duration, local_duration |
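
# A minimal shape sketch for CrossAttentionAdapter (hypothetical sizes, not
# taken from any configuration in this repository):
#
#   dp = DurationPredictor(in_channels=512, filter_channels=256)
#   adapter = CrossAttentionAdapter(
#       d_out=128, content_dim=512, prefix_dim=768, num_heads=8,
#       duration_predictor=dp,
#   )
#   content = torch.randn(2, 40, 512)     # (batch, content_len, content_dim)
#   content_mask = torch.ones(2, 40)      # 1 = valid, 0 = padding
#   prefix = torch.randn(2, 12, 768)      # (batch, prefix_len, prefix_dim)
#   prefix_mask = torch.ones(2, 12)
#   out, out_mask, g_dur, l_dur = adapter(
#       content, content_mask, prefix, prefix_mask
#   )
#   # out: (2, 40, 128), g_dur: (2,), l_dur: (2, 40)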
|
|
|