"""
Long-context visualization for both models.
"""
import json
import os

import matplotlib.pyplot as plt
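

# Assumed results schema (inferred from the filters below): each JSON holds a
# "results" list whose entries have "context_len" plus "fp16_mb",
# "uniform8_mb", "mixed_precision_mb", and "prefill_ms"; a key is absent when
# that particular run did not complete (e.g. FP16 OOM at long contexts).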
def load_long(model_name):
    """Load one model's long-context benchmark results."""
    path = os.path.expanduser(
        f"~/kv-hack/results/{model_name}/long_context_results.json"
    )
    with open(path) as f:
        return json.load(f)

os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)
mistral = load_long("mistral-7b")
llama = load_long("llama-3-8b")
C_FP16 = "#ef4444"     # red: FP16 baseline
C_UNIFORM = "#f97316"  # orange: uniform 8-bit baseline
C_MISTRAL = "#22c55e"  # green: Mistral-7B
C_LLAMA = "#3b82f6"    # blue: Llama-3-8B
# ── GRAPH 1: Both Models Side by Side ─────────────────
fig, axes = plt.subplots(1, 2, figsize=(18, 7))
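# One panel per model; the Llama panel additionally gets an OOM marker,
# since its FP16 run dies at 32K (oom_ctx below).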
for ax, data, color, title, oom_ctx in [
    (axes[0], mistral, C_MISTRAL, "Mistral-7B", None),
    (axes[1], llama, C_LLAMA, "Llama-3-8B", 32768),
]:
    # keep only contexts where every method completed
    # (missing keys mean that run OOMed or was skipped)
    valid = [r for r in data["results"]
             if "mixed_precision_mb" in r and "fp16_mb" in r]
    ctx = [r["context_len"] for r in valid]
    fp16 = [r["fp16_mb"] for r in valid]
    uni8 = [r["uniform8_mb"] for r in valid]
    ours = [r["mixed_precision_mb"] for r in valid]
    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9, label="Uniform 8-bit")
    ax.plot(ctx, ours, '^-', color=color, linewidth=3, markersize=9, label="Per-Head Mixed (Ours)")
    # shade the memory saved relative to FP16
    ax.fill_between(ctx, fp16, ours, alpha=0.08, color=color)
    # OOM marker (Llama only: its FP16 run dies at 32K)
    if oom_ctx:
        ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
        ax.text(ctx[-1] * 0.92, max(fp16) * 0.85,
                "FP16\nOOM →", color=C_FP16,
                fontweight='bold', fontsize=10, ha='right')
        # project our method to 32K: KV-cache memory grows linearly with
        # context length, so doubling the 16K figure is a fair estimate
        ours_32k = ours[-1] * 2
        ax.annotate(f"Ours at 32K:\n~{ours_32k:.0f} MB",
                    xy=(ctx[-1], ours[-1]),
                    xytext=(ctx[-2], ours[-1] + 200),
                    color=color, fontweight='bold', fontsize=9,
                    arrowprops=dict(arrowstyle='->', color=color))
    # annotate the last valid point
    ax.annotate(f"{fp16[-1]/1024:.1f} GB",
                xy=(ctx[-1], fp16[-1]),
                xytext=(-40, 10), textcoords='offset points',
                color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{ours[-1]/1024:.1f} GB",
                xy=(ctx[-1], ours[-1]),
                xytext=(-40, -20), textcoords='offset points',
                color=color, fontweight='bold', fontsize=9)
    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length",
                 fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in ctx])
plt.suptitle("Per-Head Mixed-Precision KV Cache: Long Context Benchmark\n"
             "Llama-3-8B FP16 OOMs at 32K. Our method fits.",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_both.png"),
            dpi=150, bbox_inches='tight')
print("β
Saved figures/long_context_both.png")
# ── GRAPH 2: The OOM Story ────────────────────────────
fig, ax = plt.subplots(figsize=(12, 6))
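# Compare each method's KV-cache curve against the memory an A100 has left
# after loading the model weights.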
# Mistral completed every context length (512 through 32K)
m_valid = [r for r in mistral["results"]
           if "mixed_precision_mb" in r and "fp16_mb" in r]
m_ctx = [r["context_len"] for r in m_valid]
m_fp16 = [r["fp16_mb"] for r in m_valid]
m_ours = [r["mixed_precision_mb"] for r in m_valid]
# Llama: keep only points where both FP16 and our method completed
l_valid = [r for r in llama["results"]
           if "mixed_precision_mb" in r and "fp16_mb" in r]
l_fp16 = [r["fp16_mb"] for r in l_valid]
l_ours = [r["mixed_precision_mb"] for r in l_valid]
l_ctx = [r["context_len"] for r in l_valid]
# A100 40 GB limit minus resident model weights (~2 bytes/param in FP16:
# ~14.5 GB for Mistral-7B, ~16 GB for Llama-3-8B)
mistral_model_mem = 14.5 * 1024  # MB
llama_model_mem = 16.0 * 1024    # MB
a100_total = 40 * 1024           # MB
ax.axhline(y=a100_total - mistral_model_mem,
           color='gray', linestyle='--', alpha=0.7, linewidth=2,
           label=f"A100 headroom (Mistral): {(a100_total - mistral_model_mem)/1024:.0f}GB")
ax.axhline(y=a100_total - llama_model_mem,
           color='gray', linestyle=':', alpha=0.7, linewidth=2,
           label=f"A100 headroom (Llama): {(a100_total - llama_model_mem)/1024:.0f}GB")
ax.plot(m_ctx, m_fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=7, label="FP16 (Mistral)")
ax.plot(m_ctx, m_ours, '^-', color=C_MISTRAL, linewidth=2.5, markersize=7, label="Ours (Mistral)")
ax.plot(l_ctx, l_fp16, 'o--', color="#f87171", linewidth=2.5, markersize=7, label="FP16 (Llama)")
ax.plot(l_ctx, l_ours, '^--', color=C_LLAMA, linewidth=2.5, markersize=7, label="Ours (Llama)")
# OOM annotation at the last context length Llama's FP16 run survived
ax.annotate("Llama FP16\nOOM here →",
            xy=(l_ctx[-1], l_fp16[-1]),
            xytext=(12000, l_fp16[-1] + 400),
            color=C_FP16, fontweight='bold', fontsize=10,
            arrowprops=dict(arrowstyle='->', color=C_FP16))
ax.set_xlabel("Context Length (tokens)", fontsize=13)
ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
ax.set_title("KV Cache Memory vs GPU Headroom\n"
             "Our method keeps you under the limit longer",
             fontsize=14, fontweight='bold')
ax.legend(fontsize=10, loc='upper left')
ax.grid(True, alpha=0.3)
ax.set_xticks(m_ctx)
ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in m_ctx])
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/oom_story.png"),
            dpi=150, bbox_inches='tight')
print("β
Saved figures/oom_story.png")
# ── GRAPH 3: Prefill Latency, Both Models ─────────────
fig, ax = plt.subplots(figsize=(10, 5))
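# Prefill latency: time to process the full prompt before the first
# generated token.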
# rebuild the context axis from the same filter as the latency values so
# the x/y lists always line up
m_pre = [r for r in mistral["results"] if "prefill_ms" in r]
l_pre = [r for r in llama["results"] if "prefill_ms" in r]
m_pre_ctx = [r["context_len"] for r in m_pre]
m_prefill = [r["prefill_ms"] for r in m_pre]
l_pre_ctx = [r["context_len"] for r in l_pre]
l_prefill = [r["prefill_ms"] for r in l_pre]
ax.plot(m_pre_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
        markersize=8, label="Mistral-7B")
ax.plot(l_pre_ctx, l_prefill, 's-', color=C_LLAMA, linewidth=2.5,
        markersize=8, label="Llama-3-8B")
# stack the labels: Mistral above its points, Llama below, to avoid overlap
for x, y in zip(m_pre_ctx, m_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, 10), textcoords='offset points',
                ha='center', fontsize=8, color=C_MISTRAL)
for x, y in zip(l_pre_ctx, l_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, -18), textcoords='offset points',
                ha='center', fontsize=8, color=C_LLAMA)
ax.set_xlabel("Context Length (tokens)", fontsize=13)
ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
ax.set_title("Prefill Latency vs Context Length\nBoth Models",
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
ax.set_xticks(m_pre_ctx)
ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in m_pre_ctx])
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
            dpi=150, bbox_inches='tight')
print("β
Saved figures/prefill_latency_both.png")
plt.close('all')
print("\nπ All long context graphs saved!") |