trinath3 committed on
Commit
932ccfc
·
verified ·
1 Parent(s): 8bb5d2c

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +205 -4
app.py CHANGED
@@ -1,7 +1,208 @@
 
 
 
 
1
  import gradio as gr
 
 
2
 
3
- def greet(name):
4
- return "Hello " + name + "!!"
5
 
6
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
7
- demo.launch()
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import triton
3
+ import triton.language as tl
4
+ import spaces
5
  import gradio as gr
6
+ import os
7
+ from PIL import Image
8
 
 
 
9
 
10
+ import torch
11
+ import triton.language as tl
12
+ import triton
13
+
14
+
15
# Standard (naive) SDPA — materializes the full (N, N) attention matrix;
# serves as the correctness/performance baseline for the Triton kernel below.

def attention(q, k, v):
    """Scaled dot-product attention.

    Args:
        q, k, v: tensors of shape (B, H, N, D).

    Returns:
        softmax(q @ k^T / sqrt(D)) @ v, shape (B, H, N, D).
    """
    head_dim = q.shape[-1]
    # Scores over all key positions: (B, H, N, N), scaled before the softmax.
    scores = torch.matmul(q, k.transpose(-2, -1)) * (head_dim ** -0.5)
    # Row-wise softmax over the key dimension, then blend the values.
    weights = scores.softmax(dim=-1)
    return torch.matmul(weights, v)
35
+
36
+
37
+
38
# Autotuning search space.
# BUGFIX: num_warps / num_stages are kernel *launch options* and must be passed
# as keyword arguments to triton.Config. Entries of the meta-parameter dict are
# forwarded to the kernel as extra constexpr arguments, and flash_attn_kernel
# declares no num_warps/num_stages parameters, so the old form broke autotuning.
configs = [
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=8, num_stages=2),
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_warps=16, num_stages=2),
]
43
+
44
@triton.autotune(
    configs=configs,
    key=['N', 'D'],  # Re-tune if sequence length or head dim changes
)
@triton.jit
def flash_attn_kernel(
    Q, K, V, Out,
    stride_qb, stride_qh, stride_qn, stride_qd,
    N, D: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Flash-attention forward kernel. Each program computes a BLOCK_M-row tile
    # of the output for one (batch, head) pair, streaming over K/V in BLOCK_N
    # chunks with an online softmax so the full (N, N) score matrix is never
    # materialized.
    #
    # NOTE: only Q's strides are passed in; K, V and Out are assumed to share
    # Q's exact layout (same shape and strides) — the Python wrapper allocates
    # Out with empty_like(q) and passes same-shape K/V, so this holds there.
    batch_id = tl.program_id(0)
    head_id = tl.program_id(1)
    row_block_id = tl.program_id(2)

    # Base pointers for this (batch, head) slice.
    q_ptr_base = Q + (batch_id * stride_qb) + (head_id * stride_qh)
    k_ptr_base = K + (batch_id * stride_qb) + (head_id * stride_qh)
    v_ptr_base = V + (batch_id * stride_qb) + (head_id * stride_qh)

    offs_m = row_block_id * BLOCK_M + tl.arange(0, BLOCK_M)  # query rows of this tile
    offs_d = tl.arange(0, D)                                 # head-dim lanes (D is constexpr)

    q_ptrs = q_ptr_base + (offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd)
    q_block = tl.load(q_ptrs, mask=offs_m[:, None] < N, other=0.0)

    # --- Keep all accumulators in float32 ---
    m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)  # running row-max of scores
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)                # running softmax denominator
    acc = tl.zeros([BLOCK_M, D], dtype=tl.float32)             # running unnormalized output

    qk_scale = 1.0 / (D ** 0.5)

    offs_n = tl.arange(0, BLOCK_N)
    # K is laid out as (D, BLOCK_N) for the dot: q(M,D) @ k(D,N)
    k_ptrs = k_ptr_base + (offs_n[None, :] * stride_qn + offs_d[:, None] * stride_qd)
    v_ptrs = v_ptr_base + (offs_n[:, None] * stride_qn + offs_d[None, :] * stride_qd)

    for start_n in range(0, N, BLOCK_N):
        # Column validity for this K/V chunk (the last chunk may run past N).
        col_valid = (start_n + offs_n[None, :]) < N

        # Load K block: shape (D, BLOCK_N)
        k_block = tl.load(
            k_ptrs + start_n * stride_qn,
            mask=col_valid,
            other=0.0
        )

        # q(M, D) @ k(D, N) -> qk(M, N)
        qk = tl.dot(q_block, k_block)
        qk = qk * qk_scale  # float32

        # BUGFIX: out-of-range key columns were loaded as 0.0 and therefore
        # produced a *score* of 0, which exp() turns into a nonzero softmax
        # weight — corrupting results whenever N % BLOCK_N != 0. Force those
        # columns to -inf so they contribute exp(-inf) = 0 to the softmax.
        qk = tl.where(col_valid, qk, float('-inf'))

        # --- Online softmax update (all float32) ---
        m_ij = tl.max(qk, axis=1)              # (M,) new chunk's row max
        m_i_new = tl.maximum(m_i, m_ij)        # (M,) combined running max

        alpha = tl.exp(m_i - m_i_new)          # (M,) rescale factor for old state
        p_ij = tl.exp(qk - m_i_new[:, None])   # (M, N) in float32

        l_ij = tl.sum(p_ij, axis=1)            # (M,)
        l_i_new = alpha * l_i + l_ij           # (M,)

        # Rescale accumulator, then add new contribution
        acc = acc * alpha[:, None]

        # Load V block: shape (BLOCK_N, D)
        v_block = tl.load(
            v_ptrs + start_n * stride_qn,
            mask=(start_n + offs_n[:, None]) < N,
            other=0.0
        )

        # Cast to fp16 ONLY for the dot (tensor cores), immediately cast result back
        acc += tl.dot(p_ij.to(tl.float16), v_block.to(tl.float16)).to(tl.float32)

        m_i = m_i_new
        l_i = l_i_new

    # Normalize by the softmax denominator.
    acc = acc / l_i[:, None]

    # Write output — cast down to the output dtype only at the store.
    out_ptrs = (
        Out
        + (batch_id * stride_qb)
        + (head_id * stride_qh)
        + (offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd)
    )
    tl.store(out_ptrs, acc.to(Out.dtype.element_ty), mask=offs_m[:, None] < N)
130
+
131
+
132
+
133
+
134
def flash_attention(q, k, v):
    """Launch the Triton flash-attention kernel.

    Args:
        q, k, v: tensors of shape (B, H, N, D); k and v must share q's layout
            (the kernel reuses q's strides for all operands).

    Returns:
        Tensor of shape (B, H, N, D) with the attention output.
    """
    batch, heads, seq_len, head_dim = q.shape
    out = torch.empty_like(q)

    # The grid is resolved lazily: BLOCK_M is picked by the autotuner, so the
    # number of row-blocks per (batch, head) is only known at launch time.
    def grid(meta):
        return (batch, heads, triton.cdiv(seq_len, meta['BLOCK_M']))

    flash_attn_kernel[grid](
        q, k, v, out,
        q.stride(0), q.stride(1), q.stride(2), q.stride(3),
        seq_len, head_dim,
        # BLOCK_M / BLOCK_N are injected by @triton.autotune.
    )
    return out
149
+
150
+
151
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],                           # x-axis: sequence length
        x_vals=[128 * i for i in range(2, 33)],  # sweep 256 .. 4096
        line_arg="provider",
        line_vals=["torch-native", "triton"],
        line_names=["Torch (native)", "Triton"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="TFLOPS",                         # throughput, not raw latency
        plot_name="Flash Attention Performance",
        args={"Batch": 1, "Heads": 12, "D_head": 64},
    )
)
def benchmark(Batch, Heads, N, D_head, provider):
    """Time one attention implementation at sequence length N; report TFLOPS."""
    shape = (Batch, Heads, N, D_head)
    q = torch.randn(shape, device="cuda", dtype=torch.float16)
    k = torch.randn(shape, device="cuda", dtype=torch.float16)
    v = torch.randn(shape, device="cuda", dtype=torch.float16)

    quantiles = [0.5, 0.2, 0.8]
    implementations = {
        "torch-native": lambda: attention(q, k, v),
        "triton": lambda: flash_attention(q, k, v),
    }
    ms, min_ms, max_ms = triton.testing.do_bench(
        implementations[provider], quantiles=quantiles
    )

    # FLOP count: 2*N^2*D for Q@K^T plus 2*N^2*D for P@V, per (batch, head),
    # i.e. 4 * B * H * N^2 * D total.
    def tflops(t_ms):
        return 4 * Batch * Heads * N**2 * D_head * 1e-12 / (t_ms * 1e-3)

    # Shortest time maps to the highest TFLOPS and vice versa — hence the swap.
    return tflops(ms), tflops(max_ms), tflops(min_ms)
182
+
183
+
184
+
185
@spaces.GPU(duration=180)  # Triton benchmarks can take a minute
def run_benchmark():
    """Run the attention benchmark on GPU and return the generated plots.

    Returns:
        list[str]: paths of the .png plot files written by triton's
        perf_report runner.
    """
    output_dir = "./plots"
    # exist_ok avoids a crash when the Space runs the benchmark more than once.
    os.makedirs(output_dir, exist_ok=True)

    # BUGFIX: the perf_report-decorated function defined above is named
    # `benchmark`; the previous call to the undefined `bench_flash_attention`
    # raised a NameError as soon as the button was clicked.
    # This generates one or more .png files in save_path.
    benchmark.run(save_path=output_dir, print_data=True)

    # Collect every plot image the report produced.
    return [
        os.path.join(output_dir, f)
        for f in os.listdir(output_dir)
        if f.endswith('.png')
    ]
199
+
200
# Gradio interface: a single button that launches the GPU benchmark and a
# gallery that displays the resulting performance plots.
with gr.Blocks() as demo:
    gr.Markdown("# Triton Fused Attention Benchmark on ZeroGPU")

    launch_button = gr.Button("Run Benchmark")
    plot_gallery = gr.Gallery(label="Performance Plots")

    # One click runs the whole sweep; the returned image paths fill the gallery.
    launch_button.click(fn=run_benchmark, outputs=plot_gallery)

demo.launch()