# NOTE: "Spaces / Sleeping" banner from the hosting page was captured during
# extraction; it is not part of the source file.
| # """ | |
| # analysis/kv_cache_benchmark.py | |
| # ================================ | |
| # Task 1: Benchmark KV cache vs standard generate(). | |
| # | |
| # Measures: | |
| # - Wall-clock time for generate() vs generate_cached() | |
| # - Encoder time as % of total generation time (before/after) | |
| # - Speedup ratio at src_len = 16, 32, 64 tokens | |
| # | |
# How it works:
#   Standard generate():
#     For each of T=128 steps:
#       src → encoder → memory → decoder → logits   (encoder runs 128 times)
#
#   generate_cached():
#     src → encoder → memory   (once)
#     For each of T=128 steps:
#       cached_memory → decoder → logits   (encoder runs 1 time)
#
# Expected speedup:
#   If encoder = 30% of per-step time:
#     Saved = 127/128 * 30% ≈ 29.7% of total time
#   If encoder = 50% of per-step time:
#     Saved ≈ 49.6% of total time
| # | |
| # Usage: | |
| # python -m analysis.kv_cache_benchmark | |
| # or: | |
| # from analysis.kv_cache_benchmark import run_benchmark | |
| # results = run_benchmark(model, src_tokenizer, device) | |
| # """ | |
| # | |
| # import torch | |
| # import time | |
| # import numpy as np | |
| # from typing import Dict, List | |
| # | |
| # | |
| # def _make_src(src_len: int, src_vocab: int, device: torch.device, batch_size: int = 1): | |
| # """Create a random source tensor of given length.""" | |
| # # Random real tokens (ids 5..src_vocab-1), padded to src_len | |
| # ids = torch.randint(5, src_vocab, (batch_size, src_len), device=device) | |
| # return ids | |
| # | |
| # | |
| # def _time_fn(fn, n_warmup: int = 2, n_runs: int = 5) -> float: | |
| # """ | |
| # Time a zero-argument callable. | |
| # Returns mean wall-clock seconds over n_runs after n_warmup warmup calls. | |
| # """ | |
| # # Warmup | |
| # for _ in range(n_warmup): | |
| # fn() | |
| # if torch.cuda.is_available(): | |
| # torch.cuda.synchronize() | |
| # elif torch.backends.mps.is_available(): | |
| # torch.mps.synchronize() | |
| # | |
| # times = [] | |
| # for _ in range(n_runs): | |
| # start = time.perf_counter() | |
| # fn() | |
| # if torch.cuda.is_available(): | |
| # torch.cuda.synchronize() | |
| # elif torch.backends.mps.is_available(): | |
| # torch.mps.synchronize() | |
| # times.append(time.perf_counter() - start) | |
| # | |
| # return float(np.mean(times)) | |
| # | |
| # | |
| # def benchmark_encoder_cost( | |
| # model, | |
| # src: torch.Tensor, | |
| # ) -> Dict[str, float]: | |
| # """ | |
| # Measure encoder time as a fraction of one full forward pass. | |
| # | |
| # Returns: | |
| # encoder_s : seconds for one encoder call | |
| # full_step_s : seconds for one full forward_cached decoder step | |
| # encoder_pct : encoder_s / (encoder_s + full_step_s) * 100 | |
| # """ | |
| # inner = model.model | |
| # if not hasattr(inner, 'encode_source'): | |
| # raise ValueError("Model does not support KV cache (not D3PMCrossAttention).") | |
| # | |
| # device = src.device | |
| # B = src.shape[0] | |
| # T = inner.scheduler.num_timesteps | |
| # tgt_len = inner.max_seq_len | |
| # mask_id = inner.mask_token_id | |
| # | |
| # x0_est = torch.full((B, tgt_len), mask_id, dtype=torch.long, device=device) | |
| # t = torch.zeros(B, dtype=torch.long, device=device) | |
| # | |
| # # Time encoder alone | |
| # encoder_s = _time_fn(lambda: inner.encode_source(src)) | |
| # | |
| # # Pre-compute memory for decoder timing | |
| # memory, src_pad_mask = inner.encode_source(src) | |
| # | |
| # # Time one decoder step (cached) | |
| # decoder_s = _time_fn( | |
| # lambda: inner.forward_cached(memory, src_pad_mask, x0_est, t, | |
| # inference_mode=True) | |
| # ) | |
| # | |
| # # Time one full step (non-cached = encoder + decoder) | |
| # full_s = _time_fn( | |
| # lambda: inner.forward(src, x0_est, t, inference_mode=True) | |
| # ) | |
| # | |
| # encoder_pct = 100.0 * encoder_s / max(full_s, 1e-9) | |
| # | |
| # return { | |
| # "encoder_s": encoder_s, | |
| # "decoder_s": decoder_s, | |
| # "full_step_s": full_s, | |
| # "encoder_pct": encoder_pct, | |
| # } | |
| # | |
| # | |
| # def run_benchmark( | |
| # model, | |
| # src_tokenizer, | |
| # device: torch.device, | |
| # src_lens: List[int] = [16, 32, 64], | |
| # n_runs: int = 5, | |
| # ) -> Dict: | |
| # """ | |
| # Full benchmark: compare generate() vs generate_cached() at multiple src lengths. | |
| # | |
| # Args: | |
| # model : SanskritModel (D3PMCrossAttention) | |
| # src_tokenizer : SanskritSourceTokenizer | |
| # device : torch.device | |
| # src_lens : list of source lengths to benchmark | |
| # n_runs : number of timing runs per condition | |
| # | |
| # Returns: | |
| # results dict with timing and speedup for each src_len | |
| # """ | |
| # inner = model.model | |
| # if not hasattr(inner, 'generate_cached'): | |
| # raise ValueError("Model does not support KV cache (not D3PMCrossAttention).") | |
| # | |
| # src_vocab = inner.src_embed.token_emb.weight.shape[0] | |
| # results = {} | |
| # | |
| # print("\n" + "=" * 65) | |
| # print(" KV CACHE BENCHMARK") | |
| # print("=" * 65) | |
| # print(f" {'src_len':>8} {'standard(s)':>12} {'cached(s)':>10} " | |
| # f"{'speedup':>8} {'encoder%':>9}") | |
| # print("-" * 65) | |
| # | |
| # for src_len in src_lens: | |
| # src = _make_src(src_len, src_vocab, device) | |
| # | |
| # # Encoder cost breakdown | |
| # enc_cost = benchmark_encoder_cost(model, src) | |
| # | |
| # # Time standard generate() β encoder runs T times | |
| # def run_standard(): | |
| # return inner.generate(src, temperature=0.8, top_k=40) | |
| # | |
| # # Time generate_cached() β encoder runs once | |
| # def run_cached(): | |
| # return inner.generate_cached(src, temperature=0.8, top_k=40) | |
| # | |
| # t_standard = _time_fn(run_standard, n_warmup=1, n_runs=n_runs) | |
| # t_cached = _time_fn(run_cached, n_warmup=1, n_runs=n_runs) | |
| # speedup = t_standard / max(t_cached, 1e-9) | |
| # | |
| # results[src_len] = { | |
| # "standard_s": t_standard, | |
| # "cached_s": t_cached, | |
| # "speedup": speedup, | |
| # "encoder_pct": enc_cost["encoder_pct"], | |
| # } | |
| # | |
| # print(f" {src_len:>8} {t_standard:>12.3f} {t_cached:>10.3f} " | |
| # f"{speedup:>7.2f}x {enc_cost['encoder_pct']:>8.1f}%") | |
| # | |
| # print("=" * 65) | |
| # print(f"\n Encoder cost = % of one full forward pass") | |
| # print(f" Speedup = standard_time / cached_time") | |
| # print(f" Expected: speedup β 1 / (1 - encoder_pct/100 * (T-1)/T)") | |
| # | |
| # return results | |
| # | |
| # | |
| # def print_summary(results: Dict): | |
| # """Print a human-readable summary of benchmark results.""" | |
| # print("\n SUMMARY") | |
| # print(" -------") | |
| # for src_len, r in results.items(): | |
| # saved_pct = (1.0 - 1.0 / r["speedup"]) * 100 | |
| # print(f" src_len={src_len}: {r['speedup']:.2f}x speedup " | |
| # f"({saved_pct:.1f}% time saved, " | |
| # f"encoder was {r['encoder_pct']:.1f}% of total)") | |
| # | |
| # | |
| # if __name__ == "__main__": | |
| # import sys, os | |
| # sys.path.append(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) | |
| # from config import CONFIG | |
| # from inference import load_model | |
| # from models.tokenizer import SanskritSourceTokenizer | |
| # | |
| # cfg = CONFIG | |
| # device = torch.device(cfg['training']['device']) | |
| # | |
| # model_name = cfg['model_type'] | |
| # has_neg = cfg['data']['include_negative_examples'] | |
| # ckpt = f"results7/{model_name}_neg_{has_neg}/best_model.pt" | |
| # | |
| # if not os.path.exists(ckpt): | |
| # print(f"No checkpoint at {ckpt}. Train first.") | |
| # sys.exit(1) | |
| # | |
| # model, cfg = load_model(ckpt, cfg, device) | |
| # model.eval() | |
| # | |
| # src_tokenizer = SanskritSourceTokenizer( | |
| # vocab_size = cfg['model'].get('src_vocab_size', 500), | |
| # max_len = cfg['model']['max_seq_len'], | |
| # ) | |
| # | |
| # results = run_benchmark(model, src_tokenizer, device) | |
| # print_summary(results) | |
| # ============================================================ | |
| # FULL TASK 1: KV CACHE + PROJECTION + BENCHMARK + GRAPHS | |
| # ============================================================ | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import time | |
| import numpy as np | |
| import matplotlib.pyplot as plt | |
# ============================================================
# MODEL (PATCHED WITH PROJECTION + KV CACHE)
# ============================================================
class D3PMCrossAttention(nn.Module):
    """Toy stand-in for the project's D3PM cross-attention model.

    Demonstrates the Task 1 patches on a minimal skeleton:
      * a semantic projection bottleneck in the encoder (compress → expand),
      * a cached generation path that encodes the source exactly once.

    The "encoder"/"decoder" here are placeholders (embedding + linear head);
    swap in the real modules to benchmark the actual model.
    """

    def __init__(self, d_model=512, vocab_size=500, max_seq_len=64, T=128):
        super().__init__()
        self.d_model = d_model
        self.max_seq_len = max_seq_len
        self.mask_token_id = 0

        # Placeholder encoder/decoder pieces (replace with the real ones).
        self.encoder = nn.Embedding(vocab_size, d_model)
        self.tgt_embed = nn.Embedding(vocab_size, d_model)
        self.head = nn.Linear(d_model, vocab_size)
        self.time_mlp = nn.Linear(1, d_model)
        self.hint_gate = nn.Linear(d_model, d_model)

        # Minimal scheduler stub exposing only num_timesteps.
        class Scheduler:
            def __init__(self, T):
                self.num_timesteps = T

        self.scheduler = Scheduler(T)

        # Projection layer (Task 1 requirement): compress then re-expand.
        self.semantic_proj = nn.Linear(d_model, d_model // 2)
        self.semantic_up = nn.Linear(d_model // 2, d_model)

    # --------------------------------------------------------
    # Encoder with semantic projection
    # --------------------------------------------------------
    def encode_source(self, src):
        """Encode the source once; returns (memory, src_pad_mask).

        The pad mask is always None in this stub.
        """
        mem = self.encoder(src)  # [B, L, d_model]
        mem = self.semantic_up(self.semantic_proj(mem))  # compress → expand
        return mem, None

    # --------------------------------------------------------
    # Standard (uncached) forward: re-encodes src on every call
    # --------------------------------------------------------
    def forward(self, src, x, t):
        mem, pad_mask = self.encode_source(src)
        return self.forward_cached(mem, pad_mask, x, t)

    # --------------------------------------------------------
    # Cached forward: memory is supplied, encoder is skipped
    # --------------------------------------------------------
    def forward_cached(self, memory, src_pad_mask, x, t, hint=None):
        """One decoder step given a pre-computed memory.

        Returns (logits, None) and stashes the last hidden state on
        self._last_hidden. Note the stub decoder does not actually attend
        to `memory`; it only exists to exercise the caching code path.
        """
        h = self.tgt_embed(x)
        step_frac = (t.float() / self.scheduler.num_timesteps).unsqueeze(-1)
        h = h + self.time_mlp(step_frac).unsqueeze(1)
        if hint is not None:
            # Gated injection of the previous x0 estimate.
            h = h + self.hint_gate(h) * self.tgt_embed(hint)
        self._last_hidden = h
        return self.head(h), None

    # --------------------------------------------------------
    # Old generate (slow): encoder runs once per timestep
    # --------------------------------------------------------
    def generate(self, src):
        batch, dev = src.shape[0], src.device
        x = torch.zeros((batch, self.max_seq_len), dtype=torch.long, device=dev)
        for step in reversed(range(self.scheduler.num_timesteps)):
            t_vec = torch.full((batch,), step, device=dev)
            logits, _ = self.forward(src, x, t_vec)
            x = F.softmax(logits, dim=-1).argmax(dim=-1)
        return x

    # --------------------------------------------------------
    # Fast generate (KV cache): encoder runs exactly once
    # --------------------------------------------------------
    def generate_cached(self, src):
        batch, dev = src.shape[0], src.device
        memory, pad_mask = self.encode_source(src)  # encode once, reuse below
        x = torch.zeros((batch, self.max_seq_len), dtype=torch.long, device=dev)
        hint = None
        for step in reversed(range(self.scheduler.num_timesteps)):
            t_vec = torch.full((batch,), step, device=dev)
            logits, _ = self.forward_cached(memory, pad_mask, x, t_vec, hint)
            x = F.softmax(logits, dim=-1).argmax(dim=-1)
            hint = x
        return x
# ============================================================
# BENCHMARK + MEMORY + GRAPHS
# ============================================================
def benchmark(model, device):
    """Benchmark standard vs. cached generation and plot the results.

    For each source length in [16, 32, 64], times model.generate(src)
    against model.generate_cached(src), records peak CUDA memory when a
    CUDA device is used, and draws three plots (time, speedup, memory
    reduction).

    Fixes over the original: the torch.cuda.* calls are now guarded, so
    the benchmark no longer crashes on CPU (which the runner explicitly
    selects when CUDA is unavailable); timing runs under torch.no_grad();
    divisions are guarded against zero denominators.

    Args:
        model  : module exposing generate() and generate_cached()
        device : torch.device to benchmark on
    """
    model.to(device)
    model.eval()
    vocab = 500
    src_lens = [16, 32, 64]
    use_cuda = device.type == "cuda" and torch.cuda.is_available()

    standard_times = []
    cached_times = []
    speedups = []
    memory_savings = []

    def _sync():
        # CUDA kernels launch asynchronously; sync before reading the clock.
        if use_cuda:
            torch.cuda.synchronize()

    def _timed(fn):
        # Run fn once; return (elapsed_seconds, peak_mem_MB). Memory is 0.0
        # off-CUDA since torch only tracks peak allocations for CUDA.
        if use_cuda:
            torch.cuda.reset_peak_memory_stats()
        _sync()
        start = time.time()
        fn()
        _sync()
        elapsed = time.time() - start
        mem = torch.cuda.max_memory_allocated() / 1024**2 if use_cuda else 0.0
        return elapsed, mem

    with torch.no_grad():  # inference only; skip autograd bookkeeping
        for src_len in src_lens:
            print(f"\nsrc_len = {src_len}")
            src = torch.randint(5, vocab, (1, src_len)).to(device)

            # -------- STANDARD (encoder runs every step) --------
            t_std, mem_std = _timed(lambda: model.generate(src))
            # -------- CACHED (encoder runs once) --------
            t_cache, mem_cache = _timed(lambda: model.generate_cached(src))

            speedup = t_std / max(t_cache, 1e-9)
            mem_red = 100 * (mem_std - mem_cache) / mem_std if mem_std else 0.0
            print(f"Time: {t_std:.2f}s -> {t_cache:.2f}s | {speedup:.2f}x")
            print(f"Memory: {mem_std:.0f}MB -> {mem_cache:.0f}MB | {mem_red:.1f}%")

            standard_times.append(t_std)
            cached_times.append(t_cache)
            speedups.append(speedup)
            memory_savings.append(mem_red)

    # ==========================
    # PLOT: TIME
    # ==========================
    plt.figure()
    plt.plot(src_lens, standard_times, marker='o', label="Standard")
    plt.plot(src_lens, cached_times, marker='o', label="Cached")
    plt.xlabel("Source Length")
    plt.ylabel("Time (s)")
    plt.title("Generation Time")
    plt.legend()
    plt.grid()
    plt.show()

    # ==========================
    # PLOT: SPEEDUP
    # ==========================
    plt.figure()
    plt.plot(src_lens, speedups, marker='o')
    plt.xlabel("Source Length")
    plt.ylabel("Speedup (x)")
    plt.title("KV Cache Speedup")
    plt.grid()
    plt.show()

    # ==========================
    # PLOT: MEMORY
    # ==========================
    plt.figure()
    plt.plot(src_lens, memory_savings, marker='o')
    plt.xlabel("Source Length")
    plt.ylabel("Memory Reduction (%)")
    plt.title("Memory Savings")
    plt.grid()
    plt.show()
# ============================================================
# RUN
# ============================================================
if __name__ == "__main__":
    # Guarded entry point: importing this module no longer runs the
    # benchmark (and opens plot windows) as a side effect.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = D3PMCrossAttention()
    benchmark(model, device)