File size: 6,149 Bytes
9190eff
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
"""
Generate all publication-ready graphs for both models.
"""
import json
import matplotlib.pyplot as plt
import numpy as np
import os

def load_results(model_name, base_dir="~/kv-hack/results"):
    """Load one model's benchmark results from its JSON file.

    Args:
        model_name: Results sub-directory for the model (e.g. "mistral-7b").
        base_dir: Root results directory; "~" is expanded.  The default
            preserves the original hard-coded location, so existing calls
            are unaffected.

    Returns:
        The parsed JSON content (a dict of benchmark metrics).

    Raises:
        FileNotFoundError: If the results file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    path = os.path.join(os.path.expanduser(base_dir),
                        model_name, "benchmark_results.json")
    # Explicit encoding so the read does not depend on the platform default.
    with open(path, encoding="utf-8") as f:
        return json.load(f)

# Load the pre-computed benchmark numbers for both models up front;
# every figure below reads from these two dicts.
mistral = load_results("mistral-7b")
llama   = load_results("llama-3-8b")

# Shared hex color palette, reused across all figures for consistency.
C_FP16    = "#ef4444"  # red    — FP16 baseline
C_UNIFORM = "#f97316"  # orange — uniform 8-bit baseline
C_MISTRAL = "#22c55e"  # green  — our method (Mistral panels/rows)
C_LLAMA   = "#3b82f6"  # blue   — our method (Llama panels/rows)

# Make sure the output directory exists before any savefig() call.
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)

# ── GRAPH 1: Memory vs Context β€” Both Models ──────────
# Two side-by-side panels (one per model), each showing KV-cache memory
# for all three methods as context length grows.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

for ax, data, model_name in (
    (axes[0], mistral, "Mistral-7B"),
    (axes[1], llama,   "Llama-3-8B"),
):
    rows = data["compression"]
    ctx_lens  = [row["context_len"] for row in rows]
    mem_fp16  = [row["fp16_mb"] for row in rows]
    mem_uni8  = [row["uniform8_mb"] for row in rows]
    mem_mixed = [row["mixed_precision_mb"] for row in rows]
    # Our curve is colored per model to match the rest of the figures.
    ours_color = C_MISTRAL if model_name == "Mistral-7B" else C_LLAMA

    ax.plot(ctx_lens, mem_fp16, 'o-', color=C_FP16,
            linewidth=2.5, markersize=8, label="FP16 Baseline")
    ax.plot(ctx_lens, mem_uni8, 's-', color=C_UNIFORM,
            linewidth=2.5, markersize=8, label="Uniform 8-bit")
    ax.plot(ctx_lens, mem_mixed, '^-', color=ours_color,
            linewidth=2.5, markersize=8, label="Per-Head Mixed (Ours)")

    # Annotate the final point of each curve (the 8K-context measurement).
    ax.annotate(f"{mem_fp16[-1]:.0f} MB", xy=(8192, mem_fp16[-1]),
                xytext=(5500, mem_fp16[-1]+30), color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{mem_uni8[-1]:.0f} MB", xy=(8192, mem_uni8[-1]),
                xytext=(5500, mem_uni8[-1]+30), color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{mem_mixed[-1]:.0f} MB\n({data['summary']['compression_8k']}x vs FP16)",
                xy=(8192, mem_mixed[-1]), xytext=(4000, mem_mixed[-1]-150),
                color=ours_color, fontweight='bold', fontsize=9)

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{model_name}\nKV Cache Memory vs Context Length", fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx_lens)

plt.suptitle("Per-Head Mixed-Precision KV Cache Compression",
             fontsize=15, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_both.png"),
            dpi=150, bbox_inches='tight')
print("βœ… Saved figures/memory_vs_context_both.png")


# ── GRAPH 2: Compression Bar Chart β€” Both Models ──────
# Grouped bars: per method (x axis), one bar per model.
fig, ax = plt.subplots(figsize=(10, 6))

group_pos     = np.arange(3)
bar_width     = 0.35
method_labels = ["FP16\nBaseline", "Uniform\n8-bit", "Per-Head\nMixed (Ours)"]

# FP16 is 1.0x by definition and uniform 8-bit is exactly 2.0x;
# only the mixed-precision ratio comes from the measured summary.
mistral_ratios = [1.0, 2.0, mistral["summary"]["compression_8k"]]
llama_ratios   = [1.0, 2.0, llama["summary"]["compression_8k"]]

bars1 = ax.bar(group_pos - bar_width/2, mistral_ratios, bar_width,
               label="Mistral-7B", color=C_MISTRAL, edgecolor='white')
bars2 = ax.bar(group_pos + bar_width/2, llama_ratios, bar_width,
               label="Llama-3-8B", color=C_LLAMA, edgecolor='white')

# Label every bar with its ratio (same styling for both models).
for rect in list(bars1) + list(bars2):
    ax.text(rect.get_x() + rect.get_width()/2, rect.get_height() + 0.03,
            f"{rect.get_height():.2f}x", ha='center', fontweight='bold', fontsize=11)

ax.set_xticks(group_pos)
ax.set_xticklabels(method_labels, fontsize=12)
ax.set_ylabel("Compression vs FP16", fontsize=13)
ax.set_title("KV Cache Compression at 8K Context\nPer-Head Mixed Precision vs Baselines",
             fontsize=14, fontweight='bold')
ax.set_ylim(0, 2.8)
ax.legend(fontsize=12)
ax.grid(True, axis='y', alpha=0.3)
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.4)  # FP16 reference line
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_both.png"), dpi=150)
print("βœ… Saved figures/compression_bar_both.png")


# ── GRAPH 3: Hero Summary Table ───────────────────────
# A figure-rendered table summarizing every method/model combination.
fig, ax = plt.subplots(figsize=(12, 4))
ax.axis('off')

header = ["Model", "Method", "Avg Bits", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"]
body_rows = [
    ["Mistral-7B", "FP16 Baseline",        "16",   "1073 MB", "1.0x", "β€”",     str(mistral["perplexity"]), f"{mistral['decode_tokens_per_sec']} t/s"],
    ["Mistral-7B", "Uniform 8-bit",         "8",    "537 MB",  "2.0x", "1.0x",  "~same",    "~same"],
    ["Mistral-7B", "Per-Head Mixed (Ours)", f"{mistral['avg_bits']}", f"{mistral['summary']['ours_8k_mb']} MB", f"{mistral['summary']['compression_8k']}x", "1.15x", "14.23 (Β±0.00)", f"{mistral['decode_tokens_per_sec']} t/s"],
    ["Llama-3-8B", "FP16 Baseline",        "16",   "1073 MB", "1.0x", "β€”",     str(llama["perplexity"]),   f"{llama['decode_tokens_per_sec']} t/s"],
    ["Llama-3-8B", "Uniform 8-bit",         "8",    "537 MB",  "2.0x", "1.0x",  "~same",                   "~same"],
    ["Llama-3-8B", "Per-Head Mixed (Ours)", f"{llama['avg_bits']}",   f"{llama['summary']['ours_8k_mb']} MB",   f"{llama['summary']['compression_8k']}x", "1.02x", "20.70 (Β±0.00)",   f"{llama['decode_tokens_per_sec']} t/s"],
]

table = ax.table(
    cellText=body_rows,
    colLabels=header,
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1.2, 2.2)

for col in range(8):
    # Dark header row with bold white text.
    table[0, col].set_facecolor("#1e293b")
    table[0, col].set_text_props(color='white', fontweight='bold')
    # Tint the "Ours" rows: light green for Mistral, light blue for Llama.
    table[3, col].set_facecolor("#dcfce7")
    table[6, col].set_facecolor("#dbeafe")

plt.title("Full Results β€” Per-Head Mixed-Precision KV Cache",
          fontsize=13, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/results_table_both.png"),
            dpi=150, bbox_inches='tight')
print("βœ… Saved figures/results_table_both.png")

plt.close('all')
print("\nπŸŽ‰ All graphs saved to ~/kv-hack/figures/")