import math

import torch
import torch.nn as nn

from utils.torch_utilities import concat_non_padding, restore_from_concat


######################
# fastspeech modules
######################


class LayerNorm(nn.LayerNorm):
    """Layer normalization module.

    :param int nout: output dim size
    :param int dim: dimension to be normalized
    """

    def __init__(self, nout, dim=-1):
        """Construct a LayerNorm object."""
        super(LayerNorm, self).__init__(nout, eps=1e-12)
        self.dim = dim

    def forward(self, x):
        """Apply layer normalization.

        :param torch.Tensor x: input tensor
        :return: layer normalized tensor
        :rtype torch.Tensor
        """
        if self.dim == -1:
            return super(LayerNorm, self).forward(x)
        return super(LayerNorm, self).forward(x.transpose(1, -1)).transpose(1, -1)


class DurationPredictor(nn.Module):
    """FastSpeech-style duration predictor: a stack of Conv1d blocks followed
    by a linear projection that outputs one duration value per frame."""

    def __init__(
        self,
        in_channels: int,
        filter_channels: int,
        n_layers: int = 2,
        kernel_size: int = 3,
        p_dropout: float = 0.1,
        padding: str = "SAME"
    ):
        super(DurationPredictor, self).__init__()
        self.conv = nn.ModuleList()
        self.kernel_size = kernel_size
        self.padding = padding
        for idx in range(n_layers):
            in_chans = in_channels if idx == 0 else filter_channels
            self.conv.append(
                nn.Sequential(
                    # "SAME" pads symmetrically; otherwise pad causally (left only).
                    nn.ConstantPad1d(
                        ((kernel_size - 1) // 2, (kernel_size - 1) // 2)
                        if padding == "SAME" else (kernel_size - 1, 0),
                        0
                    ),
                    nn.Conv1d(
                        in_chans, filter_channels, kernel_size,
                        stride=1, padding=0
                    ),
                    nn.ReLU(),
                    LayerNorm(filter_channels, dim=1),
                    nn.Dropout(p_dropout)
                )
            )
        self.linear = nn.Linear(filter_channels, 1)

    def forward(self, x: torch.Tensor, x_mask: torch.Tensor):
        # x: [B, T, E], x_mask: [B, T] with 1 for valid frames, 0 for padding
        x = x.transpose(1, -1)
        x_mask = x_mask.unsqueeze(1).to(x.device)
        for f in self.conv:
            x = f(x)
            x = x * x_mask.float()
        x = self.linear(x.transpose(1, -1)) * x_mask.transpose(1, -1).float()
        return x  # [B, T, 1]
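

# Hedged usage sketch (not part of the original module): the batch size,
# sequence length, and channel sizes below are illustrative assumptions only.
# It shows the expected I/O shapes of DurationPredictor: [B, T, E] features
# plus a [B, T] padding mask in, one predicted duration value per frame out.
def _duration_predictor_shape_sketch():
    predictor = DurationPredictor(in_channels=256, filter_channels=256)
    x = torch.randn(2, 50, 256)       # [B, T, E] frame-level features (assumed sizes)
    x_mask = torch.ones(2, 50)        # [B, T], 1 = valid frame, 0 = padding
    durations = predictor(x, x_mask)  # -> [B, T, 1]
    return durations.shape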


######################
# adapter modules
######################


class ContentAdapterBase(nn.Module):
    def __init__(self, d_out):
        super().__init__()
        self.d_out = d_out


class SinusoidalPositionalEmbedding(nn.Module):
    def __init__(self, d_model, dropout, max_len=1000):
        super().__init__()
        self.dropout = nn.Dropout(dropout)
        pe = torch.zeros(max_len, d_model)
        position = torch.arange(0, max_len, dtype=torch.float).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, d_model, 2).float() * (-math.log(10000.0) / d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        # [1, max_len, d_model] so the table broadcasts over batch-first input.
        pe = pe.unsqueeze(0)
        self.register_buffer('pe', pe)

    def forward(self, x):
        # x: [B, T, d_model]
        x = x + self.pe[:, :x.size(1), :]
        return self.dropout(x)


class ContentAdapter(ContentAdapterBase):
    def __init__(
        self,
        d_model: int,
        d_out: int,
        num_layers: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        norm_first: bool = False,
        activation: str = "gelu",
        duration_grad_scale: float = 0.0,
    ):
        super().__init__(d_out)
        self.duration_grad_scale = duration_grad_scale
        self.cls_embed = nn.Parameter(torch.randn(d_model))

        # The nested-tensor fast path is not supported on NPU devices.
        if hasattr(torch, "npu") and torch.npu.is_available():
            enable_nested_tensor = False
        else:
            enable_nested_tensor = True

        encoder_layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation=activation,
            norm_first=norm_first,
            batch_first=True
        )
        self.encoder_layers = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_layers,
            enable_nested_tensor=enable_nested_tensor
        )
        self.duration_predictor = duration_predictor
        self.content_proj = nn.Conv1d(d_model, d_out, 1)

    def forward(self, x, x_mask):
        batch_size = x.size(0)
        # Prepend a learnable [CLS] token; its duration prediction is used as
        # the global (utterance-level) duration.
        cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
        cls_embed = cls_embed.to(x.device).unsqueeze(1)
        x = torch.cat([cls_embed, x], dim=1)
        cls_mask = torch.ones(batch_size, 1, dtype=x_mask.dtype, device=x_mask.device)
        x_mask = torch.cat([cls_mask, x_mask], dim=1)

        x = self.encoder_layers(x, src_key_padding_mask=~x_mask.bool())

        # Scale the gradient flowing from the duration predictor back into the
        # encoder (0.0 detaches it completely).
        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
        duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)

        content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
        # content, mask (both without the [CLS] slot), global duration, per-token durations
        return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]
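

# Hedged usage sketch (illustrative sizes, not taken from any config):
# ContentAdapter consumes [B, T, d_model] content features and returns the
# projected content, the mask without the [CLS] slot, a global duration, and
# per-token durations. The helper name below is hypothetical.
def _content_adapter_shape_sketch():
    adapter = ContentAdapter(
        d_model=256, d_out=192, num_layers=2, num_heads=4,
        duration_predictor=DurationPredictor(in_channels=256, filter_channels=256),
    )
    x = torch.randn(2, 50, 256)
    x_mask = torch.ones(2, 50)
    content, mask, global_dur, local_dur = adapter(x, x_mask)
    # content: [2, 50, 192], mask: [2, 50], global_dur: [2], local_dur: [2, 50]
    return content.shape, mask.shape, global_dur.shape, local_dur.shape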


class PrefixAdapter(ContentAdapterBase):
    def __init__(
        self,
        content_dim: int,
        d_model: int,
        d_out: int,
        prefix_dim: int,
        num_layers: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        norm_first: bool = False,
        use_last_norm: bool = True,
        activation: str = "gelu",
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.duration_grad_scale = duration_grad_scale
        self.prefix_mlp = nn.Sequential(
            nn.Linear(prefix_dim, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model)
        )
        self.content_mlp = nn.Sequential(
            nn.Linear(content_dim, d_model),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_model, d_model)
        )
        layer = nn.TransformerEncoderLayer(
            d_model=d_model,
            nhead=num_heads,
            dim_feedforward=4 * d_model,
            dropout=dropout,
            activation=activation,
            batch_first=True,
            norm_first=norm_first
        )
        # The nested-tensor fast path is not supported on NPU devices.
        if hasattr(torch, "npu") and torch.npu.is_available():
            enable_nested_tensor = False
        else:
            enable_nested_tensor = True
        self.cls_embed = nn.Parameter(torch.randn(d_model))
        # self.pos_embed = SinusoidalPositionalEmbedding(d_model, dropout)
        self.layers = nn.TransformerEncoder(
            encoder_layer=layer,
            num_layers=num_layers,
            enable_nested_tensor=enable_nested_tensor
        )
        self.use_last_norm = use_last_norm
        if self.use_last_norm:
            self.last_norm = nn.LayerNorm(d_model)
        self.duration_predictor = duration_predictor
        self.content_proj = nn.Conv1d(d_model, d_out, 1)

        nn.init.normal_(self.cls_embed, 0., 0.02)
        nn.init.xavier_uniform_(self.content_proj.weight)
        nn.init.constant_(self.content_proj.bias, 0.)

    def forward(self, content, content_mask, instruction, instruction_mask):
        batch_size = content.size(0)
        # Prepend a learnable [CLS] token whose predicted duration serves as
        # the global duration.
        cls_embed = self.cls_embed.reshape(1, -1).expand(batch_size, -1)
        cls_embed = cls_embed.to(content.device).unsqueeze(1)
        content = self.content_mlp(content)
        x = torch.cat([cls_embed, content], dim=1)
        cls_mask = torch.ones(
            batch_size, 1, dtype=content_mask.dtype, device=content_mask.device
        )
        x_mask = torch.cat([cls_mask, content_mask], dim=1)

        # Pack the instruction (prefix) and content into one sequence with the
        # padding removed, run the shared encoder, then restore the layout.
        prefix = self.prefix_mlp(instruction)
        seq, seq_mask, perm = concat_non_padding(
            prefix, instruction_mask, x, x_mask
        )
        # seq = self.pos_embed(seq)
        x = self.layers(seq, src_key_padding_mask=~seq_mask.bool())
        if self.use_last_norm:
            x = self.last_norm(x)
        _, x = restore_from_concat(x, instruction_mask, x_mask, perm)

        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
        duration = self.duration_predictor(x_grad_rescaled, x_mask).squeeze(-1)

        content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
        # content, mask (both without the [CLS] slot), global duration, per-token durations
        return content[:, 1:], x_mask[:, 1:], duration[:, 0], duration[:, 1:]


class CrossAttentionAdapter(ContentAdapterBase):
    def __init__(
        self,
        d_out: int,
        content_dim: int,
        prefix_dim: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.attn = nn.MultiheadAttention(
            embed_dim=content_dim,
            num_heads=num_heads,
            dropout=dropout,
            kdim=prefix_dim,
            vdim=prefix_dim,
            batch_first=True,
        )
        self.duration_grad_scale = duration_grad_scale
        self.duration_predictor = duration_predictor
        self.global_duration_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(content_dim, 1)
        )
        self.norm = nn.LayerNorm(content_dim)
        self.content_proj = nn.Conv1d(content_dim, d_out, 1)

    def forward(self, content, content_mask, prefix, prefix_mask):
        # Content attends to the prefix (instruction) sequence.
        attn_output, attn_output_weights = self.attn(
            query=content,
            key=prefix,
            value=prefix,
            key_padding_mask=~prefix_mask.bool()
        )
        attn_output = attn_output * content_mask.unsqueeze(-1).float()
        x = self.norm(attn_output + content)

        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
        # Mean-pool over valid frames to predict the global duration.
        x_aggregated = (
            x_grad_rescaled * content_mask.unsqueeze(-1).float()
        ).sum(dim=1) / content_mask.sum(dim=1, keepdim=True).float()
        global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
        local_duration = self.duration_predictor(
            x_grad_rescaled, content_mask
        ).squeeze(-1)

        content = self.content_proj(x.transpose(1, 2)).transpose(1, 2)
        return content, content_mask, global_duration, local_duration
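

# Hedged usage sketch (illustrative sizes only): CrossAttentionAdapter lets the
# content sequence attend to an instruction/prefix sequence of a different
# width, and returns both a global and a per-token duration prediction. The
# helper name below is hypothetical.
def _cross_attention_adapter_shape_sketch():
    adapter = CrossAttentionAdapter(
        d_out=192, content_dim=256, prefix_dim=512, num_heads=4,
        duration_predictor=DurationPredictor(in_channels=256, filter_channels=256),
    )
    content = torch.randn(2, 50, 256)
    content_mask = torch.ones(2, 50)
    prefix = torch.randn(2, 10, 512)
    prefix_mask = torch.ones(2, 10)
    out, mask, global_dur, local_dur = adapter(
        content, content_mask, prefix, prefix_mask
    )
    # out: [2, 50, 192], mask: [2, 50], global_dur: [2], local_dur: [2, 50]
    return out.shape, global_dur.shape, local_dur.shape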


class ExperimentalCrossAttentionAdapter(ContentAdapterBase):
    def __init__(
        self,
        d_out: int,
        content_dim: int,
        prefix_dim: int,
        num_heads: int,
        duration_predictor: DurationPredictor,
        dropout: float = 0.1,
        duration_grad_scale: float = 0.1,
    ):
        super().__init__(d_out)
        self.content_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(content_dim, content_dim),
        )
        self.content_norm = nn.LayerNorm(content_dim)
        self.prefix_mlp = nn.Sequential(
            nn.Linear(prefix_dim, prefix_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(prefix_dim, prefix_dim),
        )
        # The prefix MLP keeps prefix_dim, so its norm must be sized to prefix_dim.
        self.prefix_norm = nn.LayerNorm(prefix_dim)
        self.attn = nn.MultiheadAttention(
            embed_dim=content_dim,
            num_heads=num_heads,
            dropout=dropout,
            kdim=prefix_dim,
            vdim=prefix_dim,
            batch_first=True,
        )
        self.duration_grad_scale = duration_grad_scale
        self.duration_predictor = duration_predictor
        self.global_duration_mlp = nn.Sequential(
            nn.Linear(content_dim, content_dim),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(content_dim, 1)
        )
        self.content_proj = nn.Sequential(
            nn.Linear(content_dim, d_out),
            nn.ReLU(),
            nn.Dropout(dropout),
            nn.Linear(d_out, d_out),
        )
        self.norm1 = nn.LayerNorm(content_dim)
        self.norm2 = nn.LayerNorm(d_out)
        self.init_weights()

    def init_weights(self):
        def _init_weights(module):
            if isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0.)
        self.apply(_init_weights)

    def forward(self, content, content_mask, prefix, prefix_mask):
        content = self.content_mlp(content)
        content = self.content_norm(content)
        prefix = self.prefix_mlp(prefix)
        prefix = self.prefix_norm(prefix)

        # Content attends to the prefix (instruction) sequence.
        attn_output, attn_weights = self.attn(
            query=content,
            key=prefix,
            value=prefix,
            key_padding_mask=~prefix_mask.bool(),
        )
        attn_output = attn_output * content_mask.unsqueeze(-1).float()
        x = attn_output + content
        x = self.norm1(x)

        x_grad_rescaled = (
            x * self.duration_grad_scale
            + x.detach() * (1 - self.duration_grad_scale)
        )
        # Mean-pool over valid frames to predict the global duration.
        x_aggregated = (
            x_grad_rescaled * content_mask.unsqueeze(-1).float()
        ).sum(dim=1) / content_mask.sum(dim=1, keepdim=True).float()
        global_duration = self.global_duration_mlp(x_aggregated).squeeze(-1)
        local_duration = self.duration_predictor(
            x_grad_rescaled, content_mask
        ).squeeze(-1)

        content = self.content_proj(x)
        content = self.norm2(content)
        return content, content_mask, global_duration, local_duration
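

# Hedged smoke-test sketch (not part of the original module): runs the
# ExperimentalCrossAttentionAdapter once on random tensors with assumed sizes
# and prints the output shapes. All dimensions below are illustrative.
if __name__ == "__main__":
    _adapter = ExperimentalCrossAttentionAdapter(
        d_out=192, content_dim=256, prefix_dim=512, num_heads=4,
        duration_predictor=DurationPredictor(in_channels=256, filter_channels=256),
    )
    _content = torch.randn(2, 50, 256)
    _content_mask = torch.ones(2, 50)
    _prefix = torch.randn(2, 10, 512)
    _prefix_mask = torch.ones(2, 10)
    _out, _mask, _g_dur, _l_dur = _adapter(
        _content, _content_mask, _prefix, _prefix_mask
    )
    # Expected: [2, 50, 192], [2, 50], [2], [2, 50]
    print(_out.shape, _mask.shape, _g_dur.shape, _l_dur.shape)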