# kv-cache-compression / visualize_long_context.py
# harshithsaiv · 0774ec2 · feat: complete 4-method benchmark with honest memory reporting
"""
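Long-context visualization — 4-method comparison.

Reads ~/kv-hack/results/<model>/long_context_results.json for Mistral-7B and
Llama-3-8B and writes three figures to ~/kv-hack/figures/: KV-cache memory vs
context length for four methods (FP16, uniform 8-bit, naive per-head uint8,
Triton 4-bit), the per-method memory at 32K context (Mistral-7B), and prefill
latency vs context length for both models.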
"""
import json
import matplotlib.pyplot as plt
import os
def load_long(model_name, base_dir="~/kv-hack/results"):
    """Load the long-context benchmark results for one model.

    Reads ``<base_dir>/<model_name>/long_context_results.json`` (with ``~``
    expanded) and returns the decoded JSON.

    Args:
        model_name: Sub-directory under ``base_dir``, e.g. ``"mistral-7b"``.
        base_dir: Root results directory. Defaults to the script's original
            hard-coded location, ``~/kv-hack/results``.

    Returns:
        The parsed JSON object. The plotting code in this script expects a
        dict with a ``"results"`` list of per-context-length rows.

    Raises:
        FileNotFoundError: if the results file does not exist.
        json.JSONDecodeError: if the file is not valid JSON.
    """
    path = os.path.join(
        os.path.expanduser(base_dir), model_name, "long_context_results.json"
    )
    # Explicit encoding: the platform default (locale) encoding is not
    # guaranteed to be UTF-8, which JSON files are expected to use.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
# Make sure the output directory for every figure below exists.
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)

# Benchmark results for the two models (loaded Mistral first, then Llama).
mistral, llama = load_long("mistral-7b"), load_long("llama-3-8b")

# Palette. C_FP16 / C_UNIFORM / C_NAIVE colour the three comparison methods;
# C_MISTRAL / C_LLAMA colour each model's Triton series and its prefill line.
C_FP16, C_UNIFORM, C_NAIVE = "#ef4444", "#f97316", "#a855f7"
C_MISTRAL, C_LLAMA = "#22c55e", "#3b82f6"
# ── GRAPH 1: Both Models 4 Methods ───────────────────
# One panel per model: KV-cache memory (MB) against context length for the
# four methods. Only result rows carrying a "triton_mb" value are plotted.
fig, axes = plt.subplots(1, 2, figsize=(18, 7))
for ax, data, triton_color, title in [
    (axes[0], mistral, C_MISTRAL, "Mistral-7B"),
    (axes[1], llama, C_LLAMA, "Llama-3-8B"),
]:
    valid = [r for r in data["results"] if "triton_mb" in r]
    if not valid:
        # The end-of-series annotations below index [-1]; fail with a clear
        # message instead of an opaque IndexError.
        raise ValueError(f"{title}: no long-context results contain 'triton_mb'")
    ctx = [r["context_len"] for r in valid]
    fp16 = [r["fp16_mb"] for r in valid]
    uni8 = [r["uniform8_mb"] for r in valid]
    naive = [r["naive_real_gpu_mb"] for r in valid]
    triton = [r["triton_mb"] for r in valid]

    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9, label="Uniform 8-bit")
    ax.plot(ctx, naive, 'D-', color=C_NAIVE, linewidth=3, markersize=9, label="Naive Per-Head (uint8)")
    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=3, markersize=9, label="Triton True 4-bit (Ours)")
    # Lightly shade the gap between the FP16 baseline and the Triton series.
    ax.fill_between(ctx, fp16, triton, alpha=0.07, color=triton_color)

    # Annotate the final point of each series with its size converted from
    # MB to GB (xytext offsets are in points relative to the data point).
    ax.annotate(f"{fp16[-1]/1024:.1f} GB",
                xy=(ctx[-1], fp16[-1]),
                xytext=(-50, 10), textcoords='offset points',
                color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{uni8[-1]/1024:.1f} GB",
                xy=(ctx[-1], uni8[-1]),
                xytext=(-50, 10), textcoords='offset points',
                color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{naive[-1]/1024:.1f} GB",
                xy=(ctx[-1], naive[-1]),
                xytext=(-50, -18), textcoords='offset points',
                color=C_NAIVE, fontweight='bold', fontsize=9)
    ax.annotate(f"{triton[-1]/1024:.1f} GB\n({valid[-1]['triton_compression']}x)",
                xy=(ctx[-1], triton[-1]),
                xytext=(-80, -35), textcoords='offset points',
                color=triton_color, fontweight='bold', fontsize=9)

    # OOM marker for Llama: dashed line at its last plotted context length.
    if title == "Llama-3-8B":
        ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
        ax.text(ctx[-1]*0.88, max(fp16)*0.88,
                "FP16\nOOM →", color=C_FP16,
                fontweight='bold', fontsize=10, ha='right')

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length (4 Methods)",
                 fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='upper left')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    # Label ticks in K tokens. ":g" keeps non-multiples of 1024 honest
    # (1536 -> "1.5K" rather than the truncated "1K").
    ax.set_xticklabels([f"{c / 1024:g}K" if c >= 1024 else str(c) for c in ctx])

plt.suptitle("Per-Head Mixed-Precision KV Cache — Long Context Benchmark",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_4methods.png"),
            dpi=150, bbox_inches='tight')
# Free the figure now that it is on disk.
plt.close(fig)
print("✅ Saved figures/long_context_4methods.png")
# ── GRAPH 2: The savings story at 32K ─────────────────
# Bar chart of KV-cache memory for the four methods at a 32K-token context,
# using the Mistral-7B results.
fig, ax = plt.subplots(figsize=(10, 6))
r32 = next((r for r in mistral["results"] if r["context_len"] == 32768), None)
if r32 is None:
    # A bare next() would surface as an unexplained StopIteration.
    raise ValueError("Mistral-7B results have no entry with context_len == 32768")

methods = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8)", "Triton True\n4-bit (Ours)"]
values = [r32["fp16_mb"], r32["uniform8_mb"], r32["naive_real_gpu_mb"], r32["triton_mb"]]
colors = [C_FP16, C_UNIFORM, C_NAIVE, C_MISTRAL]
bars = ax.bar(methods, values, color=colors, width=0.5,
              edgecolor='white', linewidth=2)
# Value label just above each bar (the +30 offset is in data units, i.e. MB).
for bar, val in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 30,
            f"{val/1024:.1f} GB", ha='center',
            fontweight='bold', fontsize=12)

# Savings arrow from the FP16 bar (x=0) to the Triton bar (x=3).
saved_gb = (r32["fp16_mb"] - r32["triton_mb"]) / 1024
ax.annotate('', xy=(3, r32["triton_mb"]),
            xytext=(0, r32["fp16_mb"]),
            arrowprops=dict(arrowstyle='<->', color='gray', lw=2))
ax.text(1.5, (r32["fp16_mb"] + r32["triton_mb"])/2,
        f"Save {saved_gb:.1f} GB\n({r32['triton_compression']}x)",
        ha='center', color='gray', fontweight='bold', fontsize=11)

ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
# The headline saving is computed from the same numbers as the arrow label,
# so the title cannot drift from the data when the benchmark is re-run.
ax.set_title(f"KV Cache Memory at 32K Context — Mistral-7B\n"
             f"Triton saves {saved_gb:.1f}GB vs FP16 baseline",
             fontsize=14, fontweight='bold')
ax.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_32k_4methods.png"),
            dpi=150, bbox_inches='tight')
# Free the figure now that it is on disk.
plt.close(fig)
print("✅ Saved figures/memory_32k_4methods.png")
# ── GRAPH 3: Prefill Latency Both Models ──────────────
# Prefill latency (ms) vs context length, one line per model, using every
# result row that carries a "prefill_ms" value.
fig, ax = plt.subplots(figsize=(10, 5))
m_valid = [r for r in mistral["results"] if "prefill_ms" in r]
l_valid = [r for r in llama["results"] if "prefill_ms" in r]
m_ctx = [r["context_len"] for r in m_valid]
l_ctx = [r["context_len"] for r in l_valid]
m_prefill = [r["prefill_ms"] for r in m_valid]
l_prefill = [r["prefill_ms"] for r in l_valid]

ax.plot(m_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
        markersize=8, label="Mistral-7B")
ax.plot(l_ctx, l_prefill, 's-', color=C_LLAMA, linewidth=2.5,
        markersize=8, label="Llama-3-8B")

# Per-point latency labels: Mistral's sit above its markers, Llama's below.
for x, y in zip(m_ctx, m_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, 10), textcoords='offset points',
                ha='center', fontsize=8, color=C_MISTRAL)
for x, y in zip(l_ctx, l_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, -18), textcoords='offset points',
                ha='center', fontsize=8, color=C_LLAMA)

ax.set_xlabel("Context Length (tokens)", fontsize=13)
ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
ax.set_title("Prefill Latency vs Context Length — Both Models",
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
# Tick every context length measured by either model, not just Mistral's, so
# Llama-only points also get a labelled tick.
all_ctx = sorted(set(m_ctx) | set(l_ctx))
ax.set_xticks(all_ctx)
# ":g" keeps non-multiples of 1024 honest (1536 -> "1.5K", not "1K").
ax.set_xticklabels([f"{c / 1024:g}K" if c >= 1024 else str(c) for c in all_ctx])
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
            dpi=150, bbox_inches='tight')
# Free the figure now that it is on disk.
plt.close(fig)
print("✅ Saved figures/prefill_latency_both.png")
# Release any figures pyplot still holds before exiting.
plt.close('all')
print("\n🎉 All long context graphs saved!")