| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional, Tuple |
|
|
| import torch |
| import torch.nn as nn |
| import torch.nn.functional as F |
|
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Create a boolean padding mask from sequence lengths.

    Args:
      lengths:
        A 1-D tensor containing sentence lengths.
      max_len:
        The length of masks. If it is smaller than ``lengths.max()``,
        the actual maximum length is used instead.
    Returns:
      Return a 2-D bool tensor of shape ``(len(lengths), max_len)``,
      where masked (padding) positions are filled with `True` and
      non-masked positions are filled with `False`.

    >>> lengths = torch.tensor([1, 3, 2, 5])
    >>> make_pad_mask(lengths)
    tensor([[False,  True,  True,  True,  True],
            [False, False, False,  True,  True],
            [False, False,  True,  True,  True],
            [False, False, False, False, False]])
    """
    assert lengths.ndim == 1, lengths.ndim
    # Grow max_len to cover the longest sequence; int() keeps it a plain
    # Python int instead of a 0-dim tensor.
    max_len = max(max_len, int(lengths.max()))
    n = lengths.size(0)
    seq_range = torch.arange(0, max_len, device=lengths.device)
    # (n, max_len): every row is [0, 1, ..., max_len - 1].
    expanded_lengths = seq_range.unsqueeze(0).expand(n, max_len)

    # Position t of row i is padding iff t >= lengths[i].
    return expanded_lengths >= lengths.unsqueeze(-1)
|
|
class ZipformerModel(nn.Module):
    """The simplest Zipformer model for inference.

    It wraps a convolution embedding frontend and a Zipformer encoder
    and exposes only the encoder forward pass (no loss computation).
    """

    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: nn.Module,
        encoder_dim: int,
    ):
        """
        Args:
          encoder_embed (nn.Module): convolution embedding module.
            Called as ``encoder_embed(x, x_lens)``; returns the
            subsampled features and their lengths.
          encoder (nn.Module): Zipformer encoder. Called as
            ``encoder(x, x_lens, src_key_padding_mask)`` with ``x`` in
            (T, N, C) layout.
          encoder_dim (int): Dimension of the Zipformer encoder output.
        """
        super().__init__()
        self.encoder_embed = encoder_embed
        self.encoder = encoder
        self.encoder_dim = encoder_dim

    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames
            in `x` before padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T', C) where T' is the
            subsampled length produced by ``encoder_embed``.
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        x, x_lens = self.encoder_embed(x, x_lens)

        src_key_padding_mask = make_pad_mask(x_lens)
        # The encoder expects (T, N, C) layout.
        x = x.permute(1, 0, 2)

        encoder_out, encoder_out_lens = self.encoder(
            x, x_lens, src_key_padding_mask
        )

        # Back to (N, T, C) for callers.
        encoder_out = encoder_out.permute(1, 0, 2)
        # Every utterance must survive subsampling with at least one frame.
        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        return encoder_out, encoder_out_lens

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        codebook_indexes: Optional[torch.Tensor] = None,
        at_targets: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames
            in `x` before padding.
          codebook_indexes:
            Codebook indexes of teacher embeddings. Only checked for
            presence here; kept for interface compatibility with the
            training model.
          at_targets:
            Presumably audio-tagging targets (TODO confirm against the
            training recipe). Only checked for presence here.

        Returns:
          Return ``(encoder_out, encoder_out_lens)``; see
          :meth:`forward_encoder`.
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        # At least one kind of target must be supplied.
        assert codebook_indexes is not None or at_targets is not None

        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        return encoder_out, encoder_out_lens