from typing import Optional, Tuple

import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from lhotse.dataset import SpecAugment
from scaling import ScaledLinear

from icefall.utils import add_sos, make_pad_mask, time_warp, torch_autocast


class AsrModel(nn.Module):
    def __init__(
        self,
        encoder_embed: nn.Module,
        encoder: EncoderInterface,
        decoder: Optional[nn.Module] = None,
        joiner: Optional[nn.Module] = None,
        attention_decoder: Optional[nn.Module] = None,
        encoder_dim: int = 384,
        decoder_dim: int = 512,
        vocab_size: int = 500,
        use_transducer: bool = True,
        use_ctc: bool = False,
        use_attention_decoder: bool = False,
    ):
        """A joint CTC & Transducer ASR model.

        - Connectionist temporal classification: labelling unsegmented
          sequence data with recurrent neural networks
          (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
        - Sequence Transduction with Recurrent Neural Networks
          (https://arxiv.org/pdf/1211.3711.pdf)
        - Pruned RNN-T for fast, memory-efficient ASR training
          (https://arxiv.org/pdf/2206.13236.pdf)

        Args:
          encoder_embed:
            It is a convolutional 2D subsampling module. It converts
            an input of shape (N, T, idim) to an output of shape
            (N, T', odim), where T' = (T-3)//2-2 = (T-7)//2.
          encoder:
            It is the transcription network in the paper. It accepts
            two inputs: `x` of shape (N, T, encoder_dim) and `x_lens` of
            shape (N,). It returns two tensors: `logits` of shape
            (N, T, encoder_dim) and `logit_lens` of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
            It is used only when use_transducer is True.
          joiner:
            It has two inputs with shapes (N, T, encoder_dim) and
            (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
            Note that its output contains unnormalized probs, i.e., not
            processed by log-softmax. It is used only when use_transducer
            is True.
          attention_decoder:
            The attention-based decoder for the AED head.
            It is used only when use_attention_decoder is True.
          encoder_dim:
            Dimension of the encoder output.
          decoder_dim:
            Dimension of the decoder (prediction network) output.
          vocab_size:
            Number of tokens in the vocabulary, including the blank symbol.
          use_transducer:
            Whether to use the transducer head. Default: True.
          use_ctc:
            Whether to use the CTC head. Default: False.
          use_attention_decoder:
            Whether to use the attention-decoder head. Default: False.
        """
        super().__init__()

        assert (
            use_transducer or use_ctc
        ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

        assert isinstance(encoder, EncoderInterface), type(encoder)

        self.encoder_embed = encoder_embed
        self.encoder = encoder

        self.use_transducer = use_transducer
        if use_transducer:
            # Modules for the Transducer head
            assert decoder is not None
            assert hasattr(decoder, "blank_id")
            assert joiner is not None

            self.decoder = decoder
            self.joiner = joiner

            self.simple_am_proj = ScaledLinear(
                encoder_dim, vocab_size, initial_scale=0.25
            )
            self.simple_lm_proj = ScaledLinear(
                decoder_dim, vocab_size, initial_scale=0.25
            )
        else:
            assert decoder is None
            assert joiner is None

        self.use_ctc = use_ctc
        if use_ctc:
            # Modules for the CTC head
            self.ctc_output = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(encoder_dim, vocab_size),
                nn.LogSoftmax(dim=-1),
            )

        self.use_attention_decoder = use_attention_decoder
        if use_attention_decoder:
            self.attention_decoder = attention_decoder
        else:
            assert attention_decoder is None

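    # A minimal construction sketch (hypothetical helper modules, for
    # illustration only; the real encoder/decoder/joiner classes live in
    # their own files in this recipe):
    #
    #     model = AsrModel(
    #         encoder_embed=encoder_embed,  # Conv2dSubsampling-style module
    #         encoder=encoder,              # an EncoderInterface implementation
    #         decoder=decoder,              # prediction network with a .blank_id
    #         joiner=joiner,                # joint network
    #         encoder_dim=384,
    #         decoder_dim=512,
    #         vocab_size=500,
    #         use_transducer=True,
    #         use_ctc=True,
    #     )
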
    def forward_encoder(
        self, x: torch.Tensor, x_lens: torch.Tensor
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        x, x_lens = self.encoder_embed(x, x_lens)

        src_key_padding_mask = make_pad_mask(x_lens)
        x = x.permute(1, 0, 2)  # (N, T, C) -> (T, N, C)

        encoder_out, encoder_out_lens = self.encoder(x, x_lens, src_key_padding_mask)

        encoder_out = encoder_out.permute(1, 0, 2)  # (T, N, C) -> (N, T, C)
        assert torch.all(encoder_out_lens > 0), (x_lens, encoder_out_lens)

        return encoder_out, encoder_out_lens

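    # A shape sketch with hypothetical inputs (illustration only):
    #
    #     x = torch.randn(2, 1000, 80)        # (N, T, C) fbank features
    #     x_lens = torch.tensor([1000, 860])
    #     out, out_lens = model.forward_encoder(x, x_lens)
    #     # out: (2, T', encoder_dim) with T' = (1000 - 7) // 2 after
    #     # subsampling; out_lens holds each row's valid length.
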
    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target Tensor of shape (sum(target_lengths),). The targets are
            assumed to be un-padded and concatenated within 1 dimension.
          target_lengths:
            A 1-D tensor of shape (N,). It contains the number of tokens in
            each target sequence.
        """
        # Compute CTC log-probs
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction="sum",
        )
        return ctc_loss

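    # The CTC head can be exercised on its own; a sketch with hypothetical
    # tensors (targets are the un-padded label sequences concatenated along
    # one dimension, as torch.nn.functional.ctc_loss expects):
    #
    #     targets = torch.tensor([5, 9, 2, 7, 7])   # utterances [5, 9] and [2, 7, 7]
    #     target_lengths = torch.tensor([2, 3])
    #     loss = model.forward_ctc(encoder_out, encoder_out_lens, targets, target_lengths)
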
    def forward_cr_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute CTC loss with consistency regularization loss.

        Args:
          encoder_out:
            Encoder output, of shape (2 * N, T, C), where the two halves of
            the batch are two differently-augmented copies of the same N
            utterances.
          encoder_out_lens:
            Encoder output lengths, of shape (2 * N,).
          targets:
            Target Tensor of shape (sum(target_lengths),). The targets are
            assumed to be un-padded and concatenated within 1 dimension.
          target_lengths:
            A 1-D tensor of shape (2 * N,). It contains the number of tokens
            in each target sequence.
        """
        # Compute CTC loss
        ctc_output = self.ctc_output(encoder_out)  # (2 * N, T, C)
        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, 2 * N, C)
            targets=targets.cpu(),
            input_lengths=encoder_out_lens.cpu(),
            target_lengths=target_lengths.cpu(),
            reduction="sum",
        )

        # Compute the consistency regularization loss: pair each copy of an
        # utterance with the other copy by swapping the two halves of the
        # batch, then penalize the KL divergence between each copy's
        # frame-level log-probs and the (detached) log-probs of its pair.
        batch_size = ctc_output.shape[0]
        assert batch_size % 2 == 0, batch_size
        # exchange: [x1, x2] -> [x2, x1]
        exchanged_targets = torch.roll(ctc_output.detach(), batch_size // 2, dims=0)
        cr_loss = nn.functional.kl_div(
            input=ctc_output,
            target=exchanged_targets,
            reduction="none",
            log_target=True,
        )
        # Zero out the contributions of padding frames
        length_mask = make_pad_mask(encoder_out_lens).unsqueeze(-1)
        cr_loss = cr_loss.masked_fill(length_mask, 0.0).sum()

        return ctc_loss, cr_loss

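    # How the half-batch swap pairs the two augmented copies (toy example;
    # the batch is stacked as [copy1, copy2] by x.repeat(2, 1, 1)):
    #
    #     t = torch.arange(4)       # rows 0-1 = copy1, rows 2-3 = copy2
    #     torch.roll(t, 2, dims=0)  # tensor([2, 3, 0, 1]): each row now
    #                               # faces the other copy of the same utterance
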
    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of
            each utterance.
          y_lens:
            A 1-D tensor of shape (N,). It contains the number of labels of
            each utterance.
          prune_range:
            The prune range for rnnt loss, i.e., how many symbols (context)
            we consider for each frame when computing the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part.
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part.

        Returns:
          Return the simple loss and the pruned loss in the form of
          (simple_loss, pruned_loss).
        """
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], starts with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS
        # y_padded: [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        # Each row of boundary is [begin_symbol, begin_frame, end_symbol, end_frame]
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # The k2 losses are computed in float32, so autocast is disabled here.
        with torch_autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
            )

        # ranges: [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # am_pruned: [B, T, prune_range, encoder_dim]
        # lm_pruned: [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        # logits: [B, T, prune_range, vocab_size]
        # project_input=False since we have already applied the joiner's input
        # projections above, before calling do_rnnt_pruning.
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch_autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )

        return simple_loss, pruned_loss

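    # In training, the two returned terms are typically combined with scalar
    # weights, e.g. (hypothetical values; the actual schedule lives in the
    # training script, not in this module):
    #
    #     loss = 0.5 * simple_loss + pruned_loss
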
    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        use_cr_ctc: bool = False,
        use_spec_aug: bool = False,
        spec_augment: Optional[SpecAugment] = None,
        supervision_segments: Optional[torch.Tensor] = None,
        time_warp_factor: Optional[int] = 80,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of
            each utterance.
          prune_range:
            The prune range for rnnt loss, i.e., how many symbols (context)
            we consider for each frame when computing the loss.
          am_scale:
            The scale to smooth the loss with am (output of encoder network)
            part.
          lm_scale:
            The scale to smooth the loss with lm (output of predictor network)
            part.
          use_cr_ctc:
            Whether to use consistency-regularized CTC.
          use_spec_aug:
            Whether to apply SpecAugment manually; used only if use_cr_ctc
            is True.
          spec_augment:
            The SpecAugment instance that returns time masks,
            used only if use_cr_ctc is True.
          supervision_segments:
            An int tensor of shape ``(S, 3)``. ``S`` is the number of
            supervision segments that exist in ``features``.
            Used only if use_cr_ctc is True.
          time_warp_factor:
            Parameter for the time warping; larger values mean more warping.
            Set to ``None``, or less than ``1``, to disable.
            Used only if use_cr_ctc is True.

        Returns:
          Return the transducer losses, CTC loss, AED loss,
          and consistency-regularization loss in the form of
          (simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss).

        Note:
          Regarding am_scale & lm_scale, they make the loss function take
          the form:

            lm_scale * lm_probs + am_scale * am_probs
            + (1 - lm_scale - am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == x_lens.size(0) == y.dim0, (x.shape, x_lens.shape, y.dim0)

        device = x.device

        if use_cr_ctc:
            assert self.use_ctc
            if use_spec_aug:
                assert spec_augment is not None and spec_augment.time_warp_factor < 1
                # Apply time warping before duplicating the input
                assert supervision_segments is not None
                x = time_warp(
                    x,
                    time_warp_factor=time_warp_factor,
                    supervision_segments=supervision_segments,
                )
                # Independently apply frequency and time masking to the two copies
                x = spec_augment(x.repeat(2, 1, 1))
            else:
                x = x.repeat(2, 1, 1)
            x_lens = x_lens.repeat(2)
            y = k2.ragged.cat([y, y], axis=0)

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, x_lens)

        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        if self.use_transducer:
            # Compute transducer loss
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y.to(device),
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
            if use_cr_ctc:
                # The batch was duplicated for CR-CTC, so halve these losses
                # to keep their scale unchanged.
                simple_loss = simple_loss * 0.5
                pruned_loss = pruned_loss * 0.5
        else:
            simple_loss = torch.empty(0)
            pruned_loss = torch.empty(0)

        if self.use_ctc:
            # Compute CTC loss
            targets = y.values
            if not use_cr_ctc:
                ctc_loss = self.forward_ctc(
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=y_lens,
                )
                cr_loss = torch.empty(0)
            else:
                ctc_loss, cr_loss = self.forward_cr_ctc(
                    encoder_out=encoder_out,
                    encoder_out_lens=encoder_out_lens,
                    targets=targets,
                    target_lengths=y_lens,
                )
                ctc_loss = ctc_loss * 0.5
                cr_loss = cr_loss * 0.5
        else:
            ctc_loss = torch.empty(0)
            cr_loss = torch.empty(0)

        if self.use_attention_decoder:
            attention_decoder_loss = self.attention_decoder.calc_att_loss(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                ys=y.to(device),
                ys_lens=y_lens.to(device),
            )
            if use_cr_ctc:
                attention_decoder_loss = attention_decoder_loss * 0.5
        else:
            attention_decoder_loss = torch.empty(0)

        return simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss

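# End-to-end usage sketch (hypothetical inputs; model built as in the sketch
# following __init__ above):
#
#     simple_loss, pruned_loss, ctc_loss, attention_decoder_loss, cr_loss = model(
#         x=x,            # (N, T, C) features
#         x_lens=x_lens,  # (N,)
#         y=y,            # k2.RaggedTensor with axes [utt][label]
#         prune_range=5,
#     )
#     # Heads that are disabled return torch.empty(0) placeholders.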