import math
from typing import Optional, Tuple

import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from subsampling import Conv2dSubsampling, VggSubsampling

from icefall.utils import make_pad_mask


class Transformer(EncoderInterface):
    def __init__(
        self,
        num_features: int,
        output_dim: int,
        subsampling_factor: int = 4,
        d_model: int = 256,
        nhead: int = 4,
        dim_feedforward: int = 2048,
        num_encoder_layers: int = 12,
        dropout: float = 0.1,
        normalize_before: bool = True,
        vgg_frontend: bool = False,
    ) -> None:
| """ |
| Args: |
| num_features: |
| The input dimension of the model. |
| output_dim: |
| The output dimension of the model. |
| subsampling_factor: |
| Number of output frames is num_in_frames // subsampling_factor. |
| Currently, subsampling_factor MUST be 4. |
| d_model: |
| Attention dimension. |
| nhead: |
| Number of heads in multi-head attention. |
| Must satisfy d_model // nhead == 0. |
| dim_feedforward: |
| The output dimension of the feedforward layers in encoder. |
| num_encoder_layers: |
| Number of encoder layers. |
| dropout: |
| Dropout in encoder. |
| normalize_before: |
| If True, use pre-layer norm; False to use post-layer norm. |
| vgg_frontend: |
| True to use vgg style frontend for subsampling. |
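
        Example (a minimal construction sketch; the 80-dim input and 500-dim
        output used here are arbitrary illustrative values)::

          >>> model = Transformer(num_features=80, output_dim=500)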
| """ |
        super().__init__()

        self.num_features = num_features
        self.output_dim = output_dim
        self.subsampling_factor = subsampling_factor
        if subsampling_factor != 4:
            raise NotImplementedError("Support only 'subsampling_factor=4'.")

        # self.encoder_embed converts the input of shape (N, T, num_features)
        # to the shape (N, T//subsampling_factor, d_model).
        # That is, it does two things simultaneously:
        #   (1) subsampling: T -> T//subsampling_factor
        #   (2) embedding: num_features -> d_model
        if vgg_frontend:
            self.encoder_embed = VggSubsampling(num_features, d_model)
        else:
            self.encoder_embed = Conv2dSubsampling(num_features, d_model)

        self.encoder_pos = PositionalEncoding(d_model, dropout)

        encoder_layer = TransformerEncoderLayer(
            d_model=d_model,
            nhead=nhead,
            dim_feedforward=dim_feedforward,
            dropout=dropout,
            normalize_before=normalize_before,
        )

        if normalize_before:
            encoder_norm = nn.LayerNorm(d_model)
        else:
            encoder_norm = None

        self.encoder = nn.TransformerEncoder(
            encoder_layer=encoder_layer,
            num_layers=num_encoder_layers,
            norm=encoder_norm,
        )

        self.encoder_output_layer = nn.Sequential(
            nn.Dropout(p=dropout), nn.Linear(d_model, output_dim)
        )

    def forward(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
| """ |
| Args: |
| x: |
| The input tensor. Its shape is (batch_size, seq_len, feature_dim). |
| x_lens: |
| A tensor of shape (batch_size,) containing the number of frames in |
| `x` before padding. |
| Returns: |
| Return a tuple containing 2 tensors: |
| - logits, its shape is (batch_size, output_seq_len, output_dim) |
| - logit_lens, a tensor of shape (batch_size,) containing the number |
| of frames in `logits` before padding. |
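
        Example (a minimal usage sketch; the input sizes are arbitrary and the
        output length follows from the subsampling factor of 4)::

          >>> model = Transformer(num_features=80, output_dim=500)
          >>> x = torch.randn(2, 100, 80)
          >>> x_lens = torch.tensor([100, 90])
          >>> logits, logit_lens = model(x, x_lens)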
| """ |
        x = self.encoder_embed(x)
        x = self.encoder_pos(x)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        # Caution: We assume the subsampling factor is 4!
        lengths = ((x_lens - 1) // 2 - 1) // 2
        assert x.size(0) == lengths.max().item()

        mask = make_pad_mask(lengths)
        x = self.encoder(x, src_key_padding_mask=mask)  # (T, N, C)

        logits = self.encoder_output_layer(x)
        logits = logits.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)

        return logits, lengths


class TransformerEncoderLayer(nn.Module):
| """ |
| Modified from torch.nn.TransformerEncoderLayer. |
| Add support of normalize_before, |
| i.e., use layer_norm before the first block. |
| |
| Args: |
| d_model: |
| the number of expected features in the input (required). |
| nhead: |
| the number of heads in the multiheadattention models (required). |
| dim_feedforward: |
| the dimension of the feedforward network model (default=2048). |
| dropout: |
| the dropout value (default=0.1). |
| activation: |
| the activation function of intermediate layer, relu or |
| gelu (default=relu). |
| normalize_before: |
| whether to use layer_norm before the first block. |
| |
| Examples:: |
| >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8) |
| >>> src = torch.rand(10, 32, 512) |
| >>> out = encoder_layer(src) |
| """ |

    def __init__(
        self,
        d_model: int,
        nhead: int,
        dim_feedforward: int = 2048,
        dropout: float = 0.1,
        activation: str = "relu",
        normalize_before: bool = True,
    ) -> None:
        super(TransformerEncoderLayer, self).__init__()
        self.self_attn = nn.MultiheadAttention(d_model, nhead, dropout=0.0)
        # Feedforward module
        self.linear1 = nn.Linear(d_model, dim_feedforward)
        self.dropout = nn.Dropout(dropout)
        self.linear2 = nn.Linear(dim_feedforward, d_model)

        self.norm1 = nn.LayerNorm(d_model)
        self.norm2 = nn.LayerNorm(d_model)
        self.dropout1 = nn.Dropout(dropout)
        self.dropout2 = nn.Dropout(dropout)

        self.activation = _get_activation_fn(activation)

        self.normalize_before = normalize_before

    def __setstate__(self, state):
        if "activation" not in state:
            state["activation"] = nn.functional.relu
        super(TransformerEncoderLayer, self).__setstate__(state)

    def forward(
        self,
        src: torch.Tensor,
        src_mask: Optional[torch.Tensor] = None,
        src_key_padding_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Pass the input through the encoder layer.

        Args:
          src: the sequence to the encoder layer (required).
          src_mask: the mask for the src sequence (optional).
          src_key_padding_mask: the mask for the src keys per batch (optional).

        Shape:
          src: (S, N, E).
          src_mask: (S, S).
          src_key_padding_mask: (N, S).
          S is the source sequence length, N is the batch size,
          E is the feature number.
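
        Example (a sketch showing a key padding mask; entries that are True
        mark padded positions to be ignored by attention)::

          >>> encoder_layer = TransformerEncoderLayer(d_model=512, nhead=8)
          >>> src = torch.rand(10, 32, 512)
          >>> key_padding_mask = torch.zeros(32, 10, dtype=torch.bool)
          >>> out = encoder_layer(src, src_key_padding_mask=key_padding_mask)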
| """ |
        # Self-attention block (pre- or post-layer norm).
        residual = src
        if self.normalize_before:
            src = self.norm1(src)
        src2 = self.self_attn(
            src,
            src,
            src,
            attn_mask=src_mask,
            key_padding_mask=src_key_padding_mask,
        )[0]
        src = residual + self.dropout1(src2)
        if not self.normalize_before:
            src = self.norm1(src)

        # Feedforward block (pre- or post-layer norm).
        residual = src
        if self.normalize_before:
            src = self.norm2(src)
        src2 = self.linear2(self.dropout(self.activation(self.linear1(src))))
        src = residual + self.dropout2(src2)
        if not self.normalize_before:
            src = self.norm2(src)
        return src


def _get_activation_fn(activation: str):
    if activation == "relu":
        return nn.functional.relu
    elif activation == "gelu":
        return nn.functional.gelu

    raise RuntimeError(
        "activation should be relu/gelu, not {}".format(activation)
    )


class PositionalEncoding(nn.Module):
| """This class implements the positional encoding |
| proposed in the following paper: |
| |
| - Attention Is All You Need: https://arxiv.org/pdf/1706.03762.pdf |
| |
| PE(pos, 2i) = sin(pos / (10000^(2i/d_modle)) |
| PE(pos, 2i+1) = cos(pos / (10000^(2i/d_modle)) |
| |
| Note:: |
| |
| 1 / (10000^(2i/d_model)) = exp(-log(10000^(2i/d_model))) |
| = exp(-1* 2i / d_model * log(100000)) |
| = exp(2i * -(log(10000) / d_model)) |
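
    Example (a minimal usage sketch; the batch size, sequence length and
    dimension below are arbitrary)::

      >>> pos_enc = PositionalEncoding(d_model=256)
      >>> x = torch.randn(2, 50, 256)
      >>> y = pos_enc(x)  # same shape as x: (2, 50, 256)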
| """ |

    def __init__(self, d_model: int, dropout: float = 0.1) -> None:
        """
        Args:
          d_model:
            Embedding dimension.
          dropout:
            Dropout probability to be applied to the output of this module.
        """
        super().__init__()
        self.d_model = d_model
        self.xscale = math.sqrt(self.d_model)
        self.dropout = nn.Dropout(p=dropout)
        # self.pe starts empty; it is lazily extended in extend_pe()
        # to shape (1, T, d_model) when the first input arrives.
        self.pe = torch.zeros(1, 0, self.d_model, dtype=torch.float32)

    def extend_pe(self, x: torch.Tensor) -> None:
        """Extend the time dimension of the positional encoding if required.

        The shape of `self.pe` is (1, T1, d_model). The shape of the input x
        is (N, T, d_model). If T > T1, then we change the shape of self.pe
        to (1, T, d_model). Otherwise, nothing is done.

        Args:
          x:
            It is a tensor of shape (N, T, C).
        Returns:
          Return None.
        """
        if self.pe is not None:
            if self.pe.size(1) >= x.size(1):
                self.pe = self.pe.to(dtype=x.dtype, device=x.device)
                return
        pe = torch.zeros(x.size(1), self.d_model, dtype=torch.float32)
        position = torch.arange(0, x.size(1), dtype=torch.float32).unsqueeze(1)
        div_term = torch.exp(
            torch.arange(0, self.d_model, 2, dtype=torch.float32)
            * -(math.log(10000.0) / self.d_model)
        )
        pe[:, 0::2] = torch.sin(position * div_term)
        pe[:, 1::2] = torch.cos(position * div_term)
        pe = pe.unsqueeze(0)
        # Now self.pe has shape (1, T, d_model), where T is x.size(1).
        self.pe = pe.to(device=x.device, dtype=x.dtype)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Add positional encoding.

        Args:
          x:
            Its shape is (N, T, C).

        Returns:
          Return a tensor of shape (N, T, C).
        """
        self.extend_pe(x)
        # Scale the input by sqrt(d_model) before adding positional encodings.
        x = x * self.xscale + self.pe[:, : x.size(1), :]
        return self.dropout(x)


class Noam(object):
| """ |
| Implements Noam optimizer. |
| |
| Proposed in |
| "Attention Is All You Need", https://arxiv.org/pdf/1706.03762.pdf |
| |
| Modified from |
| https://github.com/espnet/espnet/blob/master/espnet/nets/pytorch_backend/transformer/optimizer.py # noqa |
| |
| Args: |
| params: |
| iterable of parameters to optimize or dicts defining parameter groups |
| model_size: |
| attention dimension of the transformer model |
| factor: |
| learning rate factor |
| warm_step: |
| warmup steps |
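
    Example (a minimal usage sketch; the model and tensor sizes are arbitrary)::

      >>> model = nn.Linear(256, 256)
      >>> optimizer = Noam(model.parameters(), model_size=256, warm_step=25000)
      >>> optimizer.zero_grad()
      >>> loss = model(torch.randn(4, 256)).sum()
      >>> loss.backward()
      >>> optimizer.step()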
| """ |

    def __init__(
        self,
        params,
        model_size: int = 256,
        factor: float = 10.0,
        warm_step: int = 25000,
        weight_decay: float = 0.0,
    ) -> None:
        """Construct a Noam object."""
        self.optimizer = torch.optim.Adam(
            params, lr=0, betas=(0.9, 0.98), eps=1e-9, weight_decay=weight_decay
        )
        self._step = 0
        self.warmup = warm_step
        self.factor = factor
        self.model_size = model_size
        self._rate = 0

    @property
    def param_groups(self):
        """Return param_groups."""
        return self.optimizer.param_groups

    def step(self):
        """Update parameters and rate."""
        self._step += 1
        rate = self.rate()
        for p in self.optimizer.param_groups:
            p["lr"] = rate
        self._rate = rate
        self.optimizer.step()

    def rate(self, step=None):
        """Compute the learning rate for the given step:
        factor * model_size**(-0.5) * min(step**(-0.5), step * warmup**(-1.5)).
        """
        if step is None:
            step = self._step
        return (
            self.factor
            * self.model_size ** (-0.5)
            * min(step ** (-0.5), step * self.warmup ** (-1.5))
        )

    def zero_grad(self):
        """Reset gradient."""
        self.optimizer.zero_grad()

    def state_dict(self):
        """Return state_dict."""
        return {
            "_step": self._step,
            "warmup": self.warmup,
            "factor": self.factor,
            "model_size": self.model_size,
            "_rate": self._rate,
            "optimizer": self.optimizer.state_dict(),
        }

    def load_state_dict(self, state_dict):
        """Load state_dict."""
        for key, value in state_dict.items():
            if key == "optimizer":
                self.optimizer.load_state_dict(state_dict["optimizer"])
            else:
                setattr(self, key, value)