File size: 8,374 Bytes
33efa44
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
#!/usr/bin/env python3
"""
Benchmark comparison: Standard generation vs Cache-optimized generation.

This script measures and compares:
- Layer forward counts
- Wall clock time
- Tokens per second

Usage:
    python tests/run_benchmark.py --model Qwen/Qwen3-0.6B --heads-path /path/to/heads.pt
"""

import argparse
import time
import sys
import os

# Add project to path
sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

import torch


def make_dummy_decoder():
    """Benchmark raw JaggedKVCache operations on CPU (no model, no GPU).

    Despite the name, this does not build a decoder: it drives the cache
    update path with random tensors through three phases — prefill, draft
    (early exit), and verification (lazy fill) — then prints wall-clock
    timings and a layer-operation analysis.

    Returns:
        True on completion (so callers can treat it as a pass/fail probe).
    """
    from src.jagged_cache import JaggedKVCache

    print("\n" + "=" * 60)
    print("BENCHMARK: JaggedKVCache Operations (No GPU Required)")
    print("=" * 60)

    # Cache geometry; mirrors a small transformer (28 layers, 8 KV heads).
    num_layers = 28
    batch_size = 1
    num_heads = 8
    head_dim = 128
    seq_len = 100

    cache = JaggedKVCache(
        num_layers=num_layers,
        batch_size=batch_size,
        num_kv_heads=num_heads,
        head_dim=head_dim,
        device="cpu",
        dtype=torch.float32,
    )

    # --- Phase 1: prefill — every layer sees every position exactly once.
    print(f"\nSimulating prefill ({seq_len} tokens, {num_layers} layers)...")
    start = time.perf_counter()
    for pos in range(seq_len):
        for layer_idx in range(num_layers):
            k = torch.randn(batch_size, num_heads, 1, head_dim)
            v = torch.randn(batch_size, num_heads, 1, head_dim)
            cache.update(layer_idx, k, v, torch.tensor([pos]))
    prefill_time = (time.perf_counter() - start) * 1000
    print(f"  Prefill time: {prefill_time:.2f} ms")

    # --- Phase 2: draft — each token fills only layers up to its exit layer.
    exit_layers = [4, 8, 6, 12, 10]  # Simulate different exit layers
    # Derive the draft-token count instead of hard-coding 5, so editing
    # exit_layers keeps every count and printed stat below consistent.
    num_draft = len(exit_layers)
    print(f"\nSimulating draft phase ({num_draft} tokens, variable exit layers)...")
    draft_cache = cache.clone()

    start = time.perf_counter()
    for i, exit_layer in enumerate(exit_layers):
        pos = seq_len + i
        for layer_idx in range(exit_layer + 1):
            k = torch.randn(batch_size, num_heads, 1, head_dim)
            v = torch.randn(batch_size, num_heads, 1, head_dim)
            draft_cache.update(layer_idx, k, v, torch.tensor([pos]))
    draft_time = (time.perf_counter() - start) * 1000
    print(f"  Draft time: {draft_time:.2f} ms")

    # Show how unevenly layers are populated after early-exit drafting.
    print("\nCache state after drafting:")
    for layer_idx in [0, 4, 8, 12, 16, 20, 24, 27]:
        filled = len(draft_cache.filled_positions[layer_idx])
        print(f"  Layer {layer_idx:2d}: {filled} positions filled")

    # --- Phase 3: verification — lazily fill the layers each draft skipped.
    print("\nSimulating verification (lazy fill + full model)...")
    start = time.perf_counter()
    for pos in range(seq_len, seq_len + num_draft):
        # Find layers still missing KV entries at this position.
        missing = draft_cache.get_missing_layers(pos, num_layers - 1)
        for layer_idx in missing:
            k = torch.randn(batch_size, num_heads, 1, head_dim)
            v = torch.randn(batch_size, num_heads, 1, head_dim)
            draft_cache.update(layer_idx, k, v, torch.tensor([pos]))
    verify_time = (time.perf_counter() - start) * 1000
    print(f"  Verify time: {verify_time:.2f} ms")

    # --- Analysis: count layer ops to show where the savings come from.
    print("\n" + "=" * 60)
    print("ANALYSIS: Layer Operations")
    print("=" * 60)

    # Prefill ops (same for all approaches - one-time cost)
    prefill_ops = seq_len * num_layers
    print(f"\nPREFILL (one-time): {prefill_ops} layer ops")

    # Draft phase with early exit vs. running every layer per token.
    draft_ops = sum(exit_layer + 1 for exit_layer in exit_layers)
    draft_ops_full = num_draft * num_layers  # Without early exit
    print(f"\nDRAFT PHASE ({num_draft} tokens):")
    print(f"  With early exit: {draft_ops} ops (avg {draft_ops / num_draft:.1f} layers/token)")
    print(f"  Without early exit: {draft_ops_full} ops ({num_layers} layers/token)")
    print(
        f"  Draft savings: {draft_ops_full - draft_ops} ops ({100 * (1 - draft_ops / draft_ops_full):.0f}% reduction)"
    )

    # The KEY benefit: with cache, each draft token is O(1 token * exit_layer)
    # Without cache, it would be O(seq_len * exit_layer) per token
    print(f"\nCACHE BENEFIT:")
    print(f"  Without cache, each draft would recompute {seq_len}-token context")
    print(f"  With cache, each draft processes only 1 new token")
    per_token_savings = seq_len - 1  # Positions we don't recompute
    total_context_savings = per_token_savings * draft_ops
    print(f"  Context reuse savings: ~{total_context_savings} avoided operations")

    # Verify phase always runs the full stack for each drafted token.
    verify_ops = num_draft * num_layers
    print(f"\nVERIFY PHASE: {verify_ops} ops (fills all layers for drafted tokens)")

    print(f"\nTotal time: {prefill_time + draft_time + verify_time:.2f} ms")

    return True


def run_full_benchmark(model_name, heads_path, config_path, calibration_path=None):
    """Compare standard vs. cache-optimized generation on a real model.

    Loads the DSSD model, warms it up, times one run of `generate` (early
    exit, no cache) and one of `generate_fast` (cache-optimized), then
    prints per-run stats and the overall speedup.

    Returns:
        True on success, False if the model failed to load.
    """
    from src.inference import load_dssd_model

    print("\n" + "=" * 60)
    print(f"BENCHMARK: Full Model Comparison")
    print(f"Model: {model_name}")
    print("=" * 60)

    try:
        decoder, tokenizer = load_dssd_model(
            model_name=model_name,
            heads_path=heads_path,
            config_path=config_path,
            calibration_path=calibration_path,
            device="auto",
        )
    except Exception as e:
        # Best-effort: report and bail rather than crash the benchmark CLI.
        print(f"Error loading model: {e}")
        return False

    prompt = "Explain what machine learning is in three sentences."
    max_tokens = 50

    def timed(fn, **kwargs):
        # Run one generation call and return (result, elapsed_seconds).
        t0 = time.perf_counter()
        out = fn(prompt, **kwargs)
        return out, time.perf_counter() - t0

    def report(label, result, elapsed):
        # Shared per-run summary block.
        print(f"\n{label}:")
        print(f"  Tokens generated: {len(result.tokens)}")
        print(f"  Time: {elapsed:.2f}s")
        print(f"  Tokens/sec: {len(result.tokens) / elapsed:.2f}")
        print(f"  Avg exit layer: {result.avg_exit_layer:.1f}")

    # Warmup pass so model load / compilation cost doesn't skew timings.
    print("\nWarming up...")
    decoder.generate(
        prompt, max_tokens=10, use_early_exit=False, use_chat_template=True
    )

    print("\nRunning standard generation (no cache)...")
    result_standard, time_standard = timed(
        decoder.generate,
        max_tokens=max_tokens,
        use_early_exit=True,
        accuracy_level=0.75,
        use_chat_template=True,
    )

    print("Running cache-optimized generation (fast)...")
    result_cached, time_cached = timed(
        decoder.generate_fast,
        max_tokens=max_tokens,
        accuracy_level=0.75,
        use_chat_template=True,
    )

    print("\n" + "=" * 60)
    print("RESULTS")
    print("=" * 60)

    report("Standard Generation", result_standard, time_standard)
    report("Cache-Optimized Generation", result_cached, time_cached)

    # Speculative-decoding stats are only present on the cached run.
    if "total_drafted" in result_cached.exit_distribution:
        print(f"  Drafted: {result_cached.exit_distribution['total_drafted']}")
        print(f"  Accepted: {result_cached.exit_distribution['total_accepted']}")
        print(
            f"  Acceptance rate: {result_cached.exit_distribution['acceptance_rate']:.1%}"
        )

    print("\nSpeedup:")
    speedup = time_standard / time_cached if time_cached > 0 else 0
    print(f"  {speedup:.2f}x faster with cache")

    return True


def main():
    """Parse CLI flags and dispatch to the appropriate benchmark."""
    parser = argparse.ArgumentParser(description="Benchmark DSSD generation")
    parser.add_argument("--model", default="Qwen/Qwen3-0.6B", help="Model name")
    # Optional path flags share the same shape; register them in one pass.
    for flag, help_text in (
        ("--heads-path", "Path to aux heads checkpoint"),
        ("--config-path", "Path to model config"),
        ("--calibration-path", "Path to calibration file"),
    ):
        parser.add_argument(flag, help=help_text)
    parser.add_argument(
        "--cpu-only", action="store_true", help="Run CPU-only cache benchmark"
    )
    args = parser.parse_args()

    # Without trained heads there is nothing to load, so fall back to the
    # CPU-only cache-operations benchmark.
    if args.cpu_only or not args.heads_path:
        make_dummy_decoder()
        return

    run_full_benchmark(
        args.model,
        args.heads_path,
        args.config_path,
        args.calibration_path,
    )


# Script entry point: `python tests/run_benchmark.py [--model ... --heads-path ...]`
if __name__ == "__main__":
    main()