# HuggingFace ZeroGPU Space: Triton flash-attention benchmark demo.
import torch
import triton
import triton.language as tl
import spaces
import gradio as gr
import os
from PIL import Image
import torch
import triton.language as tl
import triton
# Standard SDPA
def attention(q, k, v):
    """Reference (non-fused) scaled dot-product attention.

    Args:
        q, k, v: tensors of shape (B, H, N, D).

    Returns:
        Tensor of shape (B, H, N, D): softmax(q @ k^T / sqrt(D)) @ v.
    """
    # Scale by 1/sqrt(d_k), where d_k is the head dimension.
    scale = q.shape[-1] ** -0.5
    # Flip only the last two dims of K so the product is (B, H, N, N).
    scores = torch.matmul(q, k.transpose(-2, -1)) * scale
    # Softmax over the key axis (last dimension of the score matrix).
    probs = scores.softmax(dim=-1)
    # (B, H, N, N) @ (B, H, N, D) -> (B, H, N, D)
    return torch.matmul(probs, v)
# Autotuning search space.
# NOTE(fix): num_warps / num_stages are kernel *launch* options and belong as
# triton.Config(...) keyword arguments. Placing them inside the meta-parameter
# dict forwards them to the kernel as undeclared constexpr arguments, which is
# not the documented triton.Config API.
configs = [
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64}, num_warps=8, num_stages=2),
    triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128}, num_warps=16, num_stages=2),
]
@triton.autotune(
    configs=configs,
    key=['N', 'D'],  # re-tune when sequence length or head dim changes
)
@triton.jit
def flash_attn_kernel(
    Q, K, V, Out,
    stride_qb, stride_qh, stride_qn, stride_qd,
    N, D: tl.constexpr,
    BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr
):
    # Flash-attention forward pass: one program per (batch, head, row-block).
    #
    # Q/K/V/Out are (B, H, N, D); the wrapper passes q's strides for all four,
    # so this assumes identically laid-out (contiguous) inputs.
    # D is constexpr so tl.arange(0, D) is a compile-time range.
    batch_id = tl.program_id(0)
    head_id = tl.program_id(1)
    row_block_id = tl.program_id(2)

    # Base pointers for this (batch, head) slice.
    q_ptr_base = Q + (batch_id * stride_qb) + (head_id * stride_qh)
    k_ptr_base = K + (batch_id * stride_qb) + (head_id * stride_qh)
    v_ptr_base = V + (batch_id * stride_qb) + (head_id * stride_qh)

    offs_m = row_block_id * BLOCK_M + tl.arange(0, BLOCK_M)
    offs_d = tl.arange(0, D)
    q_ptrs = q_ptr_base + (offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd)
    q_block = tl.load(q_ptrs, mask=offs_m[:, None] < N, other=0.0)

    # --- Online-softmax state, kept in float32 for numerical stability ---
    m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32)  # running row max
    l_i = tl.zeros([BLOCK_M], dtype=tl.float32)                # running row sum
    acc = tl.zeros([BLOCK_M, D], dtype=tl.float32)             # unnormalized output

    qk_scale = 1.0 / (D ** 0.5)
    offs_n = tl.arange(0, BLOCK_N)
    # K is laid out as (D, BLOCK_N) for the dot: q(M,D) @ k(D,N)
    k_ptrs = k_ptr_base + (offs_n[None, :] * stride_qn + offs_d[:, None] * stride_qd)
    v_ptrs = v_ptr_base + (offs_n[:, None] * stride_qn + offs_d[None, :] * stride_qd)

    for start_n in range(0, N, BLOCK_N):
        col_in_range = (start_n + offs_n[None, :]) < N
        # Load K block: shape (D, BLOCK_N); out-of-range columns read as 0.
        k_block = tl.load(
            k_ptrs + start_n * stride_qn,
            mask=col_in_range,
            other=0.0
        )
        # q(M, D) @ k(D, N) -> qk(M, N), float32
        qk = tl.dot(q_block, k_block)
        qk = qk * qk_scale
        # BUG FIX: padded (out-of-range) key columns must not receive softmax
        # weight. With other=0.0 they yield qk == 0, and exp(0 - m) > 0 would
        # leak probability mass into l_i whenever N % BLOCK_N != 0. Force
        # those scores to -inf so their exp contribution is exactly 0.
        qk = tl.where(col_in_range, qk, float('-inf'))

        # --- Online softmax update (all float32) ---
        m_ij = tl.max(qk, axis=1)               # (M,) block row max
        m_i_new = tl.maximum(m_i, m_ij)         # (M,) new running max
        alpha = tl.exp(m_i - m_i_new)           # (M,) rescale factor for old state
        p_ij = tl.exp(qk - m_i_new[:, None])    # (M, N) unnormalized probabilities
        l_ij = tl.sum(p_ij, axis=1)             # (M,)
        l_i_new = alpha * l_i + l_ij            # (M,)

        # Rescale accumulator, then add this block's contribution.
        acc = acc * alpha[:, None]
        # Load V block: shape (BLOCK_N, D)
        v_block = tl.load(
            v_ptrs + start_n * stride_qn,
            mask=(start_n + offs_n[:, None]) < N,
            other=0.0
        )
        # Cast to fp16 ONLY for the dot (tensor cores); accumulate in fp32.
        acc += tl.dot(p_ij.to(tl.float16), v_block.to(tl.float16)).to(tl.float32)

        m_i = m_i_new
        l_i = l_i_new

    # Normalize by the softmax denominator.
    acc = acc / l_i[:, None]

    # Write output — cast down to the output dtype only at the store.
    out_ptrs = (
        Out
        + (batch_id * stride_qb)
        + (head_id * stride_qh)
        + (offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd)
    )
    tl.store(out_ptrs, acc.to(Out.dtype.element_ty), mask=offs_m[:, None] < N)
def flash_attention(q, k, v):
    """Launch the autotuned flash-attention kernel on (B, H, N, D) inputs."""
    B, H, N, D = q.shape
    out = torch.empty_like(q)
    # One program per (batch, head, row-block). BLOCK_M is chosen by the
    # autotuner, so the grid is a callable that reads it from the meta dict.
    grid = lambda meta: (B, H, triton.cdiv(N, meta['BLOCK_M']))
    flash_attn_kernel[grid](
        q, k, v, out,
        *q.stride(),  # stride_qb, stride_qh, stride_qn, stride_qd
        N, D,
        # BLOCK_M / BLOCK_N omitted — injected by @triton.autotune
    )
    return out
@triton.testing.perf_report(
    triton.testing.Benchmark(
        x_names=["N"],                           # x-axis: sequence length
        x_vals=[128 * i for i in range(2, 33)],  # sweep 256 .. 4096
        line_arg="provider",
        line_vals=["torch-native", "triton"],
        line_names=["Torch (native)", "Triton"],
        styles=[("blue", "-"), ("green", "-")],
        ylabel="TFLOPS",                         # TFLOPS gives better insight than ms
        plot_name="Flash Attention Performance",
        args={"Batch": 1, "Heads": 12, "D_head": 64},
    )
)
def benchmark(Batch, Heads, N, D_head, provider):
    """Time one attention provider at sequence length N and report TFLOPS."""
    shape = (Batch, Heads, N, D_head)
    q = torch.randn(shape, device="cuda", dtype=torch.float16)
    k = torch.randn(shape, device="cuda", dtype=torch.float16)
    v = torch.randn(shape, device="cuda", dtype=torch.float16)
    quantiles = [0.5, 0.2, 0.8]
    # Select the implementation under test, then benchmark it once.
    if provider == "torch-native":
        fn = lambda: attention(q, k, v)
    if provider == "triton":
        fn = lambda: flash_attention(q, k, v)
    ms, min_ms, max_ms = triton.testing.do_bench(fn, quantiles=quantiles)
    # Attention FLOPs: 2*(Q@K^T) + 2*(P@V) = 4 * Batch * Heads * N^2 * D_head.
    tflops = lambda t_ms: 4 * Batch * Heads * N**2 * D_head * 1e-12 / (t_ms * 1e-3)
    # min time -> max TFLOPS and vice versa, hence the swapped order.
    return tflops(ms), tflops(max_ms), tflops(min_ms)
# 2. --- WRAP THE RUN COMMAND IN A DECORATED FUNCTION ---
@spaces.GPU(duration=150)  # high duration: Triton autotune/compile + benchmark sweep
def start_benchmarking():
    """Run the Triton benchmark on the ZeroGPU worker and return plot paths.

    Returns:
        list[str]: sorted paths of the .png plots Triton wrote under ./plots.
    """
    # Triton saves plots to the given directory.
    save_path = "./plots"
    os.makedirs(save_path, exist_ok=True)  # idempotent; avoids exists()+makedirs race
    # BUG FIX: the perf_report-decorated function defined above is named
    # `benchmark` — the previous call to `bench_flash_attention.run(...)`
    # raised NameError at runtime.
    benchmark.run(save_path=save_path, print_data=True)
    # Collect the .png files Triton generated (sorted for a stable gallery order).
    return sorted(
        os.path.join(save_path, f)
        for f in os.listdir(save_path)
        if f.endswith('.png')
    )
# 3. --- CREATE THE GRADIO GUI TO KEEP THE SPACE ALIVE ---
with gr.Blocks() as demo:
    # Header text.
    gr.Markdown("# Triton Attention Benchmark")
    gr.Markdown("Click the button below to trigger the ZeroGPU and run the Triton benchmark.")
    # Single action button plus a gallery that receives the generated plots.
    run_btn = gr.Button("Run Benchmark (H100/H200)", variant="primary")
    plot_gallery = gr.Gallery(label="Generated Performance Plots", columns=2)
    # Clicking runs the GPU-decorated benchmark and shows its .png outputs.
    run_btn.click(fn=start_benchmarking, outputs=plot_gallery)
demo.launch()