| """ |
| AttnVQ KV-cache demo (Gradio / Hugging Face Spaces). |
| |
| 1. Memory vs baselines — analytical KV footprint (Laguna-XS.2 geometry). |
| 2. Live generation — Laguna-XS.2 with fp16 or VQQuantizedCache (uint8 indices). |
| |
| Tab 2 needs GPU hardware on the Space and artifacts/codebooks.pt for AttnVQ modes. |
| """ |
|
|
| from __future__ import annotations |
|
|
| import math |
| import os |
| import time |
|
|
| import torch |
|
|
| |
# Prefer the real `vqkv` implementation; fall back to a flat checkout layout, and
# finally to the analytical stand-ins defined below. ImportError (not just
# ModuleNotFoundError) is caught so that a present-but-broken install (e.g. a
# missing name after a transformers upgrade) degrades to the memory-only demo
# instead of crashing the whole Space at import time.
_USING_REAL = True
try:
    from vqkv.quantizers import ProductVQKV, ScalarKV, KIVIScalarKV
    from vqkv.compressed_cache import (
        VQQuantizedCache, kv_cache_bytes, LagunaGeom,
    )
except ImportError:
    try:
        from quantizers import ProductVQKV, ScalarKV, KIVIScalarKV
        from compressed_cache import VQQuantizedCache, kv_cache_bytes, LagunaGeom
    except ImportError:
        _USING_REAL = False
        # Live AttnVQ generation checks for this sentinel before building a cache.
        VQQuantizedCache = None
|
|
if not _USING_REAL:
    # Minimal stand-ins for the `vqkv` API, used only when the real package
    # could not be imported. They cover just what the analytical memory tab
    # needs: model geometry, the KV-size formula and bits-per-element budgets.

    class LagunaGeom:
        """Laguna-XS.2 KV geometry: 10 full-attention + 30 sliding-window layers."""

        n_layers = 40
        full_layers = 10
        sliding_layers = 30
        sliding_window = 512
        n_kv_heads = 8
        head_dim = 128

    def kv_cache_bytes(context_len, bits_per_elt_full, geom=LagunaGeom(),
                       bits_per_elt_sliding=16.0):
        """Analytical KV-cache size, in bytes, for *context_len* tokens.

        Full-attention layers keep every token at *bits_per_elt_full* bits per
        element; sliding-window layers keep at most ``geom.sliding_window``
        tokens at *bits_per_elt_sliding*. Returns a dict with float byte
        counts under ``full_B``, ``sliding_B`` and ``total_B``.
        """
        # Elements stored per token per layer: keys and values for every KV head.
        kv_elems = geom.n_kv_heads * geom.head_dim * 2
        full_bytes = context_len * geom.full_layers * kv_elems * bits_per_elt_full / 8
        kept = min(context_len, geom.sliding_window)
        sliding_bytes = kept * geom.sliding_layers * kv_elems * bits_per_elt_sliding / 8
        return {
            "full_B": full_bytes,
            "sliding_B": sliding_bytes,
            "total_B": full_bytes + sliding_bytes,
        }

    class ProductVQKV:
        """Bit-budget stand-in for product VQ: ``n_sub`` indices of ``log2(n_codes)`` bits."""

        def __init__(self, n_sub, n_codes=256):
            self.n_sub, self.n_codes = n_sub, n_codes

        def bits_per_element(self, head_dim):
            """Index bits spread over *head_dim* elements (codebook storage not counted)."""
            return self.n_sub * math.log2(self.n_codes) / head_dim

    class ScalarKV:
        """Bit-budget stand-in for group-wise scalar quantization.

        ``nbits`` per element plus 16 bits of per-group overhead (presumably a
        scale) shared by ``group`` elements.
        """

        def __init__(self, nbits=2, group=64):
            self.nbits, self.group = nbits, group

        def bits_per_element(self, head_dim):
            """Bits per element; independent of *head_dim* in this stand-in."""
            return self.nbits + 16.0 / self.group

    class KIVIScalarKV(ScalarKV):
        """KIVI variant; the fallback uses exactly the ScalarKV bit accounting."""
|
|
|
|
# Runtime configuration, overridable through the environment.
ART = os.getenv("VQKV_ARTIFACTS", "artifacts")  # directory expected to hold codebooks.pt
MODEL_ID = os.getenv("LAGUNA_ID", "poolside/Laguna-XS.2")

# Geometry shared by every analytical computation, plus UI branding.
GEOM = LagunaGeom()
HD = GEOM.head_dim
OURS = "AttnVQ"
LIME = "#C6F24E"
|
|
| |
# Methods drawn in the memory tab, as (label, bits per element on the
# full-attention layers, bar colour). memory_view looks several of these up by
# exact label, so a rename here must be mirrored there.
_BASELINES = [
    ("fp16", 16.0, "#9aa6b2"),
    ("scalar int2", ScalarKV(2).bits_per_element(HD), "#ff6b5e"),
    ("KIVI int2", KIVIScalarKV(2).bits_per_element(HD), "#ffb454"),
    ("TurboQuant 2b", 2.125, "#5BC8FF"),  # hard-coded bit budget, no quantizer object
    (f"{OURS} 2-bit", ProductVQKV(32, 256).bits_per_element(HD), LIME),
    (f"{OURS} 1-bit", ProductVQKV(16, 256).bits_per_element(HD), "#9ed63a"),
]

# Live-generation dropdown label -> key into the "fitted" table of
# codebooks.pt; None selects the uncompressed DynamicCache.
_CACHE_CFG = {
    "fp16 (baseline)": None,
    f"{OURS} 2-bit": "productvq-32x256-2b",
    f"{OURS} 1-bit": "productvq-16x256-1b",
}

# Lazily populated process-wide singletons (model/tokenizer and codebook blob).
_MODEL: dict = {}
_CODEBOOKS: dict | None = None
|
|
|
|
def _baseline_bytes(context_len: int) -> list[tuple[str, float, float, str]]:
    """Return ``(label, GB, bits/elt, colour)`` for every baseline, largest GB first.

    Sizes come from the analytical ``kv_cache_bytes`` formula at *context_len*
    tokens; GB here means 1e9 bytes. Ties keep ``_BASELINES`` order.
    """
    return sorted(
        (
            (label, kv_cache_bytes(context_len, bits)["total_B"] / 1e9, bits, colour)
            for label, bits, colour in _BASELINES
        ),
        key=lambda row: row[1],
        reverse=True,
    )
|
|
|
|
def memory_view(context_len):
    """Render the analytical KV-cache comparison at *context_len* tokens.

    Args:
        context_len: context length in tokens (the Gradio slider value;
            coerced to ``int``).

    Returns:
        ``(figure, markdown)``: a matplotlib Figure with a per-method bar chart
        and a footprint-vs-context-length curve, plus a markdown summary table.
    """
    # Build the figure with the object-oriented Figure API instead of pyplot.
    # pyplot registers every figure globally until plt.close(), so a server
    # that redraws on each slider change would accumulate figures indefinitely.
    from matplotlib.figure import Figure

    context_len = int(context_len)
    rows = _baseline_bytes(context_len)
    by_label = {lbl: (gb, bpe) for lbl, gb, bpe, _ in rows}
    fp16_gb = by_label["fp16"][0]
    ours_gb, ours_bpe = by_label[f"{OURS} 2-bit"]
    sc2_gb, sc2_bpe = by_label["scalar int2"]

    fig = Figure(figsize=(11, 4.6))
    axb, axl = fig.subplots(1, 2)

    # Left panel: every method at the selected context length.
    labels = [r[0] for r in rows]
    vals = [r[1] for r in rows]
    colors = [r[3] for r in rows]
    bars = axb.barh(labels, vals, color=colors)
    axb.set_xlabel("KV-cache memory (GB)")
    axb.set_title(f"Laguna-XS.2 @ {context_len:,} tokens")
    axb.invert_yaxis()
    # Offset value labels relative to the largest bar (at short contexts bars
    # are ~0.1 GB, where a fixed 0.02 GB gap is huge) and leave headroom so the
    # labels are not clipped by the axes.
    vmax = max(vals, default=0.0)
    pad = 0.01 * vmax if vmax > 0 else 0.02
    for bar, val in zip(bars, vals):
        axb.text(val + pad, bar.get_y() + bar.get_height() / 2,
                 f"{val:.2f} GB", va="center", fontsize=8)
    if vmax > 0:
        axb.set_xlim(0, vmax * 1.2)

    # Right panel: footprint vs context length for the headline methods.
    grid = [2048, 4096, 8192, 16384, 32768, 65536, 131072]

    def curve(bpe):
        return [kv_cache_bytes(n, bpe)["total_B"] / 1e9 for n in grid]

    axl.plot(grid, curve(16.0), "o-", color="#9aa6b2", label="fp16")
    axl.plot(grid, curve(ours_bpe), "o-", color=LIME, linewidth=2.4,
             label=f"{OURS} 2-bit")
    axl.plot(grid, curve(sc2_bpe), "o--", color="#ff6b5e", alpha=0.7,
             label="scalar int2")
    axl.axvline(context_len, color="#888", ls=":", lw=1)
    axl.set_xscale("log", base=2)
    axl.set_xlabel("context length")
    axl.set_ylabel("GB")
    axl.set_title("Footprint vs context length")
    axl.grid(alpha=0.3)
    axl.legend(fontsize=8)
    fig.tight_layout()

    # Guard against a zero-length context, where every footprint is 0.
    ratio = fp16_gb / ours_gb if ours_gb > 0 else 1.0
    md = (
        f"### {ratio:.1f}× less KV cache than fp16 at {context_len:,} tokens\n"
        f"**{OURS} 2-bit** holds **{ours_gb:.2f} GB** vs fp16 **{fp16_gb:.2f} GB** "
        f"and scalar int2 **{sc2_gb:.2f} GB** — fewer bytes at higher fidelity.\n\n"
        "| method | bits/elt (full-attn layers) | KV cache |\n"
        "|---|---:|---:|\n"
        + "\n".join(
            f"| {lbl} | {bpe:.2f} | {gb:.2f} GB |"
            for lbl, gb, bpe, _ in sorted(rows, key=lambda r: r[1])
        )
    )
    return fig, md
|
|
|
|
def _get_model():
    """Return ``(model, tokenizer, device)``, loading ``MODEL_ID`` on first use.

    The model is loaded in bfloat16 onto CUDA and memoised in ``_MODEL``, so
    later calls return the same objects without reloading.

    Raises:
        RuntimeError: CUDA is not available.
    """
    if _MODEL.get("model") is None:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is required for live generation.")
        from transformers import AutoModelForCausalLM, AutoTokenizer

        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID, trust_remote_code=True, fix_mistral_regex=True)
        loaded = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            trust_remote_code=True,
        ).eval()
        _MODEL.update(model=loaded, tok=tokenizer, device="cuda:0")
    return _MODEL["model"], _MODEL["tok"], _MODEL["device"]
|
|
|
|
def _get_codebooks():
    """Load (once) and return the fitted AttnVQ codebook blob.

    Reads ``<ART>/codebooks.pt`` onto the GPU and memoises it in the module
    global ``_CODEBOOKS``. Callers index it as ``blob["fitted"][cfg_name]`` and
    ``blob["meta"]["full_layers"]``.

    Raises:
        RuntimeError: the real ``VQQuantizedCache`` could not be imported.
        FileNotFoundError: ``codebooks.pt`` does not exist under ``ART``.
    """
    global _CODEBOOKS
    if _CODEBOOKS is not None:
        return _CODEBOOKS
    if VQQuantizedCache is None:
        raise RuntimeError("vqkv.compressed_cache not importable.")
    path = os.path.join(ART, "codebooks.pt")
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"{path} not found — set VQKV_ARTIFACTS to the directory "
            f"containing codebooks.pt.")
    # SECURITY: weights_only=False unpickles arbitrary objects (the fitted
    # quantizers). Only ever point ART at trusted, locally produced artifacts.
    blob = torch.load(path, map_location="cuda", weights_only=False)
    for per_layer in blob["fitted"].values():
        # Assumes per_layer is a mutable mapping (it exposes .values() above);
        # iterate over a snapshot so the in-place assignment below is safe.
        for key, q in list(per_layer.items()):
            if hasattr(q, "to"):
                # Tensor.to returns a *new* object instead of moving in place,
                # so keep whatever .to() hands back (None means it was in place).
                moved = q.to("cuda")
                if moved is not None:
                    per_layer[key] = moved
    # Publish only once fully prepared, so a failure above does not leave a
    # half-initialised blob behind for subsequent calls.
    _CODEBOOKS = blob
    return _CODEBOOKS
|
|
|
|
def _measure_dynamic_cache(cache) -> int:
    """Return the number of bytes held by the key/value tensors of *cache*.

    Two cache layouts are supported:

    * layered caches expose ``cache.layers``, each layer holding ``.keys`` and
      ``.values`` tensors (the newer transformers cache API);
    * legacy caches expose parallel ``key_cache`` / ``value_cache`` lists.

    When ``layers`` is present it is used exclusively, so a cache that also
    offers the legacy attributes as compatibility views is not double-counted.
    Non-tensor entries (e.g. ``None`` for layers not yet filled) are skipped,
    and an object with neither layout measures as 0.
    """
    layers = getattr(cache, "layers", None)
    if layers is not None:
        tensors = (
            t
            for layer in layers
            for t in (getattr(layer, "keys", None), getattr(layer, "values", None))
        )
    else:
        tensors = (
            t
            for buf in (getattr(cache, "key_cache", None) or [],
                        getattr(cache, "value_cache", None) or [])
            for t in buf
        )
    return sum(t.numel() * t.element_size()
               for t in tensors if isinstance(t, torch.Tensor))
|
|
|
|
def _build_cache(cfg_name):
    """Create an empty KV cache for one ``_CACHE_CFG`` value.

    ``None`` selects the native transformers ``DynamicCache``; any other value
    names a fitted codebook set and yields a ``VQQuantizedCache``.

    Returns:
        ``(cache, compressed)`` where *compressed* is True for the VQ cache.
    """
    if cfg_name is None:
        from transformers.cache_utils import DynamicCache
        return DynamicCache(), False
    blob = _get_codebooks()
    return VQQuantizedCache(blob["fitted"][cfg_name], blob["meta"]["full_layers"]), True


def _cache_usage(cache, compressed):
    """Return ``(bytes, detail_markdown)`` describing *cache* after generation."""
    if not compressed:
        return _measure_dynamic_cache(cache), "fp16 DynamicCache (measured on device)"
    fp = cache.memory_footprint()
    detail = (
        f"uint8 indices {fp['compressed_indices_B'] / 1e6:.2f} MB · "
        f"codebooks {fp['codebooks_B'] / 1e6:.2f} MB · "
        f"sliding-window layers {fp['native_layers_B'] / 1e6:.2f} MB"
    )
    return fp["total_B"], detail


def run_generate(prompt, cache_label, max_new_tokens):
    """Generate a greedy completion and report the live KV-cache footprint.

    Args:
        prompt: user prompt; ``None`` / empty / whitespace-only falls back to a
            built-in example prompt.
        cache_label: a key of ``_CACHE_CFG`` selecting the KV cache.
        max_new_tokens: generation budget (coerced to ``int``).

    Returns:
        ``(completion_text, report_markdown)``. Failures are reported in the
        markdown (with an empty or partial completion) instead of raised, so
        the UI can display them.
    """
    if not prompt or not prompt.strip():
        prompt = "Write a one-sentence summary of why KV-cache compression matters for long context."
    # Validate before the (expensive) model load; a cleared dropdown gives None.
    if cache_label not in _CACHE_CFG:
        return "", f"**Unknown KV cache mode:** {cache_label!r}"

    try:
        model, tok, device = _get_model()
    except Exception as exc:
        return "", f"**Model load failed:** {exc}"

    inputs = tok(prompt, return_tensors="pt").to(device)
    n_prompt = inputs["input_ids"].shape[1]

    try:
        cache, compressed = _build_cache(_CACHE_CFG[cache_label])
    except Exception as exc:
        return "", f"**Cache setup failed:** {exc}"

    t0 = time.perf_counter()
    try:
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                past_key_values=cache,
                use_cache=True,
                do_sample=False,
            )
        if out.is_cuda:
            # CUDA work is asynchronous; wait for the GPU so tok/s is honest.
            torch.cuda.synchronize(out.device)
    except Exception as exc:
        return "", f"**Generation failed:** {exc}"
    elapsed = time.perf_counter() - t0

    seq_len = out.shape[1]
    n_new = seq_len - n_prompt
    text = tok.decode(out[0, n_prompt:], skip_special_tokens=True)

    try:
        kv_b, cache_detail = _cache_usage(cache, compressed)
    except Exception as exc:
        return text, f"**Memory accounting failed:** {exc}"

    # Analytical fp16 footprint at the same sequence length, for comparison.
    fp16_b = kv_cache_bytes(seq_len, 16.0)["total_B"]
    ratio = fp16_b / kv_b if kv_b > 0 else 1.0
    tok_s = n_new / elapsed if elapsed > 0 else 0.0

    md = (
        f"### {cache_label} · Laguna-XS.2\n\n"
        f"| | |\n|---|---|\n"
        f"| Prompt tokens | {n_prompt:,} |\n"
        f"| Generated tokens | {n_new:,} |\n"
        f"| Total sequence | {seq_len:,} |\n"
        f"| **KV cache (live)** | **{kv_b / 1e9:.3f} GB** |\n"
        f"| fp16 formula @ same length | {fp16_b / 1e9:.3f} GB |\n"
        f"| **vs fp16** | **{ratio:.1f}×** smaller |\n"
        f"| Throughput | {tok_s:.1f} tok/s |\n\n"
        f"_{cache_detail}_\n\n"
    )
    if compressed:
        md += (
            f"**{OURS}** stores only codebook **indices** on the 10 full-attention layers "
            f"(real memory savings). Dequantization is transient per layer — wall-clock speedup "
            f"needs a fused kernel."
        )
    else:
        md += "fp16 baseline uses native bf16 KV tensors in `DynamicCache`."
    return text, md
|
|
|
|
def build_demo():
    """Build (but do not launch) the two-tab Gradio Blocks app.

    Tab 1 ("Memory vs baselines") redraws ``memory_view`` on page load and on
    every slider change. Tab 2 ("Live generation") runs ``run_generate``; its
    controls are non-interactive when CUDA is unavailable, and the cache
    dropdown defaults to the uncompressed baseline when AttnVQ modes cannot run.

    Returns:
        The ``gr.Blocks`` instance.
    """
    import gradio as gr

    cuda = torch.cuda.is_available()
    cb_path = os.path.join(ART, "codebooks.pt")
    cb_ok = os.path.exists(cb_path)
    real = "real `vqkv` package" if _USING_REAL else "fallback geometry only"
    gpu_status = "yes" if cuda else "**no — tab 2 disabled**"
    cb_status = "found" if cb_ok else f"missing — AttnVQ modes need {cb_path}"
    subtitle = f"*Components: {real}. GPU: {gpu_status}. Codebooks: {cb_status}.*"

    # Pre-select an AttnVQ mode only when it can actually run; otherwise the
    # first click would just report a cache-setup failure.
    baseline_mode = next(lbl for lbl, cfg in _CACHE_CFG.items() if cfg is None)
    default_mode = f"{OURS} 2-bit" if (_USING_REAL and cb_ok) else baseline_mode

    with gr.Blocks(title=f"{OURS} KV-Cache") as demo:
        gr.Markdown(
            f"# {OURS} — attention-aware KV-cache quantization · Laguna-XS.2\n{subtitle}"
        )

        with gr.Tab("Memory vs baselines"):
            ctx = gr.Slider(2048, 131072, value=131072, step=2048, label="Context length (tokens)")
            mplot, mmd = gr.Plot(), gr.Markdown()
            ctx.change(memory_view, ctx, [mplot, mmd])
            demo.load(memory_view, ctx, [mplot, mmd])

        with gr.Tab("Live generation"):
            gr.Markdown(
                "Run **Laguna-XS.2** with a real KV cache. "
                f"**{OURS}** modes use `VQQuantizedCache` (uint8 indices); "
                "fp16 uses native `DynamicCache`. Memory is read from the live cache after generation."
            )
            prompt = gr.Textbox(
                label="Prompt",
                lines=4,
                value="Explain in two sentences how product vector quantization compresses a KV cache.",
            )
            with gr.Row():
                cache_mode = gr.Dropdown(
                    list(_CACHE_CFG),
                    value=default_mode,
                    label="KV cache",
                    interactive=cuda,
                )
                max_tok = gr.Slider(16, 256, value=64, step=16, label="Max new tokens",
                                    interactive=cuda)
            btn = gr.Button("Generate", variant="primary", interactive=cuda)
            out_text = gr.Textbox(label="Completion", lines=6)
            out_md = gr.Markdown()
            btn.click(run_generate, [prompt, cache_mode, max_tok], [out_text, out_md])

    return demo
|
|
|
|
# The app object is created at import time so that tools which import this
# module (e.g. Gradio's reload mode, which looks for `demo` by default) find it.
demo = build_demo()

if __name__ == "__main__":
    # Direct execution (`python app.py`): start the server.
    demo.launch()
|
|