Spaces:
Running on Zero
Running on Zero
Add Flash Attention 3 support (optional)
- Modified src/flux/math.py to support FA3 when USE_FA3=1
- Uses kernels library to load vllm-flash-attn3 from HuggingFace
- Registered as custom op for torch.export compatibility
- Falls back to PyTorch SDPA (FA2) when FA3 not available
- Added kernels to requirements.txt
To enable FA3: export USE_FA3=1
- requirements.txt +3 -0
- src/flux/math.py +45 -2
requirements.txt
CHANGED
|
@@ -30,3 +30,6 @@ pypinyin
|
|
| 30 |
# Web UI (spaces handles torch 2.8+ AOT compilation)
|
| 31 |
gradio>=5.0
|
| 32 |
spaces>=0.47.0
|
|
|
|
|
|
|
|
|
|
|
|
| 30 |
# Web UI (spaces handles torch 2.8+ AOT compilation)
|
| 31 |
gradio>=5.0
|
| 32 |
spaces>=0.47.0
|
| 33 |
+
|
| 34 |
+
# Flash Attention 3 support (optional, for H100 GPUs)
|
| 35 |
+
kernels
|
src/flux/math.py
CHANGED
|
@@ -1,13 +1,56 @@
|
|
| 1 |
import torch
|
| 2 |
from einops import rearrange
|
| 3 |
from torch import Tensor
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 4 |
|
| 5 |
|
| 6 |
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
| 7 |
q, k = apply_rope(q, k, pe)
|
| 8 |
|
| 9 |
-
|
| 10 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 11 |
|
| 12 |
return x
|
| 13 |
|
|
|
|
| 1 |
import torch
|
| 2 |
from einops import rearrange
|
| 3 |
from torch import Tensor
|
| 4 |
+
from typing import Optional, List
|
| 5 |
+
import os
|
| 6 |
+
|
| 7 |
+
# ============================================================
|
| 8 |
+
# Flash Attention 3 Support (for H100 GPUs)
|
| 9 |
+
# ============================================================
|
| 10 |
+
_USE_FA3 = os.environ.get("USE_FA3", "0") == "1"
|
| 11 |
+
_flash_attn_func = None
|
| 12 |
+
|
| 13 |
+
if _USE_FA3:
|
| 14 |
+
try:
|
| 15 |
+
from kernels import get_kernel
|
| 16 |
+
_fa3_kernel = get_kernel("kernels-community/vllm-flash-attn3")
|
| 17 |
+
_flash_attn_func_raw = _fa3_kernel.flash_attn_func
|
| 18 |
+
|
| 19 |
+
@torch.library.custom_op("flash::flash_attn_func", mutates_args=())
|
| 20 |
+
def _flash_attn_func(
|
| 21 |
+
q: torch.Tensor,
|
| 22 |
+
k: torch.Tensor,
|
| 23 |
+
v: torch.Tensor,
|
| 24 |
+
softmax_scale: Optional[float] = None,
|
| 25 |
+
causal: bool = False,
|
| 26 |
+
) -> torch.Tensor:
|
| 27 |
+
outputs = _flash_attn_func_raw(q, k, v, softmax_scale=softmax_scale, causal=causal)
|
| 28 |
+
return outputs[0]
|
| 29 |
+
|
| 30 |
+
@_flash_attn_func.register_fake
|
| 31 |
+
def _(q, k, v, **kwargs):
|
| 32 |
+
return torch.empty_like(q).contiguous()
|
| 33 |
+
|
| 34 |
+
print("✓ Flash Attention 3 loaded successfully!")
|
| 35 |
+
except Exception as e:
|
| 36 |
+
print(f"Flash Attention 3 not available: {e}")
|
| 37 |
+
_USE_FA3 = False
|
| 38 |
|
| 39 |
|
| 40 |
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
|
| 41 |
q, k = apply_rope(q, k, pe)
|
| 42 |
|
| 43 |
+
if _USE_FA3 and _flash_attn_func is not None:
|
| 44 |
+
# FA3 expects (B, L, H, D) format
|
| 45 |
+
q_fa3 = rearrange(q, "B H L D -> B L H D")
|
| 46 |
+
k_fa3 = rearrange(k, "B H L D -> B L H D")
|
| 47 |
+
v_fa3 = rearrange(v, "B H L D -> B L H D")
|
| 48 |
+
x = _flash_attn_func(q_fa3, k_fa3, v_fa3)
|
| 49 |
+
x = rearrange(x, "B L H D -> B L (H D)")
|
| 50 |
+
else:
|
| 51 |
+
# Standard PyTorch SDPA (uses FA2 if available)
|
| 52 |
+
x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
|
| 53 |
+
x = rearrange(x, "B H L D -> B L (H D)")
|
| 54 |
|
| 55 |
return x
|
| 56 |
|