linoyts HF Staff committed on
Commit
0c5fe0e
·
1 Parent(s): 86682c5

fix: fa3 broken on Blackwell+ (#10)

Browse files

- fix: fa3 broken on Blackwell+; fall back to scaled dot product attention (3faa48b0376d51437db80eb00f45fc3fecad8c68)

Files changed (1) hide show
  1. app.py +23 -8
app.py CHANGED
@@ -72,15 +72,30 @@ from ltx_pipelines.utils.helpers import (
72
  )
73
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
74
 
75
- # Force-patch xformers attention into the LTX attention module.
 
76
  from ltx_core.model.transformer import attention as _attn_mod
77
- print(f"[ATTN] Before patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
78
- try:
79
- from xformers.ops import memory_efficient_attention as _mea
80
- _attn_mod.memory_efficient_attention = _mea
81
- print(f"[ATTN] After patch: memory_efficient_attention={_attn_mod.memory_efficient_attention}")
82
- except Exception as e:
83
- print(f"[ATTN] xformers patch FAILED: {type(e).__name__}: {e}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
84
 
85
  logging.getLogger().setLevel(logging.INFO)
86
 
 
72
  )
73
  from ltx_pipelines.utils.media_io import decode_audio_from_file, encode_video
74
 
75
+ # Patch attention backend into the LTX attention module.
76
+ import torch.nn.functional as F
77
  from ltx_core.model.transformer import attention as _attn_mod
78
+
79
+ def _sdpa_as_mea(query, key, value, attn_bias=None, scale=None, **kwargs):
80
+ # xformers memory_efficient_attention: (B, S, H, D) -> (B, S, H, D)
81
+ # torch SDPA: (B, H, S, D) -> (B, H, S, D)
82
+ q, k, v = query.transpose(1, 2), key.transpose(1, 2), value.transpose(1, 2)
83
+ return F.scaled_dot_product_attention(q, k, v, scale=scale).transpose(1, 2)
84
+
85
+ _cap = torch.cuda.get_device_capability() if torch.cuda.is_available() else (0, 0)
86
+ _use_xformers = False
87
+ if _cap < (12, 0):
88
+ try:
89
+ from xformers.ops import memory_efficient_attention as _mea
90
+ _attn_mod.memory_efficient_attention = _mea
91
+ _use_xformers = True
92
+ print(f"[ATTN] Using xformers memory_efficient_attention")
93
+ except Exception as e:
94
+ print(f"[ATTN] xformers unavailable ({e}), falling back to SDPA")
95
+
96
+ if not _use_xformers:
97
+ _attn_mod.memory_efficient_attention = _sdpa_as_mea
98
+ print(f"[ATTN] Using SDPA fallback (sm_{_cap[0]}{_cap[1]})")
99
 
100
  logging.getLogger().setLevel(logging.INFO)
101