"""
analysis/kv_cache_benchmark.py
================================
Task 1: Benchmark KV cache vs standard generate().

Measures:
    - Wall-clock time for generate() vs generate_cached()
    - Encoder time as % of total generation time (before/after)
    - Speedup ratio at src_len = 16, 32, 64 tokens

How it works:
    Standard generate():
        For each of T=128 steps:
            src → encoder → memory → decoder → logits   (encoder runs 128 times)
    generate_cached():
        src → encoder → memory   (once)
        For each of T=128 steps:
            cached_memory → decoder → logits            (encoder runs 1 time)

Expected speedup:
    If encoder = 30% of per-step time:
        Saved = 127/128 * 30% ≈ 29.7% of total time
    If encoder = 50% of per-step time:
        Saved ≈ 49.6% of total time
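
    Worked example (plugging the numbers above into the estimate
    printed by run_benchmark, assuming T = 128 steps):
        speedup ≈ 1 / (1 - encoder_frac * (T-1)/T)
        encoder_frac = 0.30  →  1 / (1 - 0.297) ≈ 1.42x
        encoder_frac = 0.50  →  1 / (1 - 0.496) ≈ 1.98x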

Usage:
    python -m analysis.kv_cache_benchmark
or:
    from analysis.kv_cache_benchmark import run_benchmark
    results = run_benchmark(model, src_tokenizer, device)
"""
import time
from typing import Dict, Sequence

import numpy as np
import torch


def _make_src(src_len: int, src_vocab: int, device: torch.device, batch_size: int = 1):
    """Create a random source tensor of the given length."""
    # Random real tokens (ids 5..src_vocab-1); ids below 5 are skipped on the
    # assumption they are reserved for special tokens. The tensor is already
    # exactly src_len long, so no padding is applied.
    ids = torch.randint(5, src_vocab, (batch_size, src_len), device=device)
    return ids
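# Example (illustrative values only): a batch of one random 32-token source.
#   src = _make_src(32, src_vocab=500, device=torch.device("cpu"))  # shape (1, 32)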


def _time_fn(fn, n_warmup: int = 2, n_runs: int = 5) -> float:
    """
    Time a zero-argument callable.

    Returns mean wall-clock seconds over n_runs after n_warmup warmup calls.
    """
    def _sync():
        # GPU kernels launch asynchronously; synchronize so perf_counter
        # measures the actual device work, not just kernel launch time.
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elif torch.backends.mps.is_available():
            torch.mps.synchronize()

    # Warmup
    for _ in range(n_warmup):
        fn()
    _sync()

    times = []
    for _ in range(n_runs):
        start = time.perf_counter()
        fn()
        _sync()
        times.append(time.perf_counter() - start)
    return float(np.mean(times))
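# Example (illustrative, hypothetical tensors): timing a single matmul.
#   x = torch.randn(64, 512)
#   w = torch.randn(512, 512)
#   t = _time_fn(lambda: x @ w)  # mean seconds over 5 runs after 2 warmups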


@torch.no_grad()  # timing only; keep autograd bookkeeping out of the measurements
def benchmark_encoder_cost(
    model,
    src: torch.Tensor,
) -> Dict[str, float]:
    """
    Measure encoder time as a fraction of one full forward pass.

    Returns:
        encoder_s   : seconds for one encoder call
        decoder_s   : seconds for one cached decoder step
        full_step_s : seconds for one full (encoder + decoder) forward step
        encoder_pct : encoder_s / full_step_s * 100
    """
    inner = model.model
    if not hasattr(inner, 'encode_source'):
        raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")
    device = src.device
    B = src.shape[0]
    tgt_len = inner.max_seq_len
    mask_id = inner.mask_token_id
    x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
    t = torch.zeros(B, dtype=torch.long, device=device)

    # Time encoder alone
    encoder_s = _time_fn(lambda: inner.encode_source(src))

    # Pre-compute memory for decoder timing
    memory, src_pad_mask = inner.encode_source(src)

    # Time one decoder step (cached)
    decoder_s = _time_fn(
        lambda: inner.forward_cached(memory, src_pad_mask, x0_est, t,
                                     inference_mode=True)
    )

    # Time one full step (non-cached = encoder + decoder)
    full_s = _time_fn(
        lambda: inner.forward(src, x0_est, t, inference_mode=True)
    )

    encoder_pct = 100.0 * encoder_s / max(full_s, 1e-9)
    return {
        "encoder_s": encoder_s,
        "decoder_s": decoder_s,
        "full_step_s": full_s,
        "encoder_pct": encoder_pct,
    }
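# Sanity check (an expectation, not asserted): encoder_s + decoder_s should
# roughly add up to full_step_s, since forward() = encode_source() + decoder.
# A large gap would suggest forward_cached() exercises a different decoder path.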


@torch.no_grad()  # timing only; keep autograd bookkeeping out of the measurements
def run_benchmark(
    model,
    src_tokenizer,
    device: torch.device,
    src_lens: Sequence[int] = (16, 32, 64),
    n_runs: int = 5,
) -> Dict:
    """
    Full benchmark: compare generate() vs generate_cached() at multiple src lengths.

    Args:
        model         : SanskritModel (D3PMCrossAttention)
        src_tokenizer : SanskritSourceTokenizer (kept for API symmetry; the
                        benchmark sources are random token ids)
        device        : torch.device
        src_lens      : source lengths to benchmark
        n_runs        : number of timing runs per condition

    Returns:
        results dict with timing and speedup for each src_len
    """
    inner = model.model
    if not hasattr(inner, 'generate_cached'):
        raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")
    src_vocab = inner.src_embed.token_emb.weight.shape[0]
    results = {}

    print("\n" + "=" * 65)
    print(" KV CACHE BENCHMARK")
    print("=" * 65)
    print(f" {'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
          f"{'speedup':>8} {'encoder%':>9}")
    print("-" * 65)

    for src_len in src_lens:
        src = _make_src(src_len, src_vocab, device)

        # Encoder cost breakdown
        enc_cost = benchmark_encoder_cost(model, src)

        # Time standard generate(): encoder runs once per diffusion step
        def run_standard():
            return inner.generate(src, temperature=0.8, top_k=40)

        # Time generate_cached(): encoder runs once in total
        def run_cached():
            return inner.generate_cached(src, temperature=0.8, top_k=40)

        t_standard = _time_fn(run_standard, n_warmup=1, n_runs=n_runs)
        t_cached = _time_fn(run_cached, n_warmup=1, n_runs=n_runs)
        speedup = t_standard / max(t_cached, 1e-9)
        results[src_len] = {
            "standard_s": t_standard,
            "cached_s": t_cached,
            "speedup": speedup,
            "encoder_pct": enc_cost["encoder_pct"],
        }
        print(f" {src_len:>8} {t_standard:>12.3f} {t_cached:>10.3f} "
              f"{speedup:>7.2f}x {enc_cost['encoder_pct']:>8.1f}%")

    print("=" * 65)
    print("\n Encoder cost = % of one full forward pass")
    print(" Speedup = standard_time / cached_time")
    print(" Expected: speedup ≈ 1 / (1 - encoder_pct/100 * (T-1)/T)")
    return results
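# Where the expected-speedup estimate comes from: with e = encoder seconds,
# d = decoder seconds per step, and f = e / (e + d) (i.e. encoder_pct / 100):
#   standard ≈ T * (e + d)     (encoder re-runs every step)
#   cached   ≈ e + T * d       (encoder runs once)
#   speedup  = standard / cached = 1 / (1 - f * (T - 1) / T)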


def print_summary(results: Dict):
    """Print a human-readable summary of benchmark results."""
    print("\n SUMMARY")
    print(" -------")
    for src_len, r in results.items():
        saved_pct = (1.0 - 1.0 / r["speedup"]) * 100
        print(f" src_len={src_len}: {r['speedup']:.2f}x speedup "
              f"({saved_pct:.1f}% time saved, "
              f"encoder was {r['encoder_pct']:.1f}% of total)")


if __name__ == "__main__":
    import os
    import sys
    sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))

    from config import CONFIG
    from inference import load_model
    from models.tokenizer import SanskritSourceTokenizer

    cfg = CONFIG
    device = torch.device(cfg['training']['device'])
    model_name = cfg['model_type']
    has_neg = cfg['data']['include_negative_examples']
    ckpt = f"results7/{model_name}_neg_{has_neg}/best_model.pt"
    if not os.path.exists(ckpt):
        print(f"No checkpoint at {ckpt}. Train first.")
        sys.exit(1)

    model, cfg = load_model(ckpt, cfg, device)
    model.eval()
    src_tokenizer = SanskritSourceTokenizer(
        vocab_size=cfg['model'].get('src_vocab_size', 500),
        max_len=cfg['model']['max_seq_len'],
    )
    results = run_benchmark(model, src_tokenizer, device)
    print_summary(results)