Spaces:
Running on Zero
Running on Zero
Commit ·
3661ca3
1
Parent(s): 1aa5d2f
✅ Temporary fixes for FA3, which is currently broken; fallback to SDPA. (#2)
Browse files- ✅ Temporary fixes for FA3, which is currently broken; fallback to SDPA. (255c75283331ec54cbc2afa4a94afb932a549e08)
- qwenimage/qwen_fa3_processor.py +151 -60
qwenimage/qwen_fa3_processor.py
CHANGED
|
@@ -1,89 +1,183 @@
|
|
| 1 |
"""
|
| 2 |
Paired with a good language model. Thanks!
|
|
|
|
|
|
|
|
|
|
|
|
|
| 3 |
"""
|
| 4 |
|
| 5 |
import torch
|
|
|
|
| 6 |
from typing import Optional, Tuple
|
| 7 |
from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
|
| 12 |
-
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
)
|
| 25 |
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 29 |
) -> torch.Tensor:
|
| 30 |
-
|
| 31 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 32 |
|
| 33 |
-
|
| 34 |
-
def _(q, k, v, **kwargs):
|
| 35 |
-
# two outputs:
|
| 36 |
-
# 1. output: (batch, seq_len, num_heads, head_dim)
|
| 37 |
-
# 2. softmax_lse: (batch, num_heads, seq_len) with dtype=torch.float32
|
| 38 |
-
meta_q = torch.empty_like(q).contiguous()
|
| 39 |
-
return meta_q #, q.new_empty((q.size(0), q.size(2), q.size(1)), dtype=torch.float32)
|
| 40 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
class QwenDoubleStreamAttnProcessorFA3:
|
| 43 |
"""
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
"""
|
| 53 |
|
| 54 |
-
_attention_backend
|
| 55 |
|
| 56 |
def __init__(self):
|
| 57 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 58 |
|
| 59 |
@torch.no_grad()
|
| 60 |
def __call__(
|
| 61 |
self,
|
| 62 |
-
attn,
|
| 63 |
-
hidden_states: torch.FloatTensor,
|
| 64 |
-
encoder_hidden_states: torch.FloatTensor = None,
|
| 65 |
-
encoder_hidden_states_mask: torch.FloatTensor = None,
|
| 66 |
-
attention_mask: Optional[torch.FloatTensor] = None,
|
| 67 |
-
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 68 |
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
| 69 |
-
if encoder_hidden_states is None:
|
| 70 |
-
raise ValueError("QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream).")
|
| 71 |
-
if attention_mask is not None:
|
| 72 |
-
# FA3 kernel path here does not consume arbitrary masks; fail fast to avoid silent correctness issues.
|
| 73 |
-
raise NotImplementedError("attention_mask is not supported in this FA3 implementation.")
|
| 74 |
|
| 75 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 76 |
|
| 77 |
B, S_img, _ = hidden_states.shape
|
| 78 |
S_txt = encoder_hidden_states.shape[1]
|
| 79 |
|
| 80 |
-
# ---- QKV projections
|
| 81 |
-
img_q = attn.to_q(hidden_states)
|
| 82 |
img_k = attn.to_k(hidden_states)
|
| 83 |
img_v = attn.to_v(hidden_states)
|
| 84 |
|
| 85 |
-
|
| 86 |
-
txt_q = attn.add_q_proj(encoder_hidden_states) # (B, S_txt, D)
|
| 87 |
txt_k = attn.add_k_proj(encoder_hidden_states)
|
| 88 |
txt_v = attn.add_v_proj(encoder_hidden_states)
|
| 89 |
|
|
@@ -97,7 +191,7 @@ class QwenDoubleStreamAttnProcessorFA3:
|
|
| 97 |
txt_k = txt_k.unflatten(-1, (H, -1))
|
| 98 |
txt_v = txt_v.unflatten(-1, (H, -1))
|
| 99 |
|
| 100 |
-
# ---- Q/K normalization
|
| 101 |
if getattr(attn, "norm_q", None) is not None:
|
| 102 |
img_q = attn.norm_q(img_q)
|
| 103 |
if getattr(attn, "norm_k", None) is not None:
|
|
@@ -110,25 +204,22 @@ class QwenDoubleStreamAttnProcessorFA3:
|
|
| 110 |
# ---- RoPE (Qwen variant) ----
|
| 111 |
if image_rotary_emb is not None:
|
| 112 |
img_freqs, txt_freqs = image_rotary_emb
|
| 113 |
-
# expects tensors shaped (B, S, H, D_h)
|
| 114 |
img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
|
| 115 |
img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
|
| 116 |
txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
|
| 117 |
txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)
|
| 118 |
|
| 119 |
# ---- Joint attention over [text, image] along sequence axis ----
|
| 120 |
-
|
| 121 |
-
q = torch.cat([txt_q, img_q], dim=1)
|
| 122 |
k = torch.cat([txt_k, img_k], dim=1)
|
| 123 |
v = torch.cat([txt_v, img_v], dim=1)
|
| 124 |
|
| 125 |
-
|
| 126 |
-
out = flash_attn_func(q, k, v, causal=False) # out: (B, S_total, H, D_h)
|
| 127 |
|
| 128 |
# ---- Back to (B, S, D_model) ----
|
| 129 |
out = out.flatten(2, 3).to(q.dtype)
|
| 130 |
|
| 131 |
-
# Split
|
| 132 |
txt_attn_out = out[:, :S_txt, :]
|
| 133 |
img_attn_out = out[:, S_txt:, :]
|
| 134 |
|
|
|
|
| 1 |
"""
|
| 2 |
Paired with a good language model. Thanks!
|
| 3 |
+
|
| 4 |
+
FA3 is currently broken on Blackwell (sm_100) GPUs; this module detects that
|
| 5 |
+
at import time and falls back to PyTorch scaled-dot-product attention (SDPA)
|
| 6 |
+
automatically. The public class name / call signature are unchanged.
|
| 7 |
"""
|
| 8 |
|
| 9 |
import torch
|
| 10 |
+
import torch.nn.functional as F
|
| 11 |
from typing import Optional, Tuple
|
| 12 |
from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen
|
| 13 |
|
| 14 |
+
|
| 15 |
+
# ---------------------------------------------------------------------------
|
| 16 |
+
# FA3 availability check
|
| 17 |
+
# ---------------------------------------------------------------------------
|
| 18 |
+
|
| 19 |
+
def _is_blackwell() -> bool:
|
| 20 |
+
"""Return True when the current default CUDA device is an sm_100 (Blackwell) GPU."""
|
| 21 |
+
if not torch.cuda.is_available():
|
| 22 |
+
return False
|
| 23 |
+
cap = torch.cuda.get_device_capability()
|
| 24 |
+
# Blackwell → compute capability 10.x (sm_100)
|
| 25 |
+
return cap[0] >= 10
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
_fa3_available: bool = False
|
| 29 |
+
_fa3_unavailable_reason: str = ""
|
| 30 |
+
_flash_attn_func = None
|
| 31 |
+
|
| 32 |
+
if _is_blackwell():
|
| 33 |
+
_fa3_unavailable_reason = (
|
| 34 |
+
"FlashAttention-3 is not yet supported on Blackwell (sm_100) GPUs. "
|
| 35 |
+
"Falling back to scaled-dot-product attention (SDPA)."
|
| 36 |
+
)
|
| 37 |
+
else:
|
| 38 |
+
try:
|
| 39 |
+
from kernels import get_kernel
|
| 40 |
+
_k = get_kernel("kernels-community/vllm-flash-attn3")
|
| 41 |
+
_flash_attn_func = _k.flash_attn_func
|
| 42 |
+
_fa3_available = True
|
| 43 |
+
except Exception as e:
|
| 44 |
+
_fa3_unavailable_reason = (
|
| 45 |
+
"FlashAttention-3 via Hugging Face `kernels` is unavailable. "
|
| 46 |
+
f"Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n{e}\n"
|
| 47 |
+
"Falling back to scaled-dot-product attention (SDPA)."
|
| 48 |
)
|
| 49 |
|
| 50 |
+
|
| 51 |
+
# ---------------------------------------------------------------------------
|
| 52 |
+
# FA3 custom op (registered only when the kernel loaded successfully)
|
| 53 |
+
# ---------------------------------------------------------------------------
|
| 54 |
+
|
| 55 |
+
if _fa3_available:
|
| 56 |
+
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
|
| 57 |
+
def flash_attn_func(
|
| 58 |
+
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
|
| 59 |
+
) -> torch.Tensor:
|
| 60 |
+
# _flash_attn_func returns (output, softmax_lse); we only need output.
|
| 61 |
+
output, _lse = _flash_attn_func(q, k, v, causal=causal)
|
| 62 |
+
return output
|
| 63 |
+
|
| 64 |
+
@flash_attn_func.register_fake
|
| 65 |
+
def _flash_attn_func_fake(q, k, v, causal=False):
|
| 66 |
+
# output shape mirrors q: (batch, seq_len, num_heads, head_dim)
|
| 67 |
+
return torch.empty_like(q).contiguous()
|
| 68 |
+
|
| 69 |
+
else:
|
| 70 |
+
# Provide a stub so call-sites that import the symbol don't break at
|
| 71 |
+
# module load; the processor will route around it at runtime.
|
| 72 |
+
def flash_attn_func(
|
| 73 |
+
q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False
|
| 74 |
+
) -> torch.Tensor:
|
| 75 |
+
raise RuntimeError(_fa3_unavailable_reason)
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# ---------------------------------------------------------------------------
|
| 79 |
+
# SDPA fallback helper
|
| 80 |
+
# ---------------------------------------------------------------------------
|
| 81 |
+
|
| 82 |
+
def _sdpa_attention(
|
| 83 |
+
q: torch.Tensor,
|
| 84 |
+
k: torch.Tensor,
|
| 85 |
+
v: torch.Tensor,
|
| 86 |
+
causal: bool = False,
|
| 87 |
) -> torch.Tensor:
|
| 88 |
+
"""
|
| 89 |
+
Scaled dot-product attention using torch.nn.functional.scaled_dot_product_attention.
|
| 90 |
+
|
| 91 |
+
Input / output layout: (B, S, H, D_h) — same as the FA3 kernel.
|
| 92 |
+
"""
|
| 93 |
+
# SDPA expects (B, H, S, D_h)
|
| 94 |
+
q = q.transpose(1, 2)
|
| 95 |
+
k = k.transpose(1, 2)
|
| 96 |
+
v = v.transpose(1, 2)
|
| 97 |
|
| 98 |
+
out = F.scaled_dot_product_attention(q, k, v, is_causal=causal)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 99 |
|
| 100 |
+
# Back to (B, S, H, D_h)
|
| 101 |
+
return out.transpose(1, 2)
|
| 102 |
+
|
| 103 |
+
|
| 104 |
+
# ---------------------------------------------------------------------------
|
| 105 |
+
# Attention processor
|
| 106 |
+
# ---------------------------------------------------------------------------
|
| 107 |
|
| 108 |
class QwenDoubleStreamAttnProcessorFA3:
|
| 109 |
"""
|
| 110 |
+
Attention processor for the Qwen double-stream architecture.
|
| 111 |
+
|
| 112 |
+
Preferred backend: vLLM FlashAttention-3 via Hugging Face ``kernels``.
|
| 113 |
+
Automatic fallback: PyTorch ``scaled_dot_product_attention`` (SDPA) when
|
| 114 |
+
FA3 is unavailable — e.g. on Blackwell (sm_100) GPUs where FA3 is not yet
|
| 115 |
+
supported, or when the ``kernels`` package is absent.
|
| 116 |
+
|
| 117 |
+
Notes / limitations
|
| 118 |
+
-------------------
|
| 119 |
+
- Arbitrary attention masks are not supported on the FA3 path. Pass
|
| 120 |
+
``attention_mask=None`` (the default) to stay on the fast path.
|
| 121 |
+
- On the SDPA path, ``attention_mask`` is likewise ignored; add explicit
|
| 122 |
+
support here if you need it.
|
| 123 |
+
- ``encoder_hidden_states`` (text stream) is required.
|
| 124 |
"""
|
| 125 |
|
| 126 |
+
_attention_backend: str # set in __init__ after capability detection
|
| 127 |
|
| 128 |
def __init__(self):
|
| 129 |
+
if _fa3_available:
|
| 130 |
+
self._attention_backend = "fa3"
|
| 131 |
+
else:
|
| 132 |
+
import warnings
|
| 133 |
+
warnings.warn(
|
| 134 |
+
f"QwenDoubleStreamAttnProcessorFA3: {_fa3_unavailable_reason}",
|
| 135 |
+
stacklevel=2,
|
| 136 |
+
)
|
| 137 |
+
self._attention_backend = "sdpa"
|
| 138 |
+
|
| 139 |
+
def _attend(
|
| 140 |
+
self,
|
| 141 |
+
q: torch.Tensor,
|
| 142 |
+
k: torch.Tensor,
|
| 143 |
+
v: torch.Tensor,
|
| 144 |
+
causal: bool = False,
|
| 145 |
+
) -> torch.Tensor:
|
| 146 |
+
"""Dispatch to FA3 or SDPA depending on what is available."""
|
| 147 |
+
if self._attention_backend == "fa3":
|
| 148 |
+
return flash_attn_func(q, k, v, causal=causal)
|
| 149 |
+
return _sdpa_attention(q, k, v, causal=causal)
|
| 150 |
|
| 151 |
@torch.no_grad()
|
| 152 |
def __call__(
|
| 153 |
self,
|
| 154 |
+
attn,
|
| 155 |
+
hidden_states: torch.FloatTensor, # (B, S_img, D_model)
|
| 156 |
+
encoder_hidden_states: torch.FloatTensor = None, # (B, S_txt, D_model)
|
| 157 |
+
encoder_hidden_states_mask: torch.FloatTensor = None, # unused
|
| 158 |
+
attention_mask: Optional[torch.FloatTensor] = None, # unsupported on FA3 path
|
| 159 |
+
image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
|
| 160 |
) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 161 |
|
| 162 |
+
if encoder_hidden_states is None:
|
| 163 |
+
raise ValueError(
|
| 164 |
+
"QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream)."
|
| 165 |
+
)
|
| 166 |
+
if attention_mask is not None and self._attention_backend == "fa3":
|
| 167 |
+
raise NotImplementedError(
|
| 168 |
+
"attention_mask is not supported on the FA3 path. "
|
| 169 |
+
"Either drop the mask or let the processor fall back to SDPA."
|
| 170 |
+
)
|
| 171 |
|
| 172 |
B, S_img, _ = hidden_states.shape
|
| 173 |
S_txt = encoder_hidden_states.shape[1]
|
| 174 |
|
| 175 |
+
# ---- QKV projections ----
|
| 176 |
+
img_q = attn.to_q(hidden_states)
|
| 177 |
img_k = attn.to_k(hidden_states)
|
| 178 |
img_v = attn.to_v(hidden_states)
|
| 179 |
|
| 180 |
+
txt_q = attn.add_q_proj(encoder_hidden_states)
|
|
|
|
| 181 |
txt_k = attn.add_k_proj(encoder_hidden_states)
|
| 182 |
txt_v = attn.add_v_proj(encoder_hidden_states)
|
| 183 |
|
|
|
|
| 191 |
txt_k = txt_k.unflatten(-1, (H, -1))
|
| 192 |
txt_v = txt_v.unflatten(-1, (H, -1))
|
| 193 |
|
| 194 |
+
# ---- Q/K normalization ----
|
| 195 |
if getattr(attn, "norm_q", None) is not None:
|
| 196 |
img_q = attn.norm_q(img_q)
|
| 197 |
if getattr(attn, "norm_k", None) is not None:
|
|
|
|
| 204 |
# ---- RoPE (Qwen variant) ----
|
| 205 |
if image_rotary_emb is not None:
|
| 206 |
img_freqs, txt_freqs = image_rotary_emb
|
|
|
|
| 207 |
img_q = apply_rotary_emb_qwen(img_q, img_freqs, use_real=False)
|
| 208 |
img_k = apply_rotary_emb_qwen(img_k, img_freqs, use_real=False)
|
| 209 |
txt_q = apply_rotary_emb_qwen(txt_q, txt_freqs, use_real=False)
|
| 210 |
txt_k = apply_rotary_emb_qwen(txt_k, txt_freqs, use_real=False)
|
| 211 |
|
| 212 |
# ---- Joint attention over [text, image] along sequence axis ----
|
| 213 |
+
q = torch.cat([txt_q, img_q], dim=1) # (B, S_txt + S_img, H, D_h)
|
|
|
|
| 214 |
k = torch.cat([txt_k, img_k], dim=1)
|
| 215 |
v = torch.cat([txt_v, img_v], dim=1)
|
| 216 |
|
| 217 |
+
out = self._attend(q, k, v, causal=False) # (B, S_total, H, D_h)
|
|
|
|
| 218 |
|
| 219 |
# ---- Back to (B, S, D_model) ----
|
| 220 |
out = out.flatten(2, 3).to(q.dtype)
|
| 221 |
|
| 222 |
+
# ---- Split text / image segments ----
|
| 223 |
txt_attn_out = out[:, :S_txt, :]
|
| 224 |
img_attn_out = out[:, S_txt:, :]
|
| 225 |
|