harshithsaiv committed on
Commit
9e0641d
·
1 Parent(s): c32a0aa

feat: Implementation of the visualize script

Browse files
Files changed (1) hide show
  1. visualize_results.py +143 -0
visualize_results.py ADDED
@@ -0,0 +1,143 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate all publication-ready graphs for both models.
3
+ """
4
+ import json
5
+ import matplotlib.pyplot as plt
6
+ import numpy as np
7
+ import os
8
+
9
def load_results(model_name):
    """Load the benchmark results JSON for one model.

    Reads ``~/kv-hack/results/<model_name>/benchmark_results.json`` (with
    ``~`` expanded to the current user's home directory) and returns the
    parsed content.

    Args:
        model_name: Name of the results sub-directory, e.g. ``"mistral-7b"``.

    Returns:
        The object decoded by ``json.load``.

    Raises:
        FileNotFoundError: If the results file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    path = os.path.expanduser(f"~/kv-hack/results/{model_name}/benchmark_results.json")
    # JSON is UTF-8 by specification; do not depend on the locale default.
    with open(path, encoding="utf-8") as f:
        return json.load(f)
13
+
14
# Benchmark results for the two evaluated models, loaded in a fixed order.
mistral, llama = (load_results(name) for name in ("mistral-7b", "llama-3-8b"))

# Colour palette (hex) shared by every figure in this script.
C_FP16 = "#ef4444"     # FP16 baseline series
C_UNIFORM = "#f97316"  # uniform 8-bit series
C_MISTRAL = "#22c55e"  # Mistral-7B series
C_LLAMA = "#3b82f6"    # Llama-3-8B series

# Make sure the output directory for the rendered figures exists.
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)
23
+
24
# ── GRAPH 1: Memory vs Context — Both Models ──────────
# One panel per model: KV-cache size (MB) against context length for the
# FP16 baseline, uniform 8-bit, and the per-head mixed-precision scheme.
fig, axes = plt.subplots(1, 2, figsize=(16, 6))

# The colour of the "ours" series travels with each panel instead of being
# re-derived from the title string.
for ax, results, title, ours_color in [
    (axes[0], mistral, "Mistral-7B", C_MISTRAL),
    (axes[1], llama, "Llama-3-8B", C_LLAMA),
]:
    rows = results["compression"]
    ctx = [r["context_len"] for r in rows]
    fp16 = [r["fp16_mb"] for r in rows]
    uni8 = [r["uniform8_mb"] for r in rows]
    ours = [r["mixed_precision_mb"] for r in rows]

    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=8, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=2.5, markersize=8, label="Uniform 8-bit")
    ax.plot(ctx, ours, '^-', color=ours_color,
            linewidth=2.5, markersize=8, label="Per-Head Mixed (Ours)")

    # Annotate the final (largest-context) point of every series. The anchor
    # follows the last measured context rather than a hard-coded 8192; the
    # label offsets (xytext) are still tuned for an 8K x-axis.
    last_ctx = ctx[-1]
    ax.annotate(f"{fp16[-1]:.0f} MB", xy=(last_ctx, fp16[-1]),
                xytext=(5500, fp16[-1] + 30), color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{uni8[-1]:.0f} MB", xy=(last_ctx, uni8[-1]),
                xytext=(5500, uni8[-1] + 30), color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{ours[-1]:.0f} MB\n({results['summary']['compression_8k']}x vs FP16)",
                xy=(last_ctx, ours[-1]), xytext=(4000, ours[-1] - 150),
                color=ours_color, fontweight='bold', fontsize=9)

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length", fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)  # one tick per measured context length

fig.suptitle("Per-Head Mixed-Precision KV Cache Compression",
             fontsize=15, fontweight='bold', y=1.02)
fig.tight_layout()
# bbox_inches='tight' keeps the suptitle (placed above the axes at y=1.02)
# inside the saved image.
fig.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_both.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/memory_vs_context_both.png")
64
+
65
+
66
# ── GRAPH 2: Compression Bar Chart — Both Models ──────
# Grouped bars: compression ratio relative to FP16 for each method, with the
# two models side by side inside every group.
fig, ax = plt.subplots(figsize=(10, 6))

x = np.arange(3)
width = 0.35
models = ["FP16\nBaseline", "Uniform\n8-bit", "Per-Head\nMixed (Ours)"]

mistral_ratio = mistral["summary"]["compression_8k"]
llama_ratio = llama["summary"]["compression_8k"]

# FP16 (1.0x) and uniform 8-bit (2.0x) are fixed reference ratios; only the
# third bar comes from the benchmark summary.
bars1 = ax.bar(x - width / 2,
               [1.0, 2.0, mistral_ratio],
               width, label="Mistral-7B", color=C_MISTRAL, edgecolor='white')
bars2 = ax.bar(x + width / 2,
               [1.0, 2.0, llama_ratio],
               width, label="Llama-3-8B", color=C_LLAMA, edgecolor='white')

# Value label just above every bar of both series.
for bar in (*bars1, *bars2):
    height = bar.get_height()
    ax.text(bar.get_x() + bar.get_width() / 2, height + 0.03,
            f"{height:.2f}x", ha='center', fontweight='bold', fontsize=11)

ax.set_xticks(x)
ax.set_xticklabels(models, fontsize=12)
ax.set_ylabel("Compression vs FP16", fontsize=13)
ax.set_title("KV Cache Compression at 8K Context\nPer-Head Mixed Precision vs Baselines",
             fontsize=14, fontweight='bold')
# Keep the original 2.8 ceiling, but grow it when a measured ratio would
# otherwise push its bar or label out of the axes.
ax.set_ylim(0, max(2.8, 1.15 * max(mistral_ratio, llama_ratio)))
ax.legend(fontsize=12)
ax.grid(True, axis='y', alpha=0.3)
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.4)  # FP16 reference line
fig.tight_layout()
fig.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_both.png"), dpi=150)
print("✅ Saved figures/compression_bar_both.png")
99
+
100
+
101
# ── GRAPH 3: Hero Summary Table ───────────────────────
# Renders the headline numbers for both models as a table image.
fig, ax = plt.subplots(figsize=(12, 4))
ax.axis('off')

# The table states uniform 8-bit as 2.0x vs FP16, so the "vs 8-bit" column is
# the measured FP16 compression ratio divided by that factor.
UNIFORM8_RATIO = 2.0
mistral_vs_8bit = f"{mistral['summary']['compression_8k'] / UNIFORM8_RATIO:.2f}x"
llama_vs_8bit = f"{llama['summary']['compression_8k'] / UNIFORM8_RATIO:.2f}x"

# NOTE(review): the FP16 rows reuse the same perplexity / speed fields as the
# mixed-precision rows; confirm the results file has no separate baseline
# numbers. The "1073 MB" / "537 MB" cells are literals as well and could be
# read from results["compression"] — TODO confirm.
table_data = [
    ["Model", "Method", "Avg Bits", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"],
    ["Mistral-7B", "FP16 Baseline", "16", "1073 MB", "1.0x", "—",
     str(mistral["perplexity"]), f"{mistral['decode_tokens_per_sec']} t/s"],
    ["Mistral-7B", "Uniform 8-bit", "8", "537 MB", "2.0x", "1.0x", "~same", "~same"],
    ["Mistral-7B", "Per-Head Mixed (Ours)", f"{mistral['avg_bits']}",
     f"{mistral['summary']['ours_8k_mb']} MB", f"{mistral['summary']['compression_8k']}x",
     mistral_vs_8bit, str(mistral["perplexity"]), f"{mistral['decode_tokens_per_sec']} t/s"],
    ["Llama-3-8B", "FP16 Baseline", "16", "1073 MB", "1.0x", "—",
     str(llama["perplexity"]), f"{llama['decode_tokens_per_sec']} t/s"],
    ["Llama-3-8B", "Uniform 8-bit", "8", "537 MB", "2.0x", "1.0x", "~same", "~same"],
    ["Llama-3-8B", "Per-Head Mixed (Ours)", f"{llama['avg_bits']}",
     f"{llama['summary']['ours_8k_mb']} MB", f"{llama['summary']['compression_8k']}x",
     llama_vs_8bit, str(llama["perplexity"]), f"{llama['decode_tokens_per_sec']} t/s"],
]

table = ax.table(
    cellText=table_data[1:],
    colLabels=table_data[0],
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1.2, 2.2)

n_cols = len(table_data[0])

# Style the header (table row 0 holds the column labels).
for j in range(n_cols):
    table[0, j].set_facecolor("#1e293b")
    table[0, j].set_text_props(color='white', fontweight='bold')

# Tint the "ours" rows: light green for Mistral (row 3), light blue for
# Llama (row 6). Row numbers are table_data indices.
highlight = {3: "#dcfce7", 6: "#dbeafe"}
for row, shade in highlight.items():
    for j in range(n_cols):
        table[row, j].set_facecolor(shade)

ax.set_title("Full Results — Per-Head Mixed-Precision KV Cache",
             fontsize=13, fontweight='bold', pad=20)
fig.tight_layout()
fig.savefig(os.path.expanduser("~/kv-hack/figures/results_table_both.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/results_table_both.png")

# Release every figure created by this script.
plt.close('all')
print("\n🎉 All graphs saved to ~/kv-hack/figures/")