| |
| |
| |
| |
| @@ -37,6 +37,8 @@ |
| "or 3 (down-right aligned causal mask)." |
| ) |
| |
| +ATTN_MASK_NPU = None |
| + |
| |
| def is_npu_fa2_top_left_aligned_causal_mask(): |
| return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False |
| @@ -171,7 +173,9 @@ def npu_flash_attn_func( |
| head_num = q.shape[2] |
| output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0] |
| else: |
| - attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() |
| + global ATTN_MASK_NPU |
| + if ATTN_MASK_NPU is None: |
| + ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() |
| head_num = q.shape[2] |
| output = torch_npu.npu_fusion_attention( |
| q, |
| @@ -181,7 +185,7 @@ def npu_flash_attn_func( |
| "BSND", |
| keep_prob=keep_prob, |
| scale=softmax_scale, |
| - atten_mask=attn_mask_npu, |
| + atten_mask=ATTN_MASK_NPU, |
| sparse_mode=SPARSE_MODE, |
| )[0] |
| |
| @@ -222,7 +226,9 @@ def npu_flash_attn_varlen_func( |
| actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()), |
| )[0] |
| else: |
| - attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() |
| + global ATTN_MASK_NPU |
| + if ATTN_MASK_NPU is None: |
| + ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() |
| head_num = q.shape[1] |
| output = torch_npu.npu_fusion_attention( |
| q, |
| @@ -231,7 +237,7 @@ def npu_flash_attn_varlen_func( |
| head_num, |
| pse=None, |
| padding_mask=None, |
| - atten_mask=attn_mask_npu, |
| + atten_mask=ATTN_MASK_NPU, |
| scale=softmax_scale, |
| keep_prob=keep_prob, |
| input_layout="TND", |
|
|