Spaces:
Runtime error
Runtime error
Dit-document-layout-analysis
/
unilm
/decoding
/IAD
/fairseq
/examples
/simultaneous_translation
/utils
/functions.py
| # Copyright (c) Facebook, Inc. and its affiliates. | |
| # | |
| # This source code is licensed under the MIT license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| import torch | |
def exclusive_cumprod(tensor, dim: int, eps: float = 1e-10):
    """
    Compute the exclusive cumulative product of ``tensor`` along ``dim``.

    PyTorch provides an inclusive cumprod only:
        cumprod(x) = [x1, x1x2, x1x2x3, ..., prod_{i=1}^n x_i]
    The exclusive variant starts at 1 and stops one element earlier:
        exclusive_cumprod(x) = [1, x1, x1x2, ..., prod_{i=1}^{n-1} x_i]

    Args:
        tensor: input tensor. The product is evaluated by ``safe_cumprod``
            in log space, so entries below ``-eps`` raise a RuntimeError.
        dim: dimension to accumulate along. Any valid dimension of
            ``tensor`` is accepted, including negative indices.
        eps: small constant forwarded to ``safe_cumprod``.

    Returns:
        A tensor with the same shape as ``tensor``.
    """
    ones_size = list(tensor.size())
    # Indexing the size list also validates `dim` (IndexError if out of range).
    ones_size[dim] = 1
    if dim < 0:
        dim += tensor.dim()

    # Prepend a slice of ones along `dim`: the inclusive product of the padded
    # tensor is the exclusive product of the input, plus one trailing entry.
    padded = torch.cat([tensor.new_ones(ones_size), tensor], dim=dim)
    inclusive = safe_cumprod(padded, dim=dim, eps=eps)

    # Drop the trailing entry along `dim`, whatever the rank of the tensor.
    return inclusive.narrow(dim, 0, tensor.size(dim))
def safe_cumprod(tensor, dim: int, eps: float = 1e-10):
    """
    Cumulative product computed in log space to limit precision issues.

    cumprod(x)
    = [x1, x1x2, x1x2x3, ...]
    = [exp(log(x1)), exp(log(x1) + log(x2)), exp(log(x1) + log(x2) + log(x3)), ...]
    = exp(cumsum(log(x)))

    Args:
        tensor: input tensor; every entry must satisfy ``x + eps >= 0``.
        dim: dimension to accumulate along.
        eps: added to the input before taking the log so that zeros do not
            produce ``-inf``.

    Returns:
        ``exp(cumsum(log(tensor + eps), dim))``, same shape as ``tensor``.

    Raises:
        RuntimeError: if any entry of ``tensor + eps`` is negative.
    """
    if (tensor + eps < 0).any().item():
        # The literals are concatenated, so the separating space is required.
        raise RuntimeError(
            "Safe cumprod can only take non-negative tensors as input. "
            "Consider using torch.cumprod if you want to calculate negative values."
        )

    log_tensor = torch.log(tensor + eps)
    cumsum_log_tensor = torch.cumsum(log_tensor, dim)
    return torch.exp(cumsum_log_tensor)
def lengths_to_mask(lengths, max_len: int, dim: int = 0, negative_mask: bool = False):
    """
    Convert a tensor of lengths to a boolean mask.

    For example, lengths = [[2, 3, 4]], max_len = 5, dim = 0
    mask =
    [[1, 1, 1],
     [1, 1, 1],
     [0, 1, 1],
     [0, 0, 1],
     [0, 0, 0]]

    Args:
        lengths: tensor with at most two dimensions; it is flattened in
            row-major order, giving ``n`` lengths.
        max_len: number of positions in the mask.
        dim: when 0 the mask is returned as (max_len, n); for any other
            value it is returned as (n, max_len).
        negative_mask: when True the mask is inverted, i.e. positions at or
            beyond each length are True.

    Returns:
        A boolean tensor that is True where ``position < length`` (or the
        inverse when ``negative_mask`` is set).
    """
    assert len(lengths.size()) <= 2

    # One length per row: (n, 1). reshape (rather than view) also accepts
    # non-contiguous input, and no transpose is needed because the tensor is
    # flattened anyway.
    lengths = lengths.reshape(-1, 1)

    # (max_len,) position indices on the same device/dtype as `lengths`.
    positions = torch.arange(max_len, device=lengths.device).type_as(lengths)

    # Broadcast (1, max_len) < (n, 1) -> (n, max_len)
    mask = positions.unsqueeze(0) < lengths

    if negative_mask:
        mask = ~mask

    if dim == 0:
        # max_len, n
        mask = mask.t()

    return mask
def moving_sum(x, start_idx: int, end_idx: int):
    """
    Windowed sum along the source dimension, as in Equation (18) of
    "Monotonic Chunkwise Attention" (https://arxiv.org/pdf/1712.05382.pdf).

    For x = [x_1, x_2, ..., x_N]:
        MovingSum(x, start_idx, end_idx)_n
            = Sigma_{m = n - (start_idx - 1)}^{n + end_idx - 1} x_m
    for n in {1, 2, ..., N}; terms outside [1, N] count as zero.

    Args:
        x: tensor of shape (src_len, batch_size)
        start_idx: number of positions the window reaches back, counting
            the current one (must be > 0)
        end_idx: number of positions the window reaches forward, counting
            the current one (must be > 0)

    Returns:
        Tensor of shape (src_len, batch_size).

    Example
    src_len = 5
    batch_size = 3
    x =
    [[ 0, 5, 10],
     [ 1, 6, 11],
     [ 2, 7, 12],
     [ 3, 8, 13],
     [ 4, 9, 14]]
    MovingSum(x, 3, 1) =
    [[ 0, 5, 10],
     [ 1, 11, 21],
     [ 3, 18, 33],
     [ 6, 21, 36],
     [ 9, 24, 39]]
    MovingSum(x, 1, 3) =
    [[ 3, 18, 33],
     [ 6, 21, 36],
     [ 9, 24, 39],
     [ 7, 17, 27],
     [ 4, 9, 14]]
    """
    assert start_idx > 0 and end_idx > 0
    assert len(x.size()) == 2

    src_len, batch_size = x.size()
    window = start_idx + end_idx - 1

    # conv1d wants (batch, channels, length): batch_size, 1, src_len
    signal = x.t().unsqueeze(1)
    # An all-ones kernel turns the convolution into a sliding-window sum.
    kernel = x.new_ones([1, 1, window])
    convolved = torch.nn.functional.conv1d(signal, kernel, padding=window)

    # Back to (padded_len, batch_size), then cut the padded output down to
    # the src_len rows whose window is aligned with each source position.
    summed = convolved.squeeze(1).t()
    summed = summed[end_idx:-start_idx]

    assert src_len == summed.size(0)
    assert batch_size == summed.size(1)
    return summed