diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py index bb515540d14b..e32af9f4bc9e 100644 --- a/src/transformers/integrations/npu_flash_attention.py +++ b/src/transformers/integrations/npu_flash_attention.py @@ -37,6 +37,8 @@ "or 3 (down-right aligned causal mask)." ) +ATTN_MASK_NPU = None + def is_npu_fa2_top_left_aligned_causal_mask(): return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False @@ -171,7 +173,9 @@ def npu_flash_attn_func( head_num = q.shape[2] output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0] else: - attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() + global ATTN_MASK_NPU + if ATTN_MASK_NPU is None: + ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() head_num = q.shape[2] output = torch_npu.npu_fusion_attention( q, @@ -181,7 +185,7 @@ def npu_flash_attn_func( "BSND", keep_prob=keep_prob, scale=softmax_scale, - atten_mask=attn_mask_npu, + atten_mask=ATTN_MASK_NPU, sparse_mode=SPARSE_MODE, )[0] @@ -222,7 +226,9 @@ def npu_flash_attn_varlen_func( actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()), )[0] else: - attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() + global ATTN_MASK_NPU + if ATTN_MASK_NPU is None: + ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool() head_num = q.shape[1] output = torch_npu.npu_fusion_attention( q, @@ -231,7 +237,7 @@ def npu_flash_attn_varlen_func( head_num, pse=None, padding_mask=None, - atten_mask=attn_mask_npu, + atten_mask=ATTN_MASK_NPU, scale=softmax_scale, keep_prob=keep_prob, input_layout="TND",