"""
Benchmark tests for KV Cache optimization in DSSD.
This module provides deterministic benchmarks to measure:
1. Layer forward counts (direct measure of computation)
2. Wall clock time for draft + verify phases
3. Optional FLOPs estimation
Run with: python -m tests.benchmark_kv_cache
"""
import time
import torch
import torch.nn as nn
from dataclasses import dataclass, field
from typing import Dict, List, Optional, Tuple
from contextlib import contextmanager
# =============================================================================
# Instrumentation
# =============================================================================
@dataclass
class BenchmarkMetrics:
"""Tracks metrics during benchmark run."""
# Layer forward counts
layer_forward_counts: Dict[int, int] = field(default_factory=dict)
total_layer_forwards: int = 0
# Timing
draft_time_ms: float = 0.0
verify_time_ms: float = 0.0
total_time_ms: float = 0.0
# Token counts
tokens_drafted: int = 0
tokens_accepted: int = 0
tokens_rejected: int = 0
# Early exit distribution
exit_layers: List[int] = field(default_factory=list)
def reset(self):
"""Reset all metrics."""
self.layer_forward_counts.clear()
self.total_layer_forwards = 0
self.draft_time_ms = 0.0
self.verify_time_ms = 0.0
self.total_time_ms = 0.0
self.tokens_drafted = 0
self.tokens_accepted = 0
self.tokens_rejected = 0
self.exit_layers.clear()
def record_layer_forward(self, layer_idx: int):
"""Record a layer forward pass."""
self.layer_forward_counts[layer_idx] = (
self.layer_forward_counts.get(layer_idx, 0) + 1
)
self.total_layer_forwards += 1
def summary(self) -> str:
"""Return human-readable summary."""
lines = [
"=" * 50,
"BENCHMARK METRICS",
"=" * 50,
f"Total Layer Forwards: {self.total_layer_forwards}",
f"Tokens Drafted: {self.tokens_drafted}",
f"Tokens Accepted: {self.tokens_accepted}",
f"Tokens Rejected: {self.tokens_rejected}",
f"Draft Time: {self.draft_time_ms:.2f} ms",
f"Verify Time: {self.verify_time_ms:.2f} ms",
f"Total Time: {self.total_time_ms:.2f} ms",
"",
"Layer Forward Distribution:",
]
for layer_idx in sorted(self.layer_forward_counts.keys()):
count = self.layer_forward_counts[layer_idx]
lines.append(f" Layer {layer_idx:2d}: {count} forwards")
if self.exit_layers:
avg_exit = sum(self.exit_layers) / len(self.exit_layers)
lines.append(f"\nAverage Exit Layer: {avg_exit:.1f}")
lines.append("=" * 50)
return "\n".join(lines)
# Global metrics instance for instrumentation
_metrics: Optional[BenchmarkMetrics] = None
def get_metrics() -> Optional[BenchmarkMetrics]:
"""Get the current metrics instance."""
return _metrics
@contextmanager
def benchmark_context():
"""Context manager that enables metric collection."""
global _metrics
_metrics = BenchmarkMetrics()
try:
yield _metrics
finally:
_metrics = None
def instrument_layer_forward(layer_idx: int):
"""Call this from forward_layer to record layer execution."""
if _metrics is not None:
_metrics.record_layer_forward(layer_idx)
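
def attach_layer_hooks(layers) -> list:
    """Wire instrument_layer_forward into a model via forward hooks.

    A minimal sketch, not part of the original harness: it assumes the decoder
    layers are exposed as an nn.ModuleList (e.g. `decoder.model.model.layers`
    on Qwen/LLaMA-style Hugging Face models); adapt the attribute path for
    other architectures. Returns the hook handles; call .remove() on each to
    detach the instrumentation.
    """
    handles = []
    for idx, layer in enumerate(layers):
        # Bind idx at definition time; the hook fires on every forward of
        # this layer and records it against the current metrics instance
        def hook(module, args, output, idx=idx):
            instrument_layer_forward(idx)
        handles.append(layer.register_forward_hook(hook))
    return handles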
# =============================================================================
# Timer Utilities
# =============================================================================
class Timer:
"""Simple timer for benchmarking."""
def __init__(self):
self.start_time = None
self.elapsed_ms = 0.0
def start(self):
        if torch.cuda.is_available():
            torch.cuda.synchronize()
self.start_time = time.perf_counter()
def stop(self) -> float:
        if torch.cuda.is_available():
            torch.cuda.synchronize()
if self.start_time is not None:
self.elapsed_ms = (time.perf_counter() - self.start_time) * 1000
return self.elapsed_ms
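
def estimate_layer_flops(params_per_layer: int, metrics: BenchmarkMetrics) -> float:
    """Back-of-the-envelope estimate for the optional FLOPs measurement
    mentioned in the module docstring.

    A rough sketch under stated assumptions: ~2 FLOPs per parameter per token
    (one multiply-accumulate) and one token per recorded layer forward, with
    the attention O(seq_len^2) term ignored. Treat the result as an
    order-of-magnitude figure, not an exact count.
    """
    return 2.0 * params_per_layer * metrics.total_layer_forwards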
# =============================================================================
# Benchmark Test Scenarios
# =============================================================================
@dataclass
class BenchmarkConfig:
"""Configuration for benchmark runs."""
# Model setting
model_name: str = "Qwen/Qwen3-0.6B"
# Generation settings
prompt: str = "Explain what machine learning is in simple terms."
max_draft_length: int = 5
num_iterations: int = 3 # Multiple iterations for averaging
# Thresholds for early exit (simulated or real)
accuracy_level: float = 0.75
# Reproducibility
seed: int = 42
def run_single_draft_verify_benchmark(
decoder, # DSSDecoder
config: BenchmarkConfig,
use_cache: bool = False,
) -> BenchmarkMetrics:
"""
Run a single draft + verify cycle and measure metrics.
Args:
decoder: The DSSDecoder instance
config: Benchmark configuration
        use_cache: Whether to use JaggedKVCache (currently unused here;
            reserved for the cache-enabled comparison run)
Returns:
BenchmarkMetrics with recorded data
"""
# Set seed for reproducibility
torch.manual_seed(config.seed)
if torch.cuda.is_available():
torch.cuda.manual_seed(config.seed)
with benchmark_context() as metrics:
timer = Timer()
# Tokenize prompt
input_ids = decoder.tokenizer.encode(config.prompt, return_tensors="pt").to(
decoder.device
)
# Get thresholds
thresholds = {}
if decoder.calibration:
thresholds = decoder.calibration.get_thresholds_for_level(
config.accuracy_level
)
# ========== DRAFT PHASE ==========
timer.start()
drafted_tokens = []
draft_ids = input_ids.clone()
for _ in range(config.max_draft_length):
# Call the drafting function
# Note: This will need to be modified to use our instrumented version
draft_result = decoder._draft_single_token(draft_ids, thresholds)
if draft_result is None:
break
token_id, exit_head, exit_layer, uncertainty = draft_result
drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
metrics.exit_layers.append(exit_layer)
if token_id == decoder.tokenizer.eos_token_id:
break
draft_ids = torch.cat(
[draft_ids, torch.tensor([[token_id]], device=decoder.device)], dim=1
)
metrics.draft_time_ms = timer.stop()
metrics.tokens_drafted = len(drafted_tokens)
# ========== VERIFY PHASE ==========
timer.start()
if drafted_tokens:
with torch.no_grad():
outputs = decoder.model(draft_ids, use_cache=False)
verify_logits = outputs.logits
            # Verify each token: the logit at position p predicts the token at
            # p + 1, so the first drafted token is checked at input_len - 1
            start_pos = input_ids.shape[1] - 1
accepted = 0
for i, (token_id, exit_head, exit_layer, uncertainty) in enumerate(
drafted_tokens
):
verify_pos = start_pos + i
verified_token = torch.argmax(verify_logits[0, verify_pos, :]).item()
if token_id == verified_token:
accepted += 1
else:
break
metrics.tokens_accepted = accepted
metrics.tokens_rejected = len(drafted_tokens) - accepted
metrics.verify_time_ms = timer.stop()
metrics.total_time_ms = metrics.draft_time_ms + metrics.verify_time_ms
return metrics
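
def acceptance_rate(metrics: BenchmarkMetrics) -> float:
    """Convenience helper (an addition, not part of the original harness):
    the fraction of drafted tokens the verifier accepted. Higher is better;
    returns 0.0 when nothing was drafted."""
    if metrics.tokens_drafted == 0:
        return 0.0
    return metrics.tokens_accepted / metrics.tokens_drafted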
def run_baseline_benchmark(decoder, config: BenchmarkConfig) -> BenchmarkMetrics:
"""
Run baseline benchmark (current implementation without cache optimization).
"""
print(f"\n{'=' * 60}")
print("BASELINE BENCHMARK (No Cache)")
print(f"{'=' * 60}")
print(f"Model: {config.model_name}")
print(f"Prompt: '{config.prompt[:50]}...'")
print(f"Max Draft Length: {config.max_draft_length}")
print(f"Iterations: {config.num_iterations}")
all_metrics = []
for i in range(config.num_iterations):
print(f"\nIteration {i + 1}/{config.num_iterations}...")
metrics = run_single_draft_verify_benchmark(decoder, config, use_cache=False)
all_metrics.append(metrics)
print(f" Layer Forwards: {metrics.total_layer_forwards}")
print(f" Draft Time: {metrics.draft_time_ms:.2f} ms")
print(f" Verify Time: {metrics.verify_time_ms:.2f} ms")
# Average metrics
avg_metrics = BenchmarkMetrics()
avg_metrics.total_layer_forwards = sum(
m.total_layer_forwards for m in all_metrics
) // len(all_metrics)
avg_metrics.draft_time_ms = sum(m.draft_time_ms for m in all_metrics) / len(
all_metrics
)
avg_metrics.verify_time_ms = sum(m.verify_time_ms for m in all_metrics) / len(
all_metrics
)
avg_metrics.total_time_ms = sum(m.total_time_ms for m in all_metrics) / len(
all_metrics
)
    # Token counts are deterministic across seeded iterations, so the first
    # run is representative
    avg_metrics.tokens_drafted = all_metrics[0].tokens_drafted
avg_metrics.tokens_accepted = all_metrics[0].tokens_accepted
avg_metrics.tokens_rejected = all_metrics[0].tokens_rejected
    # Combine layer counts: sum across iterations first, then divide once, so
    # per-iteration integer division does not accumulate rounding error
    for m in all_metrics:
        for layer_idx, count in m.layer_forward_counts.items():
            avg_metrics.layer_forward_counts[layer_idx] = (
                avg_metrics.layer_forward_counts.get(layer_idx, 0) + count
            )
    for layer_idx in avg_metrics.layer_forward_counts:
        avg_metrics.layer_forward_counts[layer_idx] //= len(all_metrics)
print("\n" + avg_metrics.summary())
return avg_metrics
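
def compare_runs(baseline: BenchmarkMetrics, optimized: BenchmarkMetrics) -> str:
    """Hypothetical comparison helper for after JaggedKVCache lands: reports
    the reduction in layer forwards and the wall-clock speedup of `optimized`
    relative to `baseline`. A sketch; guard against zero-valued baselines."""
    forward_reduction = 1.0 - (
        optimized.total_layer_forwards / max(baseline.total_layer_forwards, 1)
    )
    speedup = baseline.total_time_ms / max(optimized.total_time_ms, 1e-9)
    return (
        f"Layer forwards reduced by {forward_reduction:.1%}, "
        f"wall-clock speedup {speedup:.2f}x"
    )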
# =============================================================================
# Main Entry Point
# =============================================================================
def main():
"""Run benchmark suite."""
    import sys
    from pathlib import Path
    # Make the repo root importable regardless of the working directory
    # (assumes this file lives in tests/ under the repo root)
    sys.path.insert(0, str(Path(__file__).resolve().parents[1]))
from src.inference import load_dssd_model
config = BenchmarkConfig()
print("Loading model...")
try:
# You'll need to update these paths to match your setup
decoder, tokenizer = load_dssd_model(
model_name=config.model_name,
heads_path="../checkpoints/qwen3-0.6b/aux_heads.pt",
config_path="../checkpoints/qwen3-0.6b/config.json",
calibration_path="../checkpoints/qwen3-0.6b/calibration.json",
device="auto",
)
print("Model loaded successfully!")
except Exception as e:
print(f"Error loading model: {e}")
print("\nTo run this benchmark, ensure you have:")
print(" 1. A trained auxiliary heads checkpoint")
print(" 2. The corresponding config.json")
print(" 3. (Optional) calibration.json for thresholds")
return
# Run baseline benchmark
baseline_metrics = run_baseline_benchmark(decoder, config)
# Save results for later comparison
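    # Minimal persistence sketch: dump the averaged metrics to JSON so a later
    # cache-enabled run can be diffed against them (the filename is an
    # assumption; adjust to your setup)
    import json
    from dataclasses import asdict
    with open("benchmark_baseline.json", "w") as f:
        json.dump(asdict(baseline_metrics), f, indent=2)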
print("\n" + "=" * 60)
print("BASELINE RESULTS SAVED")
print("Run this again after implementing JaggedKVCache to compare.")
print("=" * 60)
if __name__ == "__main__":
main()