|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from typing import Tuple

import k2
import torch
import torch.nn as nn
from encoder_interface import EncoderInterface
from scaling import ScaledLinear

from icefall.utils import add_sos, torch_autocast
|
|
|
|
|
|
|
|
class Transducer(nn.Module):
    """It implements https://arxiv.org/pdf/1211.3711.pdf
    "Sequence Transduction with Recurrent Neural Networks"

    Training uses k2's pruned RNN-T loss: a cheap "simple" loss computed on
    linear projections of the encoder/decoder outputs yields gradients that
    determine per-frame pruning ranges, and the full joiner is then
    evaluated only inside those ranges.
    """

    def __init__(
        self,
        encoder: EncoderInterface,
        decoder: nn.Module,
        joiner: nn.Module,
        encoder_dim: int,
        decoder_dim: int,
        joiner_dim: int,
        vocab_size: int,
    ):
        """
        Args:
          encoder:
            It is the transcription network in the paper. It accepts
            two inputs, `x` of shape (N, T, C) and `x_lens` of shape (N,),
            plus a `warmup` keyword argument. It returns two tensors:
            `encoder_out` of shape (N, T', encoder_dim), where T' may differ
            from T, and the corresponding lengths of shape (N,).
          decoder:
            It is the prediction network in the paper. Its input shape
            is (N, U) and its output shape is (N, U, decoder_dim).
            It should contain one attribute: `blank_id`.
          joiner:
            It has two inputs with shapes: (N, T, encoder_dim) and
            (N, U, decoder_dim). Its output shape is (N, T, U, vocab_size).
            Note that its output contains unnormalized probs, i.e., not
            processed by log-softmax. `forward` additionally requires it to
            expose `encoder_proj` and `decoder_proj` submodules and to
            accept a `project_input` keyword argument.
          encoder_dim:
            Output dimension of the encoder; input size of `simple_am_proj`.
          decoder_dim:
            Output dimension of the decoder; input size of `simple_lm_proj`.
          joiner_dim:
            Not used by this constructor; kept for interface compatibility.
          vocab_size:
            Output size of `simple_am_proj` and `simple_lm_proj`.
        """
        super().__init__()
        assert isinstance(encoder, EncoderInterface), type(encoder)
        assert hasattr(decoder, "blank_id")

        self.encoder = encoder
        self.decoder = decoder
        self.joiner = joiner

        # Linear projections to vocab_size used only by the "simple" loss,
        # whose gradients decide the pruning ranges.
        self.simple_am_proj = ScaledLinear(encoder_dim, vocab_size, initial_speed=0.5)
        self.simple_lm_proj = ScaledLinear(decoder_dim, vocab_size)

    def forward(
        self,
        x: torch.Tensor,
        x_lens: torch.Tensor,
        y: k2.RaggedTensor,
        prune_range: int = 5,
        am_scale: float = 0.0,
        lm_scale: float = 0.0,
        warmup: float = 1.0,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """
        Args:
          x:
            A 3-D tensor of shape (N, T, C).
          x_lens:
            A 1-D tensor of shape (N,). It contains the number of frames in `x`
            before padding.
          y:
            A ragged tensor with 2 axes [utt][label]. It contains labels of each
            utterance.
          prune_range:
            The prune range for the rnnt loss, i.e. how many symbols (context)
            are considered for each frame when computing the pruned loss.
          am_scale:
            The scale used to smooth the loss with the am part (output of the
            encoder network).
          lm_scale:
            The scale used to smooth the loss with the lm part (output of the
            predictor network).
          warmup:
            A value warmup >= 0 that determines which modules are active;
            values warmup > 1 "are fully warmed up" and all modules will be
            active. It is forwarded to the encoder.
        Returns:
          A tuple `(simple_loss, pruned_loss)`. Both losses are computed
          with reduction="sum" over the batch.

        Note:
           Regarding am_scale & lm_scale, it will make the loss-function one of
           the form:
              lm_scale * lm_probs + am_scale * am_probs +
              (1-lm_scale-am_scale) * combined_probs
        """
        assert x.ndim == 3, x.shape
        assert x_lens.ndim == 1, x_lens.shape
        assert y.num_axes == 2, y.num_axes

        assert x.size(0) == x_lens.size(0) == y.dim0

        # The encoder may change the number of frames, so x_lens is replaced
        # by the lengths it returns.
        encoder_out, x_lens = self.encoder(x, x_lens, warmup=warmup)
        assert torch.all(x_lens > 0)

        # Number of labels in each utterance.
        row_splits = y.shape.row_splits(1)
        y_lens = row_splits[1:] - row_splits[:-1]

        blank_id = self.decoder.blank_id

        # The decoder input is the label sequence prefixed with blank_id
        # (used as SOS), padded with blank_id.
        sos_y = add_sos(y, sos_id=blank_id)
        sos_y_padded = sos_y.pad(mode="constant", padding_value=blank_id)

        decoder_out = self.decoder(sos_y_padded)

        # Target symbols (without SOS), zero-padded and passed to k2 as int64.
        y_padded = y.pad(mode="constant", padding_value=0)
        y_padded = y_padded.to(torch.int64)

        # Per-utterance boundary rows: columns 0 and 1 stay zero, column 2
        # holds the label count and column 3 the encoder frame count.
        boundary = torch.zeros((x.size(0), 4), dtype=torch.int64, device=x.device)
        boundary[:, 2] = y_lens
        boundary[:, 3] = x_lens

        lm = self.simple_lm_proj(decoder_out)
        am = self.simple_am_proj(encoder_out)

        # Losses are computed in float32 with autocast disabled.
        with torch_autocast(enabled=False):
            simple_loss, (px_grad, py_grad) = k2.rnnt_loss_smoothed(
                lm=lm.float(),
                am=am.float(),
                symbols=y_padded,
                termination_symbol=blank_id,
                lm_only_scale=lm_scale,
                am_only_scale=am_scale,
                boundary=boundary,
                reduction="sum",
                return_grad=True,
            )

        # The gradients of the simple loss select, for every frame, the
        # window of `prune_range` symbols to keep.
        ranges = k2.get_rnnt_prune_ranges(
            px_grad=px_grad,
            py_grad=py_grad,
            boundary=boundary,
            s_range=prune_range,
        )

        # The joiner's input projections are applied here, before pruning,
        # so the joiner is later called with project_input=False.
        am_pruned, lm_pruned = k2.do_rnnt_pruning(
            am=self.joiner.encoder_proj(encoder_out),
            lm=self.joiner.decoder_proj(decoder_out),
            ranges=ranges,
        )

        logits = self.joiner(am_pruned, lm_pruned, project_input=False)

        with torch_autocast(enabled=False):
            pruned_loss = k2.rnnt_loss_pruned(
                logits=logits.float(),
                symbols=y_padded,
                ranges=ranges,
                termination_symbol=blank_id,
                boundary=boundary,
                reduction="sum",
            )

        return simple_loss, pruned_loss
|
|
|