import torch import triton import triton.language as tl import spaces import gradio as gr import os from PIL import Image import torch import triton.language as tl import triton # Standard SDPA def attention(q, k, v): # q, k, v shape: (B, H, N, D) # 1. Transpose K for the dot product: (B, H, D, N) # We only want to flip the last two dimensions k_t = k.transpose(-2, -1) # 2. Scaled Dot Product # d_k is the last dimension of q d_k = q.shape[-1] attn_weights = (q @ k_t) * (d_k ** -0.5) # 3. Softmax along the last dimension (columns of the score matrix) A = torch.softmax(attn_weights, dim=-1) # 4. Multiply by V: (B, H, N, N) @ (B, H, N, D) -> (B, H, N, D) O = A @ v return O # Define the search space configs = [ triton.Config({'BLOCK_M': 128, 'BLOCK_N': 64, 'num_warps': 8, 'num_stages': 2}), triton.Config({'BLOCK_M': 128, 'BLOCK_N': 128, 'num_warps': 16, 'num_stages': 2}), ] @triton.autotune( configs=configs, key=['N', 'D'], # Re-tune if sequence length or head dim changes ) @triton.jit def flash_attn_kernel( Q, K, V, Out, stride_qb, stride_qh, stride_qn, stride_qd, N, D: tl.constexpr, BLOCK_M: tl.constexpr, BLOCK_N: tl.constexpr ): batch_id = tl.program_id(0) head_id = tl.program_id(1) row_block_id = tl.program_id(2) q_ptr_base = Q + (batch_id * stride_qb) + (head_id * stride_qh) k_ptr_base = K + (batch_id * stride_qb) + (head_id * stride_qh) v_ptr_base = V + (batch_id * stride_qb) + (head_id * stride_qh) offs_m = row_block_id * BLOCK_M + tl.arange(0, BLOCK_M) offs_d = tl.arange(0, D) q_ptrs = q_ptr_base + (offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd) q_block = tl.load(q_ptrs, mask=offs_m[:, None] < N, other=0.0) # --- Keep all accumulators in float32 --- m_i = tl.full([BLOCK_M], float('-inf'), dtype=tl.float32) l_i = tl.zeros([BLOCK_M], dtype=tl.float32) acc = tl.zeros([BLOCK_M, D], dtype=tl.float32) qk_scale = 1.0 / (D ** 0.5) offs_n = tl.arange(0, BLOCK_N) # K is laid out as (D, BLOCK_N) for the dot: q(M,D) @ k(D,N) k_ptrs = k_ptr_base + (offs_n[None, :] * stride_qn + offs_d[:, None] * stride_qd) v_ptrs = v_ptr_base + (offs_n[:, None] * stride_qn + offs_d[None, :] * stride_qd) for start_n in range(0, N, BLOCK_N): # Load K block: shape (D, BLOCK_N) k_block = tl.load( k_ptrs + start_n * stride_qn, mask=(start_n + offs_n[None, :]) < N, other=0.0 ) # q(M, D) @ k(D, N) -> qk(M, N) qk = tl.dot(q_block, k_block) qk = qk * qk_scale # float32 # --- Online softmax update (all float32) --- m_ij = tl.max(qk, axis=1) # (M,) m_i_new = tl.maximum(m_i, m_ij) # (M,) alpha = tl.exp(m_i - m_i_new) # (M,) rescale factor p_ij = tl.exp(qk - m_i_new[:, None]) # (M, N) in float32 l_ij = tl.sum(p_ij, axis=1) # (M,) l_i_new = alpha * l_i + l_ij # (M,) # Rescale accumulator, then add new contribution acc = acc * alpha[:, None] # Load V block: shape (BLOCK_N, D) v_block = tl.load( v_ptrs + start_n * stride_qn, mask=(start_n + offs_n[:, None]) < N, other=0.0 ) # Cast to fp16 ONLY for the dot (tensor cores), immediately cast result back acc += tl.dot(p_ij.to(tl.float16), v_block.to(tl.float16)).to(tl.float32) m_i = m_i_new l_i = l_i_new # Normalize acc = acc / l_i[:, None] # Write output — cast down to original dtype only at store out_ptrs = ( Out + (batch_id * stride_qb) + (head_id * stride_qh) + (offs_m[:, None] * stride_qn + offs_d[None, :] * stride_qd) ) tl.store(out_ptrs, acc.to(Out.dtype.element_ty), mask=offs_m[:, None] < N) def flash_attention(q, k, v): B, H, N, D = q.shape out = torch.empty_like(q) # We still need to define the grid, but we don't know BLOCK_M yet. # We can use a helper or just assume a reasonable default for grid calc. grid = lambda META: (B, H, triton.cdiv(N, META['BLOCK_M'])) flash_attn_kernel[grid]( q, k, v, out, q.stride(0), q.stride(1), q.stride(2), q.stride(3), N, D, # BLOCK_M and BLOCK_N are omitted here; autotune injects them ) return out @triton.testing.perf_report( triton.testing.Benchmark( x_names=["N"], # x-axis: Sequence Length x_vals=[128 * i for i in range(2, 33)], # Sweep from 256 to 4096 line_arg="provider", line_vals=["torch-native", "triton"], line_names=["Torch (native)", "Triton"], styles=[("blue", "-"), ("green", "-")], ylabel="TFLOPS", # Changed to TFLOPS for better insight plot_name="Flash Attention Performance", args={"Batch": 1, "Heads": 12, "D_head": 64}, ) ) def benchmark(Batch, Heads, N, D_head, provider): # Use the N passed from x_vals q = torch.randn((Batch, Heads, N, D_head), device="cuda", dtype=torch.float16) k = torch.randn((Batch, Heads, N, D_head), device="cuda", dtype=torch.float16) v = torch.randn((Batch, Heads, N, D_head), device="cuda", dtype=torch.float16) quantiles = [0.5, 0.2, 0.8] if provider == "torch-native": ms, min_ms, max_ms = triton.testing.do_bench(lambda: attention(q, k, v), quantiles=quantiles) if provider == "triton": ms, min_ms, max_ms = triton.testing.do_bench(lambda: flash_attention(q, k, v), quantiles=quantiles) # Calculation for Attention TFLOPS: # 2 * (Q@K) + 2 * (Softmax@V) = 4 * Batch * Heads * N^2 * D_head tflops = lambda ms: 4 * Batch * Heads * N**2 * D_head * 1e-12 / (ms * 1e-3) return tflops(ms), tflops(max_ms), tflops(min_ms) # 2. --- WRAP THE RUN COMMAND IN A DECORATED FUNCTION --- @spaces.GPU(duration=150) # High duration for Triton compilation + Benchmarking def start_benchmarking(): # Triton saves plots to the current directory by default save_path = "./plots" if not os.path.exists(save_path): os.makedirs(save_path) # Run your original benchmark function # Note: Ensure bench_flash_attention is defined above this bench_flash_attention.run(save_path=save_path, print_data=True) # Find the .png files generated by Triton images = [os.path.join(save_path, f) for f in os.listdir(save_path) if f.endswith('.png')] return images # 3. --- CREATE THE GRADIO GUI TO KEEP THE SPACE ALIVE --- with gr.Blocks() as demo: gr.Markdown("# Triton Attention Benchmark") gr.Markdown("Click the button below to trigger the ZeroGPU and run the Triton benchmark.") run_btn = gr.Button("Run Benchmark (H100/H200)", variant="primary") plot_gallery = gr.Gallery(label="Generated Performance Plots", columns=2) run_btn.click(fn=start_benchmarking, outputs=plot_gallery) demo.launch()