"""FlashAttention-2 wrappers for Ascend NPU devices, built on `torch_npu.npu_fusion_attention`."""

import math
import os

import torch

from ..utils.import_utils import is_torch_npu_available


if is_torch_npu_available():
    from torch_npu import npu_fusion_attention


# Causal-mask alignment modes accepted by the NPU fused attention kernel.
TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3

# Down-right alignment is the default; set `NPU_FA2_SPARSE_MODE=2` to use a top-left aligned causal mask.
SPARSE_MODE = int(os.getenv("NPU_FA2_SPARSE_MODE", default=DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE))
if SPARSE_MODE not in [TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE]:
    raise ValueError(
        "Environment variable `NPU_FA2_SPARSE_MODE` can only be set to 2 (top-left aligned causal mask) "
        "or 3 (down-right aligned causal mask)."
    )

# Per-device cache of the shared causal attention mask.
ATTN_MASK_NPU_CACHE = {}


def get_attn_mask_npu(device):
    """Get or create the boolean upper-triangular (causal) attention mask for the specified device."""
    if device not in ATTN_MASK_NPU_CACHE:
        ATTN_MASK_NPU_CACHE[device] = torch.triu(torch.ones([2048, 2048], device=device), diagonal=1).bool()
    return ATTN_MASK_NPU_CACHE[device]


def is_npu_fa2_top_left_aligned_causal_mask():
    """Return whether the NPU FA2 kernel is configured for a top-left aligned causal mask."""
    return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE if is_torch_npu_available() else False


def npu_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    **kwargs,
):
    """FlashAttention on Ascend NPU for "BSND"-layout inputs (batch, seq_len, num_heads, head_dim)."""
    keep_prob = 1.0 - dropout_p

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])

    if not causal:
        head_num = q.shape[2]
        output = npu_fusion_attention(q, k, v, head_num, "BSND", keep_prob=keep_prob, scale=softmax_scale)[0]
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[2]
        output = npu_fusion_attention(
            q,
            k,
            v,
            head_num,
            "BSND",
            keep_prob=keep_prob,
            scale=softmax_scale,
            atten_mask=attn_mask_npu,
            sparse_mode=SPARSE_MODE,
        )[0]

    return output


def npu_flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,
    max_seqlen_k=None,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    **kwargs,
):
    """Variable-length FlashAttention on Ascend NPU for packed "TND" inputs (total_tokens, num_heads, head_dim).

    `cu_seqlens_q` / `cu_seqlens_k` are cumulative sequence lengths of the packed query/key batches;
    `max_seqlen_q` / `max_seqlen_k` are accepted for interface compatibility and are unused here.
    """
    keep_prob = 1.0 - dropout_p

    if softmax_scale is None:
        softmax_scale = 1.0 / math.sqrt(q.shape[-1])

    if not causal:
        head_num = q.shape[1]
        output = npu_fusion_attention(
            q,
            k,
            v,
            head_num,
            pse=None,
            atten_mask=None,
            scale=softmax_scale,
            keep_prob=keep_prob,
            input_layout="TND",
            actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
            actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
        )[0]
    else:
        attn_mask_npu = get_attn_mask_npu(q.device)
        head_num = q.shape[1]
        output = npu_fusion_attention(
            q,
            k,
            v,
            head_num,
            pse=None,
            padding_mask=None,
            atten_mask=attn_mask_npu,
            scale=softmax_scale,
            keep_prob=keep_prob,
            input_layout="TND",
            actual_seq_qlen=tuple(cu_seqlens_q[1:].cpu().numpy().tolist()),
            actual_seq_kvlen=tuple(cu_seqlens_k[1:].cpu().numpy().tolist()),
            sparse_mode=SPARSE_MODE,
        )[0]

    return output
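

# ---------------------------------------------------------------------------
# Usage sketch (illustrative only, not part of this module's API; assumes an
# Ascend NPU device is available). Shapes follow the layouts used above:
# "BSND" (batch, seq_len, num_heads, head_dim) for `npu_flash_attn_func`, and
# packed "TND" (total_tokens, num_heads, head_dim) plus cumulative sequence
# lengths for `npu_flash_attn_varlen_func`.
#
#   q = k = v = torch.randn(2, 128, 8, 64, device="npu", dtype=torch.float16)
#   out = npu_flash_attn_func(q, k, v, dropout_p=0.0, causal=True)
#   # out.shape == (2, 128, 8, 64)
#
#   # Two sequences of lengths 64 and 128 packed into one "TND" batch.
#   qp = kp = vp = torch.randn(192, 8, 64, device="npu", dtype=torch.float16)
#   cu = torch.tensor([0, 64, 192], dtype=torch.int32, device="npu")
#   out_packed = npu_flash_attn_varlen_func(qp, kp, vp, cu, cu, causal=True)
#   # out_packed.shape == (192, 8, 64)
# ---------------------------------------------------------------------------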