from typing import Optional

import torch

from fairseq.modules import (
    ESPNETMultiHeadedAttention,
    LayerNorm,
    MultiheadAttention,
    RelPositionMultiHeadedAttention,
    RotaryPositionMultiHeadedAttention,
)
from fairseq.utils import get_activation_fn


class ConvolutionModule(torch.nn.Module):
    """Convolution block used in the conformer block"""

    def __init__(
        self,
        embed_dim,
        channels,
        depthwise_kernel_size,
        dropout,
        activation_fn="swish",
        bias=False,
        export=False,
    ):
        """
        Args:
            embed_dim: Embedding dimension
            channels: Number of channels in depthwise conv layers
            depthwise_kernel_size: Depthwise conv layer kernel size
            dropout: dropout value
            activation_fn: Activation function to use after depthwise convolution kernel
            bias: If bias should be added to conv layers
            export: If layernorm should be exported to jit
        """
        super(ConvolutionModule, self).__init__()
        assert (
            depthwise_kernel_size - 1
        ) % 2 == 0, "kernel_size should be an odd number for 'SAME' padding"
        self.layer_norm = LayerNorm(embed_dim, export=export)
        self.pointwise_conv1 = torch.nn.Conv1d(
            embed_dim,
            2 * channels,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.glu = torch.nn.GLU(dim=1)
        self.depthwise_conv = torch.nn.Conv1d(
            channels,
            channels,
            depthwise_kernel_size,
            stride=1,
            padding=(depthwise_kernel_size - 1) // 2,
            groups=channels,
            bias=bias,
        )
        self.batch_norm = torch.nn.BatchNorm1d(channels)
        self.activation = get_activation_fn(activation_fn)(channels)
        self.pointwise_conv2 = torch.nn.Conv1d(
            channels,
            embed_dim,
            kernel_size=1,
            stride=1,
            padding=0,
            bias=bias,
        )
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, x):
        """
        Args:
            x: Input of shape B X T X C
        Returns:
            Tensor of shape B X T X C
        """
        x = self.layer_norm(x)
        # exchange the temporal dimension and the feature dimension for Conv1d
        x = x.transpose(1, 2)

        # GLU mechanism: pointwise conv doubles the channels, GLU halves them back
        x = self.pointwise_conv1(x)  # B X 2*channels X T
        x = self.glu(x)  # B X channels X T

        # 1D depthwise convolution with 'SAME' padding
        x = self.depthwise_conv(x)
        x = self.batch_norm(x)
        x = self.activation(x)

        x = self.pointwise_conv2(x)
        x = self.dropout(x)
        return x.transpose(1, 2)


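# Usage sketch (added for illustration; not part of the original module). The
# sizes below are arbitrary assumptions; depthwise_kernel_size must be odd.
def _example_convolution_module():
    conv = ConvolutionModule(
        embed_dim=256, channels=256, depthwise_kernel_size=31, dropout=0.1
    )
    x = torch.randn(4, 100, 256)  # B X T X C
    return conv(x)  # same shape: B X T X C

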
class FeedForwardModule(torch.nn.Module):
    """Positionwise feed forward layer used in conformer"""

    def __init__(
        self,
        input_feat,
        hidden_units,
        dropout1,
        dropout2,
        activation_fn="swish",
        bias=True,
    ):
        """
        Args:
            input_feat: Input feature dimension
            hidden_units: Hidden unit dimension
            dropout1: dropout value for layer1
            dropout2: dropout value for layer2
            activation_fn: Name of activation function
            bias: If linear layers should have bias
        """
        super(FeedForwardModule, self).__init__()
        self.layer_norm = LayerNorm(input_feat)
        self.w_1 = torch.nn.Linear(input_feat, hidden_units, bias=bias)
        self.w_2 = torch.nn.Linear(hidden_units, input_feat, bias=bias)
        self.dropout1 = torch.nn.Dropout(dropout1)
        self.dropout2 = torch.nn.Dropout(dropout2)
        self.activation = get_activation_fn(activation_fn)(hidden_units)

    def forward(self, x):
        """
        Args:
            x: Input Tensor of shape T X B X C
        Returns:
            Tensor of shape T X B X C
        """
        x = self.layer_norm(x)
        x = self.w_1(x)
        x = self.activation(x)
        x = self.dropout1(x)
        x = self.w_2(x)
        return self.dropout2(x)


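# Usage sketch (added for illustration; not part of the original module). The
# dimensions below are arbitrary assumptions.
def _example_feed_forward_module():
    ffn = FeedForwardModule(
        input_feat=256, hidden_units=1024, dropout1=0.1, dropout2=0.1
    )
    x = torch.randn(100, 4, 256)  # T X B X C
    return ffn(x)  # same shape: T X B X C

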
class ConformerEncoderLayer(torch.nn.Module):
    """Conformer block based on https://arxiv.org/abs/2005.08100. Relative
    positional encoding is only supported with the espnet MHA implementation
    (attn_type="espnet" with pos_enc_type="rel_pos"); fairseq's native
    MultiheadAttention does not support it."""

    def __init__(
        self,
        embed_dim,
        ffn_embed_dim,
        attention_heads,
        dropout,
        use_fp16,
        depthwise_conv_kernel_size=31,
        activation_fn="swish",
        attn_type=None,
        pos_enc_type="abs",
    ):
        """
        Args:
            embed_dim: Input embedding dimension
            ffn_embed_dim: FFN layer dimension
            attention_heads: Number of attention heads in MHA
            dropout: dropout value
            use_fp16: If fp16 precision should be used; forwarded as `precision`
                to the rotary position attention implementation
            depthwise_conv_kernel_size: Size of kernel in depthwise conv layer in convolution module
            activation_fn: Activation function name to use in convolution block and feed forward block
            attn_type: MHA implementation from ESPNET vs fairseq
            pos_enc_type: Positional encoding type - abs, rope, rel_pos
        """
        self.pos_enc_type = pos_enc_type
        super(ConformerEncoderLayer, self).__init__()

        self.ffn1 = FeedForwardModule(
            embed_dim,
            ffn_embed_dim,
            dropout,
            dropout,
        )

        self.self_attn_layer_norm = LayerNorm(embed_dim, export=False)
        self.self_attn_dropout = torch.nn.Dropout(dropout)
        if attn_type == "espnet":
            if self.pos_enc_type == "rel_pos":
                self.self_attn = RelPositionMultiHeadedAttention(
                    embed_dim,
                    attention_heads,
                    dropout=dropout,
                )
            elif self.pos_enc_type == "rope":
                self.self_attn = RotaryPositionMultiHeadedAttention(
                    embed_dim, attention_heads, dropout=dropout, precision=use_fp16
                )
            elif self.pos_enc_type == "abs":
                self.self_attn = ESPNETMultiHeadedAttention(
                    embed_dim,
                    attention_heads,
                    dropout=dropout,
                )
            else:
                raise Exception(
                    f"Unsupported positional encoding type {self.pos_enc_type}"
                )
        else:
            # Default to fairseq MHA (absolute positional encoding only)
            self.self_attn = MultiheadAttention(
                embed_dim,
                attention_heads,
                dropout=dropout,
            )

        self.conv_module = ConvolutionModule(
            embed_dim=embed_dim,
            channels=embed_dim,
            depthwise_kernel_size=depthwise_conv_kernel_size,
            dropout=dropout,
            activation_fn=activation_fn,
        )

        self.ffn2 = FeedForwardModule(
            embed_dim,
            ffn_embed_dim,
            dropout,
            dropout,
            activation_fn=activation_fn,
        )
        self.final_layer_norm = LayerNorm(embed_dim, export=False)

    def forward(
        self,
        x,
        encoder_padding_mask: Optional[torch.Tensor],
        position_emb: Optional[torch.Tensor] = None,
    ):
        """
        Args:
            x: Tensor of shape T X B X C
            encoder_padding_mask: Optional mask tensor of shape B X T
            position_emb: Optional positional embedding tensor; required when
                pos_enc_type == "rel_pos"
        Returns:
            Tensor of shape T X B X C
        """
        # Macaron-style half-step feed forward: the FFN output is scaled by 0.5
        residual = x
        x = self.ffn1(x)
        x = x * 0.5 + residual

        # Multi-headed self-attention with pre-layer-norm
        residual = x
        x = self.self_attn_layer_norm(x)
        if self.pos_enc_type == "rel_pos":
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                pos_emb=position_emb,
                need_weights=False,
            )
        else:
            x, attn = self.self_attn(
                query=x,
                key=x,
                value=x,
                key_padding_mask=encoder_padding_mask,
                need_weights=False,
            )
        x = self.self_attn_dropout(x)
        x = x + residual

        # Convolution module expects B X T X C, so swap the time and batch dims
        residual = x
        x = x.transpose(0, 1)
        x = self.conv_module(x)
        x = x.transpose(0, 1)
        x = residual + x

        # Second half-step feed forward, followed by the final layer norm
        residual = x
        x = self.ffn2(x)

        layer_result = x

        x = x * 0.5 + residual

        x = self.final_layer_norm(x)
        return x, (attn, layer_result)


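# Usage sketch (added for illustration; not part of the original module). The
# hyperparameters are arbitrary assumptions; with the default attn_type=None
# the layer uses fairseq's MultiheadAttention, so no position_emb is required.
def _example_conformer_encoder_layer():
    layer = ConformerEncoderLayer(
        embed_dim=256,
        ffn_embed_dim=1024,
        attention_heads=4,
        dropout=0.1,
        use_fp16=False,
    )
    x = torch.randn(100, 4, 256)  # T X B X C
    out, (attn, layer_result) = layer(x, encoder_padding_mask=None)
    return out  # T X B X C

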
class ConformerWav2Vec2EncoderLayer(ConformerEncoderLayer):
    """Encoder layer for Wav2vec2 encoder"""

    def forward(
        self,
        x: torch.Tensor,
        self_attn_mask: Optional[torch.Tensor] = None,
        self_attn_padding_mask: Optional[torch.Tensor] = None,
        need_weights: bool = False,
        att_args=None,
        position_emb=None,
    ):
        # self_attn_mask, need_weights, and att_args are accepted only for
        # interface compatibility with the wav2vec2 transformer encoder layers;
        # they are not used by the conformer block.
        return super().forward(x, self_attn_padding_mask, position_emb)


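# Usage sketch (added for illustration; not part of the original module). The
# subclass keeps the wav2vec2 transformer-layer call signature while reusing
# ConformerEncoderLayer.forward; hyperparameters are arbitrary assumptions.
def _example_conformer_wav2vec2_encoder_layer():
    layer = ConformerWav2Vec2EncoderLayer(
        embed_dim=256,
        ffn_embed_dim=1024,
        attention_heads=4,
        dropout=0.1,
        use_fp16=False,
    )
    x = torch.randn(100, 4, 256)  # T X B X C
    out, (attn, layer_result) = layer(x)
    return out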