# DevaFlow-space / analysis / kv_cache_benchmark.py
# (HF Space page residue, kept as a comment so the file parses:
#  bhsinghgrid — "Upgrade UI: model selection + tasks 1-5 + analysis modules",
#  commit 29e5bf8 verified)
# """
# analysis/kv_cache_benchmark.py
# ================================
# Task 1: Benchmark KV cache vs standard generate().
#
# Measures:
# - Wall-clock time for generate() vs generate_cached()
# - Encoder time as % of total generation time (before/after)
# - Speedup ratio at src_len = 16, 32, 64 tokens
#
# How it works:
# Standard generate():
# For each of T=128 steps:
# src β†’ encoder β†’ memory β†’ decoder β†’ logits (encoder runs 128 times)
#
# generate_cached():
# src β†’ encoder β†’ memory (once)
# For each of T=128 steps:
# cached_memory β†’ decoder β†’ logits (encoder runs 1 time)
#
# Expected speedup:
# If encoder = 30% of per-step time:
# Saved = 127/128 * 30% β‰ˆ 29.7% of total time
# If encoder = 50% of per-step time:
# Saved β‰ˆ 49.6% of total time
#
# Usage:
# python -m analysis.kv_cache_benchmark
# or:
# from analysis.kv_cache_benchmark import run_benchmark
# results = run_benchmark(model, src_tokenizer, device)
# """
#
# import torch
# import time
# import numpy as np
# from typing import Dict, List
#
#
# def _make_src(src_len: int, src_vocab: int, device: torch.device, batch_size: int = 1):
# """Create a random source tensor of given length."""
# # Random real tokens (ids 5..src_vocab-1), padded to src_len
# ids = torch.randint(5, src_vocab, (batch_size, src_len), device=device)
# return ids
#
#
# def _time_fn(fn, n_warmup: int = 2, n_runs: int = 5) -> float:
# """
# Time a zero-argument callable.
# Returns mean wall-clock seconds over n_runs after n_warmup warmup calls.
# """
# # Warmup
# for _ in range(n_warmup):
# fn()
# if torch.cuda.is_available():
# torch.cuda.synchronize()
# elif torch.backends.mps.is_available():
# torch.mps.synchronize()
#
# times = []
# for _ in range(n_runs):
# start = time.perf_counter()
# fn()
# if torch.cuda.is_available():
# torch.cuda.synchronize()
# elif torch.backends.mps.is_available():
# torch.mps.synchronize()
# times.append(time.perf_counter() - start)
#
# return float(np.mean(times))
#
#
# def benchmark_encoder_cost(
# model,
# src: torch.Tensor,
# ) -> Dict[str, float]:
# """
# Measure encoder time as a fraction of one full forward pass.
#
# Returns:
# encoder_s : seconds for one encoder call
# full_step_s : seconds for one full forward_cached decoder step
# encoder_pct : encoder_s / (encoder_s + full_step_s) * 100
# """
# inner = model.model
# if not hasattr(inner, 'encode_source'):
# raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")
#
# device = src.device
# B = src.shape[0]
# T = inner.scheduler.num_timesteps
# tgt_len = inner.max_seq_len
# mask_id = inner.mask_token_id
#
# x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device)
# t = torch.zeros(B, dtype=torch.long, device=device)
#
# # Time encoder alone
# encoder_s = _time_fn(lambda: inner.encode_source(src))
#
# # Pre-compute memory for decoder timing
# memory, src_pad_mask = inner.encode_source(src)
#
# # Time one decoder step (cached)
# decoder_s = _time_fn(
# lambda: inner.forward_cached(memory, src_pad_mask, x0_est, t,
# inference_mode=True)
# )
#
# # Time one full step (non-cached = encoder + decoder)
# full_s = _time_fn(
# lambda: inner.forward(src, x0_est, t, inference_mode=True)
# )
#
# encoder_pct = 100.0 * encoder_s / max(full_s, 1e-9)
#
# return {
# "encoder_s": encoder_s,
# "decoder_s": decoder_s,
# "full_step_s": full_s,
# "encoder_pct": encoder_pct,
# }
#
#
# def run_benchmark(
# model,
# src_tokenizer,
# device: torch.device,
# src_lens: List[int] = [16, 32, 64],
# n_runs: int = 5,
# ) -> Dict:
# """
# Full benchmark: compare generate() vs generate_cached() at multiple src lengths.
#
# Args:
# model : SanskritModel (D3PMCrossAttention)
# src_tokenizer : SanskritSourceTokenizer
# device : torch.device
# src_lens : list of source lengths to benchmark
# n_runs : number of timing runs per condition
#
# Returns:
# results dict with timing and speedup for each src_len
# """
# inner = model.model
# if not hasattr(inner, 'generate_cached'):
# raise ValueError("Model does not support KV cache (not D3PMCrossAttention).")
#
# src_vocab = inner.src_embed.token_emb.weight.shape[0]
# results = {}
#
# print("\n" + "=" * 65)
# print(" KV CACHE BENCHMARK")
# print("=" * 65)
# print(f" {'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} "
# f"{'speedup':>8} {'encoder%':>9}")
# print("-" * 65)
#
# for src_len in src_lens:
# src = _make_src(src_len, src_vocab, device)
#
# # Encoder cost breakdown
# enc_cost = benchmark_encoder_cost(model, src)
#
# # Time standard generate() β€” encoder runs T times
# def run_standard():
# return inner.generate(src, temperature=0.8, top_k=40)
#
# # Time generate_cached() β€” encoder runs once
# def run_cached():
# return inner.generate_cached(src, temperature=0.8, top_k=40)
#
# t_standard = _time_fn(run_standard, n_warmup=1, n_runs=n_runs)
# t_cached = _time_fn(run_cached, n_warmup=1, n_runs=n_runs)
# speedup = t_standard / max(t_cached, 1e-9)
#
# results[src_len] = {
# "standard_s": t_standard,
# "cached_s": t_cached,
# "speedup": speedup,
# "encoder_pct": enc_cost["encoder_pct"],
# }
#
# print(f" {src_len:>8} {t_standard:>12.3f} {t_cached:>10.3f} "
# f"{speedup:>7.2f}x {enc_cost['encoder_pct']:>8.1f}%")
#
# print("=" * 65)
# print(f"\n Encoder cost = % of one full forward pass")
# print(f" Speedup = standard_time / cached_time")
# print(f" Expected: speedup β‰ˆ 1 / (1 - encoder_pct/100 * (T-1)/T)")
#
# return results
#
#
# def print_summary(results: Dict):
# """Print a human-readable summary of benchmark results."""
# print("\n SUMMARY")
# print(" -------")
# for src_len, r in results.items():
# saved_pct = (1.0 - 1.0 / r["speedup"]) * 100
# print(f" src_len={src_len}: {r['speedup']:.2f}x speedup "
# f"({saved_pct:.1f}% time saved, "
# f"encoder was {r['encoder_pct']:.1f}% of total)")
#
#
# if __name__ == "__main__":
# import sys, os
# sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
# from config import CONFIG
# from inference import load_model
# from models.tokenizer import SanskritSourceTokenizer
#
# cfg = CONFIG
# device = torch.device(cfg['training']['device'])
#
# model_name = cfg['model_type']
# has_neg = cfg['data']['include_negative_examples']
# ckpt = f"results7/{model_name}_neg_{has_neg}/best_model.pt"
#
# if not os.path.exists(ckpt):
# print(f"No checkpoint at {ckpt}. Train first.")
# sys.exit(1)
#
# model, cfg = load_model(ckpt, cfg, device)
# model.eval()
#
# src_tokenizer = SanskritSourceTokenizer(
# vocab_size = cfg['model'].get('src_vocab_size', 500),
# max_len = cfg['model']['max_seq_len'],
# )
#
# results = run_benchmark(model, src_tokenizer, device)
# print_summary(results)
# ============================================================
# FULL TASK 1: KV CACHE + PROJECTION + BENCHMARK + GRAPHS
# ============================================================
import torch
import torch.nn as nn
import torch.nn.functional as F
import time
import numpy as np
import matplotlib.pyplot as plt
# ============================================================
# πŸ”§ MODEL (PATCHED WITH PROJECTION + KV CACHE)
# ============================================================
class D3PMCrossAttention(nn.Module):
    """Toy D3PM-style model used to benchmark KV-cache generation.

    The "encoder"/decoder here are stand-ins (embeddings + linear layers) —
    replace them with the real modules. What matters for the benchmark is the
    call structure: ``generate()`` re-encodes the source at every one of the
    T reverse-diffusion steps, while ``generate_cached()`` encodes once and
    reuses the memory for all T steps.
    """

    def __init__(self, d_model=512, vocab_size=500, max_seq_len=64, T=128):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.mask_token_id = 0
        # Dummy encoder/decoder (replace with the real modules)
        self.encoder = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.time_mlp = nn.Linear(1, d_model)
        self.hint_gate = nn.Linear(d_model, d_model)

        # Minimal scheduler stub: only exposes the diffusion step count.
        class Scheduler:
            def __init__(self, T):
                self.num_timesteps = T
        self.scheduler = Scheduler(T)

        # Projection bottleneck (Task 1 requirement): compress -> expand.
        # NOTE: d_model must be >= 2 (uses d_model // 2).
        self.semantic_proj = nn.Linear(d_model, d_model // 2)
        self.semantic_up = nn.Linear(d_model // 2, d_model)

    # ========================================================
    # ENCODER WITH PROJECTION
    # ========================================================
    def encode_source(self, src):
        """Encode ``src`` [B, L] -> (memory [B, L, d_model], src_pad_mask).

        The memory passes through the compress->expand projection. Padding
        is not modeled in this stub, so the mask is always ``None``.
        """
        memory = self.encoder(src)                 # [B, L, d]
        compressed = self.semantic_proj(memory)    # [B, L, d//2]
        memory = self.semantic_up(compressed)      # [B, L, d]
        src_pad_mask = None
        return memory, src_pad_mask

    # ========================================================
    # STANDARD (NO CACHE): re-encodes the source every call
    # ========================================================
    def forward(self, src, x, t):
        """One un-cached step: encoder + decoder. Returns (logits, None)."""
        memory, mask = self.encode_source(src)
        return self.forward_cached(memory, mask, x, t)

    # ========================================================
    # CACHED FORWARD: decoder only, memory precomputed
    # ========================================================
    def forward_cached(self, memory, src_pad_mask, x, t, hint=None):
        """Decoder step given precomputed encoder ``memory``.

        Args:
            memory       : encoder output [B, L_src, d_model] (unused by this stub's head)
            src_pad_mask : source padding mask (unused; stub)
            x            : current target token ids [B, L_tgt]
            t            : timestep per batch element [B]
            hint         : optional previous prediction [B, L_tgt], gated in

        Returns:
            (logits [B, L_tgt, vocab], None)
        """
        x = self.tgt_embed(x)
        # Timestep conditioning: normalize t to [0, 1], embed, broadcast over seq.
        t_emb = self.time_mlp((t.float() / self.scheduler.num_timesteps).unsqueeze(-1))
        x = x + t_emb.unsqueeze(1)
        if hint is not None:
            # Gate the previous step's prediction back in as a hint.
            x = x + self.hint_gate(x) * self.tgt_embed(hint)
        logits = self.head(x)
        self._last_hidden = x  # kept for downstream analysis/debugging
        return logits, None

    # ========================================================
    # OLD GENERATE (SLOW): encoder runs T times
    # ========================================================
    @torch.no_grad()
    def generate(self, src):
        """Reverse diffusion without caching: re-encode src at every step."""
        B = src.shape[0]
        device = src.device
        T = self.scheduler.num_timesteps
        x = torch.zeros((B, self.max_seq_len), dtype=torch.long, device=device)
        for t_val in range(T - 1, -1, -1):
            t = torch.full((B,), t_val, device=device)
            logits, _ = self.forward(src, x, t)
            # argmax(softmax(logits)) == argmax(logits): the softmax the
            # original computed here was redundant work on every step.
            x = torch.argmax(logits, dim=-1)
        return x

    # ========================================================
    # FAST GENERATE (KV CACHE): encoder runs once
    # ========================================================
    @torch.no_grad()
    def generate_cached(self, src):
        """Reverse diffusion with cached memory: encode src exactly once."""
        B = src.shape[0]
        device = src.device
        T = self.scheduler.num_timesteps
        memory, mask = self.encode_source(src)  # encoder runs one time only
        x = torch.zeros((B, self.max_seq_len), dtype=torch.long, device=device)
        hint = None
        for t_val in range(T - 1, -1, -1):
            t = torch.full((B,), t_val, device=device)
            logits, _ = self.forward_cached(memory, mask, x, t, hint)
            x = torch.argmax(logits, dim=-1)
            hint = x  # feed current prediction forward as next step's hint
        return x
# ============================================================
# πŸ“Š BENCHMARK + MEMORY + GRAPHS
# ============================================================
def benchmark(model, device, src_lens=None, vocab=500, plot=True):
    """Benchmark generate() vs generate_cached(): wall-clock time + peak memory.

    Bug fixes vs the original: the CUDA memory/synchronization calls ran
    unconditionally and crashed on CPU-only machines (which the script's own
    device fallback selects); division by zero was possible for very fast
    runs and for zero CPU "memory"; timing now uses the monotonic
    ``time.perf_counter`` instead of ``time.time``.

    Args:
        model    : module exposing ``generate(src)`` and ``generate_cached(src)``
        device   : torch.device to run on (memory stats reported only on CUDA)
        src_lens : source lengths to test (default: [16, 32, 64])
        vocab    : source vocabulary size for the random inputs
        plot     : if True, show the matplotlib figures (time/speedup/memory)

    Returns:
        dict mapping src_len -> {"standard_s", "cached_s", "speedup",
        "memory_savings_pct"} (the original returned None; callers that
        ignored the return value are unaffected).
    """
    model.to(device)
    model.eval()
    if src_lens is None:
        src_lens = [16, 32, 64]
    # CUDA-only APIs raise on CPU-only builds — guard every use.
    use_cuda = (device.type == "cuda") and torch.cuda.is_available()

    standard_times, cached_times = [], []
    speedups, memory_savings = [], []
    results = {}

    for src_len in src_lens:
        print(f"\nπŸ”Ή src_len = {src_len}")
        src = torch.randint(5, vocab, (1, src_len), device=device)

        # -------- STANDARD (encoder re-run every step) --------
        if use_cuda:
            torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        model.generate(src)
        if use_cuda:
            torch.cuda.synchronize()  # wait for queued kernels before timing
        t_std = time.perf_counter() - start
        mem_std = torch.cuda.max_memory_allocated() / 1024**2 if use_cuda else 0.0

        # -------- CACHED (encoder run once) --------
        if use_cuda:
            torch.cuda.reset_peak_memory_stats()
        start = time.perf_counter()
        model.generate_cached(src)
        if use_cuda:
            torch.cuda.synchronize()
        t_cache = time.perf_counter() - start
        mem_cache = torch.cuda.max_memory_allocated() / 1024**2 if use_cuda else 0.0

        speedup = t_std / max(t_cache, 1e-9)  # guard ZeroDivisionError
        mem_red = 100 * (mem_std - mem_cache) / mem_std if mem_std > 0 else 0.0
        print(f"Time: {t_std:.2f}s β†’ {t_cache:.2f}s | {speedup:.2f}x")
        print(f"Memory: {mem_std:.0f}MB β†’ {mem_cache:.0f}MB | {mem_red:.1f}%")

        standard_times.append(t_std)
        cached_times.append(t_cache)
        speedups.append(speedup)
        memory_savings.append(mem_red)
        results[src_len] = {
            "standard_s": t_std,
            "cached_s": t_cache,
            "speedup": speedup,
            "memory_savings_pct": mem_red,
        }

    if plot:
        # ==========================
        # PLOT: TIME
        # ==========================
        plt.figure()
        plt.plot(src_lens, standard_times, marker='o', label="Standard")
        plt.plot(src_lens, cached_times, marker='o', label="Cached")
        plt.xlabel("Source Length")
        plt.ylabel("Time (s)")
        plt.title("Generation Time")
        plt.legend()
        plt.grid()
        plt.show()
        # ==========================
        # PLOT: SPEEDUP
        # ==========================
        plt.figure()
        plt.plot(src_lens, speedups, marker='o')
        plt.xlabel("Source Length")
        plt.ylabel("Speedup (x)")
        plt.title("KV Cache Speedup")
        plt.grid()
        plt.show()
        # ==========================
        # PLOT: MEMORY
        # ==========================
        plt.figure()
        plt.plot(src_lens, memory_savings, marker='o')
        plt.xlabel("Source Length")
        plt.ylabel("Memory Reduction (%)")
        plt.title("Memory Savings")
        plt.grid()
        plt.show()

    return results
# ============================================================
# πŸš€ RUN
# ============================================================
if __name__ == "__main__":
    # Guarded so importing this module no longer launches the whole
    # benchmark (and its blocking matplotlib windows) as a side effect.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = D3PMCrossAttention()
    benchmark(model, device)