prithivMLmods committed on
Commit
3661ca3
·
1 Parent(s): 1aa5d2f

✅ Temporary fixes for FA3, which is currently broken; fallback to SDPA. (#2)

Browse files

- ✅ Temporary fixes for FA3, which is currently broken; fallback to SDPA. (255c75283331ec54cbc2afa4a94afb932a549e08)

Files changed (1) hide show
  1. qwenimage/qwen_fa3_processor.py +151 -60
qwenimage/qwen_fa3_processor.py CHANGED
@@ -1,89 +1,183 @@
1
  """
2
  Paired with a good language model. Thanks!
 
 
 
 
3
  """
4
 
5
  import torch
 
6
  from typing import Optional, Tuple
7
  from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
8
 
9
- try:
10
- from kernels import get_kernel
11
- _k = get_kernel("kernels-community/vllm-flash-attn3")
12
- _flash_attn_func = _k.flash_attn_func
13
- except Exception as e:
14
- _flash_attn_func = None
15
- _kernels_err = e
16
-
17
-
18
- def _ensure_fa3_available():
19
- if _flash_attn_func is None:
20
- raise ImportError(
21
- "FlashAttention-3 via Hugging Face `kernels` is required. "
22
- "Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
23
- f"{_kernels_err}"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
24
  )
25
 
26
- @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
27
- def flash_attn_func(
28
- q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
29
  ) -> torch.Tensor:
30
- outputs, lse = _flash_attn_func(q, k, v, causal=causal)
31
- return outputs
 
 
 
 
 
 
 
32
 
33
- @flash_attn_func.register_fake
34
- def _(q, k, v, **kwargs):
35
- # two outputs:
36
- # 1. output: (batch, seq_len, num_heads, head_dim)
37
- # 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
38
- meta_q = torch.empty_like(q).contiguous()
39
- return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
40
 
 
 
 
 
 
 
 
41
 
42
  class QwenDoubleStreamAttnProcessorFA3:
43
  """
44
- FA3-based attention processor for Qwen double-stream architecture.
45
- Computes joint attention over concatenated [text, image] streams using vLLM FlashAttention-3
46
- accessed via Hugging Face `kernels`.
47
-
48
- Notes / limitations:
49
- - General attention masks are not supported here (FA3 path). `is_causal=False` and no arbitrary mask.
50
- - Optional windowed attention / sink tokens / softcap can be plumbed through if you use those features.
51
- - Expects an available `apply_rotary_emb_qwen` in scope (same as your non-FA3 processor).
 
 
 
 
 
 
52
  """
53
 
54
- _attention_backend = "fa3" # for parity with your other processors, not used internally
55
 
56
  def __init__(self):
57
- _ensure_fa3_available()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
58
 
59
  @torch.no_grad()
60
  def __call__(
61
  self,
62
- attn, # Attention module with to_q/to_k/to_v/add_*_proj, norms, to_out, to_add_out, and .heads
63
- hidden_states: torch.FloatTensor, # (B, S_img, D_model) image stream
64
- encoder_hidden_states: torch.FloatTensor = None, # (B, S_txt, D_model) text stream
65
- encoder_hidden_states_mask: torch.FloatTensor = None, # unused in FA3 path
66
- attention_mask: Optional[torch.FloatTensor] = None, # unused in FA3 path
67
- image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None, # (img_freqs, txt_freqs)
68
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
69
- if encoder_hidden_states is None:
70
- raise ValueError("QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream).")
71
- if attention_mask is not None:
72
- # FA3 kernel path here does not consume arbitrary masks; fail fast to avoid silent correctness issues.
73
- raise NotImplementedError("attention_mask is not supported in this FA3 implementation.")
74
 
75
- _ensure_fa3_available()
 
 
 
 
 
 
 
 
76
 
77
  B, S_img, _ = hidden_states.shape
78
  S_txt = encoder_hidden_states.shape[1]
79
 
80
- # ---- QKV projections (image/sample stream) ----
81
- img_q = attn.to_q(hidden_states) # (B, S_img, D)
82
  img_k = attn.to_k(hidden_states)
83
  img_v = attn.to_v(hidden_states)
84
 
85
- # ---- QKV projections (text/context stream) ----
86
- txt_q = attn.add_q_proj(encoder_hidden_states) # (B, S_txt, D)
87
  txt_k = attn.add_k_proj(encoder_hidden_states)
88
  txt_v = attn.add_v_proj(encoder_hidden_states)
89
 
@@ -97,7 +191,7 @@ class QwenDoubleStreamAttnProcessorFA3:
97
  txt_k = txt_k.unflatten(-1, (H, -1))
98
  txt_v = txt_v.unflatten(-1, (H, -1))
99
 
100
- # ---- Q/K normalization (per your module contract) ----
101
  if getattr(attn, "norm_q", None) is not None:
102
  img_q = attn.norm_q(img_q)
103
  if getattr(attn, "norm_k", None) is not None:
@@ -110,25 +204,22 @@ class QwenDoubleStreamAttnProcessorFA3:
110
  # ---- RoPE (Qwen variant) ----
111
  if image_rotary_emb is not None:
112
  img_freqs, txt_freqs = image_rotary_emb
113
- # expects tensors shaped (B, S, H, D_h)
114
  img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
115
  img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
116
  txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
117
  txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)
118
 
119
  # ---- Joint attention over [text, image] along sequence axis ----
120
- # Shapes: (B, S_total, H, D_h)
121
- q = torch.cat([txt_q, img_q], dim=1)
122
  k = torch.cat([txt_k, img_k], dim=1)
123
  v = torch.cat([txt_v, img_v], dim=1)
124
 
125
- # FlashAttention-3 path expects (B, S, H, D_h) and returns (out, softmax_lse)
126
- out = flash_attn_func(q, k, v, causal=False) # out: (B, S_total, H, D_h)
127
 
128
  # ---- Back to (B, S, D_model) ----
129
  out = out.flatten(2, 3).to(q.dtype)
130
 
131
- # Split back to text / image segments
132
  txt_attn_out = out[:, :S_txt, :]
133
  img_attn_out = out[:, S_txt:, :]
134
 
 
1
  """
2
  Paired with a good language model. Thanks!
3
+
4
+ FA3 is currently broken on Blackwell (sm_100) GPUs; this module detects that
5
+ at import time and falls back to PyTorch scaled-dot-product attention (SDPA)
6
+ automatically. The public class name / call signature are unchanged.
7
  """
8
 
9
  import torch
10
+ import torch.nn.functional as F
11
  from typing import Optional, Tuple
12
  from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
13
 
14
+
15
+ # ---------------------------------------------------------------------------
16
+ # FA3 availability check
17
+ # ---------------------------------------------------------------------------
18
+
19
def _is_blackwell() -> bool:
    """Report whether the default CUDA device is treated as Blackwell.

    Only the major compute-capability number of the current default CUDA
    device is inspected; a major version of 10 or higher counts as
    Blackwell (sm_100). Without CUDA the answer is always ``False``.

    Returns:
        ``True`` when CUDA is available and the device's major compute
        capability is at least 10, ``False`` otherwise.
    """
    if torch.cuda.is_available():
        major, _minor = torch.cuda.get_device_capability()
        # Compute capability 10.x corresponds to Blackwell (sm_100).
        return major >= 10
    return False
26
+
27
+
28
# Module-level FA3 state, resolved once at import time:
#   _flash_attn_func        -- the raw kernel entry point, or None if not loaded
#   _fa3_available          -- True only after the kernel was fetched successfully
#   _fa3_unavailable_reason -- human-readable explanation when FA3 is not used
_flash_attn_func = None
_fa3_available: bool = False
_fa3_unavailable_reason: str = ""

if not _is_blackwell():
    # Best-effort load: any failure (missing `kernels` package, download
    # error, missing attribute, ...) is recorded and SDPA is used instead.
    try:
        from kernels import get_kernel
        _k = get_kernel("kernels-community/vllm-flash-attn3")
        _flash_attn_func = _k.flash_attn_func
        _fa3_available = True
    except Exception as exc:
        _fa3_unavailable_reason = (
            "FlashAttention-3 via Hugging Face `kernels` is unavailable. "
            f"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n{exc}\n"
            "Falling back to scaled-dot-product attention (SDPA)."
        )
else:
    # The kernel is not even attempted on Blackwell devices.
    _fa3_unavailable_reason = (
        "FlashAttention-3 is not yet supported on Blackwell (sm_100) GPUs. "
        "Falling back to scaled-dot-product attention (SDPA)."
    )
49
 
50
+
51
+ # ---------------------------------------------------------------------------
52
+ # FA3 custom op (registered only when the kernel loaded successfully)
53
+ # ---------------------------------------------------------------------------
54
+
55
if not _fa3_available:
    # Stub keeps `flash_attn_func` importable when the kernel could not be
    # loaded; the processor is expected to route around it at runtime.
    def flash_attn_func(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
    ) -> torch.Tensor:
        """Always raise ``RuntimeError`` carrying the recorded FA3 failure reason."""
        raise RuntimeError(_fa3_unavailable_reason)

else:
    @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
    def flash_attn_func(
        q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
    ) -> torch.Tensor:
        """Run the loaded FA3 kernel and return only its attention output."""
        # The kernel hands back (output, softmax_lse); the LSE is discarded.
        attn_out, _softmax_lse = _flash_attn_func(q, k, v, causal=causal)
        return attn_out

    @flash_attn_func.register_fake
    def _flash_attn_func_fake(q, k, v, causal=False):
        # Fake (meta) implementation: the output mirrors q,
        # i.e. (batch, seq_len, num_heads, head_dim).
        placeholder = torch.empty_like(q)
        return placeholder.contiguous()
76
+
77
+
78
+ # ---------------------------------------------------------------------------
79
+ # SDPA fallback helper
80
+ # ---------------------------------------------------------------------------
81
+
82
def _sdpa_attention(
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """
    Attention fallback built on ``torch.nn.functional.scaled_dot_product_attention``.

    Inputs and output use the (B, S, H, D_h) layout, matching the FA3 kernel.
    Axes 1 and 2 are swapped before the call, because SDPA works on
    (B, H, S, D_h), and swapped back afterwards.

    Args:
        q: Query tensor, (B, S_q, H, D_h).
        k: Key tensor, (B, S_kv, H, D_h).
        v: Value tensor, (B, S_kv, H, D_h).
        causal: Forwarded to SDPA as ``is_causal``.

    Returns:
        Attention output in (B, S_q, H, D_h) layout.
    """
    # Move heads ahead of the sequence axis for SDPA.
    q_bhsd, k_bhsd, v_bhsd = (t.transpose(1, 2) for t in (q, k, v))
    attended = F.scaled_dot_product_attention(
        q_bhsd, k_bhsd, v_bhsd, is_causal=causal
    )
    # Restore the caller's (B, S, H, D_h) layout.
    return attended.transpose(1, 2)
102
+
103
+
104
+ # ---------------------------------------------------------------------------
105
+ # Attention processor
106
+ # ---------------------------------------------------------------------------
107
 
108
  class QwenDoubleStreamAttnProcessorFA3:
109
  """
110
+ Attention processor for the Qwen double-stream architecture.
111
+
112
+ Preferred backend: vLLM FlashAttention-3 via Hugging Face ``kernels``.
113
+ Automatic fallback: PyTorch ``scaled_dot_product_attention`` (SDPA) when
114
+ FA3 is unavailable — e.g. on Blackwell (sm_100) GPUs where FA3 is not yet
115
+ supported, or when the ``kernels`` package is absent.
116
+
117
+ Notes / limitations
118
+ -------------------
119
+ - Arbitrary attention masks are not supported on the FA3 path. Pass
120
+ ``attention_mask=None`` (the default) to stay on the fast path.
121
+ - On the SDPA path, ``attention_mask`` is likewise ignored; add explicit
122
+ support here if you need it.
123
+ - ``encoder_hidden_states`` (text stream) is required.
124
  """
125
 
126
+ _attention_backend: str # set in __init__ after capability detection
127
 
128
  def __init__(self):
129
+ if _fa3_available:
130
+ self._attention_backend = "fa3"
131
+ else:
132
+ import warnings
133
+ warnings.warn(
134
+ f"QwenDoubleStreamAttnProcessorFA3: {_fa3_unavailable_reason}",
135
+ stacklevel=2,
136
+ )
137
+ self._attention_backend = "sdpa"
138
+
139
def _attend(
    self,
    q: torch.Tensor,
    k: torch.Tensor,
    v: torch.Tensor,
    causal: bool = False,
) -> torch.Tensor:
    """Run attention through the backend chosen in ``__init__``.

    ``flash_attn_func`` is used when ``self._attention_backend`` is
    ``"fa3"``; any other value routes to ``_sdpa_attention``. Both are
    called with the same ``(q, k, v, causal=...)`` arguments.
    """
    backend_fn = (
        flash_attn_func if self._attention_backend == "fa3" else _sdpa_attention
    )
    return backend_fn(q, k, v, causal=causal)
150
 
151
  @torch.no_grad()
152
  def __call__(
153
  self,
154
+ attn,
155
+ hidden_states: torch.FloatTensor, # (B, S_img, D_model)
156
+ encoder_hidden_states: torch.FloatTensor = None, # (B, S_txt, D_model)
157
+ encoder_hidden_states_mask: torch.FloatTensor = None, # unused
158
+ attention_mask: Optional[torch.FloatTensor] = None, # unsupported on FA3 path
159
+ image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
160
  ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
 
 
 
 
 
161
 
162
+ if encoder_hidden_states is None:
163
+ raise ValueError(
164
+ "QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream)."
165
+ )
166
+ if attention_mask is not None and self._attention_backend == "fa3":
167
+ raise NotImplementedError(
168
+ "attention_mask is not supported on the FA3 path. "
169
+ "Either drop the mask or let the processor fall back to SDPA."
170
+ )
171
 
172
  B, S_img, _ = hidden_states.shape
173
  S_txt = encoder_hidden_states.shape[1]
174
 
175
+ # ---- QKV projections ----
176
+ img_q = attn.to_q(hidden_states)
177
  img_k = attn.to_k(hidden_states)
178
  img_v = attn.to_v(hidden_states)
179
 
180
+ txt_q = attn.add_q_proj(encoder_hidden_states)
 
181
  txt_k = attn.add_k_proj(encoder_hidden_states)
182
  txt_v = attn.add_v_proj(encoder_hidden_states)
183
 
 
191
  txt_k = txt_k.unflatten(-1, (H, -1))
192
  txt_v = txt_v.unflatten(-1, (H, -1))
193
 
194
+ # ---- Q/K normalization ----
195
  if getattr(attn, "norm_q", None) is not None:
196
  img_q = attn.norm_q(img_q)
197
  if getattr(attn, "norm_k", None) is not None:
 
204
  # ---- RoPE (Qwen variant) ----
205
  if image_rotary_emb is not None:
206
  img_freqs, txt_freqs = image_rotary_emb
 
207
  img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
208
  img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
209
  txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
210
  txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)
211
 
212
  # ---- Joint attention over [text, image] along sequence axis ----
213
+ q = torch.cat([txt_q, img_q], dim=1) # (B, S_txt + S_img, H, D_h)
 
214
  k = torch.cat([txt_k, img_k], dim=1)
215
  v = torch.cat([txt_v, img_v], dim=1)
216
 
217
+ out = self._attend(q, k, v, causal=False) # (B, S_total, H, D_h)
 
218
 
219
  # ---- Back to (B, S, D_model) ----
220
  out = out.flatten(2, 3).to(q.dtype)
221
 
222
+ # ---- Split text / image segments ----
223
  txt_attn_out = out[:, :S_txt, :]
224
  img_attn_out = out[:, S_txt:, :]
225