Spaces:
Running on Zero
Running on Zero
fix: fa3 broken on Blackwell+ (#10)
Browse files
- fix: fa3 broken on Blackwell+, fall back to scaled dot-product attention (3faa48b0376d51437db80eb00f45fc3fecad8c68)
app.py
CHANGED
|
@@ -72,15 +72,30 @@ from ltx_pipelines.utils.helpers import (
|
|
| 72 |
)
|
| 73 |
from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
|
| 74 |
|
| 75 |
-
#
|
|
|
|
| 76 |
from ltx_core.model.transformer import attention as _attn_mod
|
| 77 |
-
|
| 78 |
-
|
| 79 |
-
|
| 80 |
-
|
| 81 |
-
|
| 82 |
-
|
| 83 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 84 |
|
| 85 |
logging.getLogger().setLevel(logging.INFO)
|
| 86 |
|
|
|
|
| 72 |
)
|
| 73 |
from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
|
| 74 |
|
| 75 |
+
# Patch attention backend into the LTX attention module.
|
| 76 |
+
import torch.nn.functional as F
|
| 77 |
from ltx_core.model.transformer import attention as _attn_mod
|
| 78 |
+
|
| 79 |
+
def _sdpa_as_mea(query, key, value, attn_bias=None, scale=None, **kwargs):
|
| 80 |
+
# xformers memory_efficient_attention: (B, S, H, D) -> (B, S, H, D)
|
| 81 |
+
# torch SDPA: (B, H, S, D) -> (B, H, S, D)
|
| 82 |
+
q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
|
| 83 |
+
return F.scaled_dot_product_attention(q, k, v, scale=scale).transpose(1, 2)
|
| 84 |
+
|
| 85 |
+
# Choose the attention kernel once, at import time, and install it on the LTX
# attention module as `memory_efficient_attention`.
# (0, 0) stands for "no CUDA device visible".
_cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
_use_xformers = False
# NOTE(review): only sm_120 and newer skip xformers here. Data-center Blackwell
# parts report sm_100 and would still take the xformers path -- confirm that
# this matches the hardware the fa3 failure was seen on.
if _cap < (12, 0):
    try:
        from xformers.ops import memory_efficient_attention as _mea
    except Exception as e:
        # Deliberately broad and best-effort: any failure to load xformers
        # just routes to the SDPA shim below.
        print(f"[ATTN] xformers unavailable ({e}), falling back to SDPA")
    else:
        _attn_mod.memory_efficient_attention = _mea
        _use_xformers = True
        print("[ATTN] Using xformers memory_efficient_attention")

if not _use_xformers:
    _attn_mod.memory_efficient_attention = _sdpa_as_mea
    print(f"[ATTN] Using SDPA fallback (sm_{_cap[0]}{_cap[1]})")
|
| 99 |
|
| 100 |
logging.getLogger().setLevel(logging.INFO)
|
| 101 |
|