| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| | |
| |
|
| | import torch |
| |
|
# NOTE(review): everything between the triple quotes below is a bare string
# literal, not a function definition. Python evaluates and discards it at
# import time, so this module does NOT define `subsequent_mask`. It appears
# to be a disabled reference implementation of a lower-triangular (causal)
# bool mask, kept for documentation only. Its docstring also lists a `dtype`
# argument that the signature does not have.
'''
def subsequent_mask(
        size: int,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size).

    This mask is used only in decoder which works in an auto-regressive mode.
    This means the current step could only do attention with its left steps.

    In encoder, fully attention is used when streaming is not necessary and
    the sequence is not long. In this case, no attention mask is needed.

    When streaming is need, chunk-based attention is used in encoder. See
    subsequent_chunk_mask for the chunk-based attention mask.

    Args:
        size (int): size of mask
        str device (str): "cpu" or "cuda" or torch.Tensor.device
        dtype (torch.device): result dtype

    Returns:
        torch.Tensor: mask

    Examples:
        >>> subsequent_mask(3)
        [[1, 0, 0],
         [1, 1, 0],
         [1, 1, 1]]
    """
    ret = torch.ones(size, size, device=device, dtype=torch.bool)
    return torch.tril(ret)
'''
| |
|
| |
|
def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
    this is for streaming encoder.

    Row ``i`` (query position) may attend to every column up to the end of
    the chunk containing ``i``. When ``num_left_chunks >= 0`` it may also see
    at most ``num_left_chunks`` chunks immediately to the left of its own
    chunk. When ``num_left_chunks < 0`` it sees all earlier positions.

    The mask is built with vectorized tensor ops (no Python loop over
    positions). ``num_left_chunks`` is a plain int, so the branch on it is
    decided in Python when the function is called.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks
            <0: use full chunk
            >=0: use num_left_chunks
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: bool mask of shape (size, size); True = may attend.

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
        >>> subsequent_chunk_mask(4, 2, num_left_chunks=0)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [0, 0, 1, 1],
         [0, 0, 1, 1]]
    """
    pos_idx = torch.arange(size, device=device)
    chunk_idx = torch.div(pos_idx, chunk_size, rounding_mode='trunc')
    # Exclusive end of the chunk each row belongs to.
    chunk_end = (chunk_idx + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < chunk_end.unsqueeze(1)
    if num_left_chunks >= 0:
        # Inclusive start of the visible window. A negative start simply
        # admits every earlier position, so no clamping is needed.
        chunk_start = (chunk_idx - num_left_chunks) * chunk_size
        ret = ret & (pos_idx.unsqueeze(0) >= chunk_start.unsqueeze(1))
    return ret
| |
|
| |
|
def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Combines the padding mask ``masks`` with a chunk-based attention mask
    (built by ``subsequent_chunk_mask``) when dynamic or static chunking is
    enabled. Otherwise the padding mask is used as-is. Any query row that
    ends up with no attendable position is forced to all True and a warning
    is logged.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): bool mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
            >=0: use num_decoding_left_chunks
            <0: use all left chunks
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs. This is always a tensor
        distinct from ``masks`` whenever rows had to be filled in; ``masks``
        itself is never modified in place.
    """
    # Local import: the module-level imports visible for this file bring in
    # only torch, so `logging` would otherwise be an unresolved name.
    import logging

    max_len = xs.size(1)
    if use_dynamic_chunk:
        if decoding_chunk_size < 0:
            chunk_size, num_left_chunks = max_len, -1
        elif decoding_chunk_size > 0:
            chunk_size, num_left_chunks = decoding_chunk_size, num_decoding_left_chunks
        else:
            chunk_size, num_left_chunks = _sample_dynamic_chunk(
                max_len, use_dynamic_left_chunk, enable_full_context)
    elif static_chunk_size > 0:
        chunk_size, num_left_chunks = static_chunk_size, num_decoding_left_chunks
    else:
        # No chunking requested: only the padding mask applies.
        chunk_size = None

    if chunk_size is None:
        chunk_masks = masks
    else:
        chunk_masks = masks & subsequent_chunk_mask(
            max_len, chunk_size, num_left_chunks, xs.device).unsqueeze(0)
    assert chunk_masks.dtype == torch.bool

    empty_rows = ~chunk_masks.any(dim=-1)
    if empty_rows.any().item():
        logging.warning('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
        # Out-of-place OR instead of indexed assignment. In the no-chunking
        # branch chunk_masks *is* the caller's `masks`, and writing into it
        # would silently mutate the caller's tensor.
        chunk_masks = chunk_masks | empty_rows.unsqueeze(-1)
    return chunk_masks


def _sample_dynamic_chunk(max_len: int, use_dynamic_left_chunk: bool,
                          enable_full_context: bool):
    """Randomly pick (chunk_size, num_left_chunks) for dynamic-chunk training.

    A value is drawn from [1, max_len). If it exceeds ``max_len // 2`` and
    ``enable_full_context`` is set, the chunk is the full sequence. Otherwise
    the draw is folded into [1, 25]. ``num_left_chunks`` is -1 (all left
    chunks) unless ``use_dynamic_left_chunk`` is set.
    """
    if max_len <= 1:
        # torch.randint(1, max_len) requires max_len > 1. A sequence this
        # short is a single chunk whatever chunk size is chosen.
        return 1, -1
    chunk_size = torch.randint(1, max_len, (1, )).item()
    if chunk_size > max_len // 2 and enable_full_context:
        chunk_size = max_len
    else:
        chunk_size = chunk_size % 25 + 1
    num_left_chunks = -1
    if use_dynamic_left_chunk:
        max_left_chunks = (max_len - 1) // chunk_size
        # max_left_chunks == 0 means chunk_size >= max_len, i.e. a single
        # chunk, so a left-chunk limit is moot. torch.randint(0, 0) would raise.
        if max_left_chunks > 0:
            num_left_chunks = torch.randint(0, max_left_chunks,
                                            (1, )).item()
    return chunk_size, num_left_chunks
| |
|
| |
|
def make_pad_mask(lengths: torch.Tensor, max_len: int = 0) -> torch.Tensor:
    """Make mask tensor containing indices of padded part.

    See description of make_non_pad_mask.

    Args:
        lengths (torch.Tensor): Batch of lengths (B,).
        max_len (int): width of the mask. If <= 0 (default), the largest value
            in ``lengths`` is used; an empty batch gets width 0.

    Returns:
        torch.Tensor: bool mask (B, max_len); True marks padded positions,
        i.e. positions whose index is >= the sequence length.

    Examples:
        >>> lengths = torch.tensor([5, 3, 2])
        >>> make_pad_mask(lengths)
        masks = [[0, 0, 0, 0 ,0],
                 [0, 0, 0, 1, 1],
                 [0, 0, 1, 1, 1]]
    """
    if max_len <= 0:
        # lengths.max() raises on an empty tensor, so handle an empty batch.
        max_len = lengths.max().item() if lengths.numel() > 0 else 0
    seq_range = torch.arange(0,
                             max_len,
                             dtype=torch.int64,
                             device=lengths.device)
    # (1, max_len) compared with (B, 1) broadcasts to (B, max_len).
    return seq_range.unsqueeze(0) >= lengths.unsqueeze(-1)
| |
|