File size: 2,192 Bytes
dfefe0b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
diff --git a/src/transformers/integrations/npu_flash_attention.py b/src/transformers/integrations/npu_flash_attention.py
index bb515540d14b..e32af9f4bc9e 100644
--- a/src/transformers/integrations/npu_flash_attention.py
+++ b/src/transformers/integrations/npu_flash_attention.py
@@ -37,6 +37,8 @@
         "or 3 (down-right aligned causal mask)."
     )
 
+ATTN_MASK_NPU = None
+
 
 def is_npu_fa2_top_left_aligned_causal_mask():
     return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False
@@ -171,7 +173,9 @@ def npu_flash_attn_func(
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[2]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -181,7 +185,7 @@ def npu_flash_attn_func(
             "BSND",
             keep_prob=keep_prob,
             scale=softmax_scale,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             sparse_mode=SPARSE_MODE,
         )[0]
 
@@ -222,7 +226,9 @@ def npu_flash_attn_varlen_func(
             actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
         )[0]
     else:
-        attn_mask_npu = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
+        global ATTN_MASK_NPU
+        if ATTN_MASK_NPU is None:
+            ATTN_MASK_NPU = torch.triu(torch.ones([2048, 2048], device=q.device), diagonal=1).bool()
         head_num = q.shape[1]
         output = torch_npu.npu_fusion_attention(
             q,
@@ -231,7 +237,7 @@ def npu_flash_attn_varlen_func(
             head_num,
             pse=None,
             padding_mask=None,
-            atten_mask=attn_mask_npu,
+            atten_mask=ATTN_MASK_NPU,
             scale=softmax_scale,
             keep_prob=keep_prob,
             input_layout="TND",