File size: 8,374 Bytes
33efa44 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 |
#!/usr/bin/env python3
"""
Benchmark comparison: Standard generation vs Cache-optimized generation.
This script measures and compares:
- Layer forward counts
- Wall clock time
- Tokens per second
Usage:
python tests/run_benchmark.py --model Qwen/Qwen3-0.6B --heads-path /path/to/heads.pt
"""
import argparse
import time
import sys
import os
# Add project to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
def make_dummy_decoder():
    """Run a CPU-only micro-benchmark of ``JaggedKVCache`` operations.

    Despite its name (kept so existing callers keep working), this function
    does not build a decoder. It times three simulated phases against a
    ``JaggedKVCache`` fed with random key/value tensors, then prints the
    timings and a layer-operation count analysis:

    1. Prefill: every layer is updated for each of ``seq_len`` positions.
    2. Draft: on a clone of the prefilled cache, one new position is added
       per entry of ``exit_layers``; only layers ``0..exit_layer`` are
       updated for that position.
    3. Verify: for each drafted position, the layers returned by
       ``get_missing_layers`` are updated.

    Returns:
        bool: Always ``True``.
    """
    from src.jagged_cache import JaggedKVCache

    print("\n" + "=" * 60)
    print("BENCHMARK: JaggedKVCache Operations (No GPU Required)")
    print("=" * 60)

    # Benchmark dimensions.
    num_layers = 28
    batch_size = 1
    num_heads = 8
    head_dim = 128
    seq_len = 100

    # One simulated exit layer per drafted token. The number of drafted
    # tokens is derived from this list so that loops, counts and printed
    # labels cannot drift apart when the list is edited.
    exit_layers = [4, 8, 6, 12, 10]
    num_draft_tokens = len(exit_layers)

    def random_kv():
        """Return a fresh random (key, value) pair for a single position."""
        k = torch.randn(batch_size, num_heads, 1, head_dim)
        v = torch.randn(batch_size, num_heads, 1, head_dim)
        return k, v

    cache = JaggedKVCache(
        num_layers=num_layers,
        batch_size=batch_size,
        num_kv_heads=num_heads,
        head_dim=head_dim,
        device="cpu",
        dtype=torch.float32,
    )

    # Simulate prefill: every layer sees every prompt position.
    print(f"\nSimulating prefill ({seq_len} tokens, {num_layers} layers)...")
    start = time.perf_counter()
    for pos in range(seq_len):
        for layer_idx in range(num_layers):
            k, v = random_kv()
            cache.update(layer_idx, k, v, torch.tensor([pos]))
    prefill_time = (time.perf_counter() - start) * 1000
    print(f" Prefill time: {prefill_time:.2f} ms")

    # Simulate the draft phase: each token stops at its own exit layer, so
    # deeper layers are left unfilled for that position.
    print(
        f"\nSimulating draft phase ({num_draft_tokens} tokens, variable exit layers)..."
    )
    draft_cache = cache.clone()
    start = time.perf_counter()
    for i, exit_layer in enumerate(exit_layers):
        pos = seq_len + i
        for layer_idx in range(exit_layer + 1):
            k, v = random_kv()
            draft_cache.update(layer_idx, k, v, torch.tensor([pos]))
    draft_time = (time.perf_counter() - start) * 1000
    print(f" Draft time: {draft_time:.2f} ms")

    # Report how many positions a sample of layers holds after drafting.
    print("\nCache state after drafting:")
    for layer_idx in [0, 4, 8, 12, 16, 20, 24, 27]:
        filled = len(draft_cache.filled_positions[layer_idx])
        print(f" Layer {layer_idx:2d}: {filled} positions filled")

    # Simulate verification: fill whatever layers the cache reports as
    # missing for each drafted position.
    print("\nSimulating verification (lazy fill + full model)...")
    start = time.perf_counter()
    for pos in range(seq_len, seq_len + num_draft_tokens):
        missing = draft_cache.get_missing_layers(pos, num_layers - 1)
        for layer_idx in missing:
            k, v = random_kv()
            draft_cache.update(layer_idx, k, v, torch.tensor([pos]))
    verify_time = (time.perf_counter() - start) * 1000
    print(f" Verify time: {verify_time:.2f} ms")

    # Operation-count analysis (pure arithmetic, independent of the timings).
    print("\n" + "=" * 60)
    print("ANALYSIS: Layer Operations")
    print("=" * 60)

    # Prefill cost is a one-time cost shared by every approach.
    prefill_ops = seq_len * num_layers
    print(f"\nPREFILL (one-time): {prefill_ops} layer ops")

    # Draft phase: compare early exit against running every layer.
    draft_ops = sum(exit_layer + 1 for exit_layer in exit_layers)
    draft_ops_full = num_draft_tokens * num_layers
    print(f"\nDRAFT PHASE ({num_draft_tokens} tokens):")
    print(
        f" With early exit: {draft_ops} ops (avg {draft_ops / num_draft_tokens:.1f} layers/token)"
    )
    print(f" Without early exit: {draft_ops_full} ops ({num_layers} layers/token)")
    print(
        f" Draft savings: {draft_ops_full - draft_ops} ops ({100 * (1 - draft_ops / draft_ops_full):.0f}% reduction)"
    )

    # With a cache each draft step handles one new token; without it the
    # whole seq_len-token context would be recomputed on every step.
    print("\nCACHE BENEFIT:")
    print(f" Without cache, each draft would recompute {seq_len}-token context")
    print(" With cache, each draft processes only 1 new token")
    per_token_savings = seq_len - 1  # Positions that are not recomputed
    total_context_savings = per_token_savings * draft_ops
    print(f" Context reuse savings: ~{total_context_savings} avoided operations")

    # Verify phase, counted as a full pass over all layers per drafted token.
    verify_ops = num_draft_tokens * num_layers
    print(f"\nVERIFY PHASE: {verify_ops} ops (fills all layers for drafted tokens)")

    print(f"\nTotal time: {prefill_time + draft_time + verify_time:.2f} ms")
    return True
def _tokens_per_second(num_tokens, elapsed):
    """Return ``num_tokens / elapsed``, or 0.0 when no time was measured."""
    return num_tokens / elapsed if elapsed > 0 else 0.0


def run_full_benchmark(model_name, heads_path, config_path, calibration_path=None):
    """Benchmark ``generate`` against ``generate_fast`` on a loaded model.

    Loads a decoder and tokenizer through ``load_dssd_model``, performs a
    short warm-up generation, then times one ``decoder.generate`` call and
    one ``decoder.generate_fast`` call on the same prompt and prints token
    counts, wall-clock time, tokens per second, average exit layer and the
    resulting speedup.

    Args:
        model_name: Model identifier forwarded to ``load_dssd_model``.
        heads_path: Forwarded to ``load_dssd_model`` as ``heads_path``.
        config_path: Forwarded to ``load_dssd_model`` as ``config_path``.
        calibration_path: Forwarded to ``load_dssd_model`` as
            ``calibration_path``; defaults to ``None``.

    Returns:
        bool: ``False`` if loading the model raised an exception (the error
        is printed), ``True`` once the benchmark has completed.
    """
    from src.inference import load_dssd_model

    print("\n" + "=" * 60)
    print("BENCHMARK: Full Model Comparison")
    print(f"Model: {model_name}")
    print("=" * 60)

    # Loading is the only step treated as recoverable: report and bail out.
    try:
        decoder, tokenizer = load_dssd_model(
            model_name=model_name,
            heads_path=heads_path,
            config_path=config_path,
            calibration_path=calibration_path,
            device="auto",
        )
    except Exception as e:
        print(f"Error loading model: {e}")
        return False

    prompt = "Explain what machine learning is in three sentences."
    max_tokens = 50

    # Warm-up run so one-time start-up costs are not charged to the first
    # timed generation.
    print("\nWarming up...")
    _ = decoder.generate(
        prompt, max_tokens=10, use_early_exit=False, use_chat_template=True
    )

    # Baseline run.
    # NOTE(review): this run is labelled "no cache" but is invoked with
    # use_early_exit=True, unlike the warm-up -- confirm this is the
    # intended baseline.
    print("\nRunning standard generation (no cache)...")
    start = time.perf_counter()
    result_standard = decoder.generate(
        prompt,
        max_tokens=max_tokens,
        use_early_exit=True,
        accuracy_level=0.75,
        use_chat_template=True,
    )
    time_standard = time.perf_counter() - start

    # Cache-optimized run via generate_fast.
    print("Running cache-optimized generation (fast)...")
    start = time.perf_counter()
    result_cached = decoder.generate_fast(
        prompt,
        max_tokens=max_tokens,
        accuracy_level=0.75,
        use_chat_template=True,
    )
    time_cached = time.perf_counter() - start

    # Print results.
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)

    standard_count = len(result_standard.tokens)
    print("\nStandard Generation:")
    print(f" Tokens generated: {standard_count}")
    print(f" Time: {time_standard:.2f}s")
    # Guarded the same way as the speedup below, so a zero elapsed time
    # cannot raise ZeroDivisionError.
    print(f" Tokens/sec: {_tokens_per_second(standard_count, time_standard):.2f}")
    print(f" Avg exit layer: {result_standard.avg_exit_layer:.1f}")

    cached_count = len(result_cached.tokens)
    print("\nCache-Optimized Generation:")
    print(f" Tokens generated: {cached_count}")
    print(f" Time: {time_cached:.2f}s")
    print(f" Tokens/sec: {_tokens_per_second(cached_count, time_cached):.2f}")
    print(f" Avg exit layer: {result_cached.avg_exit_layer:.1f}")

    # Draft/accept statistics are only printed when the result carries them.
    if "total_drafted" in result_cached.exit_distribution:
        print(f" Drafted: {result_cached.exit_distribution['total_drafted']}")
        print(f" Accepted: {result_cached.exit_distribution['total_accepted']}")
        print(
            f" Acceptance rate: {result_cached.exit_distribution['acceptance_rate']:.1%}"
        )

    print("\nSpeedup:")
    speedup = time_standard / time_cached if time_cached > 0 else 0
    print(f" {speedup:.2f}x faster with cache")
    return True
def main(argv=None):
    """Parse command-line options and run the selected benchmark.

    The CPU-only cache benchmark runs when ``--cpu-only`` is given or when
    no ``--heads-path`` is supplied; otherwise the full model benchmark
    runs with the given model, heads, config and calibration paths.

    Args:
        argv: Optional list of argument strings. ``None`` (the default)
            makes argparse read ``sys.argv[1:]``.

    Returns:
        int: ``0`` if the selected benchmark reported success, ``1``
        otherwise (for example when the model could not be loaded).
    """
    parser = argparse.ArgumentParser(description="Benchmark DSSD generation")
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B", help="Model name")
    parser.add_argument("--heads-path", help="Path to aux heads checkpoint")
    parser.add_argument("--config-path", help="Path to model config")
    parser.add_argument("--calibration-path", help="Path to calibration file")
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run CPU-only cache benchmark"
    )
    args = parser.parse_args(argv)

    if args.cpu_only or not args.heads_path:
        # Run CPU-only cache operations benchmark
        succeeded = make_dummy_decoder()
    else:
        # Run full benchmark with model
        succeeded = run_full_benchmark(
            args.model,
            args.heads_path,
            args.config_path,
            args.calibration_path,
        )

    # Surface benchmark failure to the caller instead of discarding it.
    return 0 if succeeded else 1


if __name__ == "__main__":
    # Propagate the benchmark outcome as the process exit status.
    sys.exit(main())
|