codewraith / data /source_files /clean /439bd748ba75.py
slenk's picture
Upload folder using huggingface_hub
eeef81e verified
# Copyright (c) 2021, NVIDIA CORPORATION. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
#
# Copyright 2018-2019, Mingkun Huang
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import numpy as np
import torch
from torch.autograd import Function, Variable
from torch.nn import Module
def check_type(var, t, name):
if var.dtype is not t:
raise TypeError("{} must be {}".format(name, t))
def check_contiguous(var, name):
if not var.is_contiguous():
raise ValueError("{} must be contiguous".format(name))
def check_dim(var, dim, name):
if len(var.shape) != dim:
raise ValueError("{} must be {}D".format(name, dim))
def certify_inputs(log_probs, labels, lengths, label_lengths):
# check_type(log_probs, torch.float32, "log_probs")
check_type(labels, torch.int32, "labels")
check_type(label_lengths, torch.int32, "label_lengths")
check_type(lengths, torch.int32, "lengths")
check_contiguous(log_probs, "log_probs")
check_contiguous(labels, "labels")
check_contiguous(label_lengths, "label_lengths")
check_contiguous(lengths, "lengths")
if lengths.shape[0] != log_probs.shape[0]:
raise ValueError(
f"Must have a length per example. "
f"Given lengths dim: {lengths.shape[0]}, "
f"Log probs dim : {log_probs.shape[0]}"
)
if label_lengths.shape[0] != log_probs.shape[0]:
raise ValueError(
"Must have a label length per example. "
f"Given label lengths dim : {label_lengths.shape[0]}, "
f"Log probs dim : {log_probs.shape[0]}"
)
check_dim(log_probs, 4, "log_probs")
check_dim(labels, 2, "labels")
check_dim(lengths, 1, "lenghts")
check_dim(label_lengths, 1, "label_lenghts")
max_T = torch.max(lengths)
max_U = torch.max(label_lengths)
T, U = log_probs.shape[1:3]
if T != max_T:
raise ValueError(f"Input length mismatch! Given T: {T}, Expected max T from input lengths: {max_T}")
if U != max_U + 1:
raise ValueError(f"Output length mismatch! Given U: {U}, Expected max U from target lengths: {max_U} + 1")
def _assert_no_grad(tensor):
assert not tensor.requires_grad, (
"gradients only computed for log_probs - please " "mark other tensors as not requiring gradients"
)
def forward_pass(log_probs, labels, blank):
"""
Computes probability of the forward variable alpha.
Args:
log_probs: Tensor of shape [T, U, V+1]
labels: Labels of shape [B, U]
blank: Index of the blank token.
Returns:
A tuple of the forward variable probabilities - alpha of shape [T, U]
and the log likelihood of this forward step.
"""
T, U, _ = log_probs.shape
alphas = np.zeros((T, U), dtype='f')
for t in range(1, T):
alphas[t, 0] = alphas[t - 1, 0] + log_probs[t - 1, 0, blank]
for u in range(1, U):
alphas[0, u] = alphas[0, u - 1] + log_probs[0, u - 1, labels[u - 1]]
for t in range(1, T):
for u in range(1, U):
no_emit = alphas[t - 1, u] + log_probs[t - 1, u, blank]
emit = alphas[t, u - 1] + log_probs[t, u - 1, labels[u - 1]]
alphas[t, u] = np.logaddexp(emit, no_emit)
loglike = alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank]
return alphas, loglike
def backward_pass(log_probs, labels, blank):
"""
Computes probability of the backward variable beta.
Args:
log_probs: Tensor of shape [T, U, V+1]
labels: Labels of shape [B, U]
blank: Index of the blank token.
Returns:
A tuple of the backward variable probabilities - beta of shape [T, U]
and the log likelihood of this backward step.
"""
T, U, _ = log_probs.shape
betas = np.zeros((T, U), dtype='f')
betas[T - 1, U - 1] = log_probs[T - 1, U - 1, blank]
for t in reversed(range(T - 1)):
betas[t, U - 1] = betas[t + 1, U - 1] + log_probs[t, U - 1, blank]
for u in reversed(range(U - 1)):
betas[T - 1, u] = betas[T - 1, u + 1] + log_probs[T - 1, u, labels[u]]
for t in reversed(range(T - 1)):
for u in reversed(range(U - 1)):
no_emit = betas[t + 1, u] + log_probs[t, u, blank]
emit = betas[t, u + 1] + log_probs[t, u, labels[u]]
betas[t, u] = np.logaddexp(emit, no_emit)
return betas, betas[0, 0]
def compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda):
"""
Computes the gradients of the log_probs with respect to the log probability of this step occuring.
Args:
Args:
log_probs: Tensor of shape [T, U, V+1]
alphas: Tensor of shape [T, U] which represents the forward variable.
betas: Tensor of shape [T, U] which represents the backward variable.
labels: Labels of shape [B, U]
blank: Index of the blank token.
Returns:
Gradients of shape [T, U, V+1] with respect to the forward log probability
"""
T, U, _ = log_probs.shape
grads = np.full(log_probs.shape, -float("inf"))
log_like = betas[0, 0] # == alphas[T - 1, U - 1] + betas[T - 1, U - 1]
# // grad to last blank transition
grads[T - 1, U - 1, blank] = alphas[T - 1, U - 1]
grads[: T - 1, :, blank] = alphas[: T - 1, :] + betas[1:, :]
# // grad to label transition
for u, l in enumerate(labels):
grads[:, u, l] = alphas[:, u] + betas[:, u + 1]
grads = -np.exp(grads + log_probs - log_like)
if fastemit_lambda > 0.0:
for u, l in enumerate(labels):
grads[:, u, l] = (1.0 + fastemit_lambda) * grads[:, u, l]
return grads
def fastemit_regularization(log_probs, labels, alphas, betas, blank, fastemit_lambda):
"""
Describes the computation of FastEmit regularization from the paper -
[FastEmit: Low-latency Streaming ASR with Sequence-level Emission Regularization](https://arxiv.org/abs/2010.11148)
Args:
log_probs: Tensor of shape [T, U, V+1]
labels: Unused. Labels of shape [B, U]
alphas: Tensor of shape [T, U] which represents the forward variable.
betas: Unused. Tensor of shape [T, U] which represents the backward variable.
blank: Index of the blank token.
fastemit_lambda: Float scaling factor for FastEmit regularization.
Returns:
The regularized negative log likelihood - lambda * P˜(At, u|x)
"""
# General calculation of the fastemit regularization alignments
T, U, _ = log_probs.shape
# alignment = np.zeros((T, U), dtype='float32')
#
# for t in range(0, T):
# alignment[t, U - 1] = alphas[t, U - 1] + betas[t, U - 1]
#
# for t in range(0, T):
# for u in range(0, U - 1):
# emit = alphas[t, u] + log_probs[t, u, labels[u]] + betas[t, u + 1]
# alignment[t, u] = emit
# reg = fastemit_lambda * (alignment[T - 1, U - 1])
# The above is equivalent to below, without need of computing above
# reg = fastemit_lambda * (alphas[T - 1, U - 1] + betas[T - 1, U - 1])
# The above is also equivalent to below, without need of computing the betas alignment matrix
reg = fastemit_lambda * (alphas[T - 1, U - 1] + log_probs[T - 1, U - 1, blank])
return -reg
def transduce(log_probs, labels, blank=0, fastemit_lambda=0.0):
"""
Args:
log_probs: 3D array with shape
[input len, output len + 1, vocab size]
labels: 1D array with shape [output time steps]
blank: Index of the blank token.
fastemit_lambda: Float scaling factor for FastEmit regularization.
Returns:
float: The negative log-likelihood
3D array: Gradients with respect to the
unnormalized input actications
2d arrays: Alphas matrix (TxU)
2d array: Betas matrix (TxU)
"""
alphas, ll_forward = forward_pass(log_probs, labels, blank)
betas, ll_backward = backward_pass(log_probs, labels, blank)
grads = compute_gradient(log_probs, alphas, betas, labels, blank, fastemit_lambda)
return -ll_forward, grads, alphas, betas
def transduce_batch(log_probs, labels, flen, glen, blank=0, fastemit_lambda=0.0):
"""
Compute the transducer loss of the batch.
Args:
log_probs: [B, T, U, V+1]. Activation matrix normalized with log-softmax.
labels: [B, U+1] - ground truth labels with <SOS> padded as blank token in the beginning.
flen: Length vector of the acoustic sequence.
glen: Length vector of the target sequence.
blank: Id of the blank token.
fastemit_lambda: Float scaling factor for FastEmit regularization.
Returns:
Batch of transducer forward log probabilities (loss) and the gradients of the activation matrix.
"""
grads = np.zeros_like(log_probs)
costs = []
for b in range(log_probs.shape[0]):
t = int(flen[b])
u = int(glen[b]) + 1
ll, g, alphas, betas = transduce(log_probs[b, :t, :u, :], labels[b, : u - 1], blank, fastemit_lambda)
grads[b, :t, :u, :] = g
reg = fastemit_regularization(
log_probs[b, :t, :u, :], labels[b, : u - 1], alphas, betas, blank, fastemit_lambda
)
ll += reg
costs.append(ll)
return costs, grads
class _RNNT(Function):
@staticmethod
def forward(ctx, acts, labels, act_lens, label_lens, blank, fastemit_lambda):
costs, grads = transduce_batch(
acts.detach().cpu().numpy(),
labels.cpu().numpy(),
act_lens.cpu().numpy(),
label_lens.cpu().numpy(),
blank,
fastemit_lambda,
)
costs = torch.FloatTensor([sum(costs)])
grads = torch.Tensor(grads).to(acts)
ctx.grads = grads
return costs
@staticmethod
def backward(ctx, grad_output):
return ctx.grads, None, None, None, None, None
class RNNTLoss(Module):
"""
Parameters:
`blank_label` (int): default 0 - label index of blank token
fastemit_lambda: Float scaling factor for FastEmit regularization.
"""
def __init__(self, blank: int = 0, fastemit_lambda: float = 0.0):
super(RNNTLoss, self).__init__()
self.blank = blank
self.fastemit_lambda = fastemit_lambda
self.rnnt = _RNNT.apply
def forward(self, acts, labels, act_lens, label_lens):
assert len(labels.size()) == 2
_assert_no_grad(labels)
_assert_no_grad(act_lens)
_assert_no_grad(label_lens)
certify_inputs(acts, labels, act_lens, label_lens)
acts = torch.nn.functional.log_softmax(acts, -1)
return self.rnnt(acts, labels, act_lens, label_lens, self.blank, self.fastemit_lambda)
if __name__ == '__main__':
loss = RNNTLoss(fastemit_lambda=0.01)
torch.manual_seed(0)
acts = torch.randn(1, 2, 5, 3)
labels = torch.tensor([[0, 2, 1, 2]], dtype=torch.int32)
act_lens = torch.tensor([2], dtype=torch.int32)
label_lens = torch.tensor([len(labels[0])], dtype=torch.int32)
loss_val = loss(acts, labels, act_lens, label_lens)