"""
Generate publication-ready graphs — 4 methods comparison.

Compares FP16, uniform 8-bit, naive per-head (uint8) and Triton true 4-bit
KV caches for Mistral-7B and Llama-3-8B, reading
~/kv-hack/results/<model>/benchmark_results.json and writing PNGs to
~/kv-hack/figures/.
"""
import json
import matplotlib.pyplot as plt
import numpy as np
import os
def load_results(model_name, results_root="~/kv-hack/results"):
    """Load the benchmark results JSON for one model.

    Args:
        model_name: Sub-directory name under ``results_root``
            (e.g. ``"mistral-7b"``).
        results_root: Directory holding one sub-directory per model; ``~`` is
            expanded. Defaults to ``~/kv-hack/results``, the location this
            script has always read from.

    Returns:
        The parsed contents of
        ``<results_root>/<model_name>/benchmark_results.json``.

    Raises:
        FileNotFoundError: If the results file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    path = os.path.join(
        os.path.expanduser(results_root), model_name, "benchmark_results.json"
    )
    # Explicit encoding: JSON is UTF-8, and the platform default locale may not be.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
# Benchmark outputs for both models, loaded Mistral first, then Llama.
mistral, llama = (load_results(name) for name in ("mistral-7b", "llama-3-8b"))

# Plot palette: one colour per method.
C_FP16, C_UNIFORM, C_NAIVE, C_TRITON, C_LLAMA = (
    "#ef4444",  # FP16 baseline
    "#f97316",  # uniform 8-bit
    "#a855f7",  # naive per-head (uint8)
    "#22c55e",  # Triton 4-bit
    "#3b82f6",  # Triton colour for the Llama panel in graph 1
)

# Output directory for every figure saved below.
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)
# ── GRAPH 1: Memory vs Context — Mistral-7B & Llama-3-8B, 4 methods ───
fig, axes = plt.subplots(1, 2, figsize=(18, 7))
for ax, results, title, triton_color in [
    (axes[0], mistral, "Mistral-7B", C_TRITON),
    (axes[1], llama, "Llama-3-8B", C_LLAMA),
]:
    rows = results["compression"]
    ctx = [r["context_len"] for r in rows]
    fp16 = [r["fp16_mb"] for r in rows]
    uni8 = [r["uniform8_mb"] for r in rows]
    naive = [r["naive_real_gpu_mb"] for r in rows]
    triton = [r["triton_mb"] for r in rows]

    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=8, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=2.5, markersize=8, label="Uniform 8-bit")
    ax.plot(ctx, naive, 'D-', color=C_NAIVE, linewidth=2.5, markersize=8, label="Naive Per-Head (uint8)")
    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=2.5, markersize=8, label="Triton True 4-bit (Ours)")

    # Label the last measured point of each curve. Anchor at ctx[-1] (the x
    # value those y values belong to) rather than a hard-coded 8192.
    # The ratio in the Triton label is the summary's 8K compression figure.
    summary = results["summary"]
    x_last = ctx[-1]
    for values, color, offset in (
        (fp16, C_FP16, (-60, 10)),
        (uni8, C_UNIFORM, (-60, 10)),
        (naive, C_NAIVE, (-60, -18)),
    ):
        ax.annotate(f"{values[-1]:.0f} MB",
                    xy=(x_last, values[-1]), xytext=offset,
                    textcoords='offset points', color=color, fontweight='bold', fontsize=9)
    ax.annotate(f"{triton[-1]:.0f} MB\n({summary['triton_compression_8k']}x)",
                xy=(x_last, triton[-1]), xytext=(-80, -35),
                textcoords='offset points', color=triton_color, fontweight='bold', fontsize=9)

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length", fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    # 512 -> "512", 1024 -> "1K", 8192 -> "8K". Derived from the data so the
    # label count always equals the tick count (matplotlib raises otherwise).
    ax.set_xticklabels(
        [f"{c // 1024}K" if c >= 1024 and c % 1024 == 0 else str(c) for c in ctx]
    )

fig.suptitle("Per-Head Mixed-Precision KV Cache — 4 Method Comparison",
             fontsize=14, fontweight='bold', y=1.02)
fig.tight_layout()
fig.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_4methods.png"),
            dpi=150, bbox_inches='tight')
plt.close(fig)
print("✅ Saved figures/memory_vs_context_4methods.png")
# ── GRAPH 2: Compression Bar Chart — 4 Methods ────────
fig, ax = plt.subplots(figsize=(12, 7))
x = np.arange(4)
width = 0.35
labels = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8 actual)", "Triton True\n4-bit (Ours)"]
# FP16 is the 1.0x reference and uniform 8-bit is fixed at a nominal 2.0x;
# only the two per-head ratios are read from each model's 8K summary.
m_ratios = [
    1.0,
    2.0,
    mistral["summary"]["naive_real_compression_8k"],
    mistral["summary"]["triton_compression_8k"],
]
l_ratios = [
    1.0,
    2.0,
    llama["summary"]["naive_real_compression_8k"],
    llama["summary"]["triton_compression_8k"],
]
colors = [C_FP16, C_UNIFORM, C_NAIVE, C_TRITON]

bars1 = ax.bar(x - width/2, m_ratios, width,
               label="Mistral-7B", color=colors,
               edgecolor='white', linewidth=1.5, alpha=0.9)
bars2 = ax.bar(x + width/2, l_ratios, width,
               label="Llama-3-8B", color=colors,
               edgecolor='white', linewidth=1.5, alpha=0.6,
               hatch='//')
for bar, ratio in zip(bars1, m_ratios):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
            f"{ratio:.2f}x", ha='center', fontweight='bold', fontsize=11)
for bar, ratio in zip(bars2, l_ratios):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
            f"{ratio:.2f}x", ha='center', fontweight='bold', fontsize=10,
            color='gray')

# Keep the original 2.8 ceiling unless a measured ratio would not fit under
# it; then grow the axis, leaving ~15% headroom for the value labels.
y_top = max(2.8, 1.15 * max(m_ratios + l_ratios))

ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=11)
ax.set_ylabel("Compression vs FP16", fontsize=13)
ax.set_title("KV Cache Compression at 8K Context\n4-Method Comparison — Mistral-7B vs Llama-3-8B",
             fontsize=14, fontweight='bold')
ax.set_ylim(0, y_top)
# Bars are coloured per method, so a default legend would show the first
# (FP16 red) bar for both models; use neutral swatches differing only in style.
ax.legend(handles=[
    plt.Rectangle((0, 0), 1, 1, facecolor='gray', edgecolor='white',
                  alpha=0.9, label="Mistral-7B"),
    plt.Rectangle((0, 0), 1, 1, facecolor='gray', edgecolor='white',
                  alpha=0.6, hatch='//', label="Llama-3-8B"),
], fontsize=11)
ax.grid(True, axis='y', alpha=0.3)
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.4)

# Highlight our method: the Triton group sits at x = 3 (last of np.arange(4)).
ax.add_patch(plt.Rectangle((2.5, 0), 1.0, y_top,
                           alpha=0.05, color=C_TRITON, zorder=0))
ax.text(3.0, y_top - 0.15, "Our method", ha='center',
        color=C_TRITON, fontweight='bold', fontsize=10)

fig.tight_layout()
fig.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_4methods.png"),
            dpi=150, bbox_inches='tight')
plt.close(fig)
print("✅ Saved figures/compression_bar_4methods.png")
# ── GRAPH 3: Full Results Table ────────────────────────
fig, ax = plt.subplots(figsize=(14, 5))
ax.axis('off')
s_m = mistral["summary"]
s_l = llama["summary"]
# NOTE: the FP16 / uniform 8-bit sizes, perplexity and speed cells are
# hard-coded literals, not read from the results JSON; update them by hand
# if the benchmarks are re-run.
table_data = [
    ["Model", "Method", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"],
    ["Mistral-7B", "FP16 Baseline", "1073 MB", "1.00x", "—", "14.23", "37.4 t/s"],
    ["Mistral-7B", "Uniform 8-bit", "537 MB", "2.00x", "1.00x", "~same", "~same"],
    ["Mistral-7B", "Naive Per-Head (uint8)", f"{s_m['naive_real_8k_mb']} MB", f"{s_m['naive_real_compression_8k']}x", "1.00x", "~same", "~same"],
    ["Mistral-7B", "Triton True 4-bit (Ours)", f"{s_m['triton_8k_mb']} MB", f"{s_m['triton_compression_8k']}x", f"{s_m['triton_vs_8bit_8k']}x", "14.23", "37.4 t/s"],
    ["Llama-3-8B", "FP16 Baseline", "1073 MB", "1.00x", "—", "20.70", "36.8 t/s"],
    ["Llama-3-8B", "Uniform 8-bit", "537 MB", "2.00x", "1.00x", "~same", "~same"],
    ["Llama-3-8B", "Naive Per-Head (uint8)", f"{s_l['naive_real_8k_mb']} MB", f"{s_l['naive_real_compression_8k']}x", "1.00x", "~same", "~same"],
    ["Llama-3-8B", "Triton True 4-bit (Ours)", f"{s_l['triton_8k_mb']} MB", f"{s_l['triton_compression_8k']}x", f"{s_l['triton_vs_8bit_8k']}x", "20.70", "36.8 t/s"],
]
table = ax.table(
    cellText=table_data[1:],
    colLabels=table_data[0],
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1.2, 2.0)

n_cols = len(table_data[0])
triton_row_fill = {"Mistral-7B": "#dcfce7", "Llama-3-8B": "#dbeafe"}
for j in range(n_cols):
    table[0, j].set_facecolor("#1e293b")
    table[0, j].set_text_props(color='white', fontweight='bold')
# Table row 0 is the header, so the k-th data row (1-based) is table row k.
# Shade each model's Triton row in that model's colour.
for i, row in enumerate(table_data[1:], start=1):
    if row[1].startswith("Triton"):
        for j in range(n_cols):
            table[i, j].set_facecolor(triton_row_fill[row[0]])

ax.set_title("Full Results — Per-Head Mixed-Precision KV Cache (4 Methods)",
             fontsize=13, fontweight='bold', pad=20)
fig.tight_layout()
fig.savefig(os.path.expanduser("~/kv-hack/figures/results_table_4methods.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/results_table_4methods.png")
plt.close('all')
print("\n🎉 All 4-method graphs saved!")