|
|
""" |
|
|
Benchmark tests for KV Cache optimization in DSSD. |
|
|
|
|
|
This module provides deterministic benchmarks to measure: |
|
|
1. Layer forward counts (direct measure of computation) |
|
|
2. Wall clock time for draft + verify phases |
|
|
3. Optional FLOPs estimation |
|
|
|
|
|
Run with: python -m tests.benchmark_kv_cache |
|
|
""" |
|
|
|
|
|
import time |
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from dataclasses import dataclass, field |
|
|
from typing import Dict, List, Optional, Tuple |
|
|
from contextlib import contextmanager |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BenchmarkMetrics:
    """Accumulates counters and timings collected during one benchmark run.

    Layer-forward counts are fed in through ``record_layer_forward``; the
    remaining fields are filled in by the benchmark driver. ``summary``
    renders everything as a fixed-width text report.
    """

    # Per-layer number of forward passes, keyed by layer index.
    layer_forward_counts: Dict[int, int] = field(default_factory=dict)
    # Sum of all values in ``layer_forward_counts``.
    total_layer_forwards: int = 0

    # Wall-clock durations, in milliseconds.
    draft_time_ms: float = 0.0
    verify_time_ms: float = 0.0
    total_time_ms: float = 0.0

    # Token bookkeeping for the draft/verify cycle.
    tokens_drafted: int = 0
    tokens_accepted: int = 0
    tokens_rejected: int = 0

    # Layer index at which each drafted token exited.
    exit_layers: List[int] = field(default_factory=list)

    def reset(self):
        """Zero every counter and timer, emptying the containers in place."""
        # Containers are cleared (not rebound) so outside references stay valid.
        self.layer_forward_counts.clear()
        self.exit_layers.clear()
        for int_name in (
            "total_layer_forwards",
            "tokens_drafted",
            "tokens_accepted",
            "tokens_rejected",
        ):
            setattr(self, int_name, 0)
        for float_name in ("draft_time_ms", "verify_time_ms", "total_time_ms"):
            setattr(self, float_name, 0.0)

    def record_layer_forward(self, layer_idx: int):
        """Count one forward pass through layer ``layer_idx``."""
        counts = self.layer_forward_counts
        counts[layer_idx] = counts.get(layer_idx, 0) + 1
        self.total_layer_forwards += 1

    def summary(self) -> str:
        """Render the collected metrics as a multi-line text report."""
        rule = "=" * 50
        report = [rule, "BENCHMARK METRICS", rule]
        report += [
            f"Total Layer Forwards: {self.total_layer_forwards}",
            f"Tokens Drafted: {self.tokens_drafted}",
            f"Tokens Accepted: {self.tokens_accepted}",
            f"Tokens Rejected: {self.tokens_rejected}",
            f"Draft Time: {self.draft_time_ms:.2f} ms",
            f"Verify Time: {self.verify_time_ms:.2f} ms",
            f"Total Time: {self.total_time_ms:.2f} ms",
            "",
            "Layer Forward Distribution:",
        ]
        # Keys are unique, so sorting the items orders them by layer index.
        report.extend(
            f" Layer {idx:2d}: {n} forwards"
            for idx, n in sorted(self.layer_forward_counts.items())
        )

        if self.exit_layers:
            mean_exit = sum(self.exit_layers) / len(self.exit_layers)
            report.append(f"\nAverage Exit Layer: {mean_exit:.1f}")

        report.append(rule)
        return "\n".join(report)
|
|
|
|
|
|
|
|
|
|
|
# Collector for the currently active benchmark_context(); None whenever no
# benchmark is running, in which case instrument_layer_forward() records nothing.
_metrics: Optional[BenchmarkMetrics] = None


def get_metrics() -> Optional[BenchmarkMetrics]:
    """Return the active BenchmarkMetrics collector, or None if collection is off."""
    active = _metrics
    return active
|
|
|
|
|
|
|
|
@contextmanager
def benchmark_context():
    """Enable metric collection for the duration of a ``with`` block.

    Installs a fresh BenchmarkMetrics as the module-level collector (the one
    returned by get_metrics() and fed by instrument_layer_forward()) and yields
    it. On exit, even if the block raises, the collector that was active before
    entry is restored. A nested context therefore no longer switches collection
    off for the enclosing one.

    Yields:
        The BenchmarkMetrics instance that collects data inside the block.
    """
    global _metrics
    previous = _metrics
    current = BenchmarkMetrics()
    _metrics = current
    try:
        yield current
    finally:
        # Restore rather than clear, so nesting is safe.
        _metrics = previous
|
|
|
|
|
|
|
|
def instrument_layer_forward(layer_idx: int):
    """Record that layer ``layer_idx`` ran a forward pass.

    Intended to be called from forward_layer. Does nothing unless a
    benchmark_context() is currently active.
    """
    active = _metrics
    if active is None:
        return
    active.record_layer_forward(layer_idx)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Timer:
    """Millisecond wall-clock timer that accounts for pending CUDA work.

    Both ``start()`` and ``stop()`` synchronize the CUDA device (when CUDA is
    available) before reading ``time.perf_counter()``. Asynchronously launched
    GPU kernels are therefore included in the measured interval.

    Can also be used as a context manager::

        with Timer() as t:
            ...
        print(t.elapsed_ms)
    """

    def __init__(self):
        # perf_counter() value captured by the last start(); None until started.
        self.start_time: Optional[float] = None
        # Duration measured by the most recent stop(), in milliseconds.
        self.elapsed_ms: float = 0.0

    @staticmethod
    def _synchronize():
        """Block until queued CUDA work has finished; no-op without CUDA."""
        if torch.cuda.is_available():
            torch.cuda.synchronize()

    def start(self):
        """Begin (or restart) timing from now."""
        self._synchronize()
        self.start_time = time.perf_counter()

    def stop(self) -> float:
        """Stop timing and return the milliseconds elapsed since ``start()``.

        If ``start()` was never called, the current ``elapsed_ms`` (initially
        0.0) is returned unchanged.
        """
        self._synchronize()
        if self.start_time is not None:
            self.elapsed_ms = (time.perf_counter() - self.start_time) * 1000
        return self.elapsed_ms

    def __enter__(self):
        self.start()
        return self

    def __exit__(self, exc_type, exc, tb):
        self.stop()
        # Never suppress exceptions raised inside the timed block.
        return False
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@dataclass
class BenchmarkConfig:
    """Configuration for benchmark runs.

    Raises:
        ValueError: on construction if ``num_iterations`` is less than 1.
            Averaging over zero iterations would otherwise fail later with a
            ZeroDivisionError/IndexError.
    """

    # Model identifier passed to load_dssd_model().
    model_name: str = "Qwen/Qwen3-0.6B"

    # Prompt that is tokenized and used as the drafting prefix.
    prompt: str = "Explain what machine learning is in simple terms."
    # Upper bound on the number of tokens drafted per draft+verify cycle.
    max_draft_length: int = 5
    # Number of draft+verify cycles whose metrics are averaged.
    num_iterations: int = 3

    # Passed to calibration.get_thresholds_for_level() to pick exit thresholds.
    accuracy_level: float = 0.75

    # Seed applied to torch (and CUDA, if available) before every cycle.
    seed: int = 42

    def __post_init__(self):
        if self.num_iterations < 1:
            raise ValueError(
                f"num_iterations must be >= 1, got {self.num_iterations}"
            )
|
|
|
|
|
|
|
|
def _draft_tokens(decoder, input_ids, thresholds, max_draft_length, metrics):
    """Draft up to ``max_draft_length`` tokens one at a time with the decoder.

    Stops early when ``decoder._draft_single_token`` returns None or yields the
    tokenizer's EOS token. Each drafted token's exit layer is appended to
    ``metrics.exit_layers``.

    Returns:
        ``(drafted_tokens, draft_ids)``:
        - ``drafted_tokens`` is a list of
          ``(token_id, exit_head, exit_layer, uncertainty)`` tuples.
        - ``draft_ids`` is ``input_ids`` followed by every drafted token except
          a trailing EOS.
    """
    drafted_tokens = []
    draft_ids = input_ids.clone()

    for _ in range(max_draft_length):
        draft_result = decoder._draft_single_token(draft_ids, thresholds)
        if draft_result is None:
            break

        token_id, exit_head, exit_layer, uncertainty = draft_result
        drafted_tokens.append((token_id, exit_head, exit_layer, uncertainty))
        metrics.exit_layers.append(exit_layer)

        if token_id == decoder.tokenizer.eos_token_id:
            break

        draft_ids = torch.cat(
            [draft_ids, torch.tensor([[token_id]], device=decoder.device)], dim=1
        )

    return drafted_tokens, draft_ids


def _count_accepted(decoder, draft_ids, drafted_tokens, prompt_len):
    """Verify drafted tokens with a single full-model forward pass.

    Returns the length of the longest prefix of ``drafted_tokens`` whose token
    ids equal the full model's argmax prediction at the matching position.
    """
    with torch.no_grad():
        outputs = decoder.model(draft_ids, use_cache=False)
    verify_logits = outputs.logits

    # The first drafted token is compared against the logits at the last prompt
    # position; each following token shifts one position to the right.
    start_pos = prompt_len - 1
    accepted = 0
    for offset, (token_id, *_unused) in enumerate(drafted_tokens):
        verified_token = torch.argmax(verify_logits[0, start_pos + offset, :]).item()
        if token_id != verified_token:
            break
        accepted += 1
    return accepted


def run_single_draft_verify_benchmark(
    decoder,
    config: BenchmarkConfig,
    use_cache: bool = False,
) -> BenchmarkMetrics:
    """
    Run a single draft + verify cycle and measure metrics.

    Args:
        decoder: The DSSDecoder instance.
        config: Benchmark configuration.
        use_cache: Reserved for the JaggedKVCache comparison. It is currently
            ignored; both phases always run without a KV cache.

    Returns:
        BenchmarkMetrics with layer-forward counts (recorded via
        instrument_layer_forward while the context is active), exit layers,
        token acceptance stats and draft/verify timings.
    """
    torch.manual_seed(config.seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(config.seed)

    with benchmark_context() as metrics:
        timer = Timer()

        input_ids = decoder.tokenizer.encode(config.prompt, return_tensors="pt").to(
            decoder.device
        )

        # Early-exit thresholds; empty when the decoder has no calibration.
        thresholds = {}
        if decoder.calibration:
            thresholds = decoder.calibration.get_thresholds_for_level(
                config.accuracy_level
            )

        # Draft phase.
        timer.start()
        drafted_tokens, draft_ids = _draft_tokens(
            decoder, input_ids, thresholds, config.max_draft_length, metrics
        )
        metrics.draft_time_ms = timer.stop()
        metrics.tokens_drafted = len(drafted_tokens)

        # Verify phase.
        timer.start()
        if drafted_tokens:
            accepted = _count_accepted(
                decoder, draft_ids, drafted_tokens, input_ids.shape[1]
            )
            metrics.tokens_accepted = accepted
            metrics.tokens_rejected = len(drafted_tokens) - accepted
        metrics.verify_time_ms = timer.stop()

        metrics.total_time_ms = metrics.draft_time_ms + metrics.verify_time_ms

    return metrics
|
|
|
|
|
|
|
|
def run_baseline_benchmark(decoder, config: BenchmarkConfig) -> BenchmarkMetrics:
    """
    Run baseline benchmark (current implementation without cache optimization).

    Runs ``config.num_iterations`` draft+verify cycles, printing each one, and
    returns a BenchmarkMetrics averaged across iterations:
    - Timings are float means.
    - Layer-forward counts are integer (floor) means.
    - Token statistics and exit layers come from the first iteration.

    Raises:
        ValueError: if ``config.num_iterations`` is less than 1.
    """
    if config.num_iterations < 1:
        raise ValueError(
            f"num_iterations must be >= 1, got {config.num_iterations}"
        )

    print(f"\n{'=' * 60}")
    print("BASELINE BENCHMARK (No Cache)")
    print(f"{'=' * 60}")
    print(f"Model: {config.model_name}")
    print(f"Prompt: '{config.prompt[:50]}...'")
    print(f"Max Draft Length: {config.max_draft_length}")
    print(f"Iterations: {config.num_iterations}")

    all_metrics = []
    for i in range(config.num_iterations):
        print(f"\nIteration {i + 1}/{config.num_iterations}...")
        metrics = run_single_draft_verify_benchmark(decoder, config, use_cache=False)
        all_metrics.append(metrics)
        print(f" Layer Forwards: {metrics.total_layer_forwards}")
        print(f" Draft Time: {metrics.draft_time_ms:.2f} ms")
        print(f" Verify Time: {metrics.verify_time_ms:.2f} ms")

    n = len(all_metrics)
    avg_metrics = BenchmarkMetrics()
    avg_metrics.total_layer_forwards = sum(m.total_layer_forwards for m in all_metrics) // n
    avg_metrics.draft_time_ms = sum(m.draft_time_ms for m in all_metrics) / n
    avg_metrics.verify_time_ms = sum(m.verify_time_ms for m in all_metrics) / n
    avg_metrics.total_time_ms = sum(m.total_time_ms for m in all_metrics) / n

    # Every iteration reseeds with config.seed, so drafting is presumably
    # deterministic; report the first iteration's token statistics.
    first = all_metrics[0]
    avg_metrics.tokens_drafted = first.tokens_drafted
    avg_metrics.tokens_accepted = first.tokens_accepted
    avg_metrics.tokens_rejected = first.tokens_rejected
    # Carry exit layers over so the summary can report the average exit layer.
    avg_metrics.exit_layers.extend(first.exit_layers)

    # Sum per-layer counts across iterations before dividing. Dividing each
    # iteration's count first truncates (e.g. three runs of 1 would average to 0).
    layer_totals: Dict[int, int] = {}
    for m in all_metrics:
        for layer_idx, count in m.layer_forward_counts.items():
            layer_totals[layer_idx] = layer_totals.get(layer_idx, 0) + count
    avg_metrics.layer_forward_counts.update(
        {layer_idx: total // n for layer_idx, total in layer_totals.items()}
    )

    print("\n" + avg_metrics.summary())
    return avg_metrics
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def main():
    """Run benchmark suite.

    Loads the DSSD model and checkpoints, then runs the baseline (no-cache)
    benchmark. If loading fails, prints the error and setup hints and returns
    early.
    """
    import sys
    from pathlib import Path

    # Make the project root importable so `src.inference` resolves. This file
    # is expected at <project_root>/tests/ (see "python -m tests.benchmark_kv_cache").
    # Derived from __file__ rather than a machine-specific absolute path.
    project_root = Path(__file__).resolve().parent.parent
    sys.path.insert(0, str(project_root))

    from src.inference import load_dssd_model

    config = BenchmarkConfig()

    print("Loading model...")
    try:
        # NOTE(review): checkpoint paths are relative to the current working
        # directory, not to this file.
        decoder, _tokenizer = load_dssd_model(
            model_name=config.model_name,
            heads_path="../checkpoints/qwen3-0.6b/aux_heads.pt",
            config_path="../checkpoints/qwen3-0.6b/config.json",
            calibration_path="../checkpoints/qwen3-0.6b/calibration.json",
            device="auto",
        )
    except Exception as e:
        # Top-level script boundary: report the failure and setup hints.
        print(f"Error loading model: {e}")
        print("\nTo run this benchmark, ensure you have:")
        print(" 1. A trained auxiliary heads checkpoint")
        print(" 2. The corresponding config.json")
        print(" 3. (Optional) calibration.json for thresholds")
        return
    else:
        print("Model loaded successfully!")

    run_baseline_benchmark(decoder, config)

    print("\n" + "=" * 60)
    print("BASELINE RESULTS SAVED")
    print("Run this again after implementing JaggedKVCache to compare.")
    print("=" * 60)
|
|
|
|
|
|
|
|
# Script entry point: run the benchmark suite only when executed directly
# (e.g. `python -m tests.benchmark_kv_cache`), not when imported.
if __name__ == "__main__":
    main()
|
|
|