| """ |
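FlashAttention-3 attention processor for the Qwen-Image double-stream transformer,
using the vLLM FlashAttention-3 kernel loaded through Hugging Face `kernels`.

Pair-programmed with a language model. Thanks!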
| """ |
|
|
| import torch |
| from typing import Optional, Tuple |
| from diffusers.models.transformers.transformer_qwenimage import apply_rotary_emb_qwen |
|
|
# Best-effort load of the FlashAttention-3 kernel through Hugging Face `kernels`.
# Any failure (e.g. `kernels` not installed, or `get_kernel` raising) is tolerated
# at import time. The exception is kept in `_kernels_err` so that
# `_ensure_fa3_available()` can report it once FA3 is actually needed.
# Both globals are bound up front so they always exist, whichever branch runs.
_flash_attn_func = None
_kernels_err: Optional[BaseException] = None
try:
    from kernels import get_kernel

    _k = get_kernel("kernels-community/vllm-flash-attn3")
    _flash_attn_func = _k.flash_attn_func
except Exception as e:
    _flash_attn_func = None
    # `e` is unbound when the except block exits; keep a module-level reference.
    _kernels_err = e
|
|
|
|
def _ensure_fa3_available():
    """Raise if the FlashAttention-3 kernel could not be loaded at import time.

    Raises:
        ImportError: when ``_flash_attn_func`` is None. The exception recorded
            during loading (``_kernels_err``) is included in the message and
            attached as ``__cause__``, so its original traceback is kept.
    """
    if _flash_attn_func is None:
        raise ImportError(
            "FlashAttention-3 via Hugging Face `kernels` is required. "
            "Tried `get_kernel('kernels-community/vllm-flash-attn3')` and failed with:\n"
            f"{_kernels_err}"
        ) from _kernels_err
|
|
| @torch.library.custom_op("flash::flash_attn_func", mutates_args=()) |
| def flash_attn_func( |
| q: torch.Tensor, k: torch.Tensor, v: torch.Tensor, causal: bool = False |
| ) -> torch.Tensor: |
| outputs, lse = _flash_attn_func(q, k, v, causal=causal) |
| return outputs |
|
|
| @flash_attn_func.register_fake |
| def _(q, k, v, **kwargs): |
| |
| |
| |
| meta_q = torch.empty_like(q).contiguous() |
| return meta_q |
|
|
|
|
class QwenDoubleStreamAttnProcessorFA3:
    """Joint text/image attention for Qwen double-stream blocks on FlashAttention-3.

    Each stream is processed as follows:
      - The image stream (``hidden_states``) and the text stream
        (``encoder_hidden_states``) are each projected to per-head queries,
        keys and values.
      - They are optionally QK-normalized and rotary-embedded.
      - They are concatenated along the sequence axis in ``[text, image]``
        order and attended jointly through ``flash_attn_func``, the vLLM FA3
        kernel loaded via Hugging Face ``kernels``.
      - The joint output is split back into the two streams, and each goes
        through its own output projection.

    Limitations:
      - Attention is always non-causal. Any ``attention_mask`` is rejected with
        ``NotImplementedError``.
      - ``encoder_hidden_states_mask`` is accepted but not used.
      - ``__call__`` runs under ``torch.no_grad()``, so this processor is
        inference-only.
      - Windowed attention, sink tokens and softcap are not plumbed through.
      - Rotary embeddings are applied with diffusers' ``apply_rotary_emb_qwen``.
    """

    _attention_backend = "fa3"

    def __init__(self):
        # Surface a missing FA3 kernel when the processor is built, not mid-forward.
        _ensure_fa3_available()

    @staticmethod
    def _project_heads(attn, states, projections):
        """Apply each projection to ``states`` in order and split channels into ``attn.heads`` heads."""
        return [proj(states).unflatten(-1, (attn.heads, -1)) for proj in projections]

    @staticmethod
    def _maybe_norm(attn, norm_name, tensor):
        """Apply ``attn.<norm_name>`` to ``tensor`` when that attribute exists and is not None."""
        norm = getattr(attn, norm_name, None)
        return tensor if norm is None else norm(tensor)

    @torch.no_grad()
    def __call__(
        self,
        attn,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        encoder_hidden_states_mask: torch.FloatTensor = None,
        attention_mask: Optional[torch.FloatTensor] = None,
        image_rotary_emb: Optional[Tuple[torch.Tensor, torch.Tensor]] = None,
    ) -> Tuple[torch.FloatTensor, torch.FloatTensor]:
        """Compute joint attention and return ``(image_output, text_output)``.

        Raises:
            ValueError: if ``encoder_hidden_states`` is None, or (through tuple
                unpacking) if ``hidden_states`` is not 3-D.
            NotImplementedError: if ``attention_mask`` is given.
            ImportError: if the FA3 kernel is unavailable.
        """
        if encoder_hidden_states is None:
            raise ValueError("QwenDoubleStreamAttnProcessorFA3 requires encoder_hidden_states (text stream).")
        if attention_mask is not None:
            raise NotImplementedError("attention_mask is not supported in this FA3 implementation.")
        _ensure_fa3_available()

        # Unpacking doubles as a rank check: hidden_states must be 3-D.
        _batch, _img_len, _ = hidden_states.shape
        txt_len = encoder_hidden_states.shape[1]

        img_q, img_k, img_v = self._project_heads(
            attn, hidden_states, (attn.to_q, attn.to_k, attn.to_v)
        )
        txt_q, txt_k, txt_v = self._project_heads(
            attn, encoder_hidden_states, (attn.add_q_proj, attn.add_k_proj, attn.add_v_proj)
        )

        img_q = self._maybe_norm(attn, "norm_q", img_q)
        img_k = self._maybe_norm(attn, "norm_k", img_k)
        txt_q = self._maybe_norm(attn, "norm_added_q", txt_q)
        txt_k = self._maybe_norm(attn, "norm_added_k", txt_k)

        if image_rotary_emb is not None:
            img_freqs, txt_freqs = image_rotary_emb
            img_q, img_k = (apply_rotary_emb_qwen(t, img_freqs, use_real=False) for t in (img_q, img_k))
            txt_q, txt_k = (apply_rotary_emb_qwen(t, txt_freqs, use_real=False) for t in (txt_q, txt_k))

        # Text tokens come first in the joint sequence; the split below depends on it.
        query, key, value = (
            torch.cat([txt, img], dim=1)
            for txt, img in ((txt_q, img_q), (txt_k, img_k), (txt_v, img_v))
        )

        joint = flash_attn_func(query, key, value, causal=False)
        joint = joint.flatten(2, 3).to(query.dtype)
        txt_part = joint[:, :txt_len, :]
        img_part = joint[:, txt_len:, :]

        img_out = attn.to_out[0](img_part)
        if len(attn.to_out) > 1:
            img_out = attn.to_out[1](img_out)
        txt_out = attn.to_add_out(txt_part)
        return img_out, txt_out