File size: 16,915 Bytes
18b382b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 338 339 340 341 342 343 344 345 346 347 348 349 350 351 352 353 354 355 356 357 358 359 360 361 362 363 364 365 366 367 368 369 370 371 372 373 374 375 376 377 378 379 380 381 382 383 384 385 386 387 388 389 390 391 392 393 394 395 396 397 398 399 400 401 402 403 404 405 406 407 408 409 410 411 412 413 414 415 416 417 418 419 420 421 422 423 424 425 426 427 428 429 430 431 432 433 434 435 436 437 438 439 440 441 442 443 444 445 446 447 448 449 450 451 452 453 454 455 456 457 458 459 460 461 462 463 464 465 466 467 468 469 470 471 472 473 474 475 476 477 478 479 480 481 482 483 484 485 486 487 488 |
#!/usr/bin/env python3
# Copyright (c) Delanoe Pirard / Aedelon - Apache 2.0
"""
Flash Attention Benchmark for Depth Anything 3.
Provides clear performance comparison with tables and analysis.
Usage:
python benchmarks/flash_attention_benchmark.py
python benchmarks/flash_attention_benchmark.py --detailed
"""
import argparse
import gc
import os
import sys
import time
from dataclasses import dataclass
import torch
sys.path.insert(0, os.path.join(os.path.dirname(__file__), "..", "src"))
from depth_anything_3.model.dinov2.layers import (
FLASH_ATTN_AVAILABLE,
FLASH_ATTN_VERSION,
Attention,
)
@dataclass
class BenchmarkConfig:
    """Describe one benchmark case.

    The benchmark feeds a random tensor of shape
    ``(batch_size, seq_len, embed_dim)`` to an attention layer with
    ``num_heads`` heads.
    """

    name: str  # short label such as "Small"; used as the results key
    seq_len: int  # tokens per sample
    batch_size: int  # samples per forward pass
    embed_dim: int  # attention width (replaced per model variant)
    num_heads: int  # attention heads (replaced per model variant)
    image_size: str  # human-readable image size this token count stands for

    @property
    def description(self):
        """Return the table label ``"<name> (<image_size>)"``."""
        return "{} ({})".format(self.name, self.image_size)
# Depth Anything 3 backbone variants.
# embed_dim / num_heads size each benchmarked Attention layer; depth is the
# number of layers used to extrapolate per-layer savings to a whole model.
DA3_CONFIGS = dict(
    vitb=dict(embed_dim=768, num_heads=12, depth=12),
    vitl=dict(embed_dim=1024, num_heads=16, depth=24),
    vitg=dict(embed_dim=1536, num_heads=24, depth=40),
)
def get_device_info():
    """Detect the accelerator to benchmark on and describe it.

    Preference order is CUDA, then Apple MPS, then CPU.

    Returns:
        dict with keys ``type`` ("cuda" | "mps" | "cpu"), ``device``
        (``torch.device``), ``name`` (str), ``memory_gb`` (float, CUDA only,
        else None) and ``compute_capability`` ("major.minor" string, CUDA
        only, else None).
    """
    if torch.cuda.is_available():
        # Query every property from the same (current) device; previously the
        # name/capability came from the current device but the memory came
        # from hard-coded index 0, which disagree on multi-GPU hosts.
        index = torch.cuda.current_device()
        major, minor = torch.cuda.get_device_capability(index)
        return {
            "type": "cuda",
            "device": torch.device("cuda"),
            "name": torch.cuda.get_device_name(index),
            "memory_gb": torch.cuda.get_device_properties(index).total_memory / 1024**3,
            "compute_capability": f"{major}.{minor}",
        }
    if torch.backends.mps.is_available():
        return {
            "type": "mps",
            "device": torch.device("mps"),
            "name": "Apple Silicon",
            "memory_gb": None,
            "compute_capability": None,
        }
    return {
        "type": "cpu",
        "device": torch.device("cpu"),
        "name": "CPU",
        "memory_gb": None,
        "compute_capability": None,
    }
def benchmark_attention(attn_module, x, warmup=5, runs=20):
    """Time repeated forward passes of ``attn_module`` on ``x``.

    Args:
        attn_module: Callable invoked as ``attn_module(x)`` under
            ``torch.no_grad()``.
        x: Input tensor; ``x.device`` decides how execution is synchronized.
        warmup: Number of untimed calls made before measuring.
        runs: Number of timed calls; must be at least 1.

    Returns:
        dict with ``mean_ms``, ``std_ms`` (0.0 when ``runs == 1``),
        ``min_ms`` and ``peak_mem_mb`` (CUDA peak allocated memory during the
        timed phase; 0 on other devices).

    Raises:
        ValueError: if ``runs`` is less than 1.
    """
    if runs < 1:
        raise ValueError(f"runs must be >= 1, got {runs}")
    device = x.device
    is_cuda = device.type == "cuda"

    def _sync():
        # CUDA and MPS execute kernels asynchronously; wait for completion so
        # the wall-clock interval covers execution, not just the launch.
        if is_cuda:
            torch.cuda.synchronize(device)
        elif device.type == "mps" and hasattr(torch, "mps"):
            torch.mps.synchronize()

    with torch.no_grad():
        # Untimed calls so one-off costs (e.g. lazy initialization) are excluded.
        for _ in range(warmup):
            attn_module(x)
        _sync()

        if is_cuda:
            # Reset so the peak reflects the timed phase (it still includes
            # tensors that are already allocated, such as x and the weights).
            torch.cuda.reset_peak_memory_stats(device)

        times = []
        for _ in range(runs):
            _sync()
            start = time.perf_counter()
            attn_module(x)
            _sync()
            times.append((time.perf_counter() - start) * 1000)

    peak_mem_mb = torch.cuda.max_memory_allocated(device) / 1024 / 1024 if is_cuda else 0
    times_tensor = torch.tensor(times)
    return {
        "mean_ms": times_tensor.mean().item(),
        # The sample standard deviation of one value is NaN; report 0.0 instead.
        "std_ms": times_tensor.std().item() if runs > 1 else 0.0,
        "min_ms": times_tensor.min().item(),
        "peak_mem_mb": peak_mem_mb,
    }
def print_header():
    """Print the banner that opens the benchmark report."""
    rule = "=" * 80
    title = " " * 20 + "FLASH ATTENTION BENCHMARK - DEPTH ANYTHING 3"
    print(f"\n{rule}")
    print(title)
    print(f"{rule}\n")
def get_sdpa_backend_info():
    """Report which PyTorch SDPA kernels are enabled.

    Returns:
        An empty dict when CUDA is unavailable; otherwise a dict with boolean
        entries ``flash_sdp``, ``mem_efficient_sdp`` and ``math_sdp`` taken
        from ``torch.backends.cuda``.
    """
    if not torch.cuda.is_available():
        return {}
    from torch.backends import cuda as cuda_backends

    return {
        "flash_sdp": cuda_backends.flash_sdp_enabled(),
        "mem_efficient_sdp": cuda_backends.mem_efficient_sdp_enabled(),
        "math_sdp": cuda_backends.math_sdp_enabled(),
    }
def print_device_info(device_info):
    """Print the hardware section of the report.

    Args:
        device_info: dict with keys ``type``, ``name``, ``memory_gb`` and
            ``compute_capability`` (a "major.minor" string or None), as
            produced by ``get_device_info``.
    """
    print("📋 HARDWARE CONFIGURATION")
    print("─" * 80)
    print(f"  Device Type    : {device_info['type'].upper()}")
    print(f"  Device Name    : {device_info['name']}")
    if device_info["memory_gb"]:
        print(f"  Memory         : {device_info['memory_gb']:.1f} GB")
    if device_info["compute_capability"]:
        print(f"  Compute Cap.   : {device_info['compute_capability']}")
        # Compare as an integer tuple rather than a float so the check stays
        # correct regardless of how many digits the minor version has.
        major, minor = (int(part) for part in device_info["compute_capability"].split("."))
        if (major, minor) >= (7, 5):
            print("  ✅ Flash Attention supported (≥7.5)")
        else:
            print("  ❌ Flash Attention requires ≥7.5")

    sdpa_info = get_sdpa_backend_info()
    if sdpa_info:

        def _status(flag):
            return "✅ Enabled" if flag else "❌ Disabled"

        print("\n  PyTorch SDPA Backends:")
        print(f"    Flash SDP      : {_status(sdpa_info.get('flash_sdp'))}")
        print(f"    MemEfficient   : {_status(sdpa_info.get('mem_efficient_sdp'))}")
        print(f"    Math SDP       : {_status(sdpa_info.get('math_sdp'))}")
        if sdpa_info.get("flash_sdp"):
            # NOTE(review): flash_sdp_enabled() is PyTorch's global enable flag;
            # whether the flash kernel is actually dispatched also depends on
            # the GPU, dtype and shapes.
            print("\n  ⚡ PyTorch SDPA uses Flash Attention internally!")
            print("     (No need for flash-attn package with PyTorch >= 2.2)")

    if FLASH_ATTN_AVAILABLE:
        flash_pkg = f"✅ Installed v{FLASH_ATTN_VERSION}"
    else:
        flash_pkg = "❌ Not installed (optional)"
    print(f"\n  flash-attn pkg : {flash_pkg}")
    print()
def print_table_header():
    """Print the top border and column titles of the results table.

    Column widths of 26, 14, 14, 14 and 12 box characters hold a
    24/12/12/12/10-character cell plus one space of padding on each side,
    the same layout the result rows are formatted with.
    """
    widths = (26, 14, 14, 14, 12)
    print("┌" + "┬".join("─" * w for w in widths) + "┐")
    print(
        f"│ {'Configuration':24s} │ {'flash_attn':12s} │ {'sdpa':12s} "
        f"│ {'manual':12s} │ {'Speedup':10s} │"
    )
    print("├" + "┼".join("─" * w for w in widths) + "┤")
def print_table_row(config_desc, results, baseline="sdpa"):
    """Print one row of the results table.

    Args:
        config_desc: Label for the first column (padded to 24 characters).
        results: Mapping of backend name to a result dict holding
            ``mean_ms``, or to None when that backend failed / was skipped.
        baseline: Backend the flash_attn speedup is computed against.

    Backends without a result are shown as "N/A". The speedup column is
    ``baseline / flash_attn`` mean time, marked with ⚡ above 1.1x, or "-"
    when either timing is missing.
    """
    cells = []
    for backend in ("flash_attn", "sdpa", "manual"):
        res = results.get(backend)
        # Right-align N/A to the 9-character width of the timing cells.
        cells.append(f"{res['mean_ms']:6.2f} ms" if res else f"{'N/A':>9s}")

    speedup_str = " -"
    flash = results.get("flash_attn")
    base = results.get(baseline)
    # Guard against a zero flash timing to avoid ZeroDivisionError.
    if flash and base and flash["mean_ms"] > 0:
        speedup = base["mean_ms"] / flash["mean_ms"]
        speedup_str = f" {speedup:.2f}x ⚡" if speedup > 1.1 else f" {speedup:.2f}x"

    print(
        f"│ {config_desc:24s} │ {cells[0]:12s} │ {cells[1]:12s} "
        f"│ {cells[2]:12s} │ {speedup_str:10s} │"
    )
def print_table_footer():
    """Print the bottom border of the results table.

    Uses the same 26/14/14/14/12 column widths as the table header.
    """
    widths = (26, 14, 14, 14, 12)
    print("└" + "┴".join("─" * w for w in widths) + "┘")
def print_model_analysis(model_name, config, results, num_layers):
    """Print the flash-attn vs SDPA analysis for one model variant.

    Args:
        model_name: Label shown in the heading.
        config: The benchmark config the results came from. It is not used in
            the computation and is kept for interface compatibility.
        results: Mapping of backend name to a result dict (``mean_ms`` is the
            time of ONE attention layer) or None.
        num_layers: Number of attention layers in the model, used to
            extrapolate per-layer savings to the whole network.

    Prints nothing unless a ``flash_attn`` result is present.
    """
    flash = results.get("flash_attn")
    if not flash:
        return
    flash_time = flash["mean_ms"]
    # Fall back to the flash timing (i.e. 1.0x) when SDPA is missing or failed.
    sdpa = results.get("sdpa")
    sdpa_time = sdpa["mean_ms"] if sdpa else flash_time
    speedup = sdpa_time / flash_time if flash_time > 0 else 1.0

    # Each measurement covers a single layer, so the per-layer saving is the
    # plain difference and the model-wide saving scales with the layer count.
    time_saved_per_layer = sdpa_time - flash_time
    total_time_saved = time_saved_per_layer * num_layers

    print(f"\n  📊 {model_name} Analysis:")
    print(f"    • Attention time per layer: {flash_time:.2f} ms (flash) vs {sdpa_time:.2f} ms (sdpa)")
    print(f"    • Time saved per layer: {time_saved_per_layer:.2f} ms")
    print(f"    • Total time saved ({num_layers} layers): {total_time_saved:.1f} ms")
    print(f"    • Speedup: {speedup:.2f}x on attention")

    # Amdahl's-law estimate. The attention share of total inference time is an
    # assumed figure (~15-20%), not something this benchmark measures.
    attn_fraction = 0.175
    overall_speedup = 1 / (1 - attn_fraction + attn_fraction / speedup)
    overall_improvement = (1 - 1 / overall_speedup) * 100
    print(
        f"    • Estimated full inference speedup: {overall_speedup:.2f}x (~{overall_improvement:.1f}% faster)"
    )
def _select_backends(backends, device, detailed):
    """Return the requested backends that can run on ``device`` (order kept, duplicates dropped)."""
    if backends is None:
        backends = ["flash_attn", "sdpa"] + (["manual"] if detailed else [])
    selected = []
    for name in backends:
        # flash_attn needs both the optional package and a CUDA device.
        if name == "flash_attn" and not (FLASH_ATTN_AVAILABLE and device.type == "cuda"):
            continue
        if name not in selected:
            selected.append(name)
    return selected


def _benchmark_backend(backend, config, x, device, dtype, warmup, runs, detailed):
    """Benchmark one ``Attention`` layer built with ``backend``; return None on failure."""
    gc.collect()
    if device.type == "cuda":
        torch.cuda.empty_cache()
    attn = None
    try:
        attn = Attention(
            dim=config.embed_dim,
            num_heads=config.num_heads,
            attn_backend=backend,
        ).to(device, dtype)
        attn.eval()
        return benchmark_attention(attn, x, warmup=warmup, runs=runs)
    except Exception as e:
        # Best-effort: a failing backend (e.g. unsupported on this device) is
        # shown as N/A instead of aborting the whole suite.
        if detailed:
            print(f"    {backend} failed: {e}")
        return None
    finally:
        # Release the layer before the next backend is built.
        del attn


def _print_summary(device, all_results):
    """Print device-specific recommendations and the SDPA-vs-manual speedups."""
    print("\n" + "=" * 80)
    print("📊 SUMMARY & RECOMMENDATIONS")
    print("=" * 80)

    if device.type == "cuda":
        # NOTE(review): flash_sdp_enabled() is PyTorch's global enable flag;
        # actual kernel dispatch also depends on GPU, dtype and shapes.
        sdpa_info = get_sdpa_backend_info()
        if sdpa_info.get("flash_sdp"):
            print("\n✅ Flash Attention is ACTIVE via PyTorch SDPA!")
            print("\n   Your setup:")
            print(f"   • PyTorch {torch.__version__} with native Flash Attention")
            print("   • SDPA backend: Flash SDP ⚡")
            print("   • No additional packages needed!")
            print("\n   Benefits you're already getting:")
            print("   • 2-4x faster attention vs manual implementation")
            print("   • Memory-efficient attention computation")
            print("   • Automatic kernel selection per input size")
            if FLASH_ATTN_AVAILABLE:
                print(f"\n   ℹ️  flash-attn v{FLASH_ATTN_VERSION} also installed")
                print("      (May provide slight additional optimization in some cases)")
            else:
                print("\n   ℹ️  flash-attn package: Not needed!")
                print("      PyTorch >= 2.2 includes Flash Attention natively.")
        elif FLASH_ATTN_AVAILABLE:
            print("\n✅ Flash Attention is ACTIVE via flash-attn package")
            print(f"\n   Using flash-attn v{FLASH_ATTN_VERSION}")
            print("\n   Benefits:")
            print("   • 2-3x faster attention computation")
            print("   • ~15-25% overall inference speedup")
            print("   • Lower memory usage")
        else:
            print("\n⚠️  Flash Attention not available")
            print("\n   Options to enable:")
            print("   1. Upgrade PyTorch to >= 2.2 (recommended)")
            print("   2. Install flash-attn: pip install flash-attn --no-build-isolation")
    elif device.type == "mps":
        print("\n📱 Apple Silicon (MPS) detected")
        print("\n   • Flash Attention not available for MPS")
        print("   • PyTorch SDPA uses optimized Metal kernels")
        print("   • Already running at optimal speed for your hardware")
    else:
        print("\n💻 CPU detected")
        print("\n   • Consider using GPU for faster inference")
        print("   • Flash Attention is CUDA-only")

    print("\n" + "─" * 80)
    print("⚡ PERFORMANCE COMPARISON")
    print("─" * 80)
    print("\n   SDPA vs Manual attention speedup (per layer):")
    reported = False
    for model_name, model_results in all_results.items():
        if not model_results:
            continue
        # Prefer the config named "XLarge"; otherwise use the last one run.
        names = list(model_results)
        key = next((k for k in names if "xlarge" in k.lower()), names[-1])
        res = model_results[key]
        sdpa, manual = res.get("sdpa"), res.get("manual")
        if sdpa and manual and sdpa["mean_ms"] > 0:
            speedup = manual["mean_ms"] / sdpa["mean_ms"]
            print(
                f"   • {model_name.upper():6s}: {speedup:.1f}x faster "
                f"(sdpa: {sdpa['mean_ms']:.2f}ms vs manual: {manual['mean_ms']:.2f}ms)"
            )
            reported = True
    if not reported:
        print("   • No manual-backend results (run with --detailed to include them)")
    print("\n" + "=" * 80)
    print()


def run_benchmark(test_configs, backends, warmup=5, runs=20, detailed=False):
    """Run the benchmark suite over every model variant and print the report.

    For each entry of ``DA3_CONFIGS`` and each test config (whose
    ``embed_dim``/``num_heads`` are replaced by the variant's values), a
    random input of shape ``(batch_size, seq_len, embed_dim)`` is timed
    through one ``Attention`` layer per usable backend.

    Args:
        test_configs: Sequence of ``BenchmarkConfig`` describing input sizes.
        backends: Backend names to try, in table order. ``flash_attn`` is
            skipped unless the flash-attn package is available and the device
            is CUDA. ``None`` means flash_attn/sdpa, plus manual when
            ``detailed``.
        warmup: Untimed iterations per measurement.
        runs: Timed iterations per measurement.
        detailed: Print backend failures and a per-model analysis.

    Returns:
        ``{model_name: {config_name: {backend: result dict or None}}}``.
    """
    device_info = get_device_info()
    device = device_info["device"]
    dtype = torch.float16 if device.type == "cuda" else torch.float32

    print_header()
    print_device_info(device_info)

    active_backends = _select_backends(backends, device, detailed)

    all_results = {}
    for model_name, model_config in DA3_CONFIGS.items():
        print(
            f"\n🔬 MODEL: {model_name.upper()} (dim={model_config['embed_dim']}, "
            f"heads={model_config['num_heads']}, depth={model_config['depth']})"
        )
        print("─" * 80)
        print_table_header()

        model_results = {}
        used_configs = {}
        for test_config in test_configs:
            config = BenchmarkConfig(
                name=test_config.name,
                seq_len=test_config.seq_len,
                batch_size=test_config.batch_size,
                embed_dim=model_config["embed_dim"],
                num_heads=model_config["num_heads"],
                image_size=test_config.image_size,
            )
            x = torch.randn(
                config.batch_size, config.seq_len, config.embed_dim, device=device, dtype=dtype
            )
            results = {}
            for backend in active_backends:
                results[backend] = _benchmark_backend(
                    backend, config, x, device, dtype, warmup, runs, detailed
                )
            # Free the input before allocating the next (possibly larger) one.
            del x
            model_results[config.name] = results
            used_configs[config.name] = config
            print_table_row(config.description, results)
        print_table_footer()

        if detailed and model_results:
            names = list(model_results)
            analysis_key = next(
                (k for k in names if "1024" in k.lower() or "medium" in k.lower()),
                names[0],
            )
            # Pass the config the analysed results were actually measured on.
            print_model_analysis(
                model_name.upper(),
                used_configs[analysis_key],
                model_results[analysis_key],
                model_config["depth"],
            )
        all_results[model_name] = model_results

    _print_summary(device, all_results)
    return all_results
def main():
    """Parse command-line options and run the benchmark suite."""
    parser = argparse.ArgumentParser(description="Flash Attention benchmark for DA3")
    parser.add_argument(
        "--detailed",
        action="store_true",
        help="Show detailed analysis and include manual backend",
    )
    parser.add_argument(
        "--warmup",
        type=int,
        default=5,
        help="Warmup iterations (default: 5)",
    )
    parser.add_argument(
        "--runs",
        type=int,
        default=20,
        help="Benchmark runs (default: 20)",
    )
    args = parser.parse_args()
    # At least one timed run is needed to compute any statistics; fail with a
    # clear CLI error instead of deep inside the benchmark loop.
    if args.runs < 1:
        parser.error("--runs must be >= 1")

    # Input sizes. Each seq_len is a perfect square (16², 23², 32², 37²). The
    # image_size labels assume a square image cut into 14-px patches as in the
    # DINOv2 backbone (side = sqrt(seq_len) * 14), ignoring class/register
    # tokens. TODO confirm the patch size against the model config.
    # embed_dim / num_heads are placeholders that run_benchmark replaces with
    # each model variant's values.
    test_configs = [
        BenchmarkConfig(
            name="Small",
            seq_len=256,
            batch_size=1,
            embed_dim=768,
            num_heads=12,
            image_size="224px image",
        ),
        BenchmarkConfig(
            name="Medium",
            seq_len=529,
            batch_size=1,
            embed_dim=768,
            num_heads=12,
            image_size="322px image",
        ),
        BenchmarkConfig(
            name="Large",
            seq_len=1024,
            batch_size=1,
            embed_dim=768,
            num_heads=12,
            image_size="448px image",
        ),
        BenchmarkConfig(
            name="XLarge",
            seq_len=1369,
            batch_size=1,
            embed_dim=768,
            num_heads=12,
            image_size="518px image",
        ),
    ]

    backends = ["flash_attn", "sdpa"]
    if args.detailed:
        backends.append("manual")

    run_benchmark(
        test_configs=test_configs,
        backends=backends,
        warmup=args.warmup,
        runs=args.runs,
        detailed=args.detailed,
    )
# Run the benchmark only when executed as a script, not on import.
if __name__ == "__main__":
    main()