File size: 10,979 Bytes
62dca4c | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 331 332 333 334 335 336 337 | import argparse
import time
import matplotlib.pyplot as plt
import numpy as np
import torch
import torch._dynamo as dynamo
from transformers import LlamaConfig
from transformers.cache_utils import DynamicCache
from specforge.modeling.draft.llama3_eagle import (
LlamaAttention,
LlamaFlexAttention,
prepare_decoder_attention_mask,
)
# Allow torch._dynamo to recompile a frame up to 64 times before giving up;
# presumably needed because each new sequence length can trigger a
# recompilation of the compiled attention kernels — TODO confirm.
dynamo.config.recompile_limit = 64

# Single-layer Llama configuration shared by both attention backends.
# NOTE(review): max_position_embeddings (16384) is below the longest
# benchmarked length (32768); confirm the rotary embedding handles that.
config_dict = dict(
    hidden_size=4096,
    num_attention_heads=32,
    num_key_value_heads=8,
    max_position_embeddings=16384,
    rms_norm_eps=1e-05,
    vocab_size=32000,
    hidden_act="silu",
    num_hidden_layers=1,
)
config = LlamaConfig(**config_dict)

# Number of sequential attention steps (one hidden-state tensor each) per run.
TTT_LENGTH = 7
BATCH_SIZE = 4
# Width of the hidden states fed to the attention layer: twice the model
# hidden size (presumably embeddings concatenated with hidden states —
# confirm against llama3_eagle).
HIDDEN_SIZE = 2 * config.hidden_size
def run_attention(
    seq_len: int,
    hidden_states_list: list[torch.Tensor],
    attention_backend: str = "sdpa",
    enable_profile: bool = False,
):
    """Run TTT_LENGTH attention steps plus one backward pass for a backend.

    A fresh attention module is built on every call (bfloat16, on CUDA when
    available). Step ``idx`` feeds ``hidden_states_list[idx]`` through it while
    the backend's cache accumulates, and the mean of the per-step
    ``output[0].sum()`` losses is back-propagated.

    Args:
        seq_len: Sequence length used for the position ids and masks.
        hidden_states_list: At least TTT_LENGTH input tensors; the batch size
            is read from dim 0 of the first one. Presumably each is
            (batch, seq_len, HIDDEN_SIZE) bfloat16 — TODO confirm.
        attention_backend: ``"sdpa"`` (LlamaAttention with an explicit decoder
            mask and a list-based hidden cache) or ``"flex_attention"``
            (LlamaFlexAttention with a 2-D mask and a DynamicCache).
        enable_profile: When True and the backend is ``"flex_attention"``,
            record a torch profiler trace under ``./profiler_logs/<backend>``.

    Raises:
        ValueError: If ``attention_backend`` is not recognised.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    batch_size = hidden_states_list[0].shape[0]

    if attention_backend == "sdpa":
        attn_cls = LlamaAttention
    elif attention_backend == "flex_attention":
        attn_cls = LlamaFlexAttention
    else:
        raise ValueError(f"Unknown attention backend: {attention_backend}")
    attn_func = attn_cls(config).to(device).to(torch.bfloat16)

    # Allocate inputs directly on the target device (no host-side copy).
    position_ids = (
        torch.arange(seq_len, device=device).unsqueeze(0).repeat(batch_size, 1)
    )
    attention_mask = torch.ones(batch_size, seq_len, device=device)

    if attention_backend == "sdpa":
        # Only the SDPA path consumes the expanded decoder mask. Building it
        # (and the float32 batch x seq x hidden embedding tensor it is derived
        # from) unconditionally would charge that time and memory to the
        # flex-attention measurement as well.
        input_embeds = torch.randn(
            batch_size, seq_len, config.hidden_size, device=device
        )
        decoder_attention_mask = prepare_decoder_attention_mask(
            attention_mask=attention_mask,
            input_shape=(batch_size, seq_len),
            inputs_embeds=input_embeds,
            past_key_values_length=0,
        )
        del input_embeds  # not needed once the mask exists
        # The same cache object is passed to every step so it accumulates.
        backend_kwargs = {
            "attention_mask": decoder_attention_mask,
            "cache_hidden": [[], []],
        }
    else:  # flex_attention
        backend_kwargs = {
            "attention_mask": attention_mask,
            "past_key_values": DynamicCache(),
        }

    profiler = None
    if attention_backend == "flex_attention" and enable_profile:
        profiler = torch.profiler.profile(
            activities=[
                torch.profiler.ProfilerActivity.CPU,
                torch.profiler.ProfilerActivity.CUDA,
            ],
            on_trace_ready=torch.profiler.tensorboard_trace_handler(
                f"./profiler_logs/{attention_backend}"
            ),
            record_shapes=False,
            profile_memory=False,
            with_stack=True,
            with_modules=False,
        )
        profiler.start()

    try:
        loss_list = []
        for idx in range(TTT_LENGTH):
            output = attn_func(
                hidden_states=hidden_states_list[idx],
                position_ids=position_ids,
                output_attentions=False,
                use_cache=True,
                **backend_kwargs,
            )
            # Compute a simple loss for benchmarking
            loss_list.append(output[0].sum())

        # Compute mean loss and backward pass
        if loss_list:
            mean_loss = sum(loss_list) / len(loss_list)
            mean_loss.backward()
    finally:
        # Always stop the profiler so the trace is flushed even on failure.
        if profiler is not None:
            profiler.stop()
def benchmark_function(
    attention_backend: str,
    seq_lengths: list,
    enable_profile: bool = False,
    enable_warmup: bool = True,
):
    """Benchmark a function for speed and GPU memory usage per sequence length.

    For each length: optionally run two untimed warm-up passes, then time one
    ``run_attention`` call (module construction inside it is included) and
    record CUDA memory statistics.

    Args:
        attention_backend: Backend name forwarded to ``run_attention``.
        seq_lengths: Sequence lengths to test, in order.
        enable_profile: Profile the timed run of the first sequence length.
        enable_warmup: Run two warm-up passes before each timed run.

    Returns:
        One dict per sequence length with keys ``seq_len``, ``time`` (seconds),
        ``peak_memory`` (bytes, ``torch.cuda.max_memory_allocated`` since the
        last reset; 0 without CUDA) and ``memory_increase`` (bytes allocated
        after the timed run minus before its inputs were created).
    """
    print(f"\n=== Benchmarking {attention_backend} ===")
    use_cuda = torch.cuda.is_available()
    # Match run_attention's device choice instead of hard-coding "cuda".
    device = "cuda" if use_cuda else "cpu"

    def _make_hidden_states(seq_len):
        # One trainable bfloat16 input per attention step.
        return [
            torch.randn(
                BATCH_SIZE,
                seq_len,
                HIDDEN_SIZE,
                requires_grad=True,
                device=device,
                dtype=torch.bfloat16,
            )
            for _ in range(TTT_LENGTH)
        ]

    def _reset_memory_stats():
        # Release cached blocks and start a fresh peak-memory window.
        if use_cuda:
            torch.cuda.empty_cache()
            torch.cuda.reset_peak_memory_stats()

    results_per_seq_len = []
    for seq_len in seq_lengths:
        print(f"\nTesting sequence length: {seq_len}")
        _reset_memory_stats()

        if enable_warmup:
            print("Warming up...")
            for _ in range(2):
                # Pass the inputs inline: they (and their .grad buffers) are
                # freed as soon as the call returns, so they cannot inflate the
                # peak / initial memory of the measured run below.
                run_attention(seq_len, _make_hidden_states(seq_len), attention_backend)

        # Clear cache again after warmup
        _reset_memory_stats()
        initial_memory = torch.cuda.memory_allocated() if use_cuda else 0

        hidden_states = _make_hidden_states(seq_len)
        if use_cuda:
            # Make sure input generation has finished before the clock starts.
            torch.cuda.synchronize()
        start_time = time.perf_counter()
        run_attention(
            seq_len,
            hidden_states,
            attention_backend,
            enable_profile and seq_len == seq_lengths[0],
        )
        if use_cuda:
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - start_time

        # Record memory usage
        peak_memory = 0
        current_memory = 0
        if use_cuda:
            peak_memory = torch.cuda.max_memory_allocated()
            current_memory = torch.cuda.memory_allocated()
        memory_increase = current_memory - initial_memory
        # Drop this length's inputs before the next length's warm-up.
        del hidden_states

        results_per_seq_len.append(
            {
                "seq_len": seq_len,
                "time": elapsed,
                "peak_memory": peak_memory,
                "memory_increase": memory_increase,
            }
        )
        print(f" Time: {elapsed:.3f}s")
        print(f" Peak memory: {peak_memory / 1024**3:.3f} GB")
        print(f" Memory increase: {memory_increase / 1024**3:.3f} GB")
    return results_per_seq_len
def plot_results(
    eagle_results,
    flex_results,
    seq_lengths,
    output_path: str = "attention_benchmark_comparison.png",
):
    """Plot speed and memory comparison between Eagle and Flex attention.

    Draws a two-panel figure (time with a log y-axis, peak memory in GB),
    saves it to ``output_path``, shows it, then prints per-length
    Eagle/Flex ratios.

    Args:
        eagle_results: Result dicts for the SDPA ("Eagle") backend; each needs
            ``"seq_len"``, ``"time"`` (seconds) and ``"peak_memory"`` (bytes).
        flex_results: Result dicts of the same form for flex attention.
        seq_lengths: X-axis values. Results are matched to these by their
            ``"seq_len"`` key, so the order of the result lists is irrelevant.
        output_path: File the figure is saved to.

    Raises:
        KeyError: If an entry of ``seq_lengths`` has no result in either list.
    """
    bytes_per_gb = 1024**3

    def _aligned(results):
        # Index by seq_len so appended/padded entries cannot be misaligned.
        by_len = {r["seq_len"]: r for r in results}
        return [by_len[seq_len] for seq_len in seq_lengths]

    eagle_rows = _aligned(eagle_results)
    flex_rows = _aligned(flex_results)
    eagle_times = [r["time"] for r in eagle_rows]
    flex_times = [r["time"] for r in flex_rows]
    eagle_memory = [r["peak_memory"] / bytes_per_gb for r in eagle_rows]
    flex_memory = [r["peak_memory"] / bytes_per_gb for r in flex_rows]

    _, (ax1, ax2) = plt.subplots(1, 2, figsize=(15, 6))
    line_style = {"linewidth": 2, "markersize": 8}
    panels = (
        (
            ax1,
            eagle_times,
            flex_times,
            "Time (seconds)",
            "Speed Comparison: Eagle vs Flex Attention",
        ),
        (
            ax2,
            eagle_memory,
            flex_memory,
            "Peak Memory (GB)",
            "Memory Usage Comparison: Eagle vs Flex Attention",
        ),
    )
    for ax, eagle_y, flex_y, ylabel, title in panels:
        ax.plot(seq_lengths, eagle_y, "b-o", label="Eagle (SDPA)", **line_style)
        ax.plot(seq_lengths, flex_y, "r-s", label="Flex Attention", **line_style)
        ax.set_xlabel("Sequence Length")
        ax.set_ylabel(ylabel)
        ax.set_title(title)
        ax.legend()
        ax.grid(True, alpha=0.3)

    ax1.set_xscale("linear")
    ax1.set_yscale("log")
    # Set y-axis ticks every 10GB on the memory panel.
    max_memory = max(eagle_memory + flex_memory, default=0.0)
    ax2.set_yticks(np.arange(0, max_memory + 10, 10))

    plt.tight_layout()
    plt.savefig(output_path, dpi=300, bbox_inches="tight")
    plt.show()

    # Print summary statistics
    print("\n=== Performance Summary ===")
    print(f"Sequence lengths tested: {seq_lengths}")
    print("\nSpeed ratios (Eagle/Flex):")
    for seq_len, eagle_t, flex_t in zip(seq_lengths, eagle_times, flex_times):
        ratio = eagle_t / flex_t if flex_t > 0 else float("inf")
        print(f" {seq_len:4d}: {ratio:.2f}x")
    print("\nMemory ratios (Eagle/Flex):")
    for seq_len, eagle_m, flex_m in zip(seq_lengths, eagle_memory, flex_memory):
        ratio = eagle_m / flex_m if flex_m > 0 else float("inf")
        print(f" {seq_len:4d}: {ratio:.2f}x")
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Benchmark attention mechanisms")
    parser.add_argument(
        "--enable-profile", action="store_true", help="Enable profiling"
    )
    parser.add_argument(
        "--eagle-max-seq-len",
        type=int,
        default=2560,
        help="Longest sequence length benchmarked with the SDPA (Eagle) "
        "backend; longer lengths are skipped because it goes OOM",
    )
    parser.add_argument(
        "--oom-memory-gb",
        type=float,
        default=80,
        help="Peak memory (GB) plotted for skipped SDPA sequence lengths",
    )
    args = parser.parse_args()

    print("PyTorch version:", torch.__version__)
    if torch.cuda.is_available():
        print("CUDA available:", torch.cuda.is_available())
        print("GPU:", torch.cuda.get_device_name())
        print(
            "GPU memory:",
            torch.cuda.get_device_properties(0).total_memory / 1024**3,
            "GB",
        )
    else:
        print("CUDA not available - running on CPU")

    # Define sequence lengths to test
    seq_lengths = [128 * i for i in range(1, 28, 4)]
    # Add extra long context
    seq_lengths.extend([16384, 32768])
    print(f"Testing sequence lengths: {seq_lengths}")

    # Run benchmarks
    print("\n" + "=" * 50)
    # Truncate seqlen for naive eagle, which goes OOM on long sequences.
    eagle_seq_lengths = [
        seq_len for seq_len in seq_lengths if seq_len <= args.eagle_max_seq_len
    ]
    eagle_results = benchmark_function("sdpa", eagle_seq_lengths)
    print("\n" + "=" * 50)
    flex_results = benchmark_function(
        "flex_attention", seq_lengths, enable_profile=args.enable_profile
    )

    # Placeholder eagle entries for lengths that were not run: memory is
    # pinned to --oom-memory-gb and time to the slowest flex run.
    measured_eagle = set(eagle_seq_lengths)
    max_time = max((result["time"] for result in flex_results), default=0.0)
    for result in flex_results:
        if result["seq_len"] not in measured_eagle:
            eagle_results.append(
                {
                    "seq_len": result["seq_len"],
                    "time": max_time,
                    "peak_memory": int(args.oom_memory_gb * 1024**3),
                    "memory_increase": 0,  # Not used in plotting
                }
            )

    # Plot results
    plot_results(eagle_results, flex_results, seq_lengths)
|