Commit ·
9e0641d
1
Parent(s): c32a0aa
feat: Implementation of the visualize script
Browse files- visualize_results.py +143 -0
visualize_results.py
ADDED
|
@@ -0,0 +1,143 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
Generate all publication-ready graphs for both models.
|
| 3 |
+
"""
|
| 4 |
+
import json
|
| 5 |
+
import matplotlib.pyplot as plt
|
| 6 |
+
import numpy as np
|
| 7 |
+
import os
|
| 8 |
+
|
| 9 |
+
def load_results(model_name):
    """Load the benchmark results recorded for one model.

    Args:
        model_name: Name of the model's sub-directory under
            ``~/kv-hack/results`` (for example ``"mistral-7b"``).

    Returns:
        The parsed contents of that model's ``benchmark_results.json``.

    Raises:
        FileNotFoundError: If the results file does not exist.
        json.JSONDecodeError: If the file does not contain valid JSON.
    """
    path = os.path.expanduser(f"~/kv-hack/results/{model_name}/benchmark_results.json")
    # JSON is UTF-8 by specification; do not depend on the locale's default encoding.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
|
| 13 |
+
|
| 14 |
+
# Benchmark results for both models, loaded in a fixed order (Mistral first).
mistral, llama = (load_results(name) for name in ("mistral-7b", "llama-3-8b"))

# Colour palette (hex) shared by the figures below.
C_FP16 = "#ef4444"     # FP16 baseline series
C_UNIFORM = "#f97316"  # uniform 8-bit series
C_MISTRAL = "#22c55e"  # Mistral-7B
C_LLAMA = "#3b82f6"    # Llama-3-8B

# Every figure is saved under this directory; create it before plotting.
figures_dir = os.path.expanduser("~/kv-hack/figures")
os.makedirs(figures_dir, exist_ok=True)
|
| 23 |
+
|
| 24 |
+
# ── GRAPH 1: Memory vs Context — Both Models ──────────
|
| 25 |
+
# Two side-by-side panels (one per model) of KV-cache memory against context length.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# The colour of the "ours" series travels with the model instead of being
# re-derived from the panel title in two separate places.
for ax, results, title, ours_color in [
    (axes[0], mistral, "Mistral-7B", C_MISTRAL),
    (axes[1], llama, "Llama-3-8B", C_LLAMA),
]:
    rows = results["compression"]
    ctx = [r["context_len"] for r in rows]
    fp16 = [r["fp16_mb"] for r in rows]
    uni8 = [r["uniform8_mb"] for r in rows]
    ours = [r["mixed_precision_mb"] for r in rows]

    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=8, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=2.5, markersize=8, label="Uniform 8-bit")
    ax.plot(ctx, ours, '^-', color=ours_color,
            linewidth=2.5, markersize=8, label="Per-Head Mixed (Ours)")

    # Label the last measured point of each series. The anchor uses the last
    # context length actually present in the data, so it always lies on the
    # plotted point (matplotlib skips an annotation whose data-coordinate
    # anchor falls outside the axes).
    # NOTE(review): the text offsets (5500 / 4000) and the ratio taken from
    # summary["compression_8k"] assume the last entry is the 8K measurement.
    last_ctx = ctx[-1]
    ax.annotate(f"{fp16[-1]:.0f} MB", xy=(last_ctx, fp16[-1]),
                xytext=(5500, fp16[-1]+30), color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{uni8[-1]:.0f} MB", xy=(last_ctx, uni8[-1]),
                xytext=(5500, uni8[-1]+30), color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{ours[-1]:.0f} MB\n({results['summary']['compression_8k']}x vs FP16)",
                xy=(last_ctx, ours[-1]), xytext=(4000, ours[-1]-150),
                color=ours_color,
                fontweight='bold', fontsize=9)

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length", fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    # One tick per measured context length.
    ax.set_xticks(ctx)

plt.suptitle("Per-Head Mixed-Precision KV Cache Compression",
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
# bbox_inches='tight' keeps the raised suptitle (y=1.02) inside the saved image.
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_both.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/memory_vs_context_both.png")
|
| 64 |
+
|
| 65 |
+
|
| 66 |
+
# ── GRAPH 2: Compression Bar Chart — Both Models ──────
|
| 67 |
+
# Grouped bar chart: compression ratio versus FP16 for each method, per model.
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(3)
width = 0.35
models = ["FP16\nBaseline", "Uniform\n8-bit", "Per-Head\nMixed (Ours)"]

# FP16 (1.0x) and uniform 8-bit (2.0x) are fixed reference ratios; only the
# mixed-precision ratio is read from each model's results.
bars1 = ax.bar(x - width/2,
               [1.0, 2.0, mistral["summary"]["compression_8k"]],
               width, label="Mistral-7B", color=C_MISTRAL, edgecolor='white')
bars2 = ax.bar(x + width/2,
               [1.0, 2.0, llama["summary"]["compression_8k"]],
               width, label="Llama-3-8B", color=C_LLAMA, edgecolor='white')

# Print the ratio just above every bar of both groups.
all_bars = [*bars1, *bars2]
for bar in all_bars:
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
            f"{bar.get_height():.2f}x", ha='center', fontweight='bold', fontsize=11)

ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=12)
ax.set_ylabel("Compression vs FP16", fontsize=13)
ax.set_title("KV Cache Compression at 8K Context\nPer-Head Mixed Precision vs Baselines",
             fontsize=14, fontweight='bold')
# Keep the 2.8 ceiling for typical ratios, but grow it when a measured ratio
# (plus its label) would otherwise be clipped at the top of the axes.
tallest = max(bar.get_height() for bar in all_bars)
ax.set_ylim(0, max(2.8, tallest + 0.3))
ax.legend(fontsize=12)
ax.grid(True, axis='y', alpha=0.3)
# Dashed reference line at 1.0x (no compression).
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.4)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_both.png"), dpi=150)
print("✅ Saved figures/compression_bar_both.png")
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
# ── GRAPH 3: Hero Summary Table ───────────────────────
|
| 102 |
+
# Summary table rendered as an image; the axes themselves are hidden.
fig, ax = plt.subplots(figsize=(12, 4))
ax.axis('off')

# First row is the header; the rest are one row per (model, method).
# NOTE(review): the FP16 / uniform 8-bit sizes ("1073 MB", "537 MB") and the
# "vs 8-bit" ratios ("1.15x", "1.02x") are hard-coded rather than read from the
# results files - confirm they still match the data before publishing.
# NOTE(review): the FP16 and "Ours" rows print the same perplexity and speed
# fields for a model - confirm that is intended.
table_data = [
    ["Model", "Method", "Avg Bits", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"],
    ["Mistral-7B", "FP16 Baseline", "16", "1073 MB", "1.0x", "—", str(mistral["perplexity"]), f"{mistral['decode_tokens_per_sec']} t/s"],
    ["Mistral-7B", "Uniform 8-bit", "8", "537 MB", "2.0x", "1.0x", "~same", "~same"],
    ["Mistral-7B", "Per-Head Mixed (Ours)", f"{mistral['avg_bits']}", f"{mistral['summary']['ours_8k_mb']} MB", f"{mistral['summary']['compression_8k']}x", "1.15x", str(mistral["perplexity"]), f"{mistral['decode_tokens_per_sec']} t/s"],
    ["Llama-3-8B", "FP16 Baseline", "16", "1073 MB", "1.0x", "—", str(llama["perplexity"]), f"{llama['decode_tokens_per_sec']} t/s"],
    ["Llama-3-8B", "Uniform 8-bit", "8", "537 MB", "2.0x", "1.0x", "~same", "~same"],
    ["Llama-3-8B", "Per-Head Mixed (Ours)", f"{llama['avg_bits']}", f"{llama['summary']['ours_8k_mb']} MB", f"{llama['summary']['compression_8k']}x", "1.02x", str(llama["perplexity"]), f"{llama['decode_tokens_per_sec']} t/s"],
]

table = ax.table(
    cellText=table_data[1:],
    colLabels=table_data[0],
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1.2, 2.2)

# Column count follows the header row, so adding or removing a column cannot
# leave cells unstyled or index past the table.
n_cols = len(table_data[0])

# style header (table row 0 holds the column labels)
for j in range(n_cols):
    table[0, j].set_facecolor("#1e293b")
    table[0, j].set_text_props(color='white', fontweight='bold')

# highlight our rows: row 3 is Mistral "Ours" (green), row 6 is Llama "Ours" (blue)
for j in range(n_cols):
    table[3, j].set_facecolor("#dcfce7")
    table[6, j].set_facecolor("#dbeafe")

plt.title("Full Results — Per-Head Mixed-Precision KV Cache",
          fontsize=13, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/results_table_both.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/results_table_both.png")
|
| 141 |
+
|
| 142 |
+
# Release every figure created above now that all of them are on disk.
plt.close('all')
# NOTE(review): the leading emoji was unreadable in the source; a party popper
# is assumed here - confirm against the original file.
print("\n🎉 All graphs saved to ~/kv-hack/figures/")
|