Bobby committed
Commit 31860a5 · 1 Parent(s): 86e4232

Fall back TRELLIS attention backends to SDPA when flash-attn is missing

app.py CHANGED
@@ -1,5 +1,6 @@
 import argparse
 import concurrent.futures
+import importlib.util
 import os
 import sys
 import time
@@ -8,7 +9,14 @@ from typing import Any, Dict, Generator, List, Optional, Tuple
 
 os.environ["OPENCV_IO_ENABLE_OPENEXR"] = "1"
 os.environ["PYTORCH_CUDA_ALLOC_CONF"] = "expandable_segments:True"
-os.environ.setdefault("ATTN_BACKEND", "flash_attn_3")
+if importlib.util.find_spec("flash_attn") is not None:
+    _attn_backend = "flash_attn"
+elif importlib.util.find_spec("xformers") is not None:
+    _attn_backend = "xformers"
+else:
+    _attn_backend = "sdpa"
+os.environ.setdefault("ATTN_BACKEND", _attn_backend)
+os.environ.setdefault("SPARSE_ATTN_BACKEND", _attn_backend)
 os.environ.setdefault("SPCONV_ALGO", "native")
 os.environ["FLEX_GEMM_AUTOTUNE_CACHE_PATH"] = os.path.join(
     os.path.dirname(os.path.abspath(__file__)), "autotune_cache.json"
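
Note on the startup probe: importlib.util.find_spec only checks that a package is importable, it does not import it, so the check stays cheap. A minimal standalone sketch of the same selection logic (the pick_attn_backend helper name is ours, for illustration only):

import importlib.util

def pick_attn_backend() -> str:
    # Prefer flash-attn, then xformers, else PyTorch's built-in SDPA.
    if importlib.util.find_spec("flash_attn") is not None:
        return "flash_attn"
    if importlib.util.find_spec("xformers") is not None:
        return "xformers"
    return "sdpa"

print(f"[app] attention backend: {pick_attn_backend()}")

Because setdefault is used, an ATTN_BACKEND or SPARSE_ATTN_BACKEND value already present in the environment still takes precedence over the probe.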
trellis/modules/attention/__init__.py CHANGED
@@ -1,6 +1,6 @@
 from typing import *
 
-BACKEND = 'flash_attn'
+BACKEND = 'sdpa'
 DEBUG = False
 
 def __from_env():
@@ -12,6 +12,9 @@ def __from_env():
     env_attn_backend = os.environ.get('ATTN_BACKEND')
     env_sttn_debug = os.environ.get('ATTN_DEBUG')
 
+    if env_attn_backend == 'flash_attn_3':
+        env_attn_backend = 'flash_attn'
+
     if env_attn_backend is not None and env_attn_backend in ['xformers', 'flash_attn', 'sdpa', 'naive']:
         BACKEND = env_attn_backend
     if env_sttn_debug is not None:
@@ -25,7 +28,7 @@ def __from_env():
 __from_env()
 
 
-def set_backend(backend: Literal['xformers', 'flash_attn']):
+def set_backend(backend: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
     global BACKEND
     BACKEND = backend
 
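
The widened Literal simply catches up with the values __from_env() already accepted, and the new alias keeps configs that still export the old flash_attn_3 default working. A sketch of the intended behaviour (assumes TRELLIS and its dependencies are importable; the env var must be set before the module is first imported):

import os
os.environ["ATTN_BACKEND"] = "flash_attn_3"   # legacy value from the old default

from trellis.modules import attention         # __from_env() normalizes the alias
assert attention.BACKEND == "flash_attn"      # provided flash-attn imports cleanly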
 
trellis/modules/attention/full_attn.py CHANGED
@@ -1,12 +1,22 @@
 from typing import *
 import torch
 import math
-from . import DEBUG, BACKEND
+from . import DEBUG, BACKEND, set_backend
 
 if BACKEND == 'xformers':
-    import xformers.ops as xops
+    try:
+        import xformers.ops as xops
+    except ImportError:
+        BACKEND = 'sdpa'
+        set_backend(BACKEND)
+        from torch.nn.functional import scaled_dot_product_attention as sdpa
 elif BACKEND == 'flash_attn':
-    import flash_attn
+    try:
+        import flash_attn
+    except ImportError:
+        BACKEND = 'sdpa'
+        set_backend(BACKEND)
+        from torch.nn.functional import scaled_dot_product_attention as sdpa
 elif BACKEND == 'sdpa':
     from torch.nn.functional import scaled_dot_product_attention as sdpa
 elif BACKEND == 'naive':
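
The try/except runs at import time, so a missing wheel degrades the whole process to SDPA exactly once, and set_backend keeps the package-level BACKEND in sync with the module-local rebinding. The pattern in isolation (a minimal sketch assuming only torch is installed; names are illustrative):

BACKEND = "flash_attn"                 # requested backend
try:
    import flash_attn                  # optional accelerated kernels
except ImportError:
    BACKEND = "sdpa"                   # degrade gracefully instead of crashing
    from torch.nn.functional import scaled_dot_product_attention as sdpa
print(f"attention backend in use: {BACKEND}")

Note that only ImportError is caught; a flash-attn build that imports but fails at kernel launch would still surface at call time.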
trellis/modules/sparse/__init__.py CHANGED
@@ -2,7 +2,7 @@ from typing import *
 
 BACKEND = 'spconv'
 DEBUG = False
-ATTN = 'flash_attn'
+ATTN = 'sdpa'
 
 def __from_env():
     import os
@@ -21,7 +21,10 @@ def __from_env():
         BACKEND = env_sparse_backend
     if env_sparse_debug is not None:
         DEBUG = env_sparse_debug == '1'
-    if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn']:
+    if env_sparse_attn == 'flash_attn_3':
+        env_sparse_attn = 'flash_attn'
+
+    if env_sparse_attn is not None and env_sparse_attn in ['xformers', 'flash_attn', 'sdpa', 'naive']:
         ATTN = env_sparse_attn
 
     print(f"[SPARSE] Backend: {BACKEND}, Attention: {ATTN}")
@@ -38,7 +41,7 @@ def set_debug(debug: bool):
     global DEBUG
     DEBUG = debug
 
-def set_attn(attn: Literal['xformers', 'flash_attn']):
+def set_attn(attn: Literal['xformers', 'flash_attn', 'sdpa', 'naive']):
    global ATTN
    ATTN = attn
 
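
As with the dense module, set_attn's widened signature just matches what __from_env() now accepts; a one-liner sketch (again assuming TRELLIS is importable):

from trellis.modules import sparse
sparse.set_attn("sdpa")      # 'sdpa' / 'naive' were previously rejected by the type
print(sparse.ATTN)           # -> 'sdpa'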
 
trellis/modules/sparse/attention/full_attn.py CHANGED
@@ -1,12 +1,23 @@
 from typing import *
 import torch
+from torch.nn.functional import scaled_dot_product_attention as sdpa
 from .. import SparseTensor
-from .. import DEBUG, ATTN
+from .. import DEBUG, ATTN, set_attn
 
 if ATTN == 'xformers':
-    import xformers.ops as xops
+    try:
+        import xformers.ops as xops
+    except ImportError:
+        ATTN = 'sdpa'
+        set_attn(ATTN)
 elif ATTN == 'flash_attn':
-    import flash_attn
+    try:
+        import flash_attn
+    except ImportError:
+        ATTN = 'sdpa'
+        set_attn(ATTN)
+elif ATTN in {'sdpa', 'naive'}:
+    pass
 else:
     raise ValueError(f"Unknown attention module: {ATTN}")
 
@@ -16,6 +27,14 @@ __all__ = [
 ]
 
 
+def _sdpa_chunk(q: torch.Tensor, k: torch.Tensor, v: torch.Tensor) -> torch.Tensor:
+    q = q.permute(1, 0, 2).unsqueeze(0)    # [1, H, Lq, C]
+    k = k.permute(1, 0, 2).unsqueeze(0)    # [1, H, Lk, C]
+    v = v.permute(1, 0, 2).unsqueeze(0)    # [1, H, Lk, C]
+    out = sdpa(q, k, v)
+    return out.squeeze(0).permute(1, 0, 2)    # [Lq, H, C]
+
+
 @overload
 def sparse_scaled_dot_product_attention(qkv: SparseTensor) -> SparseTensor:
     """
@@ -206,6 +225,31 @@ def sparse_scaled_dot_product_attention(*args, **kwargs):
             out = flash_attn.flash_attn_varlen_kvpacked_func(q, kv, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
         elif num_all_args == 3:
             out = flash_attn.flash_attn_varlen_func(q, k, v, cu_seqlens_q, cu_seqlens_kv, max(q_seqlen), max(kv_seqlen))
+    elif ATTN in {'sdpa', 'naive'}:
+        outs = []
+        q_start, kv_start = 0, 0
+        for q_len, kv_len in zip(q_seqlen, kv_seqlen):
+            if num_all_args == 1:
+                qkv_chunk = qkv[q_start:q_start + q_len]
+                q_i, k_i, v_i = qkv_chunk.unbind(dim=1)
+            elif num_all_args == 2:
+                q_i = q[q_start:q_start + q_len]
+                kv_chunk = kv[kv_start:kv_start + kv_len]
+                k_i, v_i = kv_chunk.unbind(dim=1)
+            else:
+                q_i = q[q_start:q_start + q_len]
+                k_i = k[kv_start:kv_start + kv_len]
+                v_i = v[kv_start:kv_start + kv_len]
+
+            outs.append(_sdpa_chunk(q_i, k_i, v_i))
+            q_start += q_len
+            kv_start += kv_len
+        if outs:
+            out = torch.cat(outs, dim=0)
+        elif num_all_args == 1:
+            out = torch.empty((0, qkv.shape[-2], qkv.shape[-1]), device=device, dtype=qkv.dtype)
+        else:
+            out = torch.empty((0, q.shape[-2], q.shape[-1]), device=device, dtype=q.dtype)
     else:
         raise ValueError(f"Unknown attention module: {ATTN}")
 
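
The SDPA path loops over sequences because scaled_dot_product_attention has no varlen-packed entry point; _sdpa_chunk adapts flash-attn's unbatched [L, H, C] layout to SDPA's batched [B, H, L, C] layout and back. A torch-only sanity sketch of that layout round trip:

import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

L, H, C = 7, 4, 16                              # tokens, heads, head dim
q, k, v = (torch.randn(L, H, C) for _ in range(3))

out = sdpa(
    q.permute(1, 0, 2).unsqueeze(0),            # [1, H, L, C]
    k.permute(1, 0, 2).unsqueeze(0),
    v.permute(1, 0, 2).unsqueeze(0),
).squeeze(0).permute(1, 0, 2)                   # back to [L, H, C]

assert out.shape == (L, H, C)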
 
trellis/modules/sparse/attention/serialized_attn.py CHANGED
@@ -2,13 +2,24 @@ from typing import *
 from enum import Enum
 import torch
 import math
+from torch.nn.functional import scaled_dot_product_attention as sdpa
 from .. import SparseTensor
-from .. import DEBUG, ATTN
+from .. import DEBUG, ATTN, set_attn
 
 if ATTN == 'xformers':
-    import xformers.ops as xops
+    try:
+        import xformers.ops as xops
+    except ImportError:
+        ATTN = 'sdpa'
+        set_attn(ATTN)
 elif ATTN == 'flash_attn':
-    import flash_attn
+    try:
+        import flash_attn
+    except ImportError:
+        ATTN = 'sdpa'
+        set_attn(ATTN)
+elif ATTN in {'sdpa', 'naive'}:
+    pass
 else:
     raise ValueError(f"Unknown attention module: {ATTN}")
 
@@ -18,6 +29,21 @@ __all__ = [
 ]
 
 
+def _sdpa_varlen_qkv(qkv_feats: torch.Tensor, seq_lens: List[int]) -> torch.Tensor:
+    outs = []
+    start = 0
+    for seq_len in seq_lens:
+        chunk = qkv_feats[start:start + seq_len]
+        q, k, v = chunk.unbind(dim=1)
+        q = q.permute(1, 0, 2).unsqueeze(0)
+        k = k.permute(1, 0, 2).unsqueeze(0)
+        v = v.permute(1, 0, 2).unsqueeze(0)
+        out = sdpa(q, k, v).squeeze(0).permute(1, 0, 2)
+        outs.append(out)
+        start += seq_len
+    return torch.cat(outs, dim=0) if outs else qkv_feats.new_empty((0, qkv_feats.shape[2], qkv_feats.shape[3]))
+
+
 class SerializeMode(Enum):
     Z_ORDER = 0
     Z_ORDER_TRANSPOSED = 1
@@ -168,6 +194,8 @@ def sparse_serialized_scaled_dot_product_self_attention(
         out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
     elif ATTN == 'flash_attn':
         out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
+    elif ATTN in {'sdpa', 'naive'}:
+        out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
     else:
         raise ValueError(f"Unknown attention module: {ATTN}")
     out = out.reshape(B * N, H, C)   # [M, H, C]
@@ -183,6 +211,10 @@ def sparse_serialized_scaled_dot_product_self_attention(
         cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                      .to(qkv.device).int()
         out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens))  # [M, H, C]
+    elif ATTN in {'sdpa', 'naive'}:
+        out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
+    else:
+        raise ValueError(f"Unknown attention module: {ATTN}")
 
     out = out[bwd_indices]      # [T, H, C]
 
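
_sdpa_varlen_qkv plays the role of flash_attn_varlen_qkvpacked_func: it takes token-packed [M, 3, H, C] input plus per-sequence lengths and attends within each sequence independently. A torch-only sketch of its contract:

import torch
from torch.nn.functional import scaled_dot_product_attention as sdpa

H, C = 4, 8
seq_lens = [3, 5, 2]                              # M = 10 tokens in 3 sequences
qkv = torch.randn(sum(seq_lens), 3, H, C)         # packed along dim 0

outs, start = [], 0
for n in seq_lens:
    q, k, v = qkv[start:start + n].unbind(dim=1)  # each [n, H, C]
    outs.append(sdpa(q.permute(1, 0, 2).unsqueeze(0),
                     k.permute(1, 0, 2).unsqueeze(0),
                     v.permute(1, 0, 2).unsqueeze(0)).squeeze(0).permute(1, 0, 2))
    start += n

out = torch.cat(outs, dim=0)
assert out.shape == (sum(seq_lens), H, C)         # output stays token-packed

The per-sequence Python loop is slower than the fused varlen kernel, which is why this is a fallback rather than the default on flash-attn-capable machines.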
 
trellis/modules/sparse/attention/windowed_attn.py CHANGED
@@ -1,13 +1,24 @@
 from typing import *
 import torch
 import math
+from torch.nn.functional import scaled_dot_product_attention as sdpa
 from .. import SparseTensor
-from .. import DEBUG, ATTN
+from .. import DEBUG, ATTN, set_attn
 
 if ATTN == 'xformers':
-    import xformers.ops as xops
+    try:
+        import xformers.ops as xops
+    except ImportError:
+        ATTN = 'sdpa'
+        set_attn(ATTN)
 elif ATTN == 'flash_attn':
-    import flash_attn
+    try:
+        import flash_attn
+    except ImportError:
+        ATTN = 'sdpa'
+        set_attn(ATTN)
+elif ATTN in {'sdpa', 'naive'}:
+    pass
 else:
     raise ValueError(f"Unknown attention module: {ATTN}")
 
@@ -17,6 +28,21 @@ __all__ = [
 ]
 
 
+def _sdpa_varlen_qkv(qkv_feats: torch.Tensor, seq_lens: List[int]) -> torch.Tensor:
+    outs = []
+    start = 0
+    for seq_len in seq_lens:
+        chunk = qkv_feats[start:start + seq_len]
+        q, k, v = chunk.unbind(dim=1)
+        q = q.permute(1, 0, 2).unsqueeze(0)
+        k = k.permute(1, 0, 2).unsqueeze(0)
+        v = v.permute(1, 0, 2).unsqueeze(0)
+        out = sdpa(q, k, v).squeeze(0).permute(1, 0, 2)
+        outs.append(out)
+        start += seq_len
+    return torch.cat(outs, dim=0) if outs else qkv_feats.new_empty((0, qkv_feats.shape[2], qkv_feats.shape[3]))
+
+
 def calc_window_partition(
     tensor: SparseTensor,
     window_size: Union[int, Tuple[int, ...]],
@@ -110,6 +136,8 @@ def sparse_windowed_scaled_dot_product_self_attention(
         out = xops.memory_efficient_attention(q, k, v)          # [B, N, H, C]
     elif ATTN == 'flash_attn':
         out = flash_attn.flash_attn_qkvpacked_func(qkv_feats)   # [B, N, H, C]
+    elif ATTN in {'sdpa', 'naive'}:
+        out = _sdpa_varlen_qkv(qkv_feats.reshape(B * N, 3, H, C), [N] * B)
     else:
         raise ValueError(f"Unknown attention module: {ATTN}")
     out = out.reshape(B * N, H, C)   # [M, H, C]
@@ -125,6 +153,10 @@ def sparse_windowed_scaled_dot_product_self_attention(
         cu_seqlens = torch.cat([torch.tensor([0]), torch.cumsum(torch.tensor(seq_lens), dim=0)], dim=0) \
                      .to(qkv.device).int()
         out = flash_attn.flash_attn_varlen_qkvpacked_func(qkv_feats, cu_seqlens, max(seq_lens))  # [M, H, C]
+    elif ATTN in {'sdpa', 'naive'}:
+        out = _sdpa_varlen_qkv(qkv_feats, seq_lens)
+    else:
+        raise ValueError(f"Unknown attention module: {ATTN}")
 
     out = out[bwd_indices]      # [T, H, C]
 
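
The windowed variant duplicates the same _sdpa_varlen_qkv helper, so serialized and windowed sparse attention both gain a pure-PyTorch path. To force that path end to end (a hypothetical session; the env vars must be set before TRELLIS is imported, since both __from_env() hooks read them at import time):

import os
os.environ["ATTN_BACKEND"] = "sdpa"
os.environ["SPARSE_ATTN_BACKEND"] = "sdpa"

# Any TRELLIS import must come after the assignments above.
from trellis.modules import attention, sparse
print(attention.BACKEND, sparse.ATTN)   # expected: sdpa sdpa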