| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """Utility functions for Transformer.""" |
|
|
| import math |
| import time |
| from typing import List, Tuple |
|
|
| import torch |
| from torch.nn.utils.rnn import pad_sequence |
|
|
| from whisper.tokenizer import LANGUAGES as WhiserLanguages |
|
|
# Keys of whisper's LANGUAGES mapping (presumably language codes), frozen into
# a tuple. add_whisper_tokens() uses a language's position in this tuple as
# the offset added to the <sot> id when computing the language token id.
WHISPER_LANGS = tuple(WhiserLanguages.keys())
# Module-level id constant; the docstring examples in this file use -1 as the
# padding/ignore id (presumably this is the default callers pass - confirm).
IGNORE_ID = -1
|
|
|
|
def pad_list(xs: List[torch.Tensor], pad_value: int):
    """Pad a list of variable-length tensors into a single batch tensor.

    Every tensor is copied into the leading part of its row along the first
    (time) dimension; the remaining positions hold ``pad_value``. The
    trailing dimensions, dtype and device are taken from ``xs[0]``.

    Args:
        xs (List[torch.Tensor]): List of Tensors
            [(T_1, `*`), (T_2, `*`), ..., (T_B, `*`)] with 1 to 3 dimensions.
        pad_value (int): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tmax, `*`).

    Raises:
        ValueError: If ``xs`` is empty, or the first tensor does not have
            1, 2 or 3 dimensions.

    Examples:
        >>> x = [torch.ones(4), torch.ones(2), torch.ones(1)]
        >>> x
        [tensor([1., 1., 1., 1.]), tensor([1., 1.]), tensor([1.])]
        >>> pad_list(x, 0)
        tensor([[1., 1., 1., 1.],
                [1., 1., 0., 0.],
                [1., 0., 0., 0.]])

    """
    # Fail early with a clear message instead of the opaque error raised by
    # max() on an empty sequence.
    if len(xs) == 0:
        raise ValueError("pad_list() requires a non-empty list of tensors")
    max_len = max([len(item) for item in xs])
    first = xs[0]
    ndim = first.ndim
    if ndim < 1 or ndim > 3:
        raise ValueError(f"Unsupported ndim: {ndim}")
    # (B, Tmax, *trailing dims of the first tensor), pre-filled with padding.
    shape = [len(xs), max_len] + list(first.shape[1:])
    pad_res = torch.full(shape,
                         pad_value,
                         dtype=first.dtype,
                         device=first.device)
    for i, x in enumerate(xs):
        pad_res[i, :len(x)] = x
    return pad_res
|
|
|
|
def add_blank(ys_pad: torch.Tensor, blank: int,
              ignore_id: int) -> torch.Tensor:
    """Prepend a blank column to padded targets for a transducer predictor.

    One <blank> id is placed in front of every sequence, and every position
    equal to ``ignore_id`` is replaced by ``blank``.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        blank (int): index of <blank>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> blank = 0
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,   4,   5],
                [ 4,  5,  6,  -1,  -1],
                [ 7,  8,  9,  -1,  -1]], dtype=torch.int32)
        >>> ys_in = add_blank(ys_pad, 0, -1)
        >>> ys_in
        tensor([[0,  1,  2,  3,  4,  5],
                [0,  4,  5,  6,  0,  0],
                [0,  7,  8,  9,  0,  0]])
    """
    batch_size = ys_pad.size(0)
    blank_token = torch.tensor([blank],
                               dtype=torch.long,
                               requires_grad=False,
                               device=ys_pad.device)
    # (1,) -> (B,) -> (B, 1): one blank per sequence, as a column.
    blank_column = blank_token.repeat(batch_size).unsqueeze(1)
    prepended = torch.cat([blank_column, ys_pad], dim=1)
    is_padding = prepended == ignore_id
    return torch.where(is_padding, blank, prepended)
|
|
|
|
def add_sos_eos(ys_pad: torch.Tensor, sos: int, eos: int,
                ignore_id: int) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build decoder inputs and targets by adding <sos> and <eos> labels.

    Padding entries (``ignore_id``) are stripped from every row first. The
    input sequences get <sos> prepended and are re-padded with ``eos``; the
    output sequences get <eos> appended and are re-padded with ``ignore_id``.

    Args:
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        sos (int): index of <sos>
        eos (int): index of <eos>
        ignore_id (int): index of padding

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + 1)
        ys_out (torch.Tensor) : (B, Lmax + 1)

    Examples:
        >>> sos_id = 10
        >>> eos_id = 11
        >>> ignore_id = -1
        >>> ys_pad
        tensor([[ 1,  2,  3,  4,  5],
                [ 4,  5,  6, -1, -1],
                [ 7,  8,  9, -1, -1]], dtype=torch.int32)
        >>> ys_in,ys_out=add_sos_eos(ys_pad, sos_id , eos_id, ignore_id)
        >>> ys_in
        tensor([[10,  1,  2,  3,  4,  5],
                [10,  4,  5,  6, 11, 11],
                [10,  7,  8,  9, 11, 11]])
        >>> ys_out
        tensor([[ 1,  2,  3,  4,  5, 11],
                [ 4,  5,  6, 11, -1, -1],
                [ 7,  8,  9, 11, -1, -1]])
    """
    device = ys_pad.device
    sos_token = torch.tensor([sos],
                             dtype=torch.long,
                             requires_grad=False,
                             device=device)
    eos_token = torch.tensor([eos],
                             dtype=torch.long,
                             requires_grad=False,
                             device=device)
    ys_in = []
    ys_out = []
    for row in ys_pad:
        # Keep only the real labels of this sequence.
        labels = row[row != ignore_id]
        ys_in.append(torch.cat([sos_token, labels], dim=0))
        ys_out.append(torch.cat([labels, eos_token], dim=0))
    return pad_list(ys_in, eos), pad_list(ys_out, ignore_id)
|
|
|
|
def add_whisper_tokens(special_tokens, ys_pad: torch.Tensor, ignore_id: int,
                       tasks: List[str], no_timestamp: bool, langs: List[str],
                       use_prev: bool) -> Tuple[torch.Tensor, torch.Tensor]:
    """Add whisper-style tokens.

    ([PREV] -> [previous text tokens or hotwords]).optional --
      ┌------------------------------------------------------↲
      ↓
    [sot] -> [language id] -> [transcribe] -> [begin time] -> [text tokens] -> [end time] -> ... -> [eot]    # noqa
        |          |                |-------> [no timestamps] -> [text tokens] ----------------------↑       # noqa
        |          |                                                                                 |       # noqa
        |          |--------> [translate]  -> [begin time] -> [text tokens] -> [end time] -> ... --->|       # noqa
        |                           |-------> [no timestamps] -> [text tokens] --------------------->|       # noqa
        |                                                                                            |       # noqa
        |--> [no speech(VAD)] ---------------------------------------------------------------------->|       # noqa

    Only the "no timestamps" paths and the VAD path are implemented; the
    previous-text prefix and the timestamp paths raise NotImplementedError.

    Args:
        special_tokens: mapping from special token name to token id
        ys_pad (torch.Tensor): batch of padded target sequences (B, Lmax)
        ignore_id (int): index of padding
        tasks (List[str]): list of task tags, one per sequence
            ("transcribe", "translate" or "vad")
        no_timestamp (bool): if True, use the [no timestamps] token
        langs (List[str]): list of language tags, one per sequence
        use_prev (bool): whether to prepend the [PREV] section

    Returns:
        ys_in (torch.Tensor) : (B, Lmax + ?)
        ys_out (torch.Tensor) : (B, Lmax + ?)

    Raises:
        NotImplementedError: for ``use_prev=True``, for an unsupported task,
            or for transcribe/translate with ``no_timestamp=False``.

    """
    batch_size = ys_pad.size(0)
    assert len(langs) == batch_size
    assert len(tasks) == batch_size

    # [PREV] section (hotword list / previous text) is not supported yet.
    prev_tokens = [special_tokens["sot_prev"]] if use_prev else []
    if use_prev:
        raise NotImplementedError

    # Task tag -> name of the special token used as the task id.
    task_token_names = {
        "transcribe": "transcribe",
        "translate": "translate",
        "vad": "no_speech",
    }

    prefixes = []
    for task, lang in zip(tasks, langs):
        if task not in task_token_names:
            raise NotImplementedError("unsupported task {}".format(task))
        task_id = special_tokens[task_token_names[task]]
        sot_id = special_tokens["sot"]
        # Language ids directly follow <sot>, in WHISPER_LANGS order.
        language_id = sot_id + 1 + WHISPER_LANGS.index(lang)
        tokens = prev_tokens + [sot_id, language_id, task_id]
        if task == "vad":
            tokens.append(special_tokens["no_speech"])
        elif no_timestamp:
            tokens.append(special_tokens["no_timestamps"])
        else:
            tokens.append(special_tokens["timestamp_begin"])
            # The timestamp tokens that would follow are not supported yet.
            raise NotImplementedError
        prefixes.append(
            torch.tensor(tokens,
                         dtype=torch.long,
                         requires_grad=False,
                         device=ys_pad.device))

    eot_token = torch.tensor([special_tokens["eot"]],
                             dtype=torch.long,
                             requires_grad=False,
                             device=ys_pad.device)

    ys_in = []
    ys_out = []
    for prefix, row in zip(prefixes, ys_pad):
        labels = row[row != ignore_id]
        ys_in.append(torch.cat([prefix, labels], dim=0))
        # Targets are shifted by one: drop the first prefix token, add <eot>.
        ys_out.append(torch.cat([prefix[1:], labels, eot_token], dim=0))
    return pad_list(ys_in, special_tokens["eot"]), pad_list(ys_out, ignore_id)
|
|
|
|
def reverse_pad_list(ys_pad: torch.Tensor,
                     ys_lens: torch.Tensor,
                     pad_value: float = -1.0) -> torch.Tensor:
    """Reverse the valid part of every padded sequence and pad again.

    Each row is cast to int32, cut to its length from ``ys_lens``, flipped,
    and the flipped rows are padded back into one batch with ``pad_value``.

    Args:
        ys_pad (tensor): The padded tensor (B, Tokenmax).
        ys_lens (tensor): The lens of token seqs (B)
        pad_value (float): Value for padding.

    Returns:
        Tensor: Padded tensor (B, Tokenmax).

    Examples:
        >>> x
        tensor([[1, 2, 3, 4], [5, 6, 7, 0], [8, 9, 0, 0]])
        >>> reverse_pad_list(x, torch.tensor([4, 3, 2]), 0)
        tensor([[4, 3, 2, 1],
                [7, 6, 5, 0],
                [9, 8, 0, 0]])

    """
    reversed_seqs = []
    for seq, length in zip(ys_pad, ys_lens):
        valid = seq.int()[:length]
        reversed_seqs.append(valid.flip(0))
    return pad_sequence(reversed_seqs,
                        batch_first=True,
                        padding_value=pad_value)
|
|
|
|
def th_accuracy(pad_outputs: torch.Tensor, pad_targets: torch.Tensor,
                ignore_label: int) -> torch.Tensor:
    """Compute label accuracy over the non-ignored target positions.

    Args:
        pad_outputs (Tensor): Prediction tensors (B * Lmax, D).
        pad_targets (LongTensor): Target label tensors (B, Lmax).
        ignore_label (int): Ignore label id.

    Returns:
        torch.Tensor: Accuracy value (0.0 - 1.0).

    """
    batch_size, max_len = pad_targets.size(0), pad_targets.size(1)
    scores = pad_outputs.view(batch_size, max_len, pad_outputs.size(1))
    # Predicted label = index of the best score along the last dimension.
    pad_pred = scores.argmax(2)
    valid = pad_targets != ignore_label
    hits = pad_pred.masked_select(valid) == pad_targets.masked_select(valid)
    accuracy = torch.sum(hits) / torch.sum(valid)
    return accuracy.detach()
|
|
|
|
def get_subsample(config):
    """Return the number associated with the encoder's input layer type.

    Reads ``config["encoder_conf"]["input_layer"]`` and maps
    "conv2d" -> 4, "conv2d6" -> 6, "conv2d8" -> 8 (presumably the
    subsampling rate of that layer). Any other value fails the assertion.
    """
    rates = {"conv2d": 4, "conv2d6": 6, "conv2d8": 8}
    input_layer = config["encoder_conf"]["input_layer"]
    assert input_layer in rates
    return rates.get(input_layer)
|
|
|
|
def log_add(*args) -> float:
    """Stable log add: compute log(sum(exp(a) for a in args)).

    The largest argument is factored out before exponentiating so that the
    intermediate values cannot overflow.

    Args:
        *args: Values in the log domain.

    Returns:
        float: The log of the summed exponentials. ``-inf`` when every
        argument is ``-inf`` (or no argument is given), ``+inf`` when any
        argument is ``+inf``.
    """
    if all(a == -float('inf') for a in args):
        return -float('inf')
    a_max = max(args)
    if a_max == float('inf'):
        # inf - inf below would be NaN and poison the whole result, while
        # the true sum is simply dominated by the infinite term.
        return float('inf')
    lsp = math.log(sum(math.exp(a - a_max) for a in args))
    return a_max + lsp
|
|
|
|
def mask_to_bias(mask: torch.Tensor, dtype: torch.dtype) -> torch.Tensor:
    """Convert a boolean mask into an additive bias tensor of ``dtype``.

    True entries become 0 and False entries become -1.0e+10.

    Args:
        mask (torch.Tensor): Boolean mask.
        dtype (torch.dtype): One of float32, bfloat16 or float16.

    Returns:
        torch.Tensor: Bias tensor with the same shape as ``mask``.
    """
    assert mask.dtype == torch.bool
    assert dtype in (torch.float32, torch.bfloat16, torch.float16)
    kept = mask.to(dtype)
    # Flip the mask (1 -> 0, 0 -> 1) and scale the flipped entries.
    dropped = 1.0 - kept
    return dropped * -1.0e+10
|
|
|
|
def get_nested_attribute(obj, attr_path):
    """Resolve a dotted attribute path such as ``"a.b.c"`` on ``obj``.

    A DistributedDataParallel wrapper is unwrapped to its ``module`` first,
    then each path component is looked up with ``getattr`` in turn.
    """
    is_ddp = isinstance(obj, torch.nn.parallel.DistributedDataParallel)
    target = obj.module if is_ddp else obj
    for name in attr_path.split('.'):
        target = getattr(target, name)
    return target
|
|
|
|
def lrs_to_str(lrs: List):
    """Format values in 4-digit scientific notation, separated by spaces."""
    formatted = (f"{lr:.4e}" for lr in lrs)
    return " ".join(formatted)
|
|
|
|
class StepTimer:
    """Utility class for measuring steps/second."""

    def __init__(self, step=0.0):
        # Step count at the start of the current measuring window.
        self.last_iteration = step
        self.start()

    def start(self):
        """Begin a new measuring window at the current wall-clock time."""
        self.last_time = time.time()

    def steps_per_second(self, cur_step, restart=True):
        """Return the average steps per second since the window started.

        Args:
            cur_step: Current step count.
            restart (bool): If True, start a new measuring window at
                ``cur_step`` after computing the rate.
        """
        step_now = float(cur_step)
        elapsed = time.time() - self.last_time
        rate = (step_now - self.last_iteration) / elapsed
        if restart:
            self.start()
            self.last_iteration = step_now
        return rate
|
|
|
|
def tensor_to_scalar(x):
    """Return ``x.item()`` for a tensor; return anything else unchanged."""
    return x.item() if isinstance(x, torch.Tensor) else x
|
|
|
|
def is_torch_npu_available() -> bool:
    """Check if torch_npu is available.

    torch_npu is a npu adapter of PyTorch. Availability is probed by trying
    to import the package. When the import fails and CUDA is not available
    either, an installation hint is printed.

    Returns:
        bool: True if ``torch_npu`` can be imported, False otherwise.
    """
    try:
        import torch_npu  # noqa: F401 -- imported only to probe availability
        return True
    except ImportError:
        if not torch.cuda.is_available():
            # Adjacent literals instead of a backslash continuation inside
            # the string, which would embed the source indentation of the
            # second line into the printed message.
            print("Module \"torch_npu\" not found. \"pip install torch_npu\" "
                  "if you are using Ascend NPU, otherwise, ignore it")
    return False


# Evaluated once, when this module is imported.
TORCH_NPU_AVAILABLE = is_torch_npu_available()
|
|