from typing import Optional, Tuple

import k2
import torch
import torch.nn as nn

from scaling import ScaledLinear

from icefall.utils import add_sos, torch_autocast
|
class AsrModel(nn.Module):
    def __init__(
        self,
        encoder: nn.Module,
        decoder: Optional[nn.Module] = None,
        joiner: Optional[nn.Module] = None,
        encoder_dim: int = 768,
        decoder_dim: int = 512,
        vocab_size: int = 500,
        use_transducer: bool = True,
        use_ctc: bool = False,
    ):
        """A joint CTC & Transducer ASR model.

        - Connectionist temporal classification: labelling unsegmented sequence data with recurrent neural networks (http://imagine.enpc.fr/~obozinsg/teaching/mva_gm/papers/ctc.pdf)
        - Sequence Transduction with Recurrent Neural Networks (https://arxiv.org/pdf/1211.3711.pdf)
        - Pruned RNN-T for fast, memory-efficient ASR training (https://arxiv.org/pdf/2206.13236.pdf)

        A minimal usage sketch is given under ``__main__`` at the bottom of
        this file.

        Args:
          encoder:
            It is the transcription network in the paper. Its
            `extract_features()` method accepts `x` of shape (N, T) and a
            bool `padding_mask` of shape (N, T), and returns the encoder
            output of shape (N, T, encoder_dim) together with the
            corresponding output padding mask.
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
            It is used only when use_transducer is True.
          joiner:
            It has two inputs with shapes (N, T, encoder_dim) and
            (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
            Note that its output contains unnormalized probs, i.e., they are
            not processed by log-softmax.
            It is used only when use_transducer is True.
          encoder_dim:
            Output dimension of the encoder. Default: 768.
          decoder_dim:
            Output dimension of the decoder. Default: 512.
          vocab_size:
            Vocabulary size. Default: 500.
          use_transducer:
            Whether to use the transducer head. Default: True.
          use_ctc:
            Whether to use the CTC head. Default: False.
        """
        super().__init__()

        assert (
            use_transducer or use_ctc
        ), f"At least one of them should be True, but got use_transducer={use_transducer}, use_ctc={use_ctc}"

        self.encoder = encoder

        self.use_transducer = use_transducer
        if use_transducer:
            # Modules for Transducer head
            assert decoder is not None
            assert hasattr(decoder, "blank_id")
            assert joiner is not None

            self.decoder = decoder
            self.joiner = joiner

            self.simple_am_proj = ScaledLinear(
                encoder_dim, vocab_size, initial_scale=0.25
            )
            self.simple_lm_proj = ScaledLinear(
                decoder_dim, vocab_size, initial_scale=0.25
            )
        else:
            assert decoder is None
            assert joiner is None

        self.use_ctc = use_ctc
        if use_ctc:
            # Modules for CTC head
            self.ctc_output = nn.Sequential(
                nn.Dropout(p=0.1),
                nn.Linear(encoder_dim, vocab_size),
                nn.LogSoftmax(dim=-1),
            )
|
    def forward_encoder(
        self,
        x: torch.Tensor,
        padding_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute encoder outputs.

        Args:
          x:
            A 2-D tensor of shape (N, T).
          padding_mask:
            An optional bool tensor of shape (N, T). True entries mark
            padded positions. If None, no position is treated as padding.

        Returns:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
        """
        if padding_mask is None:
            padding_mask = torch.zeros_like(x, dtype=torch.bool)

        encoder_out, padding_mask = self.encoder.extract_features(
            source=x,
            padding_mask=padding_mask,
            mask=self.encoder.training,
        )
        # Count the valid (non-padding) positions of each utterance
        encoder_out_lens = torch.sum(~padding_mask, dim=1)
        assert torch.all(encoder_out_lens > 0), encoder_out_lens

        return encoder_out, encoder_out_lens
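
    # For reference, a `padding_mask` for `forward_encoder` can be built from
    # per-utterance sample counts (a sketch; `lengths` is a hypothetical 1-D
    # tensor holding the number of valid samples, and True marks padding):
    #
    #     lengths = torch.tensor([50, 37])
    #     padding_mask = torch.arange(x.size(1)) >= lengths.unsqueeze(1)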
|
    def forward_ctc(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        targets: torch.Tensor,
        target_lengths: torch.Tensor,
    ) -> torch.Tensor:
        """Compute CTC loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          targets:
            Target tensor of shape (sum(target_lengths),). The targets are
            assumed to be un-padded and concatenated along one dimension.
          target_lengths:
            A 1-D tensor of shape (N,) containing the number of target
            labels of each utterance.
        """
        # Compute CTC log-probs
        ctc_output = self.ctc_output(encoder_out)  # (N, T, C)

        ctc_loss = torch.nn.functional.ctc_loss(
            log_probs=ctc_output.permute(1, 0, 2),  # (T, N, C)
            targets=targets,
            input_lengths=encoder_out_lens,
            target_lengths=target_lengths,
            reduction="sum",
        )
        return ctc_loss
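
    # For reference, the un-padded, concatenated `targets` and per-utterance
    # `target_lengths` can be derived from a k2.RaggedTensor `y` exactly as
    # `forward()` does below (a sketch):
    #
    #     targets = y.values                                 # (sum(lens),)
    #     row_splits = y.shape.row_splits(1)
    #     target_lengths = row_splits[1:] - row_splits[:-1]  # (N,)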
|
    def forward_transducer(
        self,
        encoder_out: torch.Tensor,
        encoder_out_lens: torch.Tensor,
        y: k2.RaggedTensor,
        y_lens: torch.Tensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Compute Transducer loss.

        Args:
          encoder_out:
            Encoder output, of shape (N, T, C).
          encoder_out_lens:
            Encoder output lengths, of shape (N,).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains the labels
            of each utterance.
          y_lens:
            A 1-D tensor of shape (N,) containing the number of labels of
            each utterance.
          prune_range:
            The prune range for the rnnt loss, i.e., how many symbols
            (context) are kept for each frame when computing the pruned loss.
          am_scale:
            The scale used to smooth the loss with the am (output of the
            encoder network) part.
          lm_scale:
            The scale used to smooth the loss with the lm (output of the
            prediction network) part.

        Returns:
          Return a tuple (simple_loss, pruned_loss).
        """
        blank_id = self.decoder.blank_id
        sos_y = add_sos(y, sos_id=blank_id)

        # sos_y_padded: [B, S + 1], start with SOS.
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        # decoder_out: [B, S + 1, decoder_dim]
        decoder_out = self.decoder(sos_y_padded)

        # Note: y does not start with SOS.
        # y_padded : [B, S]
        y_padded = y.pad(mode="constant", padding_value=0)

        y_padded = y_padded.to(torch.int64)
        # Each row of `boundary` is [begin_symbol, begin_frame, end_symbol,
        # end_frame], i.e., [0, 0, y_len, num_frames] here.
        boundary = torch.zeros(
            (encoder_out.size(0), 4),
            dtype=torch.int64,
            device=encoder_out.device,
        )
        boundary[:, 2] = y_lens
        boundary[:, 3] = encoder_out_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)
|
        # Compute the simple (unpruned) RNN-T loss in full precision.
        # return_grad=True also returns the gradients (px_grad, py_grad),
        # which are used below to determine the pruning bounds.
        with torch_autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
            )
        # ranges : [B, T, prune_range]
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )
        # am_pruned : [B, T, prune_range, encoder_dim]
        # lm_pruned : [B, T, prune_range, decoder_dim]
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )
        # logits : [B, T, prune_range, vocab_size]

        # project_input=False since we applied the encoder/decoder input
        # projections prior to do_rnnt_pruning (an optimization for speed).
        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch_autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )

        return simple_loss, pruned_loss
|
    def forward(
        self,
        x: torch.Tensor,
        y: k2.RaggedTensor,
        padding_mask: Optional[torch.Tensor] = None,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
    ) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 2-D tensor of shape (N, T).
          y:
            A ragged tensor with 2 axes [utt][label]. It contains the labels
            of each utterance.
          padding_mask:
            An optional bool tensor of shape (N, T). True entries mark
            padded positions in `x`.
          prune_range:
            The prune range for the rnnt loss, i.e., how many symbols
            (context) are kept for each frame when computing the pruned loss.
          am_scale:
            The scale used to smooth the loss with the am (output of the
            encoder network) part.
          lm_scale:
            The scale used to smooth the loss with the lm (output of the
            prediction network) part.

        Returns:
          Return the transducer losses, the CTC loss, and the encoder output
          lengths, in the form of
          (simple_loss, pruned_loss, ctc_loss, encoder_out_lens).

        Note:
          Regarding am_scale & lm_scale, they will make the loss function
          take the form:

            lm_scale * lm_probs + am_scale * am_probs
            + (1 - lm_scale - am_scale) * combined_probs
        """
        assert x.ndim == 2, x.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == y.dim0, (x.shape, y.dim0)

        # Compute encoder outputs
        encoder_out, encoder_out_lens = self.forward_encoder(x, padding_mask)

        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]
|
        if self.use_transducer:
            # Compute transducer loss
            simple_loss, pruned_loss = self.forward_transducer(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                y=y.to(x.device),
                y_lens=y_lens,
                prune_range=prune_range,
                am_scale=am_scale,
                lm_scale=lm_scale,
            )
        else:
            simple_loss = torch.empty(0)
            pruned_loss = torch.empty(0)

        if self.use_ctc:
            # Compute CTC loss
            targets = y.values
            ctc_loss = self.forward_ctc(
                encoder_out=encoder_out,
                encoder_out_lens=encoder_out_lens,
                targets=targets,
                target_lengths=y_lens,
            )
        else:
            ctc_loss = torch.empty(0)

        return simple_loss, pruned_loss, ctc_loss, encoder_out_lens
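

# ---------------------------------------------------------------------------
# A minimal smoke-test sketch (an illustrative addition, not part of the
# training recipe). The stub Encoder/Decoder/Joiner below are hypothetical
# stand-ins that only mimic the interfaces AsrModel relies on
# (`extract_features`, `blank_id`, `encoder_proj`/`decoder_proj`,
# `project_input`); real recipes pass actual network modules.
if __name__ == "__main__":

    class _StubEncoder(nn.Module):
        def __init__(self, encoder_dim: int):
            super().__init__()
            self.proj = nn.Linear(1, encoder_dim)

        def extract_features(self, source, padding_mask, mask):
            # Treat each sample as one frame: (N, T) -> (N, T, encoder_dim)
            return self.proj(source.unsqueeze(-1)), padding_mask

    class _StubDecoder(nn.Module):
        def __init__(self, vocab_size: int, decoder_dim: int, blank_id: int = 0):
            super().__init__()
            self.blank_id = blank_id
            self.embed = nn.Embedding(vocab_size, decoder_dim)

        def forward(self, y):
            # (N, U) -> (N, U, decoder_dim); cast for nn.Embedding
            return self.embed(y.to(torch.int64))

    class _StubJoiner(nn.Module):
        def __init__(self, encoder_dim, decoder_dim, joiner_dim, vocab_size):
            super().__init__()
            self.encoder_proj = nn.Linear(encoder_dim, joiner_dim)
            self.decoder_proj = nn.Linear(decoder_dim, joiner_dim)
            self.output = nn.Linear(joiner_dim, vocab_size)

        def forward(self, am, lm, project_input=True):
            if project_input:
                am, lm = self.encoder_proj(am), self.decoder_proj(lm)
            return self.output(torch.tanh(am + lm))

    encoder_dim, decoder_dim, vocab_size = 16, 8, 10
    model = AsrModel(
        encoder=_StubEncoder(encoder_dim),
        decoder=_StubDecoder(vocab_size, decoder_dim),
        joiner=_StubJoiner(encoder_dim, decoder_dim, 8, vocab_size),
        encoder_dim=encoder_dim,
        decoder_dim=decoder_dim,
        vocab_size=vocab_size,
        use_transducer=True,
        use_ctc=True,
    )

    x = torch.randn(2, 50)  # (N, T): 2 utterances of 50 "samples"
    y = k2.RaggedTensor([[1, 2, 3], [4, 5]])  # labels per utterance
    simple_loss, pruned_loss, ctc_loss, encoder_out_lens = model(
        x, y, prune_range=2
    )
    print(simple_loss, pruned_loss, ctc_loss, encoder_out_lens)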