import copy
import math
from typing import List, Optional, Tuple

import torch
import torch.nn.functional as F
from torch import Tensor, nn

from icefall.transformer_lm.attention import RelPositionMultiheadAttention
from icefall.transformer_lm.scaling import (
    ActivationBalancer,
    BasicNorm,
    DoubleSwish,
    ScaledConv1d,
    ScaledConv2d,
    ScaledLinear,
)
from icefall.utils import is_jit_tracing, make_pad_mask


class Transformer(torch.nn.Module):
    """A stack of transformer encoder layers with relative positional encoding.

    Args:
      input_dim (int): Input feature dimension
      d_model (int): The dimension of the transformer
      dim_feedforward (int): The dimension of the feedforward module
      nhead (int): The number of attention heads
      num_layers (int): The number of encoder layers
      dropout_rate (float): Dropout rate
      att_dropout (float): Dropout rate in the attention module
    """

    def __init__(
        self,
        input_dim: int,
        d_model: int,
        dim_feedforward: int,
        nhead: int = 4,
        num_layers: int = 6,
        dropout_rate: float = 0.1,
        att_dropout: float = 0.0,
    ):
        super().__init__()

        self.encoder_layers = num_layers
        self.d_model = d_model

        self.embed = ScaledLinear(input_dim, d_model)
        self.norm_before = BasicNorm(d_model, learn_eps=False)

        self.encoder_pos = RelPositionalEncoding(d_model, dropout_rate)

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            dim_feedforward=dim_feedforward,
            nhead=nhead,
            dropout_rate=dropout_rate,
        )

        self.encoder = TransformerEncoder(encoder_layer, num_layers)

    def _create_attention_mask(self, x_lens: torch.Tensor):
        """Create a causal attention mask.

        Returns a boolean tensor of shape (max_len, max_len) in which True
        entries are masked out, so that position i can only attend to
        positions <= i.
        """
        max_len = max(x_lens)
        ones = torch.ones(max_len, max_len, device=x_lens.device, dtype=torch.bool)
        return torch.triu(ones, diagonal=1)

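    # For example, max_len = 3 gives the mask
    #   [[False,  True,  True],
    #    [False, False,  True],
    #    [False, False, False]]
    # where True marks the (future) positions that may not be attended to.
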
    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Transformer forward.

        Args:
          x (torch.Tensor): Input tensor of shape (B, T, input_dim)
          x_lens (torch.Tensor): The lengths of the input tensors before
            padding, of shape (B,)

        Returns:
          A tuple of 2 tensors:
            - x: output features of the transformer, of shape (B, T, d_model)
            - x_lens: output feature lengths of the transformer
        """
        attention_mask = self._create_attention_mask(x_lens)
        src_key_padding_mask = make_pad_mask(x_lens)

        x = self.norm_before(self.embed(x))

        x, pos_emb = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (B, T, C) -> (T, B, C)

        x = self.encoder(
            x,
            pos_emb,
            mask=attention_mask,
            src_key_padding_mask=src_key_padding_mask,
        )

        x = x.permute(1, 0, 2)  # (T, B, C) -> (B, T, C)
        return x, x_lens


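# A minimal usage sketch (illustrative only; the hyperparameters below are
# assumptions, not values from this file):
#
#   model = Transformer(input_dim=100, d_model=256, dim_feedforward=1024)
#   x = torch.randn(2, 50, 100)        # (B, T, input_dim)
#   x_lens = torch.tensor([50, 30])    # valid lengths before padding
#   y, y_lens = model(x, x_lens)       # y: (2, 50, 256)

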
class TransformerEncoder(torch.nn.Module):
    def __init__(self, encoder_layer: torch.nn.Module, num_layers: int) -> None:
        """TransformerEncoder is a stack of N encoder layers.

        Args:
          encoder_layer (torch.nn.Module): An instance of TransformerEncoderLayer()
          num_layers (int): Number of layers to be stacked
        """
        super().__init__()
        self.layers = nn.ModuleList(
            [copy.deepcopy(encoder_layer) for _ in range(num_layers)]
        )
        self.num_layers = num_layers

    def forward(
        self,
        src: torch.Tensor,
        pos_emb: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Pass the input through each encoder layer in turn.

        Args:
          src: The sequence to the encoder (required).
          pos_emb: Positional embedding tensor (required).
          mask: The mask for the src sequence (optional).
          src_key_padding_mask: The mask for the src keys per batch (optional).

        Returns:
          output: Transformer-encoded features
        """
        output = src

        for mod in self.layers:
            output = mod(
                output,
                pos_emb,
                src_key_padding_mask=src_key_padding_mask,
                src_mask=mask,
            )

        return output


class TransformerEncoderLayer(torch.nn.Module):
    def __init__(
        self,
        d_model: int,
        dim_feedforward: int,
        nhead: int,
        dropout_rate: float,
    ):
        """TransformerEncoderLayer is made up of a self-attention and a
        feedforward module.

        Args:
          d_model (int): The model size
          dim_feedforward (int): Dimension of the feedforward module
          nhead (int): Number of heads
          dropout_rate (float): Dropout rate
        """
        super().__init__()

        self.d_model = d_model

        # self-attention with relative positional encoding
        self.self_attn = RelPositionMultiheadAttention(d_model, nhead, dropout=0.0)
        # position-wise feedforward module
        self.feed_forward = nn.Sequential(
            ScaledLinear(d_model, dim_feedforward),
            ActivationBalancer(channel_dim=-1),
            DoubleSwish(),
            nn.Dropout(dropout_rate),
            ScaledLinear(dim_feedforward, d_model, initial_scale=0.25),
        )

        self.norm_final = BasicNorm(d_model)

        # keeps the pre-norm activations in a reasonable range
        self.balancer = ActivationBalancer(
            channel_dim=-1, min_positive=0.45, max_positive=0.55, max_abs=6.0
        )

        self.dropout = nn.Dropout(dropout_rate)

    def forward(
        self,
        src: torch.Tensor,
        pos_emb: torch.Tensor,
        src_key_padding_mask: Optional[torch.Tensor] = None,
        src_mask: Optional[torch.Tensor] = None,
        cache=None,
    ):
        """
        Pass the input through the encoder layer.

        Args:
          src: The sequence to the encoder layer (required).
          pos_emb: Positional embedding tensor (required).
          src_key_padding_mask: The mask for the src keys per batch (optional).
          src_mask: The mask for the src sequence (optional).
        """
        # self-attention module with a residual connection
        src_att = self.self_attn(
            src,
            src,
            src,
            pos_emb=pos_emb,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]

        src = src + self.dropout(src_att)

        # feedforward module with a residual connection
        src = src + self.dropout(self.feed_forward(src))

        src = self.norm_final(self.balancer(src))

        return src


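# Per-layer dataflow, for reference:
#   src -> self_attn(src, src, src, pos_emb) -> dropout -> (+) residual
#       -> feed_forward -> dropout -> (+) residual
#       -> balancer -> norm_final

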
class RelPositionalEncoding(torch.nn.Module):
    """Relative positional encoding module.

    See Appendix B of "Transformer-XL: Attentive Language Models Beyond a
    Fixed-Length Context".
    Modified from https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/embedding.py

    Args:
      d_model: Embedding dimension.
      dropout_rate: Dropout rate.
      max_len: Maximum input length.
    """

    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000) -> None:
        """Construct a RelPositionalEncoding object."""
        super().__init__()
        if is_jit_tracing():
            # When tracing with JIT, use a fixed maximum length that is
            # assumed to be large enough for any input.
            max_len = 10000

        self.d_model = d_model
        self.dropout = torch.nn.Dropout(p=dropout_rate)
        self.pe = None
        self.extend_pe(torch.tensor(0.0).expand(1, max_len))

    def extend_pe(self, x: torch.Tensor, left_context: int = 0) -> None:
        """Reset the positional encodings."""
        x_size_1 = x.size(1) + left_context
        if self.pe is not None:
            # self.pe contains both positive and negative parts;
            # its length along dim 1 is 2 * input_len - 1.
            if self.pe.size(1) >= x_size_1 * 2 - 1:
                # The cached encodings are long enough; just make sure
                # their dtype and device match the input.
                if self.pe.dtype != x.dtype or str(self.pe.device) != str(x.device):
                    self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        # Build sinusoidal encodings for relative positions: pe_positive
        # covers the offsets where the key is to the left of the query,
        # pe_negative the offsets where it is to the right.
        pe_positive = torch.zeros(x_size_1, self.d_model)
        pe_negative = torch.zeros(x_size_1, self.d_model)
        position = torch.arange(0, x_size_1, dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe_positive[:, 0::2] = torch.sin(position * div_term)
        pe_positive[:, 1::2] = torch.cos(position * div_term)
        pe_negative[:, 0::2] = torch.sin(-1 * position * div_term)
        pe_negative[:, 1::2] = torch.cos(-1 * position * div_term)

        # Reverse the order of the positive encodings and concatenate them
        # with the negative ones (dropping the duplicated offset 0); this
        # layout supports the shifting trick of Transformer-XL.
        pe_positive = torch.flip(pe_positive, [0]).unsqueeze(0)
        pe_negative = pe_negative[1:].unsqueeze(0)
        pe = torch.cat([pe_positive, pe_negative], dim=1)
        self.pe = pe.to(device=x.device, dtype=x.dtype)

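    # For example, with x of shape (1, 3, d_model) and left_context = 0,
    # self.pe stores encodings for the relative offsets [2, 1, 0, -1, -2]
    # (2 * 3 - 1 = 5 entries), and forward() below slices out the ones
    # needed for the current input length.
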
    def forward(
        self,
        x: torch.Tensor,
        left_context: int = 0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Add positional encoding.

        Args:
          x (torch.Tensor): Input tensor (batch, time, `*`).
          left_context (int): Left context (in frames) used during streaming
            decoding; this is used only in real streaming decoding and MUST
            be 0 in all other circumstances.

        Returns:
          torch.Tensor: Encoded tensor (batch, time, `*`).
          torch.Tensor: Positional embedding tensor (1, 2*time-1, `*`).
        """
        self.extend_pe(x, left_context)
        x_size_1 = x.size(1) + left_context
        pos_emb = self.pe[
            :,
            self.pe.size(1) // 2
            - x_size_1
            + 1 : self.pe.size(1) // 2
            + x.size(1),
        ]
        return self.dropout(x), self.dropout(pos_emb)
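
# A quick shape check for RelPositionalEncoding (illustrative only):
#   pos_enc = RelPositionalEncoding(d_model=256, dropout_rate=0.0)
#   x = torch.randn(4, 10, 256)
#   y, pos_emb = pos_enc(x)
#   # y: (4, 10, 256); pos_emb: (1, 19, 256), i.e. 2*10 - 1 relative positions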