| |
| """Benchmark causal mask performance scaling with sequence length""" |
|
|
| import torch |
| import time |
| import matplotlib.pyplot as plt |
| import numpy as np |
| from typing import List |
| import kernels |
|
|
# Resolve the Flash SDPA kernel once at import time so every benchmark call
# below reuses the same loaded kernel object.
_KERNEL_REPO = "kernels-community/metal-flash-sdpa"
metal_flash_sdpa = kernels.get_kernel(_KERNEL_REPO)
|
|
|
|
def create_cu_seqlens(seq_lengths: List[int], device: str = "mps") -> torch.Tensor:
    """Build the cumulative sequence-length (offset) tensor for packed batches.

    The result has ``len(seq_lengths) + 1`` entries: it starts at 0 and entry
    ``i + 1`` is the sum of the first ``i + 1`` lengths, so consecutive pairs
    delimit each sequence inside a packed token dimension.

    Args:
        seq_lengths: Length of every sequence in the batch, in batch order.
        device: Device the tensor is created on. Defaults to ``"mps"``, which
            is what the benchmarks in this file use.

    Returns:
        A 1-D ``torch.int32`` tensor of cumulative offsets, e.g. ``[3, 5]``
        yields ``[0, 3, 8]``. An empty input yields ``[0]``.
    """
    offsets = [0]
    running_total = 0
    for length in seq_lengths:
        running_total += length
        offsets.append(running_total)
    return torch.tensor(offsets, dtype=torch.int32, device=device)
|
|
|
|
def benchmark_flash_sdpa_causal(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    head_dim: int,
    dtype: torch.dtype,
    num_iterations: int = 20,
    num_warmup: int = 5,
) -> float:
    """Time the Flash SDPA varlen kernel with causal masking enabled.

    Every sequence in the batch has the same length ``seq_len``; the inputs
    are packed along the token dimension as
    ``(batch_size * seq_len, num_heads, head_dim)`` random tensors on MPS.

    Args:
        batch_size: Number of equal-length sequences in the packed batch.
        num_heads: Number of attention heads.
        seq_len: Length of each sequence.
        head_dim: Per-head feature dimension; also sets the softmax scale
            ``1 / sqrt(head_dim)``.
        dtype: Element type of the query/key/value tensors.
        num_iterations: Number of timed kernel launches. Must be positive.
        num_warmup: Number of untimed launches run before timing starts.

    Returns:
        Mean wall-clock time per launch, in milliseconds.

    Raises:
        ValueError: If ``num_iterations`` is not positive.
    """
    # Fail fast: otherwise a zero count only surfaces as a ZeroDivisionError
    # after all the GPU work has already been done.
    if num_iterations <= 0:
        raise ValueError(f"num_iterations must be positive, got {num_iterations}")

    seq_lengths = [seq_len] * batch_size
    cu_seqlens = create_cu_seqlens(seq_lengths)
    total_tokens = sum(seq_lengths)

    packed_shape = (total_tokens, num_heads, head_dim)
    query = torch.randn(packed_shape, dtype=dtype, device="mps")
    key = torch.randn(packed_shape, dtype=dtype, device="mps")
    value = torch.randn(packed_shape, dtype=dtype, device="mps")
    out = torch.empty_like(query)

    scale = 1.0 / (head_dim**0.5)

    def run_kernel() -> None:
        # Single definition of the call so warm-up and timed runs cannot drift.
        metal_flash_sdpa.flash_attention_varlen(
            out=out,
            query=query,
            key=key,
            value=value,
            cu_seqlens_q=cu_seqlens,
            cu_seqlens_k=cu_seqlens,
            max_seqlen_q=seq_len,
            max_seqlen_k=seq_len,
            do_causal=True,
            scale=scale,
            # NOTE(review): 1.0 presumably means "no soft-capping" for this
            # kernel - confirm against the kernel's documentation.
            softcapping=1.0,
        )

    # Warm-up launches are excluded from the measurement.
    for _ in range(num_warmup):
        run_kernel()
    torch.mps.synchronize()

    start_time = time.perf_counter()
    for _ in range(num_iterations):
        run_kernel()
    # Wait for all queued MPS work before reading the clock.
    torch.mps.synchronize()
    end_time = time.perf_counter()

    return (end_time - start_time) * 1000 / num_iterations
|
|
|
|
def benchmark_naive_sdpa_causal(
    batch_size: int,
    num_heads: int,
    seq_len: int,
    head_dim: int,
    dtype: torch.dtype,
    num_iterations: int = 20,
    num_warmup: int = 5,
) -> float:
    """Time a naive matmul/softmax attention with an explicit causal mask.

    Inputs are random ``(batch_size, num_heads, seq_len, head_dim)`` tensors
    on MPS. Each iteration materialises the full ``seq_len x seq_len`` score
    matrix, masks the strictly-upper triangle with ``-inf``, applies softmax
    and multiplies by the values.

    Args:
        batch_size: Number of sequences in the batch.
        num_heads: Number of attention heads.
        seq_len: Length of each sequence.
        head_dim: Per-head feature dimension; also sets the softmax scale
            ``1 / sqrt(head_dim)``.
        dtype: Element type of the query/key/value tensors.
        num_iterations: Number of timed attention passes. Must be positive.
        num_warmup: Number of untimed passes run before timing starts.

    Returns:
        Mean wall-clock time per attention pass, in milliseconds.

    Raises:
        ValueError: If ``num_iterations`` is not positive.
    """
    # Fail fast: otherwise a zero count only surfaces as a ZeroDivisionError
    # after all the GPU work has already been done.
    if num_iterations <= 0:
        raise ValueError(f"num_iterations must be positive, got {num_iterations}")

    shape = (batch_size, num_heads, seq_len, head_dim)
    query = torch.randn(shape, dtype=dtype, device="mps")
    key = torch.randn(shape, dtype=dtype, device="mps")
    value = torch.randn(shape, dtype=dtype, device="mps")

    scale = 1.0 / (head_dim**0.5)

    # True above the diagonal: position i must not attend to positions j > i.
    mask = torch.triu(torch.ones(seq_len, seq_len, device="mps"), diagonal=1).bool()

    def attend() -> torch.Tensor:
        # Single definition of the pass so warm-up and timed runs cannot drift.
        scores = torch.matmul(query, key.transpose(-2, -1)) * scale
        scores = scores.masked_fill(mask, float("-inf"))
        attn_weights = torch.softmax(scores, dim=-1)
        return torch.matmul(attn_weights, value)

    # Warm-up passes are excluded from the measurement.
    for _ in range(num_warmup):
        attend()
    torch.mps.synchronize()

    start_time = time.perf_counter()
    for _ in range(num_iterations):
        attend()
    # Wait for all queued MPS work before reading the clock.
    torch.mps.synchronize()
    end_time = time.perf_counter()

    return (end_time - start_time) * 1000 / num_iterations
|
|
|
|
def run_scaling_benchmark():
    """Benchmark Flash vs naive causal attention over a sweep of lengths.

    Runs both implementations at each sequence length with a fixed batch
    size, head count, head dimension and dtype, printing one table row per
    length as results become available.

    Returns:
        A 4-tuple ``(seq_lengths, flash_times, naive_times, speedups)`` of
        equal-length lists. Times are mean milliseconds per iteration and
        each speedup is ``naive_time / flash_time``.
    """
    # Fixed benchmark configuration.
    batch_size = 4
    num_heads = 16
    head_dim = 64
    dtype = torch.float16

    seq_lengths = [512, 768, 1024, 1536, 2048, 3072, 4096]

    print("=" * 80)
    print("Causal Mask Performance Scaling Benchmark")
    # Derived from the configuration above so the banner cannot go stale.
    print(f"Batch Size: {batch_size}, Head Dimension: {head_dim}")
    print("=" * 80)

    flash_times = []
    naive_times = []
    speedups = []

    print(f"{'Seq Len':<8} {'Flash (ms)':<12} {'Naive (ms)':<12} {'Speedup':<10}")
    print("-" * 50)

    for seq_len in seq_lengths:
        flash_time = benchmark_flash_sdpa_causal(
            batch_size, num_heads, seq_len, head_dim, dtype
        )
        flash_times.append(flash_time)

        naive_time = benchmark_naive_sdpa_causal(
            batch_size, num_heads, seq_len, head_dim, dtype
        )
        naive_times.append(naive_time)

        speedup = naive_time / flash_time
        speedups.append(speedup)

        # Last column: no width padding, so the "x" suffix sits directly
        # after the number ("2.50x" rather than "2.50      x").
        print(f"{seq_len:<8} {flash_time:<12.2f} {naive_time:<12.2f} {speedup:.2f}x")

    return seq_lengths, flash_times, naive_times, speedups
|
|
|
|
def create_line_plot(seq_lengths, flash_times, naive_times, speedups):
    """Plot execution time against sequence length for both implementations.

    Draws one line per implementation, labels every point with its time in
    milliseconds, writes the figure to ``benchmark.png`` and then shows it.

    Args:
        seq_lengths: Sequence lengths, in plotting order (x axis).
        flash_times: Flash SDPA times in milliseconds, one per length.
        naive_times: Naive SDPA times in milliseconds, one per length.
        speedups: Accepted for signature parity with the other reporting
            helpers; not used by this plot.

    Raises:
        ValueError: If ``seq_lengths`` is empty.
    """
    if not seq_lengths:
        raise ValueError("create_line_plot requires at least one data point")

    fig, ax = plt.subplots(1, 1, figsize=(12, 8))
    fig.suptitle(
        "Causal Mask Performance Scaling\n(Batch Size: 4, Head Dimension: 64)",
        fontsize=16,
    )

    # (times, marker, legend label, colour) for each plotted series.
    series = (
        (flash_times, "o", "Flash SDPA", "blue"),
        (naive_times, "s", "Naive SDPA", "red"),
    )
    for times, marker, label, color in series:
        ax.plot(
            seq_lengths,
            times,
            marker=marker,
            linewidth=3,
            markersize=10,
            label=label,
            color=color,
        )

    ax.set_xlabel("Sequence Length", fontsize=14)
    ax.set_ylabel("Time (ms)", fontsize=14)
    ax.set_title("Execution Time vs Sequence Length", fontsize=16)
    ax.grid(True, alpha=0.3)
    ax.legend(fontsize=12)

    # Label every data point with its measured time.
    for seq_len, flash_time, naive_time in zip(seq_lengths, flash_times, naive_times):
        for elapsed, color in ((flash_time, "blue"), (naive_time, "red")):
            ax.annotate(
                f"{elapsed:.1f}ms",
                xy=(seq_len, elapsed),
                xytext=(5, 5),
                textcoords="offset points",
                fontsize=10,
                color=color,
            )

    ax.set_xlim(seq_lengths[0] - 100, seq_lengths[-1] + 100)
    # Scale to the slowest point of EITHER series; using only the naive
    # times would clip the Flash curve whenever Flash is the slower one.
    slowest = max(max(flash_times), max(naive_times))
    ax.set_ylim(0, slowest * 1.1)

    plt.tight_layout()
    plt.savefig("benchmark.png", dpi=300, bbox_inches="tight")
    plt.show()
|
|
|
|
def print_analysis(seq_lengths, flash_times, naive_times, speedups):
    """Print a textual summary of the scaling benchmark results.

    Reports how each implementation's time grew from the first to the last
    sequence length, the mean/max/min speedup, whether the speedup trend
    improves or degrades with length, and how the measured Flash growth
    compares with quadratic growth in sequence length.

    Args:
        seq_lengths: Sequence lengths, in benchmark order.
        flash_times: Flash SDPA times in milliseconds, one per length.
        naive_times: Naive SDPA times in milliseconds, one per length.
        speedups: ``naive / flash`` ratio, one per length.

    Raises:
        ValueError: If ``seq_lengths`` is empty.
    """
    if not seq_lengths:
        raise ValueError("print_analysis requires at least one data point")

    first_len, last_len = seq_lengths[0], seq_lengths[-1]
    # First-to-last growth factors, each reused by several lines below.
    flash_growth = flash_times[-1] / flash_times[0]
    naive_growth = naive_times[-1] / naive_times[0]
    length_growth = last_len / first_len
    quadratic_growth = length_growth**2

    print("\n" + "=" * 80)
    print("DETAILED ANALYSIS")
    print("=" * 80)

    print("\n1. Performance Scaling:")
    print(
        f" • Flash SDPA: {flash_times[0]:.2f}ms → {flash_times[-1]:.2f}ms ({flash_growth:.1f}x increase)"
    )
    print(
        f" • Naive SDPA: {naive_times[0]:.2f}ms → {naive_times[-1]:.2f}ms ({naive_growth:.1f}x increase)"
    )

    print("\n2. Speedup Analysis:")
    best_idx = int(np.argmax(speedups))
    worst_idx = int(np.argmin(speedups))
    print(f" • Average Speedup: {np.mean(speedups):.2f}x")
    print(
        f" • Max Speedup: {speedups[best_idx]:.2f}x (at seq_len={seq_lengths[best_idx]})"
    )
    print(
        f" • Min Speedup: {speedups[worst_idx]:.2f}x (at seq_len={seq_lengths[worst_idx]})"
    )

    print("\n3. Efficiency Analysis:")
    speedup_improvement = speedups[-1] / speedups[0]
    # Report the actual endpoints rather than a hard-coded "512→4096".
    print(
        f" • Speedup improvement from {first_len}→{last_len}: {speedup_improvement:.2f}x"
    )

    # A +/-10% band around 1.0 is treated as "no meaningful change".
    if speedup_improvement > 1.1:
        print(" • Flash SDPA becomes MORE efficient at longer sequences")
    elif speedup_improvement < 0.9:
        print(" • Flash SDPA becomes LESS efficient at longer sequences")
    else:
        print(" • Flash SDPA maintains consistent efficiency across sequence lengths")

    print("\n4. Theoretical Complexity:")
    print(f" • Sequence length increased by: {length_growth:.1f}x")
    print(f" • Theoretical O(n²) complexity increase: {quadratic_growth:.1f}x")
    print(f" • Actual Flash SDPA increase: {flash_growth:.1f}x")
    # Measured growth relative to quadratic growth in sequence length.
    efficiency_ratio = flash_growth / quadratic_growth
    print(f" • Flash SDPA efficiency ratio: {efficiency_ratio:.3f} (lower is better)")
|
|
|
|
def main():
    """Run the scaling benchmark, then plot and summarise its results."""
    # The benchmark returns (seq_lengths, flash_times, naive_times, speedups),
    # which is exactly the positional signature of both reporting helpers.
    results = run_scaling_benchmark()

    create_line_plot(*results)
    print_analysis(*results)


if __name__ == "__main__":
    main()
|
|