| | """Unility functions for Transformer.""" |
| |
|
| | import math |
| | from typing import List, Tuple |
| |
|
| | import torch |
| | from torch.nn.utils.rnn import pad_sequence |
| |
|
| | IGNORE_ID = -1 |
| |
|
| |
|
def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Perform padding for the list of tensors.

    Args:
        xs (List): List of Tensors [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)].
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    n_batch = len(xs)
    max_len = max(x.size(0) for x in xs)
    pad = torch.zeros(n_batch, max_len, dtype=xs[0].dtype, device=xs[0].device)
    pad = pad.fill_(pad_value)
    for i in range(n_batch):
        pad[i, : xs[i].size(0)] = xs[i]
    return pad


def add_blank(ys_pad: torch.Tensor, blank: int, ignore_id: int) -> torch.Tensor:
    """Prepend a <blank> label for the transducer predictor.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        blank (int): index of <blank>
        ignore_id (int): index used for padding

    Returns:
        ys_in (torch.Tensor): (B, Lmax + 1)

    Examples:
        >>> blank = 0
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in = add_blank(ys_pad, 0, -1)
        >>> ys_in
        tensor([[0, 1, 2, 3, 4, 5],
                [0, 4, 5, 6, 0, 0],
                [0, 7, 8, 9, 0, 0]])
    """
    bs = ys_pad.size(0)
    _blank = torch.tensor(
        [blank], dtype=torch.long, requires_grad=False, device=ys_pad.device
    )
    _blank = _blank.repeat(bs).unsqueeze(1)  # (B, 1)
    out = torch.cat([_blank, ys_pad], dim=1)  # (B, Lmax + 1)
    return torch.where(out == ignore_id, blank, out)


def add_sos_eos(
    ys_pad: torch.Tensor, sos: int, eos: int, ignore_id: int
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add <sos> and <eos> labels.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor): (B, Lmax + 1)
        ys_out (torch.Tensor): (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in, ys_out = add_sos_eos(ys_pad, sos_id, eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    _sos = torch.tensor(
        [sos], dtype=torch.long, requires_grad=False, device=ys_pad.device
    )
    _eos = torch.tensor(
        [eos], dtype=torch.long, requires_grad=False, device=ys_pad.device
    )
    ys = [y[y != ignore_id] for y in ys_pad]  # strip padding
    ys_in = [torch.cat([_sos, y], dim=0) for y in ys]
    ys_out = [torch.cat([y, _eos], dim=0) for y in ys]
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)


def reverse_pad_list(
    ys_pad: torch.Tensor, ys_lens: torch.Tensor, pad_value: float = -1.0
) -> torch.Tensor:
    """Reverse padding for the list of tensors.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])

    """
    r_ys_pad = pad_sequence(
        [torch.flip(y.int()[:i], [0]) for y, i in zip(ys_pad, ys_lens)],
        batch_first=True,
        padding_value=pad_value,
    )
    return r_ys_pad


def th_accuracy(
    pad_outputs: torch.Tensor, pad_targets: torch.Tensor, ignore_label: int
) -> float:
    """Calculate accuracy.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        float: Accuracy value (0.0 - 1.0).

    """
    pad_pred = pad_outputs.view(
        pad_targets.size(0), pad_targets.size(1), pad_outputs.size(1)
    ).argmax(2)
    mask = pad_targets != ignore_label
    numerator = torch.sum(
        pad_pred.masked_select(mask) == pad_targets.masked_select(mask)
    )
    denominator = torch.sum(mask)
    return float(numerator) / float(denominator)


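# A minimal usage sketch for th_accuracy (shapes and values are illustrative,
# not from the original module): predictions arrive flattened to (B * Lmax, D),
# targets stay (B, Lmax), and ignore_label positions are excluded from the count.
#     >>> logits = torch.randn(2 * 3, 5)                    # B=2, Lmax=3, D=5
#     >>> targets = torch.tensor([[1, 2, -1], [3, -1, -1]])
#     >>> acc = th_accuracy(logits, targets, ignore_label=-1)
#     >>> 0.0 <= acc <= 1.0
#     True

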
def get_rnn(rnn_type: str) -> Type[torch.nn.Module]:
    """Return the torch.nn RNN class matching rnn_type."""
    assert rnn_type in ["rnn", "lstm", "gru"]
    if rnn_type == "rnn":
        return torch.nn.RNN
    elif rnn_type == "lstm":
        return torch.nn.LSTM
    else:
        return torch.nn.GRU


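# Usage sketch (hypothetical sizes): get_rnn returns the class itself, so it
# still has to be instantiated with the usual torch.nn RNN arguments.
#     >>> rnn_cls = get_rnn("lstm")
#     >>> rnn = rnn_cls(input_size=80, hidden_size=256, batch_first=True)

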
def get_activation(act):
    """Return activation function."""

    from modules.wenet_extractor.transformer.swish import Swish

    activation_funcs = {
        "hardtanh": torch.nn.Hardtanh,
        "tanh": torch.nn.Tanh,
        "relu": torch.nn.ReLU,
        "selu": torch.nn.SELU,
        # Older PyTorch has no torch.nn.SiLU; fall back to the bundled Swish.
        "swish": getattr(torch.nn, "SiLU", Swish),
        "gelu": torch.nn.GELU,
    }

    return activation_funcs[act]()


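# Usage sketch: "swish" resolves to torch.nn.SiLU on recent PyTorch versions
# and to the local Swish module otherwise; the returned object is an instance.
#     >>> act = get_activation("swish")
#     >>> y = act(torch.randn(4))

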
def get_subsample(config):
    """Return the subsampling rate implied by the encoder input layer."""
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in ["conv2d", "conv2d6", "conv2d8"]
    if input_layer == "conv2d":
        return 4
    elif input_layer == "conv2d6":
        return 6
    elif input_layer == "conv2d8":
        return 8


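# Usage sketch with a minimal hypothetical config dict:
#     >>> config = {"encoder_conf": {"input_layer": "conv2d"}}
#     >>> get_subsample(config)
#     4

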
def remove_duplicates_and_blank(hyp: List[int]) -> List[int]:
    """Collapse repeated labels and drop blanks (id 0), CTC-style."""
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        if hyp[cur] != 0:
            new_hyp.append(hyp[cur])
        prev = cur
        # Skip over the whole run of identical labels.
        while cur < len(hyp) and hyp[cur] == hyp[prev]:
            cur += 1
    return new_hyp


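# Example of the CTC-style collapse (blank id assumed to be 0, as in the code):
#     >>> remove_duplicates_and_blank([0, 1, 1, 0, 0, 2, 2, 3, 0])
#     [1, 2, 3]

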
def replace_duplicates_with_blank(hyp: List[int]) -> List[int]:
    """Replace repeats after the first label of a run with blanks (id 0)."""
    new_hyp: List[int] = []
    cur = 0
    while cur < len(hyp):
        new_hyp.append(hyp[cur])
        prev = cur
        cur += 1
        # Emit a blank for each additional copy of a non-blank label.
        while cur < len(hyp) and hyp[cur] == hyp[prev] and hyp[cur] != 0:
            new_hyp.append(0)
            cur += 1
    return new_hyp


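# Example: within a run of a repeated non-blank label, every copy after the
# first is replaced by blank (id 0), so the output length matches the input:
#     >>> replace_duplicates_with_blank([1, 1, 2, 2, 2, 3])
#     [1, 0, 2, 0, 0, 3]

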
def log_add(args: List[float]) -> float:
    """Stable log add (log-sum-exp over args)."""
    if all(a == -float("inf") for a in args):
        return -float("inf")
    a_max = max(args)
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp
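

# Example: log_add computes log(sum(exp(a_i))) without overflow by factoring
# out the maximum; two probabilities of 0.25 sum to 0.5 in log space:
#     >>> round(log_add([math.log(0.25), math.log(0.25)]), 4)
#     -0.6931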