kernels-bot commited on
Commit
f81bfcb
·
verified ·
1 Parent(s): c4d0b6b

Uploaded using `kernel-builder`.

Browse files
Files changed (1) hide show
  1. README.md +8 -150
README.md CHANGED
@@ -1,160 +1,18 @@
1
  ---
 
2
  license: bsd-3-clause
3
- tags:
4
- - kernels
5
  ---
6
 
7
- <!-- ![Status](https://hubwebhook.dholtz.com/shield?repo=kernels-community/flash-attn2) -->
8
 
9
- # Flash Attention
10
 
11
- Flash Attention is a fast and memory-efficient implementation of the attention mechanism, designed to work with large models and long sequences. This is a Hugging Face compliant kernel build of Flash Attention.
12
 
13
- Original code here [https://github.com/Dao-AILab/flash-attention](https://github.com/Dao-AILab/flash-attention).
14
 
 
15
 
16
- [`scripts/readme_example.py`](scripts/readme_example.py) provides a simple example of how to use the Flash Attention kernel in PyTorch. It demonstrates standard attention, causal attention, and variable-length sequences.
17
- ```python
18
- # /// script
19
- # dependencies = [
20
- # "numpy",
21
- # "torch",
22
- # "kernels"
23
- # ]
24
- # ///
25
- import torch
26
- from kernels import get_kernel
27
 
28
- # Setup
29
- torch.manual_seed(42)
30
- flash_attn = get_kernel("kernels-community/flash-attn2")
31
- device = torch.device("cuda")
32
-
33
- # Create test tensors
34
- B, S, H, D = 2, 5, 4, 8 # batch, seq_len, heads, head_dim
35
- q = k = v = torch.randn(B, S, H, D, device=device, dtype=torch.float16)
36
-
37
- # Reference implementation using PyTorch SDPA
38
- def reference_attention(query, key, value, causal=False):
39
- query, key, value = (x.transpose(1, 2).contiguous() for x in (query, key, value))
40
- with torch.nn.attention.sdpa_kernel(torch.nn.attention.SDPBackend.MATH):
41
- out = torch.nn.functional.scaled_dot_product_attention(query, key, value, is_causal=causal)
42
- return out.transpose(1, 2).contiguous()
43
-
44
- # 1. Standard attention
45
- print("\n1. Standard attention:")
46
- out_ref = reference_attention(q, k, v)
47
- out_flash = flash_attn.fwd(
48
- q=q,
49
- k=k,
50
- v=v,
51
- is_causal=False,
52
- )[0]
53
- print(f"Reference output: {out_ref.shape}")
54
- print(f"Flash output: {out_flash.shape}")
55
- print(f"Outputs close: {torch.allclose(out_flash, out_ref, atol=1e-2, rtol=1e-3)}")
56
-
57
- # 2. Causal attention (for autoregressive models)
58
- print("\n2. Causal attention:")
59
-
60
- out_ref_causal = reference_attention(q, k, v, causal=True)
61
- out_causal = flash_attn.fwd(
62
- q=q,
63
- k=k,
64
- v=v,
65
- is_causal=True,
66
- )[0]
67
- print(f"Reference causal output: {out_ref_causal.shape}")
68
- print(f"Flash causal output: {out_causal.shape}")
69
- print(f"Outputs close: {torch.allclose(out_causal, out_ref_causal, atol=1e-2, rtol=1e-3)}")
70
-
71
- def var_reference_attention(q, k, v, cu_seqlens_q, cu_seqlens_k, max_seqlen_q, max_seqlen_k, causal=False):
72
- batch_size = cu_seqlens_q.shape[0] - 1
73
- # Return output in packed format (same as flash attention)
74
- total_tokens_q = q.shape[0]
75
- out = torch.zeros((total_tokens_q, q.shape[1], q.shape[2]), device=q.device, dtype=q.dtype)
76
-
77
- for b in range(batch_size):
78
- start_q, end_q = cu_seqlens_q[b], cu_seqlens_q[b + 1]
79
- start_k, end_k = cu_seqlens_k[b], cu_seqlens_k[b + 1]
80
-
81
- # Extract slices for this batch
82
- q_slice = q[start_q:end_q] # Shape: (seq_len_q, H, D)
83
- k_slice = k[start_k:end_k] # Shape: (seq_len_k, H, D)
84
- v_slice = v[start_k:end_k] # Shape: (seq_len_k, H, D)
85
-
86
- # Add batch dimension for reference_attention
87
- q_slice = q_slice.unsqueeze(0) # Shape: (1, seq_len_q, H, D)
88
- k_slice = k_slice.unsqueeze(0) # Shape: (1, seq_len_k, H, D)
89
- v_slice = v_slice.unsqueeze(0) # Shape: (1, seq_len_k, H, D)
90
-
91
- # Compute attention and remove batch dimension
92
- attn_out = reference_attention(q_slice, k_slice, v_slice, causal=causal)
93
- attn_out = attn_out.squeeze(0) # Shape: (seq_len_q, H, D)
94
-
95
- # Place result in output tensor (packed format)
96
- out[start_q:end_q] = attn_out
97
-
98
- return out
99
-
100
- # 3. Variable length sequences (packed format)
101
- print("\n3. Variable length sequences:")
102
- # Pack sequences of lengths [3,4,3] for q and [4,5,3] for k into single tensors
103
- q_var = torch.randn(10, H, D, device=device, dtype=torch.float16) # total_q=10
104
- k_var = v_var = torch.randn(12, H, D, device=device, dtype=torch.float16) # total_k=12
105
- cu_q = torch.tensor([0, 3, 7, 10], device=device, dtype=torch.int32) # cumulative sequence lengths
106
- cu_k = torch.tensor([0, 4, 9, 12], device=device, dtype=torch.int32)
107
-
108
- out_var_ref = var_reference_attention(q_var, k_var, v_var, cu_q, cu_k, max_seqlen_q=4, max_seqlen_k=5, causal=False)
109
- # Custom function to handle variable
110
- out_var = flash_attn.varlen_fwd(
111
- q=q_var,
112
- k=k_var,
113
- v=v_var,
114
- cu_seqlens_q=cu_q,
115
- cu_seqlens_k=cu_k,
116
- max_seqlen_q=4,
117
- max_seqlen_k=5,
118
- )[0]
119
- print(f"Variable length output: {out_var.shape}")
120
- print(f"Reference variable length output: {out_var_ref.shape}")
121
- print(f"Outputs close: {torch.allclose(out_var, out_var_ref, atol=1e-2, rtol=1e-3)}")
122
- ```
123
-
124
- run it using the following command:
125
-
126
- ```bash
127
- uv run scripts/readme_example.py
128
- ```
129
-
130
- ```txt
131
- Reading inline script metadata from `scripts/readme_example.py`
132
- Fetching 20 files: 100%|██████████████████████████████████████████████████| 20/20 [00:00<00:00, 16371.21it/s]
133
-
134
- 1. Standard attention:
135
- Reference output: torch.Size([2, 5, 4, 8])
136
- Flash output: torch.Size([2, 5, 4, 8])
137
- Outputs close: True
138
-
139
- 2. Causal attention:
140
- Reference causal output: torch.Size([2, 5, 4, 8])
141
- Flash causal output: torch.Size([2, 5, 4, 8])
142
- Outputs close: True
143
-
144
- 3. Variable length sequences:
145
- Variable length output: torch.Size([10, 4, 8])
146
- Reference variable length output: torch.Size([10, 4, 8])
147
- Outputs close: True
148
- ```
149
-
150
-
151
- ### Performance
152
-
153
- <img class="dark:hidden border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_light_animation.svg" />
154
- <img class="hidden dark:block border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_dark_animation.svg" />
155
-
156
- <img class="dark:hidden border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_light_latency.svg" />
157
- <img class="hidden dark:block border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_dark_latency.svg" />
158
-
159
- <img class="dark:hidden border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_light_throughput.svg" />
160
- <img class="hidden dark:block border border-gray-200 dark:border-gray-700 rounded-lg" src="media/benches_dark_throughput.svg" />
 
1
  ---
2
+ library_name: kernels
3
  license: bsd-3-clause
 
 
4
  ---
5
 
6
+ This is the repository card of kernels-community/flash-attn2 that has been pushed on the Hub. It was built to be used with the [`kernels` library](https://github.com/huggingface/kernels). This card was automatically generated.
7
 
8
+ ## How to use
9
 
10
+ Usage example not available.
11
 
12
+ ## Available functions
13
 
14
+ Function list not available.
15
 
16
+ ## Benchmarks
 
 
 
 
 
 
 
 
 
 
17
 
18
+ Benchmarking script is available for this kernel. Run `kernels benchmark kernels-community/flash-attn2`.