Dssd_Demo / tests /run_benchmark.py
Florian valade
Track metrics during streaming, remove redundant generation re-runs
33efa44
#!/usr/bin/env python3
"""
Benchmark comparison: Standard generation vs Cache-optimized generation.
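This script measures and compares:
- Layer forward counts (analytic estimate, CPU-only cache benchmark)
- Wall clock time
- Tokens per second
Usage:
python tests/run_benchmark.py --model Qwen/Qwen3-0.6B --heads-path /path/to/heads.pt
python tests/run_benchmark.py --cpu-only  # JaggedKVCache benchmark, no model or GPU needed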
"""
import argparse
import time
import sys
import os
# Add project to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
import torch
def _print_layer_op_analysis(seq_len, num_layers, exit_layers):
    """Print analytic layer-operation counts for one simulated draft/verify round.

    Args:
        seq_len: Number of prompt positions already in the cache (prefill).
        num_layers: Total number of decoder layers.
        exit_layers: Early-exit layer index for each drafted token (non-empty).
    """
    num_draft = len(exit_layers)

    print("\n" + "=" * 60)
    print("ANALYSIS: Layer Operations")
    print("=" * 60)

    # Prefill ops (same for all approaches - one-time cost)
    prefill_ops = seq_len * num_layers
    print(f"\nPREFILL (one-time): {prefill_ops} layer ops")

    # Draft phase: token i runs layers 0..exit_layers[i] instead of all layers.
    draft_ops = sum(exit_layer + 1 for exit_layer in exit_layers)
    draft_ops_full = num_draft * num_layers  # Without early exit
    print(f"\nDRAFT PHASE ({num_draft} tokens):")
    print(f" With early exit: {draft_ops} ops (avg {draft_ops / num_draft:.1f} layers/token)")
    print(f" Without early exit: {draft_ops_full} ops ({num_layers} layers/token)")
    print(
        f" Draft savings: {draft_ops_full - draft_ops} ops ({100 * (1 - draft_ops / draft_ops_full):.0f}% reduction)"
    )

    # Without a cache, draft token i (at position seq_len + i) would re-run its
    # exit_layer + 1 layers over all seq_len + i earlier positions as well as
    # itself; with the cache only the new position is processed.
    context_savings = sum(
        (seq_len + i) * (exit_layer + 1) for i, exit_layer in enumerate(exit_layers)
    )
    print("\nCACHE BENEFIT:")
    print(f" Without cache, each draft would recompute {seq_len}-token context")
    print(" With cache, each draft processes only 1 new token")
    print(f" Context reuse savings: ~{context_savings} avoided operations")

    # Verify phase
    verify_ops = num_draft * num_layers
    print(f"\nVERIFY PHASE: {verify_ops} ops (fills all layers for drafted tokens)")


def make_dummy_decoder():
    """Benchmark ``JaggedKVCache`` operations on CPU without loading a model.

    Despite its name, this does not build a decoder. It simulates the cache
    traffic of one speculative-decoding round with random K/V tensors:
    prefill of every layer, early-exit drafting, then a lazy fill of the
    layers the drafts skipped. It prints the wall-clock time of each phase and
    an analytic count of layer operations.

    Returns:
        True once the benchmark has run.
    """
    from src.jagged_cache import JaggedKVCache

    print("\n" + "=" * 60)
    print("BENCHMARK: JaggedKVCache Operations (No GPU Required)")
    print("=" * 60)

    # Test cache performance
    num_layers = 28
    batch_size = 1
    num_heads = 8
    head_dim = 128
    seq_len = 100
    exit_layers = [4, 8, 6, 12, 10]  # Simulated early-exit layer per draft token
    num_draft = len(exit_layers)

    cache = JaggedKVCache(
        num_layers=num_layers,
        batch_size=batch_size,
        num_kv_heads=num_heads,
        head_dim=head_dim,
        device="cpu",
        dtype=torch.float32,
    )

    # Allocate one K/V pair up front so the timed loops measure cache updates
    # rather than random-number generation; the values themselves are irrelevant.
    # NOTE(review): assumes JaggedKVCache.update copies (or never mutates) its
    # inputs, so reusing the same tensors for every call is safe.
    k = torch.randn(batch_size, num_heads, 1, head_dim)
    v = torch.randn(batch_size, num_heads, 1, head_dim)

    # Simulate prefill
    print(f"\nSimulating prefill ({seq_len} tokens, {num_layers} layers)...")
    start = time.perf_counter()
    for pos in range(seq_len):
        pos_t = torch.tensor([pos])
        for layer_idx in range(num_layers):
            cache.update(layer_idx, k, v, pos_t)
    prefill_time = (time.perf_counter() - start) * 1000
    print(f" Prefill time: {prefill_time:.2f} ms")

    # Simulate draft phase (early exit at different layers)
    print(f"\nSimulating draft phase ({num_draft} tokens, variable exit layers)...")
    draft_cache = cache.clone()
    start = time.perf_counter()
    for i, exit_layer in enumerate(exit_layers):
        pos_t = torch.tensor([seq_len + i])
        for layer_idx in range(exit_layer + 1):
            draft_cache.update(layer_idx, k, v, pos_t)
    draft_time = (time.perf_counter() - start) * 1000
    print(f" Draft time: {draft_time:.2f} ms")

    # Print cache state for every 4th layer plus the last one.
    print("\nCache state after drafting:")
    sampled_layers = sorted({*range(0, num_layers, 4), num_layers - 1})
    for layer_idx in sampled_layers:
        filled = len(draft_cache.filled_positions[layer_idx])
        print(f" Layer {layer_idx:2d}: {filled} positions filled")

    # Simulate verification (fill all layers for all drafted positions)
    print("\nSimulating verification (lazy fill + full model)...")
    start = time.perf_counter()
    for pos in range(seq_len, seq_len + num_draft):
        pos_t = torch.tensor([pos])
        # Only the layers that get_missing_layers reports still need K/V here.
        missing = draft_cache.get_missing_layers(pos, num_layers - 1)
        for layer_idx in missing:
            draft_cache.update(layer_idx, k, v, pos_t)
    verify_time = (time.perf_counter() - start) * 1000
    print(f" Verify time: {verify_time:.2f} ms")

    _print_layer_op_analysis(seq_len, num_layers, exit_layers)

    print(f"\nTotal time: {prefill_time + draft_time + verify_time:.2f} ms")
    return True
def _print_generation_stats(title, result, elapsed):
    """Print token count, wall-clock time, throughput and avg exit layer for one run."""
    num_tokens = len(result.tokens)
    tokens_per_sec = num_tokens / elapsed if elapsed > 0 else 0.0
    print(f"\n{title}:")
    print(f" Tokens generated: {num_tokens}")
    print(f" Time: {elapsed:.2f}s")
    print(f" Tokens/sec: {tokens_per_sec:.2f}")
    print(f" Avg exit layer: {result.avg_exit_layer:.1f}")


def run_full_benchmark(
    model_name,
    heads_path,
    config_path,
    calibration_path=None,
    *,
    prompt="Explain what machine learning is in three sentences.",
    max_tokens=50,
    accuracy_level=0.75,
):
    """Time ``decoder.generate`` against ``decoder.generate_fast`` on one prompt.

    Loads the model with ``load_dssd_model`` and warms up both generation
    paths. It then runs each path once and prints tokens, wall-clock time,
    throughput, average exit layer, draft/acceptance statistics (when
    reported) and the resulting speedup.

    Args:
        model_name: Model identifier passed to ``load_dssd_model``.
        heads_path: Path to the auxiliary heads checkpoint.
        config_path: Path to the model config, passed through unchanged.
        calibration_path: Optional calibration file, passed through unchanged.
        prompt: Prompt used for both timed runs.
        max_tokens: Generation budget for both timed runs.
        accuracy_level: Early-exit accuracy level for both timed runs.

    Returns:
        False if loading the model raised, True otherwise.
    """
    from src.inference import load_dssd_model

    print("\n" + "=" * 60)
    print("BENCHMARK: Full Model Comparison")
    print(f"Model: {model_name}")
    print("=" * 60)

    try:
        decoder, _tokenizer = load_dssd_model(
            model_name=model_name,
            heads_path=heads_path,
            config_path=config_path,
            calibration_path=calibration_path,
            device="auto",
        )
    except Exception as e:  # script boundary: report and let the caller decide
        print(f"Error loading model: {e}")
        return False

    # Warm up BOTH code paths so any one-off first-call cost is not charged to
    # whichever method happens to be timed.
    print("\nWarming up...")
    decoder.generate(prompt, max_tokens=10, use_early_exit=False, use_chat_template=True)
    decoder.generate_fast(
        prompt, max_tokens=10, accuracy_level=accuracy_level, use_chat_template=True
    )

    # Benchmark standard generation
    print("\nRunning standard generation (no cache)...")
    start = time.perf_counter()
    result_standard = decoder.generate(
        prompt,
        max_tokens=max_tokens,
        use_early_exit=True,
        accuracy_level=accuracy_level,
        use_chat_template=True,
    )
    time_standard = time.perf_counter() - start

    # Benchmark cache-optimized generation (fast version)
    print("Running cache-optimized generation (fast)...")
    start = time.perf_counter()
    result_cached = decoder.generate_fast(
        prompt,
        max_tokens=max_tokens,
        accuracy_level=accuracy_level,
        use_chat_template=True,
    )
    time_cached = time.perf_counter() - start

    # Print results
    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)
    _print_generation_stats("Standard Generation", result_standard, time_standard)
    _print_generation_stats("Cache-Optimized Generation", result_cached, time_cached)

    stats = result_cached.exit_distribution
    if "total_drafted" in stats:
        print(f" Drafted: {stats['total_drafted']}")
        print(f" Accepted: {stats['total_accepted']}")
        print(f" Acceptance rate: {stats['acceptance_rate']:.1%}")

    print("\nSpeedup:")
    speedup = time_standard / time_cached if time_cached > 0 else 0
    print(f" {speedup:.2f}x faster with cache")
    return True
def main():
    """Parse command-line arguments and run the selected benchmark.

    The CPU-only ``JaggedKVCache`` benchmark runs when ``--cpu-only`` is set or
    when no ``--heads-path`` is given. Otherwise the full model comparison runs.

    Returns:
        Process exit code: 0 on success, 1 if the full benchmark reports failure.
    """
    parser = argparse.ArgumentParser(description="Benchmark DSSD generation")
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B", help="Model name")
    parser.add_argument("--heads-path", help="Path to aux heads checkpoint")
    parser.add_argument("--config-path", help="Path to model config")
    parser.add_argument("--calibration-path", help="Path to calibration file")
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run CPU-only cache benchmark"
    )
    args = parser.parse_args()

    if args.cpu_only or not args.heads_path:
        if not args.cpu_only:
            # Make the implicit fallback visible instead of silently ignoring --model.
            print("No --heads-path given; running CPU-only cache benchmark.")
        make_dummy_decoder()
        return 0

    # Run full benchmark with model
    ok = run_full_benchmark(
        args.model,
        args.heads_path,
        args.config_path,
        args.calibration_path,
    )
    return 0 if ok else 1
if __name__ == "__main__":
    # Use main()'s return value as the process exit status (None means success).
    sys.exit(main())