"""
Long context visualization - 4 methods comparison.
"""
import json
import matplotlib.pyplot as plt
import os
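# Load one model's long-context benchmark results. Assumes the JSON written by
# the benchmark run: a top-level "results" list with one entry per context length.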
def load_long(model_name):
    path = os.path.expanduser(
        f"~/kv-hack/results/{model_name}/long_context_results.json"
    )
    with open(path) as f:
        return json.load(f)
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)
mistral = load_long("mistral-7b")
llama = load_long("llama-3-8b")
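# Color palette: one color per compression method, plus per-model accents
# (green = Mistral, blue = Llama) used for the Triton curves and latency lines.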
C_FP16 = "#ef4444"
C_UNIFORM = "#f97316"
C_NAIVE = "#a855f7"
C_MISTRAL = "#22c55e"
C_LLAMA = "#3b82f6"
# -- GRAPH 1: Both Models 4 Methods ------------------------------------------
fig, axes = plt.subplots(1, 2, figsize=(18, 7))
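# One panel per model: loop over (axis, results, accent color, title) and draw
# the four memory curves (FP16, uniform 8-bit, naive per-head uint8, Triton 4-bit).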
for ax, data, triton_color, title in [
    (axes[0], mistral, C_MISTRAL, "Mistral-7B"),
    (axes[1], llama, C_LLAMA, "Llama-3-8B"),
]:
    valid = [r for r in data["results"] if "triton_mb" in r]
    ctx = [r["context_len"] for r in valid]
    fp16 = [r["fp16_mb"] for r in valid]
    uni8 = [r["uniform8_mb"] for r in valid]
    naive = [r["naive_real_gpu_mb"] for r in valid]
    triton = [r["triton_mb"] for r in valid]
    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9, label="Uniform 8-bit")
    ax.plot(ctx, naive, 'D-', color=C_NAIVE, linewidth=3, markersize=9, label="Naive Per-Head (uint8)")
    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=3, markersize=9, label="Triton True 4-bit (Ours)")
    ax.fill_between(ctx, fp16, triton, alpha=0.07, color=triton_color)
    # annotate last point
    ax.annotate(f"{fp16[-1]/1024:.1f} GB",
                xy=(ctx[-1], fp16[-1]),
                xytext=(-50, 10), textcoords='offset points',
                color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{uni8[-1]/1024:.1f} GB",
                xy=(ctx[-1], uni8[-1]),
                xytext=(-50, 10), textcoords='offset points',
                color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{naive[-1]/1024:.1f} GB",
                xy=(ctx[-1], naive[-1]),
                xytext=(-50, -18), textcoords='offset points',
                color=C_NAIVE, fontweight='bold', fontsize=9)
    ax.annotate(f"{triton[-1]/1024:.1f} GB\n({valid[-1]['triton_compression']}x)",
                xy=(ctx[-1], triton[-1]),
                xytext=(-80, -35), textcoords='offset points',
                color=triton_color, fontweight='bold', fontsize=9)
    # OOM marker for llama
    if title == "Llama-3-8B":
        ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
        ax.text(ctx[-1]*0.88, max(fp16)*0.88,
                "FP16\nOOM →", color=C_FP16,
                fontweight='bold', fontsize=10, ha='right')
    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length (4 Methods)",
                 fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='upper left')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in ctx])
plt.suptitle("Per-Head Mixed-Precision KV Cache - Long Context Benchmark",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_4methods.png"),
dpi=150, bbox_inches='tight')
print("β
Saved figures/long_context_4methods.png")
# -- GRAPH 2: The savings story at 32K ---------------------------------------
fig, ax = plt.subplots(figsize=(10, 6))
# use mistral 32K numbers
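# Assumes the Mistral sweep includes a 32768-token run; next() raises StopIteration otherwise.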
r32 = next(r for r in mistral["results"] if r["context_len"] == 32768)
methods = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8)", "Triton True\n4-bit (Ours)"]
values = [r32["fp16_mb"], r32["uniform8_mb"], r32["naive_real_gpu_mb"], r32["triton_mb"]]
colors = [C_FP16, C_UNIFORM, C_NAIVE, C_MISTRAL]
bars = ax.bar(methods, values, color=colors, width=0.5,
edgecolor='white', linewidth=2)
for bar, val in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 30,
            f"{val/1024:.1f} GB", ha='center',
            fontweight='bold', fontsize=12)
# savings arrows
ax.annotate('', xy=(3, r32["triton_mb"]),
xytext=(0, r32["fp16_mb"]),
arrowprops=dict(arrowstyle='<->', color='gray', lw=2))
ax.text(1.5, (r32["fp16_mb"] + r32["triton_mb"])/2,
f"Save {(r32['fp16_mb']-r32['triton_mb'])/1024:.1f} GB\n({r32['triton_compression']}x)",
ha='center', color='gray', fontweight='bold', fontsize=11)
ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
ax.set_title("KV Cache Memory at 32K Context β Mistral-7B\nTriton saves 2.4GB vs FP16 baseline",
fontsize=14, fontweight='bold')
ax.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_32k_4methods.png"),
dpi=150, bbox_inches='tight')
print("β
Saved figures/memory_32k_4methods.png")
# -- GRAPH 3: Prefill Latency Both Models ------------------------------------
fig, ax = plt.subplots(figsize=(10, 5))
m_valid = [r for r in mistral["results"] if "prefill_ms" in r]
l_valid = [r for r in llama["results"] if "prefill_ms" in r]
m_ctx = [r["context_len"] for r in m_valid]
l_ctx = [r["context_len"] for r in l_valid]
m_prefill = [r["prefill_ms"] for r in m_valid]
l_prefill = [r["prefill_ms"] for r in l_valid]
ax.plot(m_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
markersize=8, label="Mistral-7B")
ax.plot(l_ctx, l_prefill, 's-', color=C_LLAMA, linewidth=2.5,
markersize=8, label="Llama-3-8B")
for x, y in zip(m_ctx, m_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, 10), textcoords='offset points',
                ha='center', fontsize=8, color=C_MISTRAL)
for x, y in zip(l_ctx, l_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, -18), textcoords='offset points',
                ha='center', fontsize=8, color=C_LLAMA)
ax.set_xlabel("Context Length (tokens)", fontsize=13)
ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
ax.set_title("Prefill Latency vs Context Length β Both Models",
fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xticks(m_ctx)
ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in m_ctx])
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
dpi=150, bbox_inches='tight')
print("β
Saved figures/prefill_latency_both.png")
plt.close('all')
print("\nπ All long context graphs saved!") |