dheeena's picture
Add files using upload-large-folder tool
6f0b660 verified
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import math
import os
import torch
from ..utils.import_utils import is_torch_npu_available
if is_torch_npu_available():
from torch_npu import npu_fusion_attention
# FlashAttention2 is supported on Ascend NPU with down-right aligned causal mask by default.
# Set environment variable `NPU_FA2_SPARSE_MODE` to 2 when using top-left aligned causal mask.
TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE = 2
DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE = 3

# The only `NPU_FA2_SPARSE_MODE` values this module accepts.
_SUPPORTED_SPARSE_MODES = (TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE, DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE)
_SPARSE_MODE_ERROR_MESSAGE = (
    "Environment variable `NPU_FA2_SPARSE_MODE` can only be set as 2 (top-left aligned causal mask) "
    "or 3 (down-right aligned causal mask)."
)


def parse_npu_fa2_sparse_mode(raw_value):
    """Convert a raw ``NPU_FA2_SPARSE_MODE`` value into a validated sparse mode.

    Args:
        raw_value: The string read from the environment, or the integer default used
            when the variable is unset.

    Returns:
        int: ``TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE`` (2) or ``DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE`` (3).

    Raises:
        ValueError: If the value is not an integer literal, or is an integer other than 2 or 3.
    """
    try:
        mode = int(raw_value)
    except ValueError as err:
        # Without this, a typo such as "three" surfaces as a bare `int()` parsing error
        # that never mentions which setting is wrong.
        raise ValueError(_SPARSE_MODE_ERROR_MESSAGE) from err
    if mode not in _SUPPORTED_SPARSE_MODES:
        raise ValueError(_SPARSE_MODE_ERROR_MESSAGE)
    return mode


# Resolved once at import time; an invalid setting fails fast with a descriptive ValueError.
SPARSE_MODE = parse_npu_fa2_sparse_mode(
    os.getenv("NPU_FA2_SPARSE_MODE", default=DOWN_RIGHT_ALIGNED_CAUSAL_MASK_MODE)
)

# Per-device cache of the causal masks built lazily by `get_attn_mask_npu`.
ATTN_MASK_NPU_CACHE = {}
def get_attn_mask_npu(device):
    """Return the shared 2048x2048 boolean causal mask for ``device``, building it on first use.

    The mask is ``True`` strictly above the main diagonal (``torch.triu(..., diagonal=1)``).
    It is memoized in ``ATTN_MASK_NPU_CACHE`` keyed by ``device``, so each device allocates
    it only once.
    """
    cached_mask = ATTN_MASK_NPU_CACHE.get(device)
    if cached_mask is None:
        # NOTE(review): the fixed 2048x2048 size is presumably what npu_fusion_attention's
        # sparse modes 2/3 expect for a compressed causal mask — confirm before changing it.
        all_ones = torch.ones([2048, 2048], device=device)
        cached_mask = torch.triu(all_ones, diagonal=1).bool()
        ATTN_MASK_NPU_CACHE[device] = cached_mask
    return cached_mask
def is_npu_fa2_top_left_aligned_causal_mask():
    """Report whether NPU FlashAttention2 is configured for a top-left aligned causal mask.

    Returns ``False`` whenever torch_npu is unavailable. Otherwise returns whether the
    module-level ``SPARSE_MODE`` equals ``TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE``.
    """
    if not is_torch_npu_available():
        return False
    return SPARSE_MODE == TOP_LEFT_ALIGNED_CAUSAL_MASK_MODE
def npu_flash_attn_func(
    q,
    k,
    v,
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    **kwargs,
):
    """Run attention on Ascend NPU for ``BSND``-layout inputs, mirroring ``flash_attn_func``.

    Forwards to ``npu_fusion_attention`` with ``input_layout="BSND"``.

    Args:
        q, k, v: Query/key/value tensors. The head count is read from ``q.shape[2]``, so they
            are presumably (batch, seq, heads, head_dim) — confirm against callers.
        dropout_p: Dropout probability, passed to the kernel as ``keep_prob = 1 - dropout_p``.
        softmax_scale: Pre-softmax scale. Defaults to ``1 / sqrt(q.shape[-1])``.
        causal: When truthy, also passes the cached causal mask from ``get_attn_mask_npu``
            and the module-level ``SPARSE_MODE``.
        **kwargs: Accepted for signature compatibility with ``flash-attn`` and ignored.

    Returns:
        The first element of the ``npu_fusion_attention`` result.
    """
    fusion_kwargs = {"keep_prob": 1.0 - dropout_p}
    fusion_kwargs["scale"] = 1.0 / math.sqrt(q.shape[-1]) if softmax_scale is None else softmax_scale

    if causal:
        # Causal attention relies on the shared mask plus the configured sparse mode.
        fusion_kwargs["atten_mask"] = get_attn_mask_npu(q.device)
        fusion_kwargs["sparse_mode"] = SPARSE_MODE

    num_heads = q.shape[2]
    results = npu_fusion_attention(q, k, v, num_heads, "BSND", **fusion_kwargs)
    return results[0]
def npu_flash_attn_varlen_func(
    q,
    k,
    v,
    cu_seqlens_q,
    cu_seqlens_k,
    max_seqlen_q=None,  # defined for aligning params order with corresponding function in `flash-attn`
    max_seqlen_k=None,  # defined for aligning params order with corresponding function in `flash-attn`
    dropout_p=0.0,
    softmax_scale=None,
    causal=False,
    **kwargs,
):
    """Run variable-length attention on Ascend NPU, mirroring ``flash_attn_varlen_func``.

    Forwards to ``npu_fusion_attention`` with ``input_layout="TND"``.

    Args:
        q, k, v: Packed query/key/value tensors. The head count is read from ``q.shape[1]``,
            so they are presumably (total_tokens, heads, head_dim) — confirm against callers.
        cu_seqlens_q, cu_seqlens_k: Tensors whose entries after the first are passed to the
            kernel as ``actual_seq_qlen`` / ``actual_seq_kvlen`` (as host-side Python tuples).
        max_seqlen_q, max_seqlen_k: Unused; kept so positional order matches ``flash-attn``.
        dropout_p: Dropout probability, passed as ``keep_prob = 1 - dropout_p``.
        softmax_scale: Pre-softmax scale. Defaults to ``1 / sqrt(q.shape[-1])``.
        causal: When truthy, passes the cached causal mask, ``padding_mask=None`` and the
            module-level ``SPARSE_MODE``; otherwise passes ``atten_mask=None``.
        **kwargs: Accepted for signature compatibility with ``flash-attn`` and ignored.

    Returns:
        The first element of the ``npu_fusion_attention`` result.
    """
    keep_prob = 1.0 - dropout_p
    scale = 1.0 / math.sqrt(q.shape[-1]) if softmax_scale is None else softmax_scale

    def _seq_ends(cu_seqlens):
        # Skip the first entry (presumably 0, per the flash-attn cu_seqlens convention)
        # and move the rest to host memory as plain Python ints.
        return tuple(cu_seqlens[1:].cpu().numpy().tolist())

    fusion_kwargs = {
        "pse": None,
        "scale": scale,
        "keep_prob": keep_prob,
        "input_layout": "TND",
    }
    if causal:
        fusion_kwargs["padding_mask"] = None
        fusion_kwargs["atten_mask"] = get_attn_mask_npu(q.device)
        fusion_kwargs["sparse_mode"] = SPARSE_MODE
    else:
        fusion_kwargs["atten_mask"] = None

    fusion_kwargs["actual_seq_qlen"] = _seq_ends(cu_seqlens_q)
    fusion_kwargs["actual_seq_kvlen"] = _seq_ends(cu_seqlens_k)

    num_heads = q.shape[1]
    return npu_fusion_attention(q, k, v, num_heads, **fusion_kwargs)[0]