| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
| from typing import Optional, Tuple |
|
|
| import k2 |
| import torch |
| import torch.nn as nn |
| from scaling import ScaledLinear |
|
|
| from icefall.utils import add_sos, torch_autocast |
|
|
|
|
class AsrModel(nn.Module):
    """Joint CTC / pruned-transducer ASR model built on a feature encoder."""

    def __init__(
        self,
        encoder,
        decoder: Optional[nn.Module] = None,
        joiner: Optional[nn.Module] = None,
        encoder_dim: int = 768,
        decoder_dim: int = 512,
        vocab_size: int = 500,
        use_transducer: bool = True,
        use_ctc: bool = False,
    ):
        """A joint CTC & Transducer ASR model.

        - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
        - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
        - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)

        Args:
          encoder:
            It is the transcription network in the paper. It must provide
            ``extract_features(source, padding_mask, mask)``: `source` is a
            2-D tensor of shape (N, T) and `padding_mask` a boolean tensor
            of shape (N, T). It returns a tuple ``(features, padding_mask)``
            where `features` has shape (N, T', encoder_dim) and the returned
            `padding_mask` is a boolean (N, T') mask that is True at padded
            frames. The encoder's own ``training`` flag is passed as ``mask``.
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
            Required (and only allowed) when use_transducer is True.
          joiner:
            It has two inputs with shapes: (N, T, encoder_dim) and (N, U, decoder_dim).
            Its output shape is (N, T, U, vocab_size). Note that its output contains
            unnormalized probs, i.e., not processed by log-softmax.
            It must also expose `encoder_proj` and `decoder_proj` sub-modules
            and accept a `project_input` keyword argument.
            Required (and only allowed) when use_transducer is True.
          encoder_dim:
            Last dimension of the encoder output.
          decoder_dim:
            Last dimension of the decoder output.
          vocab_size:
            Output vocabulary size of the simple projections and the CTC head.
          use_transducer:
            Whether to use the transducer head. Default: True.
          use_ctc:
            Whether to use the CTC head. Default: False.
        """
        super().__init__()

        assert (
            use_transducer or use_ctc
        ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

        self.encoder = encoder

        self.use_transducer = use_transducer
        if use_transducer:
            # The transducer head needs a prediction network with a blank id
            # and a joiner.
            assert decoder is not None
            assert hasattr(decoder, "blank_id")
            assert joiner is not None

            self.decoder = decoder
            self.joiner = joiner

            # Cheap linear projections used by the "simple" (smoothed) RNN-T
            # loss, whose gradients determine the pruning ranges.
            self.simple_am_proj = ScaledLinear(
                encoder_dim, vocab_size, initial_scale=0.25
            )
            self.simple_lm_proj = ScaledLinear(
                decoder_dim, vocab_size, initial_scale=0.25
            )
        else:
            assert decoder is None
            assert joiner is None

        self.use_ctc = use_ctc
        if use_ctc:
            # CTC head: dropout -> linear -> log-probabilities over vocab.
            self.ctc_output = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(encoder_dim, vocab_size),
                nn.LogSoftmax(dim=-1),
            )

    def forward_encoder(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 2-D tensor of shape (N, T).
          padding_mask:
            Optional boolean tensor of shape (N, T); True marks padded
            positions. If None, no position is treated as padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,), i.e. the number of
            non-padded frames per utterance.
        """
        if padding_mask is None:
            padding_mask = torch.zeros_like(x, dtype=torch.bool)

        # The encoder's training flag controls whether masking is applied.
        encoder_out, padding_mask = self.encoder.extract_features(
            source=x,
            padding_mask=padding_mask,
            mask=self.encoder.training,
        )
        # Count non-padded output frames per utterance.
        encoder_out_lens = torch.sum(~padding_mask, dim=1)
        assert torch.all(encoder_out_lens > 0), encoder_out_lens

        return encoder_out, encoder_out_lens

    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths)). The targets are assumed
            to be un-padded and concatenated within 1 dimension.
          target_lengths:
            Number of target labels per utterance, of shape (N,).

        Returns:
          The CTC loss summed over the batch (reduction="sum").
        """
        # Log-probabilities of shape (N, T, vocab_size).
        ctc_output = self.ctc_output(encoder_out)

        # F.ctc_loss expects time-major log-probs: (T, N, vocab_size).
        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),
            targets=targets,
            input_lengths=encoder_out_lens,
            target_lengths=target_lengths,
            reduction="sum",
        )
        return ctc_loss

    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          y_lens:
            Number of labels per utterance, of shape (N,).
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part

        Returns:
          A tuple (simple_loss, pruned_loss), both summed over the batch.
        """
        # Prepend the blank symbol as start-of-sequence for the decoder input.
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: (N, 1 + max label length), padded with blank_id.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: (N, U + 1, decoder_dim)
        decoder_out = self.decoder(sos_y_padded)

        # Padding value does not matter here: positions past y_lens are
        # excluded via `boundary`.
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        # boundary rows are [0, 0, num_labels, num_frames] per utterance.
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # The k2 RNN-T losses are computed in float32 with autocast disabled.
        with torch_autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
            )

        # Use the gradients of the simple loss to select, per frame, a window
        # of `prune_range` symbols for the full joiner.
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # Project first, then gather only the pruned (frame, symbol) pairs.
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # Inputs are already projected above, hence project_input=False.
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch_autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )

        return simple_loss, pruned_loss

    def forward(
        self,
        x: torch.Tensor,
        y: k2.RaggedTensor,
        padding_mask: Optional[torch.Tensor] = None,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, T).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          padding_mask:
            Optional boolean tensor of shape (N, T); True marks padded
            positions. If None, no position is treated as padding.
          prune_range:
            The prune range for rnnt loss, it means how many symbols(context)
            we are considering for each frame to compute the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part
        Returns:
          Return the transducer losses, the CTC loss and the encoder output
          lengths, in form of (simple_loss, pruned_loss, ctc_loss,
          encoder_out_lens). A loss whose head is disabled is returned as an
          empty tensor.

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 2, x.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == y.dim0, (x.shape, y.dim0)

        encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)

        # Number of labels per utterance, from the ragged tensor's row splits.
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        if self.use_transducer:
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y.to(x.device),
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
        else:
            # Placeholders for the disabled head, created on x's device
            # rather than the CPU default.
            simple_loss = torch.empty(0, device=x.device)
            pruned_loss = torch.empty(0, device=x.device)

        if self.use_ctc:
            # CTC consumes the concatenated, un-padded labels.
            targets = y.values
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=targets,
                target_lengths=y_lens,
            )
        else:
            ctc_loss = torch.empty(0, device=x.device)

        return simple_loss, pruned_loss, ctc_loss, encoder_out_lens
|
|