Spaces:
Running on Zero
Running on Zero
Bobby commited on
Commit ·
31860a5
1
Parent(s): 86e4232
Fallback TRELLIS attention backends to SDPA when flash-attn missing
Browse files- app.py +9 -1
- trellis/modules/attention/__init__.py +5 -2
- trellis/modules/attention/full_attn.py +13 -3
- trellis/modules/sparse/__init__.py +6 -3
- trellis/modules/sparse/attention/full_attn.py +47 -3
- trellis/modules/sparse/attention/serialized_attn.py +35 -3
- trellis/modules/sparse/attention/windowed_attn.py +35 -3
app.py
CHANGED
|
@@ -1,5 +1,6 @@
|
|
| 1 |
import argparse
|
| 2 |
import concurrent.futures
|
|
|
|
| 3 |
import os
|
| 4 |
import sys
|
| 5 |
import time
|
|
@@ -8,7 +9,14 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
|
|
| 8 |
|
| 9 |
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 10 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
os.environ.setdefault("SPCONV_ALGO", "native")
|
| 13 |
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(
|
| 14 |
os.path.dirname(os.path.abspath(__file__)), "autotune_cache.json"
|
|
|
|
| 1 |
import argparse
|
| 2 |
import concurrent.futures
|
| 3 |
+
import importlib.util
|
| 4 |
import os
|
| 5 |
import sys
|
| 6 |
import time
|
|
|
|
| 9 |
|
| 10 |
os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
|
| 11 |
os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
|
| 12 |
+
if importlib.util.find_spec("flash_attn") is not None:
|
| 13 |
+
_attn_backend = "flash_attn"
|
| 14 |
+
elif importlib.util.find_spec("xformers") is not None:
|
| 15 |
+
_attn_backend = "xformers"
|
| 16 |
+
else:
|
| 17 |
+
_attn_backend = "sdpa"
|
| 18 |
+
os.environ.setdefault("ATTN_BACKEND", _attn_backend)
|
| 19 |
+
os.environ.setdefault("SPARSE_ATTN_BACKEND", _attn_backend)
|
| 20 |
os.environ.setdefault("SPCONV_ALGO", "native")
|
| 21 |
os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(
|
| 22 |
os.path.dirname(os.path.abspath(__file__)), "autotune_cache.json"
|
trellis/modules/attention/__init__.py
CHANGED
|
@@ -1,6 +1,6 @@
|
|
| 1 |
from typing import *
|
| 2 |
|
| 3 |
-
BACKEND = '
|
| 4 |
DEBUG = False
|
| 5 |
|
| 6 |
def __from_env():
|
|
@@ -12,6 +12,9 @@ def __from_env():
|
|
| 12 |
env_attn_backend = os.environ.get('ATTN_BACKEND')
|
| 13 |
env_sttn_debug = os.environ.get('ATTN_DEBUG')
|
| 14 |
|
|
|
|
|
|
|
|
|
|
| 15 |
if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
|
| 16 |
BACKEND = env_attn_backend
|
| 17 |
if env_sttn_debug is not None:
|
|
@@ -25,7 +28,7 @@ def __from_env():
|
|
| 25 |
__from_env()
|
| 26 |
|
| 27 |
|
| 28 |
-
def set_backend(backend: Literal['xformers', 'flash_attn']):
|
| 29 |
global BACKEND
|
| 30 |
BACKEND = backend
|
| 31 |
|
|
|
|
| 1 |
from typing import *
|
| 2 |
|
| 3 |
+
BACKEND = 'sdpa'
|
| 4 |
DEBUG = False
|
| 5 |
|
| 6 |
def __from_env():
|
|
|
|
| 12 |
env_attn_backend = os.environ.get('ATTN_BACKEND')
|
| 13 |
env_sttn_debug = os.environ.get('ATTN_DEBUG')
|
| 14 |
|
| 15 |
+
if env_attn_backend == 'flash_attn_3':
|
| 16 |
+
env_attn_backend = 'flash_attn'
|
| 17 |
+
|
| 18 |
if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
|
| 19 |
BACKEND = env_attn_backend
|
| 20 |
if env_sttn_debug is not None:
|
|
|
|
| 28 |
__from_env()
|
| 29 |
|
| 30 |
|
| 31 |
+
def set_backend(backend: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
|
| 32 |
global BACKEND
|
| 33 |
BACKEND = backend
|
| 34 |
|
trellis/modules/attention/full_attn.py
CHANGED
|
@@ -1,12 +1,22 @@
|
|
| 1 |
from typing import *
|
| 2 |
import torch
|
| 3 |
import math
|
| 4 |
-
from . import DEBUG, BACKEND
|
| 5 |
|
| 6 |
if BACKEND == 'xformers':
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
elif BACKEND == 'flash_attn':
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
elif BACKEND == 'sdpa':
|
| 11 |
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 12 |
elif BACKEND == 'naive':
|
|
|
|
| 1 |
from typing import *
|
| 2 |
import torch
|
| 3 |
import math
|
| 4 |
+
from . import DEBUG, BACKEND, set_backend
|
| 5 |
|
| 6 |
if BACKEND == 'xformers':
|
| 7 |
+
try:
|
| 8 |
+
import xformers.ops as xops
|
| 9 |
+
except ImportError:
|
| 10 |
+
BACKEND = 'sdpa'
|
| 11 |
+
set_backend(BACKEND)
|
| 12 |
+
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 13 |
elif BACKEND == 'flash_attn':
|
| 14 |
+
try:
|
| 15 |
+
import flash_attn
|
| 16 |
+
except ImportError:
|
| 17 |
+
BACKEND = 'sdpa'
|
| 18 |
+
set_backend(BACKEND)
|
| 19 |
+
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 20 |
elif BACKEND == 'sdpa':
|
| 21 |
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 22 |
elif BACKEND == 'naive':
|
trellis/modules/sparse/__init__.py
CHANGED
|
@@ -2,7 +2,7 @@ from typing import *
|
|
| 2 |
|
| 3 |
BACKEND = 'spconv'
|
| 4 |
DEBUG = False
|
| 5 |
-
ATTN = '
|
| 6 |
|
| 7 |
def __from_env():
|
| 8 |
import os
|
|
@@ -21,7 +21,10 @@ def __from_env():
|
|
| 21 |
BACKEND = env_sparse_backend
|
| 22 |
if env_sparse_debug is not None:
|
| 23 |
DEBUG = env_sparse_debug == '1'
|
| 24 |
-
if env_sparse_attn
|
|
|
|
|
|
|
|
|
|
| 25 |
ATTN = env_sparse_attn
|
| 26 |
|
| 27 |
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
|
|
@@ -38,7 +41,7 @@ def set_debug(debug: bool):
|
|
| 38 |
global DEBUG
|
| 39 |
DEBUG = debug
|
| 40 |
|
| 41 |
-
def set_attn(attn: Literal['xformers', 'flash_attn']):
|
| 42 |
global ATTN
|
| 43 |
ATTN = attn
|
| 44 |
|
|
|
|
| 2 |
|
| 3 |
BACKEND = 'spconv'
|
| 4 |
DEBUG = False
|
| 5 |
+
ATTN = 'sdpa'
|
| 6 |
|
| 7 |
def __from_env():
|
| 8 |
import os
|
|
|
|
| 21 |
BACKEND = env_sparse_backend
|
| 22 |
if env_sparse_debug is not None:
|
| 23 |
DEBUG = env_sparse_debug == '1'
|
| 24 |
+
if env_sparse_attn == 'flash_attn_3':
|
| 25 |
+
env_sparse_attn = 'flash_attn'
|
| 26 |
+
|
| 27 |
+
if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sdpa', 'naive']:
|
| 28 |
ATTN = env_sparse_attn
|
| 29 |
|
| 30 |
print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
|
|
|
|
| 41 |
global DEBUG
|
| 42 |
DEBUG = debug
|
| 43 |
|
| 44 |
+
def set_attn(attn: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
|
| 45 |
global ATTN
|
| 46 |
ATTN = attn
|
| 47 |
|
trellis/modules/sparse/attention/full_attn.py
CHANGED
|
@@ -1,12 +1,23 @@
|
|
| 1 |
from typing import *
|
| 2 |
import torch
|
|
|
|
| 3 |
from .. import SparseTensor
|
| 4 |
-
from .. import DEBUG, ATTN
|
| 5 |
|
| 6 |
if ATTN == 'xformers':
|
| 7 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 8 |
elif ATTN == 'flash_attn':
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
else:
|
| 11 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 12 |
|
|
@@ -16,6 +27,14 @@ __all__ = [
|
|
| 16 |
]
|
| 17 |
|
| 18 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 19 |
@overload
|
| 20 |
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
|
| 21 |
"""
|
|
@@ -206,6 +225,31 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
|
|
| 206 |
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
| 207 |
elif num_all_args == 3:
|
| 208 |
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 209 |
else:
|
| 210 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 211 |
|
|
|
|
| 1 |
from typing import *
|
| 2 |
import torch
|
| 3 |
+
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 4 |
from .. import SparseTensor
|
| 5 |
+
from .. import DEBUG, ATTN, set_attn
|
| 6 |
|
| 7 |
if ATTN == 'xformers':
|
| 8 |
+
try:
|
| 9 |
+
import xformers.ops as xops
|
| 10 |
+
except ImportError:
|
| 11 |
+
ATTN = 'sdpa'
|
| 12 |
+
set_attn(ATTN)
|
| 13 |
elif ATTN == 'flash_attn':
|
| 14 |
+
try:
|
| 15 |
+
import flash_attn
|
| 16 |
+
except ImportError:
|
| 17 |
+
ATTN = 'sdpa'
|
| 18 |
+
set_attn(ATTN)
|
| 19 |
+
elif ATTN in {'sdpa', 'naive'}:
|
| 20 |
+
pass
|
| 21 |
else:
|
| 22 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 23 |
|
|
|
|
| 27 |
]
|
| 28 |
|
| 29 |
|
| 30 |
+
def _sdpa_chunk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
|
| 31 |
+
q = q.permute(1, 0, 2).unsqueeze(0) # [1, H, Lq, C]
|
| 32 |
+
k = k.permute(1, 0, 2).unsqueeze(0) # [1, H, Lk, C]
|
| 33 |
+
v = v.permute(1, 0, 2).unsqueeze(0) # [1, H, Lk, C]
|
| 34 |
+
out = sdpa(q, k, v)
|
| 35 |
+
return out.squeeze(0).permute(1, 0, 2) # [Lq, H, C]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
@overload
|
| 39 |
def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
|
| 40 |
"""
|
|
|
|
| 225 |
out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
| 226 |
elif num_all_args == 3:
|
| 227 |
out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
|
| 228 |
+
elif ATTN in {'sdpa', 'naive'}:
|
| 229 |
+
outs = []
|
| 230 |
+
q_start, kv_start = 0, 0
|
| 231 |
+
for q_len, kv_len in zip(q_seqlen, kv_seqlen):
|
| 232 |
+
if num_all_args == 1:
|
| 233 |
+
qkv_chunk = qkv[q_start:q_start + q_len]
|
| 234 |
+
q_i, k_i, v_i = qkv_chunk.unbind(dim=1)
|
| 235 |
+
elif num_all_args == 2:
|
| 236 |
+
q_i = q[q_start:q_start + q_len]
|
| 237 |
+
kv_chunk = kv[kv_start:kv_start + kv_len]
|
| 238 |
+
k_i, v_i = kv_chunk.unbind(dim=1)
|
| 239 |
+
else:
|
| 240 |
+
q_i = q[q_start:q_start + q_len]
|
| 241 |
+
k_i = k[kv_start:kv_start + kv_len]
|
| 242 |
+
v_i = v[kv_start:kv_start + kv_len]
|
| 243 |
+
|
| 244 |
+
outs.append(_sdpa_chunk(q_i, k_i, v_i))
|
| 245 |
+
q_start += q_len
|
| 246 |
+
kv_start += kv_len
|
| 247 |
+
if outs:
|
| 248 |
+
out = torch.cat(outs, dim=0)
|
| 249 |
+
elif num_all_args == 1:
|
| 250 |
+
out = torch.empty((0, qkv.shape[-2], qkv.shape[-1]), device=device, dtype=qkv.dtype)
|
| 251 |
+
else:
|
| 252 |
+
out = torch.empty((0, q.shape[-2], q.shape[-1]), device=device, dtype=q.dtype)
|
| 253 |
else:
|
| 254 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 255 |
|
trellis/modules/sparse/attention/serialized_attn.py
CHANGED
|
@@ -2,13 +2,24 @@ from typing import *
|
|
| 2 |
from enum import Enum
|
| 3 |
import torch
|
| 4 |
import math
|
|
|
|
| 5 |
from .. import SparseTensor
|
| 6 |
-
from .. import DEBUG, ATTN
|
| 7 |
|
| 8 |
if ATTN == 'xformers':
|
| 9 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
elif ATTN == 'flash_attn':
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
else:
|
| 13 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 14 |
|
|
@@ -18,6 +29,21 @@ __all__ = [
|
|
| 18 |
]
|
| 19 |
|
| 20 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 21 |
class SerializeMode(Enum):
|
| 22 |
Z_ORDER = 0
|
| 23 |
Z_ORDER_TRANSPOSED = 1
|
|
@@ -168,6 +194,8 @@ def sparse_serialized_scaled_dot_product_self_attention(
|
|
| 168 |
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 169 |
elif ATTN == 'flash_attn':
|
| 170 |
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
|
|
|
|
|
|
|
| 171 |
else:
|
| 172 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 173 |
out = out.reshape(B * N, H, C) # [M, H, C]
|
|
@@ -183,6 +211,10 @@ def sparse_serialized_scaled_dot_product_self_attention(
|
|
| 183 |
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 184 |
.to(qkv.device).int()
|
| 185 |
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 186 |
|
| 187 |
out = out[bwd_indices] # [T, H, C]
|
| 188 |
|
|
|
|
| 2 |
from enum import Enum
|
| 3 |
import torch
|
| 4 |
import math
|
| 5 |
+
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 6 |
from .. import SparseTensor
|
| 7 |
+
from .. import DEBUG, ATTN, set_attn
|
| 8 |
|
| 9 |
if ATTN == 'xformers':
|
| 10 |
+
try:
|
| 11 |
+
import xformers.ops as xops
|
| 12 |
+
except ImportError:
|
| 13 |
+
ATTN = 'sdpa'
|
| 14 |
+
set_attn(ATTN)
|
| 15 |
elif ATTN == 'flash_attn':
|
| 16 |
+
try:
|
| 17 |
+
import flash_attn
|
| 18 |
+
except ImportError:
|
| 19 |
+
ATTN = 'sdpa'
|
| 20 |
+
set_attn(ATTN)
|
| 21 |
+
elif ATTN in {'sdpa', 'naive'}:
|
| 22 |
+
pass
|
| 23 |
else:
|
| 24 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 25 |
|
|
|
|
| 29 |
]
|
| 30 |
|
| 31 |
|
| 32 |
+
def _sdpa_varlen_qkv(qkv_feats: torch.Tensor, seq_lens: List[int]) -> torch.Tensor:
|
| 33 |
+
outs = []
|
| 34 |
+
start = 0
|
| 35 |
+
for seq_len in seq_lens:
|
| 36 |
+
chunk = qkv_feats[start:start + seq_len]
|
| 37 |
+
q, k, v = chunk.unbind(dim=1)
|
| 38 |
+
q = q.permute(1, 0, 2).unsqueeze(0)
|
| 39 |
+
k = k.permute(1, 0, 2).unsqueeze(0)
|
| 40 |
+
v = v.permute(1, 0, 2).unsqueeze(0)
|
| 41 |
+
out = sdpa(q, k, v).squeeze(0).permute(1, 0, 2)
|
| 42 |
+
outs.append(out)
|
| 43 |
+
start += seq_len
|
| 44 |
+
return torch.cat(outs, dim=0) if outs else qkv_feats.new_empty((0, qkv_feats.shape[2], qkv_feats.shape[3]))
|
| 45 |
+
|
| 46 |
+
|
| 47 |
class SerializeMode(Enum):
|
| 48 |
Z_ORDER = 0
|
| 49 |
Z_ORDER_TRANSPOSED = 1
|
|
|
|
| 194 |
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 195 |
elif ATTN == 'flash_attn':
|
| 196 |
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
|
| 197 |
+
elif ATTN in {'sdpa', 'naive'}:
|
| 198 |
+
out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
|
| 199 |
else:
|
| 200 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 201 |
out = out.reshape(B * N, H, C) # [M, H, C]
|
|
|
|
| 211 |
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 212 |
.to(qkv.device).int()
|
| 213 |
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
|
| 214 |
+
elif ATTN in {'sdpa', 'naive'}:
|
| 215 |
+
out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
|
| 216 |
+
else:
|
| 217 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 218 |
|
| 219 |
out = out[bwd_indices] # [T, H, C]
|
| 220 |
|
trellis/modules/sparse/attention/windowed_attn.py
CHANGED
|
@@ -1,13 +1,24 @@
|
|
| 1 |
from typing import *
|
| 2 |
import torch
|
| 3 |
import math
|
|
|
|
| 4 |
from .. import SparseTensor
|
| 5 |
-
from .. import DEBUG, ATTN
|
| 6 |
|
| 7 |
if ATTN == 'xformers':
|
| 8 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 9 |
elif ATTN == 'flash_attn':
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
else:
|
| 12 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 13 |
|
|
@@ -17,6 +28,21 @@ __all__ = [
|
|
| 17 |
]
|
| 18 |
|
| 19 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 20 |
def calc_window_partition(
|
| 21 |
tensor: SparseTensor,
|
| 22 |
window_size: Union[int, Tuple[int, ...]],
|
|
@@ -110,6 +136,8 @@ def sparse_windowed_scaled_dot_product_self_attention(
|
|
| 110 |
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 111 |
elif ATTN == 'flash_attn':
|
| 112 |
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
|
|
|
|
|
|
|
| 113 |
else:
|
| 114 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 115 |
out = out.reshape(B * N, H, C) # [M, H, C]
|
|
@@ -125,6 +153,10 @@ def sparse_windowed_scaled_dot_product_self_attention(
|
|
| 125 |
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 126 |
.to(qkv.device).int()
|
| 127 |
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
|
|
|
|
|
|
|
|
|
|
|
|
|
| 128 |
|
| 129 |
out = out[bwd_indices] # [T, H, C]
|
| 130 |
|
|
|
|
| 1 |
from typing import *
|
| 2 |
import torch
|
| 3 |
import math
|
| 4 |
+
from torch.nn.functional import scaled_dot_product_attention as sdpa
|
| 5 |
from .. import SparseTensor
|
| 6 |
+
from .. import DEBUG, ATTN, set_attn
|
| 7 |
|
| 8 |
if ATTN == 'xformers':
|
| 9 |
+
try:
|
| 10 |
+
import xformers.ops as xops
|
| 11 |
+
except ImportError:
|
| 12 |
+
ATTN = 'sdpa'
|
| 13 |
+
set_attn(ATTN)
|
| 14 |
elif ATTN == 'flash_attn':
|
| 15 |
+
try:
|
| 16 |
+
import flash_attn
|
| 17 |
+
except ImportError:
|
| 18 |
+
ATTN = 'sdpa'
|
| 19 |
+
set_attn(ATTN)
|
| 20 |
+
elif ATTN in {'sdpa', 'naive'}:
|
| 21 |
+
pass
|
| 22 |
else:
|
| 23 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 24 |
|
|
|
|
| 28 |
]
|
| 29 |
|
| 30 |
|
| 31 |
+
def _sdpa_varlen_qkv(qkv_feats: torch.Tensor, seq_lens: List[int]) -> torch.Tensor:
|
| 32 |
+
outs = []
|
| 33 |
+
start = 0
|
| 34 |
+
for seq_len in seq_lens:
|
| 35 |
+
chunk = qkv_feats[start:start + seq_len]
|
| 36 |
+
q, k, v = chunk.unbind(dim=1)
|
| 37 |
+
q = q.permute(1, 0, 2).unsqueeze(0)
|
| 38 |
+
k = k.permute(1, 0, 2).unsqueeze(0)
|
| 39 |
+
v = v.permute(1, 0, 2).unsqueeze(0)
|
| 40 |
+
out = sdpa(q, k, v).squeeze(0).permute(1, 0, 2)
|
| 41 |
+
outs.append(out)
|
| 42 |
+
start += seq_len
|
| 43 |
+
return torch.cat(outs, dim=0) if outs else qkv_feats.new_empty((0, qkv_feats.shape[2], qkv_feats.shape[3]))
|
| 44 |
+
|
| 45 |
+
|
| 46 |
def calc_window_partition(
|
| 47 |
tensor: SparseTensor,
|
| 48 |
window_size: Union[int, Tuple[int, ...]],
|
|
|
|
| 136 |
out = xops.memory_efficient_attention(q, k, v) # [B, N, H, C]
|
| 137 |
elif ATTN == 'flash_attn':
|
| 138 |
out = flash_attn.flash_attn_qkvpacked_func(qkv_feats) # [B, N, H, C]
|
| 139 |
+
elif ATTN in {'sdpa', 'naive'}:
|
| 140 |
+
out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
|
| 141 |
else:
|
| 142 |
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 143 |
out = out.reshape(B * N, H, C) # [M, H, C]
|
|
|
|
| 153 |
cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
|
| 154 |
.to(qkv.device).int()
|
| 155 |
out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens)) # [M, H, C]
|
| 156 |
+
elif ATTN in {'sdpa', 'naive'}:
|
| 157 |
+
out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
|
| 158 |
+
else:
|
| 159 |
+
raise ValueError(f"Unknown attention module: {ATTN}")
|
| 160 |
|
| 161 |
out = out[bwd_indices] # [T, H, C]
|
| 162 |
|