import torch


def _sample_dynamic_chunk(max_len: int,
                          use_dynamic_left_chunk: bool,
                          enable_full_context: bool):
    """Randomly draw ``(chunk_size, num_left_chunks)`` for dynamic-chunk training.

    The number and order of ``torch.randint`` calls are the same as in the
    original inline implementation, so seeded runs stay reproducible. The only
    difference is that ranges which would make ``torch.randint`` raise
    (``low >= high``) are handled explicitly instead of crashing.

    Args:
        max_len (int): length of the (padded) input sequence.
        use_dynamic_left_chunk (bool): whether to also sample a number of
            left chunks.
        enable_full_context (bool): allow the sampled chunk to be promoted
            to the full sequence length.

    Returns:
        tuple[int, int]: chunk size and number of left chunks (-1 = all).
    """
    if max_len <= 1:
        # torch.randint(1, max_len) would raise here. With at most one frame,
        # every chunk size >= 1 produces the same (full) mask.
        return 1, -1
    # chunk size is either [1, 25] or full context(max_len).
    # Since we use 4 times subsampling and allow up to 1s(100 frames)
    # delay, the maximum frame is 100 / 4 = 25.
    chunk_size = torch.randint(1, max_len, (1, )).item()
    num_left_chunks = -1
    if chunk_size > max_len // 2 and enable_full_context:
        return max_len, num_left_chunks
    chunk_size = chunk_size % 25 + 1
    if use_dynamic_left_chunk:
        max_left_chunks = (max_len - 1) // chunk_size
        if max_left_chunks > 0:
            num_left_chunks = torch.randint(0, max_left_chunks, (1, )).item()
        else:
            # The chunk already spans the whole sequence, so there is no left
            # chunk to choose; torch.randint(0, 0) would raise.
            num_left_chunks = 0
    return chunk_size, num_left_chunks


def add_optional_chunk_mask(xs: torch.Tensor,
                            masks: torch.Tensor,
                            use_dynamic_chunk: bool,
                            use_dynamic_left_chunk: bool,
                            decoding_chunk_size: int,
                            static_chunk_size: int,
                            num_decoding_left_chunks: int,
                            enable_full_context: bool = True):
    """ Apply optional mask for encoder.

    Args:
        xs (torch.Tensor): padded input, (B, L, D), L for max length
        masks (torch.Tensor): bool mask for xs, (B, 1, L)
        use_dynamic_chunk (bool): whether to use dynamic chunk or not
        use_dynamic_left_chunk (bool): whether to use dynamic left chunk for
            training.
        decoding_chunk_size (int): decoding chunk size for dynamic chunk, it's
            0: default for training, use random dynamic chunk.
            <0: for decoding, use full chunk.
            >0: for decoding, use fixed chunk size as set.
        static_chunk_size (int): chunk size for static chunk training/decoding
            if it's greater than 0, if use_dynamic_chunk is true,
            this parameter will be ignored
        num_decoding_left_chunks: number of left chunks, this is for decoding,
            the chunk size is decoding_chunk_size.
                >=0: use num_decoding_left_chunks
                <0: use all left chunks
            NOTE: ``subsequent_chunk_mask`` currently ignores the number of
            left chunks, so in practice every row sees all left context.
        enable_full_context (bool):
            True: chunk size is either [1, 25] or full context(max_len)
            False: chunk size ~ U[1, 25]

    Returns:
        torch.Tensor: chunk mask of the input xs. (B, L, L) when a chunk mask
        is applied, otherwise ``masks`` itself. Rows that would be entirely
        False are set to True (in a new tensor; ``masks`` is never modified).
    """
    # Whether to use chunk mask or not
    if use_dynamic_chunk:
        max_len = xs.size(1)
        if decoding_chunk_size < 0:
            chunk_size = max_len
            num_left_chunks = -1
        elif decoding_chunk_size > 0:
            chunk_size = decoding_chunk_size
            num_left_chunks = num_decoding_left_chunks
        else:
            chunk_size, num_left_chunks = _sample_dynamic_chunk(
                max_len, use_dynamic_left_chunk, enable_full_context)
        chunk_masks = subsequent_chunk_mask(xs.size(1), chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = masks & chunk_masks.unsqueeze(0)  # (B, L, L)
    elif static_chunk_size > 0:
        num_left_chunks = num_decoding_left_chunks
        chunk_masks = subsequent_chunk_mask(xs.size(1), static_chunk_size,
                                            num_left_chunks,
                                            xs.device)  # (L, L)
        chunk_masks = masks & chunk_masks.unsqueeze(0)  # (B, L, L)
    else:
        chunk_masks = masks
    assert chunk_masks.dtype == torch.bool
    # A row with no True entry would leave attention nothing to attend to.
    all_false = ~chunk_masks.any(dim=-1)
    if all_false.any():
        print('get chunk_masks all false at some timestep, force set to true, make sure they are masked in futuer computation!')
        # Out-of-place: in the no-chunk branch chunk_masks *is* the caller's
        # `masks`, and an in-place write would silently mutate it.
        chunk_masks = chunk_masks.masked_fill(all_false.unsqueeze(-1), True)
    return chunk_masks


def subsequent_chunk_mask(
        size: int,
        chunk_size: int,
        num_left_chunks: int = -1,
        device: torch.device = torch.device("cpu"),
) -> torch.Tensor:
    """Create mask for subsequent steps (size, size) with chunk size,
       this is for streaming encoder

    Row ``i`` may attend to every position before the end of the chunk that
    contains ``i``, i.e. ``j < (i // chunk_size + 1) * chunk_size``.

    Args:
        size (int): size of mask
        chunk_size (int): size of chunk
        num_left_chunks (int): number of left chunks. Accepted for interface
            compatibility but currently IGNORED: all left chunks are always
            visible (see NOTE below).
        device (torch.device): "cpu" or "cuda" or torch.Tensor.device

    Returns:
        torch.Tensor: bool mask of shape (size, size)

    Examples:
        >>> subsequent_chunk_mask(4, 2)
        [[1, 1, 0, 0],
         [1, 1, 0, 0],
         [1, 1, 1, 1],
         [1, 1, 1, 1]]
    """
    # NOTE this modified implementation meets onnx export requirements, but it doesn't support num_left_chunks
    pos_idx = torch.arange(size, device=device)
    block_value = (torch.div(pos_idx, chunk_size, rounding_mode='trunc') + 1) * chunk_size
    ret = pos_idx.unsqueeze(0) < block_value.unsqueeze(1)
    return ret


def causal_block_mask(size, block_size=1, device="cpu", dtype=torch.bool):
    """Create mask for subsequent steps (size, size).

    Same block-causal pattern as ``subsequent_chunk_mask`` (with
    ``block_size`` as the chunk size), cast to ``dtype``.

    :param int size: size of mask
    :param int block_size: block size of mask
    :param str device: "cpu" or "cuda" or torch.Tensor.device
    :param torch.dtype dtype: result dtype
    :rtype: torch.Tensor

    >>> causal_block_mask(4, 2)
    [[1, 1, 0, 0],
     [1, 1, 0, 0],
     [1, 1, 1, 1],
     [1, 1, 1, 1]]
    """
    return subsequent_chunk_mask(size, block_size, device=device).to(dtype)