""" AttnVQ KV-cache demo (Gradio / Hugging Face Spaces). 1. Memory vs baselines — analytical KV footprint (Laguna-XS.2 geometry). 2. Live generation — Laguna-XS.2 with fp16 or VQQuantizedCache (uint8 indices). Tab 2 needs GPU hardware on the Space and artifacts/codebooks.pt for AttnVQ modes. """ from __future__ import annotations import math import os import time import torch # --------------------------------------------------------------------------- _USING_REAL = True try: from vqkv.quantizers import ProductVQKV, ScalarKV, KIVIScalarKV # type: ignore from vqkv.compressed_cache import ( # type: ignore VQQuantizedCache, kv_cache_bytes, LagunaGeom, ) except ModuleNotFoundError: try: from quantizers import ProductVQKV, ScalarKV, KIVIScalarKV # type: ignore from compressed_cache import VQQuantizedCache, kv_cache_bytes, LagunaGeom # type: ignore except ModuleNotFoundError: _USING_REAL = False VQQuantizedCache = None # type: ignore if not _USING_REAL: class LagunaGeom: n_layers = 40 full_layers = 10 sliding_layers = 30 sliding_window = 512 n_kv_heads = 8 head_dim = 128 def kv_cache_bytes(context_len, bits_per_elt_full, geom=LagunaGeom(), bits_per_elt_sliding=16.0): elts = geom.n_kv_heads * geom.head_dim * 2 full_tok = context_len * geom.full_layers slide_tok = min(context_len, geom.sliding_window) * geom.sliding_layers full_B = full_tok * elts * bits_per_elt_full / 8 slide_B = slide_tok * elts * bits_per_elt_sliding / 8 return {"full_B": full_B, "sliding_B": slide_B, "total_B": full_B + slide_B} class ProductVQKV: def __init__(self, n_sub, n_codes=256): self.n_sub = n_sub self.n_codes = n_codes def bits_per_element(self, head_dim): return self.n_sub * math.log2(self.n_codes) / head_dim class ScalarKV: def __init__(self, nbits=2, group=64): self.nbits = nbits self.group = group def bits_per_element(self, head_dim): return self.nbits + 16.0 / self.group class KIVIScalarKV(ScalarKV): pass ART = os.environ.get("VQKV_ARTIFACTS", "artifacts") MODEL_ID = os.environ.get("LAGUNA_ID", "poolside/Laguna-XS.2") GEOM = LagunaGeom() HD = GEOM.head_dim OURS = "AttnVQ" LIME = "#C6F24E" # Analytical bits/element for memory chart (TurboQuant from measured cheap-metrics run). _BASELINES = [ ("fp16", 16.0, "#9aa6b2"), ("scalar int2", ScalarKV(2).bits_per_element(HD), "#ff6b5e"), ("KIVI int2", KIVIScalarKV(2).bits_per_element(HD), "#ffb454"), ("TurboQuant 2b", 2.125, "#5BC8FF"), (f"{OURS} 2-bit", ProductVQKV(32, 256).bits_per_element(HD), LIME), (f"{OURS} 1-bit", ProductVQKV(16, 256).bits_per_element(HD), "#9ed63a"), ] _CACHE_CFG = { "fp16 (baseline)": None, f"{OURS} 2-bit": "productvq-32x256-2b", f"{OURS} 1-bit": "productvq-16x256-1b", } _MODEL: dict = {} _CODEBOOKS: dict | None = None def _baseline_bytes(context_len: int) -> list[tuple[str, float, float, str]]: """(label, GB, bpe, color) sorted by GB descending.""" rows = [] for label, bpe, color in _BASELINES: gb = kv_cache_bytes(context_len, bpe)["total_B"] / 1e9 rows.append((label, gb, bpe, color)) rows.sort(key=lambda r: r[1], reverse=True) return rows def memory_view(context_len): import matplotlib matplotlib.use("Agg") import matplotlib.pyplot as plt context_len = int(context_len) rows = _baseline_bytes(context_len) fp16_gb = next(g for lbl, g, _, _ in rows if lbl == "fp16") ours_2b = next((g, bpe) for lbl, g, bpe, _ in rows if lbl == f"{OURS} 2-bit") fig, (axb, axl) = plt.subplots(1, 2, figsize=(11, 4.6)) labels = [r[0] for r in rows] vals = [r[1] for r in rows] colors = [r[3] for r in rows] bars = axb.barh(labels, vals, color=colors) axb.set_xlabel("KV-cache memory (GB)") axb.set_title(f"Laguna-XS.2 @ {context_len:,} tokens") axb.invert_yaxis() for bar, val in zip(bars, vals): axb.text(val + 0.02, bar.get_y() + bar.get_height() / 2, f"{val:.2f} GB", va="center", fontsize=8) grid = [2048, 4096, 8192, 16384, 32768, 65536, 131072] bpe_2b = ours_2b[1] axl.plot(grid, [kv_cache_bytes(L, 16.0)["total_B"] / 1e9 for L in grid], "o-", color="#9aa6b2", label="fp16") axl.plot(grid, [kv_cache_bytes(L, bpe_2b)["total_B"] / 1e9 for L in grid], "o-", color=LIME, linewidth=2.4, label=f"{OURS} 2-bit") axl.plot(grid, [kv_cache_bytes(L, ScalarKV(2).bits_per_element(HD))["total_B"] / 1e9 for L in grid], "o--", color="#ff6b5e", alpha=0.7, label="scalar int2") axl.axvline(context_len, color="#888", ls=":", lw=1) axl.set_xscale("log", base=2) axl.set_xlabel("context length") axl.set_ylabel("GB") axl.set_title("Footprint vs context length") axl.grid(alpha=0.3) axl.legend(fontsize=8) fig.tight_layout() ratio = fp16_gb / ours_2b[0] sc2 = next(g for lbl, g, _, _ in rows if lbl == "scalar int2") md = ( f"### {ratio:.1f}× less KV cache than fp16 at {context_len:,} tokens\n" f"**{OURS} 2-bit** holds **{ours_2b[0]:.2f} GB** vs fp16 **{fp16_gb:.2f} GB** " f"and scalar int2 **{sc2:.2f} GB** — fewer bytes at higher fidelity.\n\n" "| method | bits/elt (full-attn layers) | KV cache |\n" "|---|---:|---:|\n" + "\n".join( f"| {lbl} | {bpe:.2f} | {gb:.2f} GB |" for lbl, gb, bpe, _ in sorted(rows, key=lambda r: r[1]) ) ) return fig, md def _get_model(): if _MODEL.get("model") is not None: return _MODEL["model"], _MODEL["tok"], _MODEL["device"] if not torch.cuda.is_available(): raise RuntimeError("CUDA is required for live generation.") from transformers import AutoModelForCausalLM, AutoTokenizer tok = AutoTokenizer.from_pretrained( MODEL_ID, trust_remote_code=True, fix_mistral_regex=True) model = AutoModelForCausalLM.from_pretrained( MODEL_ID, torch_dtype=torch.bfloat16, device_map="cuda", trust_remote_code=True, ).eval() device = "cuda:0" _MODEL.update(model=model, tok=tok, device=device) return model, tok, device def _get_codebooks(): global _CODEBOOKS if _CODEBOOKS is not None: return _CODEBOOKS if VQQuantizedCache is None: raise RuntimeError("vqkv.compressed_cache not importable.") path = os.path.join(ART, "codebooks.pt") if not os.path.exists(path): raise FileNotFoundError(f"{path} not found — set VQKV_ARTIFACTS=artifacts.") _CODEBOOKS = torch.load(path, map_location="cuda", weights_only=False) for per_layer in _CODEBOOKS["fitted"].values(): for q in per_layer.values(): if hasattr(q, "to"): q.to("cuda") return _CODEBOOKS def _measure_dynamic_cache(cache) -> int: total = 0 for buf in (getattr(cache, "key_cache", []), getattr(cache, "value_cache", [])): for t in buf: if isinstance(t, torch.Tensor): total += t.numel() * t.element_size() return total def run_generate(prompt, cache_label, max_new_tokens): if not prompt.strip(): prompt = "Write a one-sentence summary of why KV-cache compression matters for long context." try: model, tok, device = _get_model() except Exception as exc: return "", f"**Model load failed:** {exc}" cfg_name = _CACHE_CFG[cache_label] inputs = tok(prompt, return_tensors="pt").to(device) n_prompt = inputs["input_ids"].shape[1] if cfg_name is None: from transformers.cache_utils import DynamicCache cache = DynamicCache() compressed = False else: try: blob = _get_codebooks() cache = VQQuantizedCache(blob["fitted"][cfg_name], blob["meta"]["full_layers"]) compressed = True except Exception as exc: return "", f"**Cache setup failed:** {exc}" t0 = time.perf_counter() try: with torch.inference_mode(): out = model.generate( **inputs, max_new_tokens=int(max_new_tokens), past_key_values=cache, use_cache=True, do_sample=False, ) except Exception as exc: return "", f"**Generation failed:** {exc}" elapsed = time.perf_counter() - t0 n_new = out.shape[1] - n_prompt text = tok.decode(out[0, n_prompt:], skip_special_tokens=True) seq_len = out.shape[1] if compressed: fp = cache.memory_footprint() kv_b = fp["total_B"] cache_detail = ( f"uint8 indices {fp['compressed_indices_B'] / 1e6:.2f} MB · " f"codebooks {fp['codebooks_B'] / 1e6:.2f} MB · " f"sliding-window layers {fp['native_layers_B'] / 1e6:.2f} MB" ) else: kv_b = _measure_dynamic_cache(cache) cache_detail = "fp16 DynamicCache (measured on device)" fp16_b = kv_cache_bytes(seq_len, 16.0)["total_B"] ratio = fp16_b / kv_b if kv_b > 0 else 1.0 tok_s = n_new / elapsed if elapsed > 0 else 0.0 md = ( f"### {cache_label} · Laguna-XS.2\n\n" f"| | |\n|---|---|\n" f"| Prompt tokens | {n_prompt:,} |\n" f"| Generated tokens | {n_new:,} |\n" f"| Total sequence | {seq_len:,} |\n" f"| **KV cache (live)** | **{kv_b / 1e9:.3f} GB** |\n" f"| fp16 formula @ same length | {fp16_b / 1e9:.3f} GB |\n" f"| **vs fp16** | **{ratio:.1f}×** smaller |\n" f"| Throughput | {tok_s:.1f} tok/s |\n\n" f"_{cache_detail}_\n\n" ) if compressed: md += ( f"**{OURS}** stores only codebook **indices** on the 10 full-attention layers " f"(real memory savings). Dequantization is transient per layer — wall-clock speedup " f"needs a fused kernel." ) else: md += "fp16 baseline uses native bf16 KV tensors in `DynamicCache`." return text, md def build_demo(): import gradio as gr cuda = torch.cuda.is_available() cb_ok = os.path.exists(os.path.join(ART, "codebooks.pt")) real = "real `vqkv` package" if _USING_REAL else "fallback geometry only" subtitle = ( f"*Components: {real}. GPU: {'yes' if cuda else '**no — tab 2 disabled**'}. " f"Codebooks: {'found' if cb_ok else 'missing — AttnVQ modes need artifacts/codebooks.pt'}.*" ) with gr.Blocks(title=f"{OURS} KV-Cache") as demo: gr.Markdown( f"# {OURS} — attention-aware KV-cache quantization · Laguna-XS.2\n{subtitle}" ) with gr.Tab("Memory vs baselines"): ctx = gr.Slider(2048, 131072, value=131072, step=2048, label="Context length (tokens)") mplot, mmd = gr.Plot(), gr.Markdown() ctx.change(memory_view, ctx, [mplot, mmd]) demo.load(memory_view, ctx, [mplot, mmd]) with gr.Tab("Live generation"): gr.Markdown( "Run **Laguna-XS.2** with a real KV cache. " f"**{OURS}** modes use `VQQuantizedCache` (uint8 indices); " "fp16 uses native `DynamicCache`. Memory is read from the live cache after generation." ) prompt = gr.Textbox( label="Prompt", lines=4, value="Explain in two sentences how product vector quantization compresses a KV cache.", ) with gr.Row(): cache_mode = gr.Dropdown( list(_CACHE_CFG), value=f"{OURS} 2-bit", label="KV cache", interactive=cuda, ) max_tok = gr.Slider(16, 256, value=64, step=16, label="Max new tokens", interactive=cuda) btn = gr.Button("Generate", variant="primary", interactive=cuda) out_text = gr.Textbox(label="Completion", lines=6) out_md = gr.Markdown() btn.click(run_generate, [prompt, cache_mode, max_tok], [out_text, out_md]) return demo demo = build_demo() if __name__ == "__main__": demo.launch()