TSXu committed on
Commit
8fc8d44
·
1 Parent(s): dd7a8c2

Add Flash Attention 3 support (optional)

Browse files

- Modified src/flux/math.py to support FA3 when USE_FA3=1
- Uses kernels library to load vllm-flash-attn3 from HuggingFace
- Registered as custom op for torch.export compatibility
- Falls back to PyTorch SDPA (FA2) when FA3 not available
- Added kernels to requirements.txt

To enable FA3: export USE_FA3=1

Files changed (2) hide show
  1. requirements.txt +3 -0
  2. src/flux/math.py +45 -2
requirements.txt CHANGED
@@ -30,3 +30,6 @@ pypinyin
30
  # Web UI (spaces handles torch 2.8+ AOT compilation)
31
  gradio>=5.0
32
  spaces>=0.47.0
 
 
 
 
30
  # Web UI (spaces handles torch 2.8+ AOT compilation)
31
  gradio>=5.0
32
  spaces>=0.47.0
33
+
34
+ # Flash Attention 3 support (optional, for H100 GPUs)
35
+ kernels
src/flux/math.py CHANGED
@@ -1,13 +1,56 @@
1
  import torch
2
  from einops import rearrange
3
  from torch import Tensor
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4
 
5
 
6
  def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
7
  q, k = apply_rope(q, k, pe)
8
 
9
- x = torch.nn.functional.scaled_dot_product_attention(q, k, v)
10
- x = rearrange(x, "B H L D -> B L (H D)")
 
 
 
 
 
 
 
 
 
11
 
12
  return x
13
 
 
1
  import torch
2
  from einops import rearrange
3
  from torch import Tensor
4
+ from typing import Optional, List
5
+ import os
6
+
7
# ============================================================
# Flash Attention 3 Support (for H100 GPUs)
# ============================================================
# Opt-in via the environment: `export USE_FA3=1`. If the kernel cannot
# be loaded, _USE_FA3 is reset to False and attention() falls back to
# PyTorch SDPA.
_USE_FA3 = os.environ.get("USE_FA3", "0") == "1"
# Populated with the registered custom op on success; stays None otherwise.
_flash_attn_func = None

if _USE_FA3:
    try:
        from kernels import get_kernel
        # Fetch the vllm-flash-attn3 kernel from the Hugging Face Hub.
        _fa3_kernel = get_kernel("kernels-community/vllm-flash-attn3")
        _flash_attn_func_raw = _fa3_kernel.flash_attn_func

        # Register the kernel as a torch custom op so that
        # torch.export / torch.compile can trace through the opaque call.
        @torch.library.custom_op("flash::flash_attn_func", mutates_args=())
        def _flash_attn_func(
            q: torch.Tensor,
            k: torch.Tensor,
            v: torch.Tensor,
            softmax_scale: Optional[float] = None,  # None -> kernel default scaling
            causal: bool = False,
        ) -> torch.Tensor:
            # The raw kernel returns a tuple; index 0 is the attention output.
            outputs = _flash_attn_func_raw(q, k, v, softmax_scale=softmax_scale, causal=causal)
            return outputs[0]

        # Fake (meta) implementation for tracing: output has the same
        # shape/dtype as the query tensor.
        @_flash_attn_func.register_fake
        def _(q, k, v, **kwargs):
            return torch.empty_like(q).contiguous()

        print("✓ Flash Attention 3 loaded successfully!")
    except Exception as e:
        # Best-effort feature: report and disable rather than crash at import.
        print(f"Flash Attention 3 not available: {e}")
        _USE_FA3 = False
38
 
39
 
40
def attention(q: Tensor, k: Tensor, v: Tensor, pe: Tensor) -> Tensor:
    """Scaled dot-product attention with rotary position embeddings.

    Applies RoPE to ``q`` and ``k`` (inputs are laid out as ``B H L D``,
    per the rearrange patterns below), then dispatches to the Flash
    Attention 3 custom op when it was loaded at import time, otherwise
    to PyTorch's SDPA. The result is flattened to ``B L (H D)``.
    """
    q, k = apply_rope(q, k, pe)

    if not (_USE_FA3 and _flash_attn_func is not None):
        # Standard PyTorch SDPA (uses FA2 if available).
        out = torch.nn.functional.scaled_dot_product_attention(q, k, v)
        return rearrange(out, "B H L D -> B L (H D)")

    # FA3 expects (B, L, H, D) format.
    out = _flash_attn_func(
        rearrange(q, "B H L D -> B L H D"),
        rearrange(k, "B H L D -> B L H D"),
        rearrange(v, "B H L D -> B L H D"),
    )
    return rearrange(out, "B L H D -> B L (H D)")
56