Spaces:
Running
on
Zero
Running
on
Zero
File size: 6,856 Bytes
6f25f9f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import os
import torch
import warnings
try:
import flash_attn_interface
FLASH_ATTN_3_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_3_AVAILABLE = False
try:
import flash_attn
FLASH_ATTN_2_AVAILABLE = True
except ModuleNotFoundError:
FLASH_ATTN_2_AVAILABLE = False
try:
major, minor = torch.cuda.get_device_capability(0)
if f"{major}.{minor}" == "8.0":
from sageattention_sm80 import sageattn
SAGE_ATTENTION_AVAILABLE = True
elif f"{major}.{minor}" == "8.6":
from sageattention_sm86 import sageattn
SAGE_ATTENTION_AVAILABLE = True
elif f"{major}.{minor}" == "8.9":
from sageattention_sm89 import sageattn
SAGE_ATTENTION_AVAILABLE = True
elif f"{major}.{minor}" == "9.0":
from sageattention_sm90 import sageattn
SAGE_ATTENTION_AVAILABLE = True
elif major>9:
from sageattention_sm120 import sageattn
SAGE_ATTENTION_AVAILABLE = True
except:
try:
from sageattention import sageattn
SAGE_ATTENTION_AVAILABLE = True
except:
sageattn = None
SAGE_ATTENTION_AVAILABLE = False
def flash_attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
version=None,
):
"""
q: [B, Lq, Nq, C1].
k: [B, Lk, Nk, C1].
v: [B, Lk, Nk, C2]. Nq must be divisible by Nk.
q_lens: [B].
k_lens: [B].
dropout_p: float. Dropout probability.
softmax_scale: float. The scaling of QK^T before applying softmax.
causal: bool. Whether to apply causal attention mask.
window_size: (left right). If not (-1, -1), apply sliding window local attention.
deterministic: bool. If True, slightly slower and uses more memory.
dtype: torch.dtype. Apply when dtype of q/k/v is not float16/bfloat16.
"""
half_dtypes = (torch.float16, torch.bfloat16)
assert dtype in half_dtypes
assert q.device.type == 'cuda' and q.size(-1) <= 256
# params
b, lq, lk, out_dtype = q.size(0), q.size(1), k.size(1), q.dtype
def half(x):
return x if x.dtype in half_dtypes else x.to(dtype)
# preprocess query
if q_lens is None:
q = half(q.flatten(0, 1))
q_lens = torch.tensor(
[lq] * b, dtype=torch.int32).to(
device=q.device, non_blocking=True)
else:
q = half(torch.cat([u[:v] for u, v in zip(q, q_lens)]))
# preprocess key, value
if k_lens is None:
k = half(k.flatten(0, 1))
v = half(v.flatten(0, 1))
k_lens = torch.tensor(
[lk] * b, dtype=torch.int32).to(
device=k.device, non_blocking=True)
else:
k = half(torch.cat([u[:v] for u, v in zip(k, k_lens)]))
v = half(torch.cat([u[:v] for u, v in zip(v, k_lens)]))
q = q.to(v.dtype)
k = k.to(v.dtype)
if q_scale is not None:
q = q * q_scale
if version is not None and version == 3 and not FLASH_ATTN_3_AVAILABLE:
warnings.warn(
'Flash attention 3 is not available, use flash attention 2 instead.'
)
# apply attention
if (version is None or version == 3) and FLASH_ATTN_3_AVAILABLE:
# Note: dropout_p, window_size are not supported in FA3 now.
x = flash_attn_interface.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
seqused_q=None,
seqused_k=None,
max_seqlen_q=lq,
max_seqlen_k=lk,
softmax_scale=softmax_scale,
causal=causal,
deterministic=deterministic).unflatten(0, (b, lq))
else:
assert FLASH_ATTN_2_AVAILABLE
x = flash_attn.flash_attn_varlen_func(
q=q,
k=k,
v=v,
cu_seqlens_q=torch.cat([q_lens.new_zeros([1]), q_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
cu_seqlens_k=torch.cat([k_lens.new_zeros([1]), k_lens]).cumsum(
0, dtype=torch.int32).to(q.device, non_blocking=True),
max_seqlen_q=lq,
max_seqlen_k=lk,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic).unflatten(0, (b, lq))
# output
return x.type(out_dtype)
def attention(
q,
k,
v,
q_lens=None,
k_lens=None,
dropout_p=0.,
softmax_scale=None,
q_scale=None,
causal=False,
window_size=(-1, -1),
deterministic=False,
dtype=torch.bfloat16,
fa_version=None,
attention_type=None,
attn_mask=None,
):
attention_type = os.environ.get("VIDEOX_ATTENTION_TYPE", "FLASH_ATTENTION") if attention_type is None else attention_type
if torch.is_grad_enabled() and attention_type == "SAGE_ATTENTION":
attention_type = "FLASH_ATTENTION"
if attention_type == "SAGE_ATTENTION" and SAGE_ATTENTION_AVAILABLE:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
)
out = sageattn(
q, k, v, attn_mask=attn_mask, tensor_layout="NHD", is_causal=causal, dropout_p=dropout_p)
elif attention_type == "FLASH_ATTENTION" and (FLASH_ATTN_2_AVAILABLE or FLASH_ATTN_3_AVAILABLE):
return flash_attention(
q=q,
k=k,
v=v,
q_lens=q_lens,
k_lens=k_lens,
dropout_p=dropout_p,
softmax_scale=softmax_scale,
q_scale=q_scale,
causal=causal,
window_size=window_size,
deterministic=deterministic,
dtype=dtype,
version=fa_version,
)
else:
if q_lens is not None or k_lens is not None:
warnings.warn(
'Padding mask is disabled when using scaled_dot_product_attention. It can have a significant impact on performance.'
)
q = q.transpose(1, 2)
k = k.transpose(1, 2)
v = v.transpose(1, 2)
out = torch.nn.functional.scaled_dot_product_attention(
q, k, v, attn_mask=attn_mask, is_causal=causal, dropout_p=dropout_p)
out = out.transpose(1, 2).contiguous()
return out |