import torch
def get_attention_mask(sequence_length, device, mask_type="block-causal", **kwargs):
    # Check for None before calling .lower() to avoid an AttributeError.
    if mask_type is None or mask_type.lower() == 'none':
        return None
    elif mask_type.lower() == 'block-causal':
        return _block_causal_mask_impl(sequence_length, device, **kwargs)
    elif mask_type.lower() == 'causal':
        return _causal_mask_impl(sequence_length, device, **kwargs)
    else:
        raise NotImplementedError(f"Mask type {mask_type} not implemented")
def _block_causal_mask_impl(sequence_length, device, block_size=16, **kwargs):
    """
    Create a block-causal mask: True marks positions that must be disabled (masked out).
    """
    assert sequence_length % block_size == 0, "for block-causal masks, sequence length must be divisible by block size"
    # One (block_size x block_size) block of ones per block along the diagonal.
    blocks = torch.ones(sequence_length // block_size, block_size, block_size, device=device)
    block_diag_enable_mask = torch.block_diag(*blocks)
    # Standard lower-triangular causal mask of enabled positions.
    causal_enable_mask = torch.ones(sequence_length, sequence_length, device=device).tril_(0)
    # A position is disabled if it is allowed by neither the block-diagonal nor the causal mask.
    disable_mask = (block_diag_enable_mask + causal_enable_mask) < 0.5
    return disable_mask
def _causal_mask_impl(sequence_length, device, **kwargs):
    """
    Create a causal mask as an additive float mask: -inf above the diagonal, 0 elsewhere.
    """
    causal_disable_mask = torch.triu(
        torch.full((sequence_length, sequence_length), float('-inf'), dtype=torch.float32, device=device),
        diagonal=1,
    )
    return causal_disable_mask
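# Usage sketch (not part of the original file, an illustrative assumption): how these
# masks could be consumed by torch.nn.functional.scaled_dot_product_attention.
# SDPA treats a boolean mask as "True = may attend", so the block-causal disable
# mask is inverted here; the causal mask is already an additive float mask.
def _demo_scaled_dot_product_attention(sequence_length=9, block_size=3, device="cpu"):
    import torch.nn.functional as F
    # Random query/key/value of shape (batch, heads, seq, head_dim), for illustration only.
    q = torch.randn(1, 2, sequence_length, 8, device=device)
    k = torch.randn(1, 2, sequence_length, 8, device=device)
    v = torch.randn(1, 2, sequence_length, 8, device=device)
    disable_mask = get_attention_mask(sequence_length, device, mask_type="block-causal", block_size=block_size)
    out_block = F.scaled_dot_product_attention(q, k, v, attn_mask=~disable_mask)
    additive_mask = get_attention_mask(sequence_length, device, mask_type="causal")
    out_causal = F.scaled_dot_product_attention(q, k, v, attn_mask=additive_mask)
    return out_block, out_causal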
if __name__ == '__main__':
    # Fall back to CPU when CUDA is unavailable so the demo stays runnable.
    device = "cuda" if torch.cuda.is_available() else "cpu"
    mask = get_attention_mask(9, device, mask_type="block-causal", block_size=3)
    print(mask)
    mask = get_attention_mask(9, device, mask_type="causal")
    print(mask)