attnvq / app.py
adirik's picture
add demo and slides
20d8ec4
Raw
History Blame Contribute Delete
12.4 kB
"""
AttnVQ KV-cache demo (Gradio / Hugging Face Spaces).
1. Memory vs baselines — analytical KV footprint (Laguna-XS.2 geometry).
2. Live generation — Laguna-XS.2 with fp16 or VQQuantizedCache (uint8 indices).
Tab 2 needs GPU hardware on the Space and artifacts/codebooks.pt for AttnVQ modes.
"""
from __future__ import annotations
import math
import os
import time
import torch
# ---------------------------------------------------------------------------
# Component resolution: prefer the installed `vqkv` package, then flat
# `quantizers` / `compressed_cache` modules importable from the path, and
# finally a pure-Python fallback that only supports the analytical memory tab
# (VQQuantizedCache is None in that case).
# ---------------------------------------------------------------------------
_USING_REAL = True
try:
    from vqkv.quantizers import ProductVQKV, ScalarKV, KIVIScalarKV  # type: ignore
    from vqkv.compressed_cache import (  # type: ignore
        VQQuantizedCache, kv_cache_bytes, LagunaGeom,
    )
except ModuleNotFoundError:
    try:
        from quantizers import ProductVQKV, ScalarKV, KIVIScalarKV  # type: ignore
        from compressed_cache import VQQuantizedCache, kv_cache_bytes, LagunaGeom  # type: ignore
    except ModuleNotFoundError:
        _USING_REAL = False
        VQQuantizedCache = None  # type: ignore

if not _USING_REAL:

    class LagunaGeom:
        """Fallback Laguna-XS.2 attention geometry for the analytical memory model."""

        n_layers = 40
        full_layers = 10      # layers that cache every context token
        sliding_layers = 30   # layers that cache at most `sliding_window` tokens
        sliding_window = 512
        n_kv_heads = 8
        head_dim = 128

    def kv_cache_bytes(context_len, bits_per_elt_full, geom=None,
                       bits_per_elt_sliding=16.0):
        """Analytical KV-cache size in bytes for a context of ``context_len`` tokens.

        Full-attention layers store all ``context_len`` tokens at
        ``bits_per_elt_full``; sliding-window layers store at most
        ``geom.sliding_window`` tokens at ``bits_per_elt_sliding``. Keys and
        values are both counted (the factor 2 in ``elts``).

        Args:
            context_len: Number of tokens in the context (must be >= 0).
            bits_per_elt_full: Bits per cached element on full-attention layers.
            geom: Geometry object; defaults to a fresh ``LagunaGeom()``.
            bits_per_elt_sliding: Bits per element on sliding-window layers.

        Returns:
            dict with ``"full_B"``, ``"sliding_B"`` and ``"total_B"`` (floats, bytes).

        Raises:
            ValueError: if ``context_len`` is negative.
        """
        if geom is None:
            geom = LagunaGeom()
        if context_len < 0:
            raise ValueError(f"context_len must be >= 0, got {context_len!r}")
        elts = geom.n_kv_heads * geom.head_dim * 2  # K and V per token per layer
        full_tok = context_len * geom.full_layers
        slide_tok = min(context_len, geom.sliding_window) * geom.sliding_layers
        full_B = full_tok * elts * bits_per_elt_full / 8
        slide_B = slide_tok * elts * bits_per_elt_sliding / 8
        return {"full_B": full_B, "sliding_B": slide_B, "total_B": full_B + slide_B}

    class ProductVQKV:
        """Bit-cost model of product VQ: ``n_sub`` indices of log2(n_codes) bits per head vector."""

        def __init__(self, n_sub, n_codes=256):
            self.n_sub = n_sub
            self.n_codes = n_codes

        def bits_per_element(self, head_dim):
            """Index bits per head vector divided by ``head_dim``."""
            return self.n_sub * math.log2(self.n_codes) / head_dim

    class ScalarKV:
        """Bit-cost model of group-wise scalar quantization."""

        def __init__(self, nbits=2, group=64):
            self.nbits = nbits
            self.group = group

        def bits_per_element(self, head_dim):
            """``nbits`` plus 16 bits of overhead shared by each group of ``group`` elements.

            ``head_dim`` is accepted for interface parity and unused here.
            The 16-bit overhead is presumably an fp16 scale — TODO confirm.
            """
            return self.nbits + 16.0 / self.group

    class KIVIScalarKV(ScalarKV):
        """Fallback reuses ``ScalarKV``'s cost model unchanged."""
# Runtime configuration, overridable through environment variables.
ART = os.environ.get("VQKV_ARTIFACTS", "artifacts")  # directory searched for codebooks.pt
MODEL_ID = os.environ.get("LAGUNA_ID", "poolside/Laguna-XS.2")  # passed to from_pretrained()
GEOM = LagunaGeom()
HD = GEOM.head_dim  # head dimension handed to every bits_per_element() call below
OURS = "AttnVQ"  # display name of this method; embedded in the labels below
LIME = "#C6F24E"  # highlight colour for the AttnVQ 2-bit series
# Analytical bits/element for memory chart (TurboQuant from measured cheap-metrics run).
# Each row is (label, bits per element on full-attention layers, bar colour).
# memory_view() looks rows up by the exact labels "fp16", "scalar int2" and
# f"{OURS} 2-bit", so renaming a label here requires updating it there too.
_BASELINES = [
("fp16", 16.0, "#9aa6b2"),
("scalar int2", ScalarKV(2).bits_per_element(HD), "#ff6b5e"),
("KIVI int2", KIVIScalarKV(2).bits_per_element(HD), "#ffb454"),
("TurboQuant 2b", 2.125, "#5BC8FF"),
(f"{OURS} 2-bit", ProductVQKV(32, 256).bits_per_element(HD), LIME),
(f"{OURS} 1-bit", ProductVQKV(16, 256).bits_per_element(HD), "#9ed63a"),
]
# Dropdown label -> config name looked up in codebooks.pt["fitted"].
# None selects the uncompressed DynamicCache baseline in run_generate().
_CACHE_CFG = {
"fp16 (baseline)": None,
f"{OURS} 2-bit": "productvq-32x256-2b",
f"{OURS} 1-bit": "productvq-16x256-1b",
}
# Lazily filled process-wide caches: _get_model() stores model/tok/device in
# _MODEL; _get_codebooks() stores the loaded codebook blob in _CODEBOOKS.
_MODEL: dict = {}
_CODEBOOKS: dict | None = None
def _baseline_bytes(context_len: int) -> list[tuple[str, float, float, str]]:
    """Price every ``_BASELINES`` entry at ``context_len`` tokens.

    Returns:
        ``(label, GB, bits/elt, color)`` tuples, largest footprint first.
        Ties keep their ``_BASELINES`` order (the sort is stable).
    """
    priced = [
        (name, kv_cache_bytes(context_len, bits)["total_B"] / 1e9, bits, colour)
        for name, bits, colour in _BASELINES
    ]
    return sorted(priced, key=lambda row: row[1], reverse=True)
def memory_view(context_len):
    """Render the analytical KV-cache comparison for one context length.

    Args:
        context_len: Context length in tokens; coerced with ``int`` because
            Gradio sliders may deliver a float.

    Returns:
        ``(figure, markdown)``: a two-panel matplotlib ``Figure`` (per-method
        bars at ``context_len`` and footprint-vs-length curves) and a markdown
        summary table.

    The figure is created with ``matplotlib.figure.Figure`` instead of
    ``pyplot``. This handler runs on every slider change, and pyplot figures
    that are never closed stay in pyplot's global registry for the life of
    the process. pyplot's global state is also not safe to share across
    Gradio's worker threads.
    """
    from matplotlib.figure import Figure

    context_len = int(context_len)
    rows = _baseline_bytes(context_len)
    # label -> (GB, bits/elt); the labels in _BASELINES are unique.
    by_label = {lbl: (gb, bpe) for lbl, gb, bpe, _ in rows}
    fp16_gb = by_label["fp16"][0]
    ours_gb, ours_bpe = by_label[f"{OURS} 2-bit"]
    sc2_gb, sc2_bpe = by_label["scalar int2"]

    fig = Figure(figsize=(11, 4.6))
    axb, axl = fig.subplots(1, 2)

    # Left panel: one bar per method at the selected length, largest on top.
    labels = [r[0] for r in rows]
    vals = [r[1] for r in rows]
    colors = [r[3] for r in rows]
    bars = axb.barh(labels, vals, color=colors)
    axb.set_xlabel("KV-cache memory (GB)")
    axb.set_title(f"Laguna-XS.2 @ {context_len:,} tokens")
    axb.invert_yaxis()
    # Offset value labels relative to the longest bar, and pad the x-range so
    # the longest bar's label stays inside the panel at every context length.
    x_max = max(vals, default=0.0)
    for bar, val in zip(bars, vals):
        axb.text(val + 0.01 * x_max, bar.get_y() + bar.get_height() / 2,
                 f"{val:.2f} GB", va="center", fontsize=8)
    if x_max > 0:
        axb.set_xlim(0, 1.2 * x_max)

    # Right panel: total footprint vs context length on a log2 x-axis.
    grid = [2048, 4096, 8192, 16384, 32768, 65536, 131072]

    def curve(bpe):
        return [kv_cache_bytes(n, bpe)["total_B"] / 1e9 for n in grid]

    axl.plot(grid, curve(16.0), "o-", color="#9aa6b2", label="fp16")
    axl.plot(grid, curve(ours_bpe), "o-", color=LIME, linewidth=2.4,
             label=f"{OURS} 2-bit")
    axl.plot(grid, curve(sc2_bpe), "o--", color="#ff6b5e", alpha=0.7,
             label="scalar int2")
    axl.axvline(context_len, color="#888", ls=":", lw=1)
    axl.set_xscale("log", base=2)
    axl.set_xlabel("context length")
    axl.set_ylabel("GB")
    axl.set_title("Footprint vs context length")
    axl.grid(alpha=0.3)
    axl.legend(fontsize=8)
    fig.tight_layout()

    ratio = fp16_gb / ours_gb
    md = (
        f"### {ratio:.1f}× less KV cache than fp16 at {context_len:,} tokens\n"
        f"**{OURS} 2-bit** holds **{ours_gb:.2f} GB** vs fp16 **{fp16_gb:.2f} GB** "
        f"and scalar int2 **{sc2_gb:.2f} GB** — fewer bytes at higher fidelity.\n\n"
        "| method | bits/elt (full-attn layers) | KV cache |\n"
        "|---|---:|---:|\n"
        + "\n".join(
            f"| {lbl} | {bpe:.2f} | {gb:.2f} GB |"
            for lbl, gb, bpe, _ in sorted(rows, key=lambda r: r[1])
        )
    )
    return fig, md
def _get_model():
    """Return ``(model, tokenizer, device)``, loading Laguna-XS.2 on first use.

    The loaded objects are memoised in ``_MODEL`` so later calls are free.
    The model is loaded in bfloat16 onto CUDA and put in eval mode.

    Raises:
        RuntimeError: CUDA is unavailable (checked only before the first load).
    """
    if _MODEL.get("model") is None:
        if not torch.cuda.is_available():
            raise RuntimeError("CUDA is required for live generation.")
        # Imported lazily so the memory tab works without loading transformers.
        from transformers import AutoModelForCausalLM, AutoTokenizer

        # NOTE(review): trust_remote_code runs Python shipped in the model repo.
        tokenizer = AutoTokenizer.from_pretrained(
            MODEL_ID,
            trust_remote_code=True,
            fix_mistral_regex=True,
        )
        loaded = AutoModelForCausalLM.from_pretrained(
            MODEL_ID,
            torch_dtype=torch.bfloat16,
            device_map="cuda",
            trust_remote_code=True,
        )
        _MODEL.update(model=loaded.eval(), tok=tokenizer, device="cuda:0")
    return _MODEL["model"], _MODEL["tok"], _MODEL["device"]
def _get_codebooks():
    """Load ``<ART>/codebooks.pt`` once and return the memoised blob.

    ``run_generate`` reads ``blob["fitted"][cfg_name]`` and
    ``blob["meta"]["full_layers"]`` from the returned object. Every value
    inside each ``blob["fitted"]`` entry that has a ``.to`` method is moved
    to CUDA.

    Raises:
        RuntimeError: the compressed-cache implementation was not importable.
        FileNotFoundError: the codebook file does not exist.
    """
    global _CODEBOOKS
    if _CODEBOOKS is not None:
        return _CODEBOOKS
    if VQQuantizedCache is None:
        raise RuntimeError("vqkv.compressed_cache not importable.")
    path = os.path.join(ART, "codebooks.pt")
    if not os.path.exists(path):
        raise FileNotFoundError(
            f"{path} not found — set VQKV_ARTIFACTS to the directory containing codebooks.pt."
        )
    # NOTE(security): weights_only=False unpickles arbitrary Python objects;
    # only point VQKV_ARTIFACTS at trusted, locally produced artifacts.
    blob = torch.load(path, map_location="cuda", weights_only=False)
    for per_layer in blob["fitted"].values():
        for key, q in list(per_layer.items()):
            if hasattr(q, "to"):
                moved = q.to("cuda")
                # Tensor.to returns a new object instead of moving in place;
                # keep the result unless .to() is in-place and returns None.
                if moved is not None:
                    per_layer[key] = moved
    # Publish only after post-processing succeeded, so a malformed file is
    # not served from the cache on later calls.
    _CODEBOOKS = blob
    return _CODEBOOKS
def _measure_dynamic_cache(cache) -> int:
    """Return the bytes held by the key/value tensors of an uncompressed HF cache.

    Two cache layouts are supported:

    * the per-layer API of newer transformers releases
      (``cache.layers[i].keys`` / ``.values``);
    * the legacy ``cache.key_cache`` / ``cache.value_cache`` lists.

    The legacy lists are consulted only if the per-layer API yields no bytes,
    so no tensor is counted twice. Non-tensor entries, such as ``None`` for
    layers that are not yet populated, are skipped. Returns 0 when neither
    layout holds any tensors.
    """
    def _nbytes(tensors) -> int:
        return sum(
            t.numel() * t.element_size()
            for t in tensors
            if isinstance(t, torch.Tensor)
        )

    layered = [
        tensor
        for layer in (getattr(cache, "layers", None) or [])
        for tensor in (getattr(layer, "keys", None), getattr(layer, "values", None))
    ]
    total = _nbytes(layered)
    if total:
        return total
    legacy = [*getattr(cache, "key_cache", []), *getattr(cache, "value_cache", [])]
    return _nbytes(legacy)
def run_generate(prompt, cache_label, max_new_tokens):
    """Greedy-generate a completion with the chosen KV cache and report its size.

    Args:
        prompt: User prompt. A built-in default is used when it is empty,
            blank or ``None``.
        cache_label: Dropdown label, expected to be a key of ``_CACHE_CFG``.
            A ``None`` config selects the uncompressed ``DynamicCache``.
            Otherwise the config names an entry of the codebook blob's
            ``"fitted"`` dict, used to build a ``VQQuantizedCache``.
        max_new_tokens: Generation budget, coerced with ``int``.

    Returns:
        ``(completion, markdown)``. Failures in label lookup, model loading,
        cache setup, generation and memory accounting are reported in the
        markdown rather than raised. Throughput is measured over the whole
        ``generate()`` call, so it includes prefill.
    """
    if not prompt or not prompt.strip():
        prompt = "Write a one-sentence summary of why KV-cache compression matters for long context."
    if cache_label not in _CACHE_CFG:
        # Report labels outside _CACHE_CFG (e.g. None) instead of raising KeyError.
        return "", f"**Unknown KV cache mode:** {cache_label!r}"
    try:
        model, tok, device = _get_model()
    except Exception as exc:
        return "", f"**Model load failed:** {exc}"

    cfg_name = _CACHE_CFG[cache_label]
    compressed = cfg_name is not None
    try:
        if compressed:
            blob = _get_codebooks()
            cache = VQQuantizedCache(blob["fitted"][cfg_name], blob["meta"]["full_layers"])
        else:
            from transformers.cache_utils import DynamicCache

            cache = DynamicCache()
    except Exception as exc:
        return "", f"**Cache setup failed:** {exc}"

    inputs = tok(prompt, return_tensors="pt").to(device)
    n_prompt = inputs["input_ids"].shape[1]

    t0 = time.perf_counter()
    try:
        with torch.inference_mode():
            out = model.generate(
                **inputs,
                max_new_tokens=int(max_new_tokens),
                past_key_values=cache,
                use_cache=True,
                do_sample=False,
            )
        # CUDA kernels are launched asynchronously; synchronize so the timer
        # below covers all GPU work queued by generate().
        torch.cuda.synchronize()
    except Exception as exc:
        return "", f"**Generation failed:** {exc}"
    elapsed = time.perf_counter() - t0

    seq_len = out.shape[1]
    n_new = seq_len - n_prompt
    text = tok.decode(out[0, n_prompt:], skip_special_tokens=True)

    if compressed:
        try:
            fp = cache.memory_footprint()
            kv_b = fp["total_B"]
            cache_detail = (
                f"uint8 indices {fp['compressed_indices_B'] / 1e6:.2f} MB · "
                f"codebooks {fp['codebooks_B'] / 1e6:.2f} MB · "
                f"sliding-window layers {fp['native_layers_B'] / 1e6:.2f} MB"
            )
        except Exception as exc:
            return text, f"**Memory accounting failed:** {exc}"
    else:
        kv_b = _measure_dynamic_cache(cache)
        cache_detail = "fp16 DynamicCache (measured on device)"

    fp16_b = kv_cache_bytes(seq_len, 16.0)["total_B"]
    # A zero reading means nothing was measured, not "same size as fp16".
    ratio_cell = (
        f"**{fp16_b / kv_b:.1f}×** smaller" if kv_b > 0 else "n/a (no KV bytes measured)"
    )
    tok_s = n_new / elapsed if elapsed > 0 else 0.0
    md = (
        f"### {cache_label} · Laguna-XS.2\n\n"
        f"| | |\n|---|---|\n"
        f"| Prompt tokens | {n_prompt:,} |\n"
        f"| Generated tokens | {n_new:,} |\n"
        f"| Total sequence | {seq_len:,} |\n"
        f"| **KV cache (live)** | **{kv_b / 1e9:.3f} GB** |\n"
        f"| fp16 formula @ same length | {fp16_b / 1e9:.3f} GB |\n"
        f"| **vs fp16** | {ratio_cell} |\n"
        f"| Throughput | {tok_s:.1f} tok/s |\n\n"
        f"_{cache_detail}_\n\n"
    )
    if compressed:
        md += (
            f"**{OURS}** stores only codebook **indices** on the 10 full-attention layers "
            f"(real memory savings). Dequantization is transient per layer — wall-clock speedup "
            f"needs a fused kernel."
        )
    else:
        md += "fp16 baseline uses native bf16 KV tensors in `DynamicCache`."
    return text, md
def build_demo():
    """Assemble and return (without launching) the two-tab Gradio app.

    Tab "Memory vs baselines" re-renders ``memory_view`` on page load and on
    every slider change. Tab "Live generation" wires the Generate button to
    ``run_generate``. Its controls are non-interactive when CUDA is
    unavailable.
    """
    import gradio as gr

    cuda = torch.cuda.is_available()
    cb_ok = os.path.exists(os.path.join(ART, "codebooks.pt"))
    real = "real `vqkv` package" if _USING_REAL else "fallback geometry only"
    subtitle = (
        f"*Components: {real}. GPU: {'yes' if cuda else '**no — tab 2 disabled**'}. "
        f"Codebooks: {'found' if cb_ok else 'missing — AttnVQ modes need artifacts/codebooks.pt'}.*"
    )
    # AttnVQ modes need both the compressed-cache class and the codebook file.
    # When either is missing, pre-select the fp16 baseline so the first click
    # on Generate does not immediately fail with "Cache setup failed".
    fp16_label = next(label for label, cfg in _CACHE_CFG.items() if cfg is None)
    vq_ready = cb_ok and VQQuantizedCache is not None
    default_mode = f"{OURS} 2-bit" if vq_ready else fp16_label

    with gr.Blocks(title=f"{OURS} KV-Cache") as demo:
        gr.Markdown(
            f"# {OURS} — attention-aware KV-cache quantization · Laguna-XS.2\n{subtitle}"
        )
        with gr.Tab("Memory vs baselines"):
            ctx = gr.Slider(2048, 131072, value=131072, step=2048, label="Context length (tokens)")
            mplot, mmd = gr.Plot(), gr.Markdown()
            ctx.change(memory_view, ctx, [mplot, mmd])
            demo.load(memory_view, ctx, [mplot, mmd])
        with gr.Tab("Live generation"):
            gr.Markdown(
                "Run **Laguna-XS.2** with a real KV cache. "
                f"**{OURS}** modes use `VQQuantizedCache` (uint8 indices); "
                "fp16 uses native `DynamicCache`. Memory is read from the live cache after generation."
            )
            prompt = gr.Textbox(
                label="Prompt",
                lines=4,
                value="Explain in two sentences how product vector quantization compresses a KV cache.",
            )
            with gr.Row():
                cache_mode = gr.Dropdown(
                    list(_CACHE_CFG),
                    value=default_mode,
                    label="KV cache",
                    interactive=cuda,
                )
                max_tok = gr.Slider(16, 256, value=64, step=16, label="Max new tokens",
                                    interactive=cuda)
            btn = gr.Button("Generate", variant="primary", interactive=cuda)
            out_text = gr.Textbox(label="Completion", lines=6)
            out_md = gr.Markdown()
            btn.click(run_generate, [prompt, cache_mode, max_tok], [out_text, out_md])
    return demo
# Build the UI at import time so the Blocks app is available as the module
# attribute `demo`; start the server only when this file is run as a script.
demo = build_demo()
if __name__ == "__main__":
    demo.launch()