Spaces:
Sleeping
Sleeping
| # Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved. | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| # | |
| # Copyright 2018-2019, Mingkun Huang | |
| # | |
| # Licensed under the Apache License, Version 2.0 (the "License"); | |
| # you may not use this file except in compliance with the License. | |
| # You may obtain a copy of the License at | |
| # | |
| # http://www.apache.org/licenses/LICENSE-2.0 | |
| # | |
| # Unless required by applicable law or agreed to in writing, software | |
| # distributed under the License is distributed on an "AS IS" BASIS, | |
| # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | |
| # See the License for the specific language governing permissions and | |
| # limitations under the License. | |
| import numpy as np | |
| import torch | |
| from torch.autograd import Function, Variable | |
| from torch.nn import Module | |
def check_type(var, t, name):
    """Verify that ``var.dtype`` is the dtype object ``t``.

    Args:
        var: Object exposing a ``dtype`` attribute.
        t: Expected dtype; the comparison is by identity (``is``).
        name: Label for ``var`` used in the error message.

    Raises:
        TypeError: If ``var.dtype`` is not ``t``.
    """
    if var.dtype is t:
        return
    raise TypeError("{} must be {}".format(name, t))
def check_contiguous(var, name):
    """Verify that ``var.is_contiguous()`` reports True.

    Args:
        var: Object exposing an ``is_contiguous()`` method.
        name: Label for ``var`` used in the error message.

    Raises:
        ValueError: If ``var`` is not contiguous.
    """
    if var.is_contiguous():
        return
    raise ValueError("{} must be contiguous".format(name))
def check_dim(var, dim, name):
    """Verify that ``var`` has exactly ``dim`` dimensions.

    Args:
        var: Object exposing a ``shape`` attribute.
        dim: Required number of dimensions.
        name: Label for ``var`` used in the error message.

    Raises:
        ValueError: If ``len(var.shape)`` differs from ``dim``.
    """
    rank = len(var.shape)
    if rank == dim:
        return
    raise ValueError("{} must be {}D".format(name, dim))
def certify_inputs(log_probs, labels, lengths, label_lengths):
    """Validate dtype, memory layout, rank and sizes of the transducer loss inputs.

    Args:
        log_probs: 4-D contiguous tensor. Axis 0 must match the number of
            length entries; axes 1 and 2 must equal ``max(lengths)`` and
            ``max(label_lengths) + 1`` respectively. Its dtype is not
            enforced (the float32 check below is disabled).
        labels: 2-D contiguous int32 tensor of target labels.
        lengths: 1-D contiguous int32 tensor, one input length per example.
        label_lengths: 1-D contiguous int32 tensor, one target length per example.

    Raises:
        TypeError: If ``labels``, ``label_lengths`` or ``lengths`` is not int32.
        ValueError: If a tensor is not contiguous, has the wrong rank, or the
            sizes are inconsistent with each other.
    """
    # check_type(log_probs, torch.float32, "log_probs")
    check_type(labels, torch.int32, "labels")
    check_type(label_lengths, torch.int32, "label_lengths")
    check_type(lengths, torch.int32, "lengths")

    check_contiguous(log_probs, "log_probs")
    check_contiguous(labels, "labels")
    check_contiguous(label_lengths, "label_lengths")
    check_contiguous(lengths, "lengths")

    # Rank checks come before any shape indexing so that a wrong-rank input
    # is reported with the intended ValueError rather than an IndexError.
    check_dim(log_probs, 4, "log_probs")
    check_dim(labels, 2, "labels")
    check_dim(lengths, 1, "lengths")
    check_dim(label_lengths, 1, "label_lengths")

    if lengths.shape[0] != log_probs.shape[0]:
        raise ValueError(
            "Must have a length per example. "
            f"Given lengths dim: {lengths.shape[0]}, "
            f"Log probs dim : {log_probs.shape[0]}"
        )
    if label_lengths.shape[0] != log_probs.shape[0]:
        raise ValueError(
            "Must have a label length per example. "
            f"Given label lengths dim : {label_lengths.shape[0]}, "
            f"Log probs dim : {log_probs.shape[0]}"
        )

    # Convert to plain ints so the error messages show numbers, not tensor reprs.
    max_T = int(torch.max(lengths))
    max_U = int(torch.max(label_lengths))
    T, U = log_probs.shape[1:3]
    if T != max_T:
        raise ValueError(f"Input length mismatch! Given T: {T}, Expected max T from input lengths: {max_T}")
    if U != max_U + 1:
        raise ValueError(f"Output length mismatch! Given U: {U}, Expected max U from target lengths: {max_U} + 1")
def _assert_no_grad(tensor):
    """Assert that ``tensor.requires_grad`` is falsy.

    Args:
        tensor: Object exposing a ``requires_grad`` attribute.

    Raises:
        AssertionError: If ``tensor.requires_grad`` is truthy (only when
            assertions are enabled).
    """
    message = "gradients only computed for log_probs - please mark other tensors as not requiring gradients"
    assert not tensor.requires_grad, message
def forward_pass(log_probs, labels, blank):
    """
    Fill the forward variable (alpha) lattice in log space.

    Args:
        log_probs: 3-D array indexed as [t, u, token].
        labels: 1-D label sequence; ``labels[u - 1]`` is the token emitted
            when moving from output position ``u - 1`` to ``u``.
        blank: Index of the blank token.

    Returns:
        A tuple ``(alphas, loglike)`` where ``alphas`` is a float32 array of
        shape [T, U] and ``loglike`` is the final cell of ``alphas`` plus the
        blank log probability at the final cell.
    """
    n_frames, n_positions, _ = log_probs.shape
    alphas = np.zeros((n_frames, n_positions), dtype='f')

    # Row-major sweep: every cell only reads cells that were filled earlier.
    for t in range(n_frames):
        for u in range(n_positions):
            if t == 0 and u == 0:
                # Origin keeps its initial value of 0.
                continue
            if u == 0:
                # First column: reachable only through blank transitions.
                alphas[t, 0] = alphas[t - 1, 0] + log_probs[t - 1, 0, blank]
            elif t == 0:
                # First row: reachable only through label transitions.
                alphas[0, u] = alphas[0, u - 1] + log_probs[0, u - 1, labels[u - 1]]
            else:
                stay = alphas[t - 1, u] + log_probs[t - 1, u, blank]
                advance = alphas[t, u - 1] + log_probs[t, u - 1, labels[u - 1]]
                alphas[t, u] = np.logaddexp(advance, stay)

    last_t, last_u = n_frames - 1, n_positions - 1
    loglike = alphas[last_t, last_u] + log_probs[last_t, last_u, blank]
    return alphas, loglike
def backward_pass(log_probs, labels, blank):
    """
    Fill the backward variable (beta) lattice in log space.

    Args:
        log_probs: 3-D array indexed as [t, u, token].
        labels: 1-D label sequence; ``labels[u]`` is the token emitted when
            moving from output position ``u`` to ``u + 1``.
        blank: Index of the blank token.

    Returns:
        A tuple ``(betas, loglike)`` where ``betas`` is a float32 array of
        shape [T, U] and ``loglike`` is ``betas[0, 0]``.
    """
    n_frames, n_positions, _ = log_probs.shape
    last_t, last_u = n_frames - 1, n_positions - 1
    betas = np.zeros((n_frames, n_positions), dtype='f')

    # Reverse row-major sweep: every cell only reads cells filled earlier.
    for t in reversed(range(n_frames)):
        for u in reversed(range(n_positions)):
            if t == last_t and u == last_u:
                # Terminal cell: only the final blank remains.
                betas[t, u] = log_probs[t, u, blank]
            elif u == last_u:
                # Last column: only blank transitions remain.
                betas[t, u] = betas[t + 1, u] + log_probs[t, u, blank]
            elif t == last_t:
                # Last row: only label transitions remain.
                betas[t, u] = betas[t, u + 1] + log_probs[t, u, labels[u]]
            else:
                stay = betas[t + 1, u] + log_probs[t, u, blank]
                advance = betas[t, u + 1] + log_probs[t, u, labels[u]]
                betas[t, u] = np.logaddexp(advance, stay)

    return betas, betas[0, 0]
def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda):
    """
    Build the gradient array of the negative log likelihood from the
    forward and backward lattices.

    Args:
        log_probs: 3-D array indexed as [t, u, token].
        alphas: Array of shape [T, U] holding the forward variable.
        betas: Array of shape [T, U] holding the backward variable.
        labels: 1-D label sequence; entry ``u`` is the token emitted when
            leaving output position ``u``.
        blank: Index of the blank token.
        fastemit_lambda: FastEmit scaling factor. When greater than 0 the
            gradients of the label transitions are multiplied by
            ``1 + fastemit_lambda``.

    Returns:
        Array with the same shape as ``log_probs``. Entries that correspond
        to neither a blank nor a label transition are ``-exp(-inf)``, i.e. zero.
    """
    n_frames, n_positions, _ = log_probs.shape
    last_t, last_u = n_frames - 1, n_positions - 1
    total_log_prob = betas[0, 0]  # == alphas[T - 1, U - 1] + betas[T - 1, U - 1]

    # Accumulate alpha + beta terms in log space; untouched cells stay at -inf.
    log_grads = np.full(log_probs.shape, -float("inf"))

    # Blank transitions: the final blank, then every blank that advances t.
    log_grads[last_t, last_u, blank] = alphas[last_t, last_u]
    log_grads[:last_t, :, blank] = alphas[:last_t, :] + betas[1:, :]

    # Label transitions.
    for pos, token in enumerate(labels):
        log_grads[:, pos, token] = alphas[:, pos] + betas[:, pos + 1]

    grads = -np.exp(log_grads + log_probs - total_log_prob)

    if not fastemit_lambda > 0.0:
        return grads

    emit_scale = 1.0 + fastemit_lambda
    for pos, token in enumerate(labels):
        grads[:, pos, token] = emit_scale * grads[:, pos, token]
    return grads
def fastemit_regularization(log_probs, labels, alphas, betas, blank, fastemit_lambda):
    """
    Compute the FastEmit regularization term described in
    [FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization](https://arxiv.org/abs/2010.11148)

    Args:
        log_probs: 3-D array indexed as [t, u, token].
        labels: Unused.
        alphas: Array of shape [T, U] holding the forward variable.
        betas: Unused.
        blank: Index of the blank token.
        fastemit_lambda: Float scaling factor for FastEmit regularization.

    Returns:
        ``-fastemit_lambda * (alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank])``.
    """
    n_frames, n_positions, _ = log_probs.shape
    last_t, last_u = n_frames - 1, n_positions - 1

    # Only the final lattice cell is needed: the last alpha plus the last
    # blank log probability. The betas matrix is therefore not consulted.
    final_log_prob = alphas[last_t, last_u] + log_probs[last_t, last_u, blank]
    return -(fastemit_lambda * final_log_prob)
def transduce(log_probs, labels, blank=0, fastemit_lambda=0.0):
    """
    Run the forward pass, backward pass and gradient computation for one example.

    Args:
        log_probs: 3D array with shape
            [input len, output len + 1, vocab size]
        labels: 1D array with shape [output time steps]
        blank: Index of the blank token.
        fastemit_lambda: Float scaling factor for FastEmit regularization.

    Returns:
        A 4-tuple of:
        - the negated log-likelihood from the forward pass,
        - the 3D gradient array returned by ``compute_gradient``,
        - the alphas matrix (TxU),
        - the betas matrix (TxU).
    """
    alphas, forward_ll = forward_pass(log_probs, labels, blank)
    # The backward log-likelihood is not needed here; only the lattice is.
    betas, _ = backward_pass(log_probs, labels, blank)
    grads = compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda)
    neg_log_likelihood = -forward_ll
    return neg_log_likelihood, grads, alphas, betas
def transduce_batch(log_probs, labels, flen, glen, blank=0, fastemit_lambda=0.0):
    """
    Compute the transducer loss and gradients for every example of a batch.

    Args:
        log_probs: 4-D array indexed as [b, t, u, token].
            NOTE(review): callers are expected to pass log-softmax
            normalized activations - confirm against the caller.
        labels: 2-D array indexed as [b, u]; the first ``glen[b]`` entries
            of row ``b`` are used as that example's labels.
        flen: Per-example length along the t axis.
        glen: Per-example number of labels.
        blank: Id of the blank token.
        fastemit_lambda: Float scaling factor for FastEmit regularization.

    Returns:
        A tuple ``(costs, grads)``: ``costs`` is a list with one value per
        example (negated log-likelihood plus the FastEmit term) and ``grads``
        is an array shaped like ``log_probs`` that is zero outside each
        example's valid [t, u] region.
    """
    batch_grads = np.zeros_like(log_probs)
    costs = []

    for b in range(log_probs.shape[0]):
        n_frames = int(flen[b])
        n_positions = int(glen[b]) + 1  # one more lattice column than labels

        # Restrict the example to its valid (unpadded) region.
        example_log_probs = log_probs[b, :n_frames, :n_positions, :]
        example_labels = labels[b, : n_positions - 1]

        nll, example_grads, alphas, betas = transduce(
            example_log_probs, example_labels, blank, fastemit_lambda
        )
        batch_grads[b, :n_frames, :n_positions, :] = example_grads

        nll += fastemit_regularization(
            example_log_probs, example_labels, alphas, betas, blank, fastemit_lambda
        )
        costs.append(nll)

    return costs, batch_grads
class _RNNT(Function):
    """Autograd wrapper around the NumPy transducer loss in ``transduce_batch``.

    The forward pass computes both the loss and its gradient with respect to
    ``acts``; the gradient is cached on ``ctx`` and handed back, scaled by the
    incoming gradient, in the backward pass.
    """

    @staticmethod
    def forward(ctx, acts, labels, act_lens, label_lens, blank, fastemit_lambda):
        """Compute the summed batch loss.

        Args:
            ctx: Autograd context; receives the cached gradient as ``ctx.grads``.
            acts: Activation tensor passed to ``transduce_batch`` as NumPy data.
            labels: Label tensor.
            act_lens: Per-example activation lengths.
            label_lens: Per-example label lengths.
            blank: Index of the blank token.
            fastemit_lambda: Float scaling factor for FastEmit regularization.

        Returns:
            A float32 CPU tensor of shape [1] holding the sum of the
            per-example costs.
        """
        costs, grads = transduce_batch(
            acts.detach().cpu().numpy(),
            labels.cpu().numpy(),
            act_lens.cpu().numpy(),
            label_lens.cpu().numpy(),
            blank,
            fastemit_lambda,
        )
        costs = torch.FloatTensor([sum(costs)])
        # Match the dtype and device of ``acts`` so the gradient can flow back.
        grads = torch.Tensor(grads).to(acts)
        ctx.grads = grads
        return costs

    @staticmethod
    def backward(ctx, grad_output):
        """Return the cached gradient scaled by the upstream gradient.

        Args:
            ctx: Autograd context holding ``ctx.grads`` from the forward pass.
            grad_output: Gradient of the downstream objective with respect to
                the loss returned by ``forward`` (shape [1]).

        Returns:
            A 6-tuple matching the inputs of ``forward``: the gradient for
            ``acts`` followed by ``None`` for the non-differentiable inputs.
        """
        # Apply the chain rule: any scaling of the loss (e.g. averaging or
        # loss weighting) must reach the activations. The loss lives on the
        # CPU as float32, so align dtype/device with the cached gradient.
        scale = grad_output.to(ctx.grads)
        return ctx.grads * scale, None, None, None, None, None
class RNNTLoss(Module):
    """
    Module wrapper around the ``_RNNT`` autograd function.

    Parameters:
        blank (int): default 0 - label index of blank token
        fastemit_lambda: Float scaling factor for FastEmit regularization.
    """

    def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0):
        super().__init__()
        self.rnnt = _RNNT.apply
        self.fastemit_lambda = fastemit_lambda
        self.blank = blank

    def forward(self, acts, labels, act_lens, label_lens):
        """Validate the inputs, apply log-softmax to ``acts`` and compute the loss.

        Args:
            acts: Unnormalized activations; log-softmax is taken over the last axis.
            labels: 2-D label tensor that must not require gradients.
            act_lens: Per-example activation lengths; must not require gradients.
            label_lens: Per-example label lengths; must not require gradients.

        Returns:
            The value produced by the ``_RNNT`` autograd function.
        """
        assert labels.dim() == 2
        # Gradients are only produced for the activations.
        for no_grad_input in (labels, act_lens, label_lens):
            _assert_no_grad(no_grad_input)
        certify_inputs(acts, labels, act_lens, label_lens)

        log_probs = torch.nn.functional.log_softmax(acts, -1)
        return self.rnnt(log_probs, labels, act_lens, label_lens, self.blank, self.fastemit_lambda)
if __name__ == '__main__':
    # Small smoke run: one example, 2 frames, 4 labels (5 lattice columns), 3 tokens.
    torch.manual_seed(0)
    loss = RNNTLoss(fastemit_lambda=0.01)

    acts = torch.randn(1, 2, 5, 3)
    # NOTE(review): label 0 coincides with the default blank index of
    # RNNTLoss - confirm this is intended for the demo.
    labels = torch.tensor([[0, 2, 1, 2]], dtype=torch.int32)
    act_lens = torch.tensor([2], dtype=torch.int32)
    label_lens = torch.tensor([labels.shape[1]], dtype=torch.int32)

    loss_val = loss(acts, labels, act_lens, label_lens)