harshithsaiv commited on
Commit
eec6c0e
Β·
1 Parent(s): 1a0124b

feat: adding visualization for longer context

Browse files
Files changed (1) hide show
  1. visualize_long_context.py +177 -0
visualize_long_context.py ADDED
@@ -0,0 +1,177 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Long context visualization β€” both models.
3
+ """
4
+ import json
5
+ import matplotlib.pyplot as plt
6
+ import matplotlib.ticker as ticker
7
+ import os
8
+
9
+ def load_long(model_name):
10
+ path = os.path.expanduser(
11
+ f"~/kv-hack/results/{model_name}/long_context_results.json"
12
+ )
13
+ with open(path) as f:
14
+ return json.load(f)
15
+
16
+ os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)
17
+
18
+ mistral = load_long("mistral-7b")
19
+ llama = load_long("llama-3-8b")
20
+
21
+ C_FP16 = "#ef4444"
22
+ C_UNIFORM = "#f97316"
23
+ C_MISTRAL = "#22c55e"
24
+ C_LLAMA = "#3b82f6"
25
+
26
+ # ── GRAPH 1: Both Models Side by Side ─────────────────
27
+ fig, axes = plt.subplots(1, 2, figsize=(18, 7))
28
+
29
+ for ax, data, color, title, oom_ctx in [
30
+ (axes[0], mistral, C_MISTRAL, "Mistral-7B", None),
31
+ (axes[1], llama, C_LLAMA, "Llama-3-8B", 32768),
32
+ ]:
33
+ valid = [r for r in data["results"] if "mixed_precision_mb" in r]
34
+ ctx = [r["context_len"] for r in valid]
35
+ fp16 = [r["fp16_mb"] for r in valid]
36
+ uni8 = [r["uniform8_mb"] for r in valid]
37
+ ours = [r["mixed_precision_mb"] for r in valid]
38
+
39
+ ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=3, markersize=9, label="FP16 Baseline")
40
+ ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=3, markersize=9, label="Uniform 8-bit")
41
+ ax.plot(ctx, ours, '^-', color=color, linewidth=3, markersize=9, label="Per-Head Mixed (Ours)")
42
+ ax.fill_between(ctx, fp16, ours, alpha=0.08, color=color)
43
+
44
+ # OOM marker
45
+ if oom_ctx:
46
+ ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
47
+ ax.text(ctx[-1]*0.92, max(fp16)*0.85,
48
+ "FP16\nOOM β†’", color=C_FP16,
49
+ fontweight='bold', fontsize=10, ha='right')
50
+ # show where ours would be at 32K
51
+ ours_32k = ours[-1] * 2
52
+ ax.annotate(f"Ours at 32K:\n~{ours_32k:.0f}MB βœ…",
53
+ xy=(ctx[-1], ours[-1]),
54
+ xytext=(ctx[-2], ours[-1]+200),
55
+ color=color, fontweight='bold', fontsize=9,
56
+ arrowprops=dict(arrowstyle='->', color=color))
57
+
58
+ # annotate last valid point
59
+ ax.annotate(f"{fp16[-1]/1024:.1f} GB",
60
+ xy=(ctx[-1], fp16[-1]),
61
+ xytext=(-40, 10), textcoords='offset points',
62
+ color=C_FP16, fontweight='bold', fontsize=9)
63
+ ax.annotate(f"{ours[-1]/1024:.1f} GB",
64
+ xy=(ctx[-1], ours[-1]),
65
+ xytext=(-40, -20), textcoords='offset points',
66
+ color=color, fontweight='bold', fontsize=9)
67
+
68
+ ax.set_xlabel("Context Length (tokens)", fontsize=12)
69
+ ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
70
+ ax.set_title(f"{title}\nKV Cache Memory vs Context Length",
71
+ fontsize=13, fontweight='bold')
72
+ ax.legend(fontsize=10)
73
+ ax.grid(True, alpha=0.3)
74
+ ax.set_xticks(ctx)
75
+ ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in ctx])
76
+
77
+ plt.suptitle("Per-Head Mixed-Precision KV Cache β€” Long Context Benchmark\n"
78
+ "Llama-3-8B FP16 OOMs at 32K. Our method fits.",
79
+ fontsize=14, fontweight='bold', y=1.02)
80
+ plt.tight_layout()
81
+ plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_both.png"),
82
+ dpi=150, bbox_inches='tight')
83
+ print("βœ… Saved figures/long_context_both.png")
84
+
85
+
86
+ # ── GRAPH 2: The OOM Story ────────────────────────────
87
+ fig, ax = plt.subplots(figsize=(12, 6))
88
+
89
+ # project to 32K for both
90
+ all_ctx = [512, 1024, 2048, 4096, 8192, 16384, 32768]
91
+ # mistral has all points
92
+ m_fp16 = [r["fp16_mb"] for r in mistral["results"] if "fp16_mb" in r]
93
+ m_ours = [r["mixed_precision_mb"] for r in mistral["results"]
94
+ if "mixed_precision_mb" in r]
95
+ m_ctx = [r["context_len"] for r in mistral["results"]
96
+ if "mixed_precision_mb" in r]
97
+
98
+ # llama valid points
99
+ l_valid = [r for r in llama["results"] if "mixed_precision_mb" in r]
100
+ l_fp16 = [r["fp16_mb"] for r in l_valid]
101
+ l_ours = [r["mixed_precision_mb"] for r in l_valid]
102
+ l_ctx = [r["context_len"] for r in l_valid]
103
+
104
+ # A100 40GB memory limit line (minus model weights)
105
+ mistral_model_mem = 14.5 * 1024 # MB
106
+ llama_model_mem = 16.0 * 1024 # MB
107
+ a100_total = 40 * 1024 # MB
108
+
109
+ ax.axhline(y=a100_total - mistral_model_mem,
110
+ color='gray', linestyle='--', alpha=0.7, linewidth=2,
111
+ label=f"A100 headroom (Mistral): {(a100_total-mistral_model_mem)/1024:.0f}GB")
112
+ ax.axhline(y=a100_total - llama_model_mem,
113
+ color='gray', linestyle=':', alpha=0.7, linewidth=2,
114
+ label=f"A100 headroom (Llama): {(a100_total-llama_model_mem)/1024:.0f}GB")
115
+
116
+ ax.plot(m_ctx, m_fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=7, label="FP16 (Mistral)")
117
+ ax.plot(m_ctx, m_ours, '^-', color=C_MISTRAL, linewidth=2.5, markersize=7, label="Ours (Mistral)")
118
+ ax.plot(l_ctx, l_fp16, 'o--', color="#f87171", linewidth=2.5, markersize=7, label="FP16 (Llama)")
119
+ ax.plot(l_ctx, l_ours, '^--', color=C_LLAMA, linewidth=2.5, markersize=7, label="Ours (Llama)")
120
+
121
+ # OOM annotation
122
+ ax.annotate("Llama FP16\nOOM here ❌",
123
+ xy=(16384, l_fp16[-1]),
124
+ xytext=(12000, l_fp16[-1]+400),
125
+ color=C_FP16, fontweight='bold', fontsize=10,
126
+ arrowprops=dict(arrowstyle='->', color=C_FP16))
127
+
128
+ ax.set_xlabel("Context Length (tokens)", fontsize=13)
129
+ ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
130
+ ax.set_title("KV Cache Memory vs GPU Headroom\n"
131
+ "Our method keeps you under the limit longer",
132
+ fontsize=14, fontweight='bold')
133
+ ax.legend(fontsize=10, loc='upper left')
134
+ ax.grid(True, alpha=0.3)
135
+ ax.set_xticks(m_ctx)
136
+ ax.set_xticklabels(["512","1K","2K","4K","8K","16K","32K"])
137
+ plt.tight_layout()
138
+ plt.savefig(os.path.expanduser("~/kv-hack/figures/oom_story.png"),
139
+ dpi=150, bbox_inches='tight')
140
+ print("βœ… Saved figures/oom_story.png")
141
+
142
+
143
+ # ── GRAPH 3: Prefill Latency Both Models ─────────────
144
+ fig, ax = plt.subplots(figsize=(10, 5))
145
+
146
+ m_prefill = [r["prefill_ms"] for r in mistral["results"] if "prefill_ms" in r]
147
+ l_prefill = [r["prefill_ms"] for r in llama["results"] if "prefill_ms" in r]
148
+
149
+ ax.plot(m_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
150
+ markersize=8, label="Mistral-7B")
151
+ ax.plot(l_ctx, l_prefill, 's-', color=C_LLAMA, linewidth=2.5,
152
+ markersize=8, label="Llama-3-8B")
153
+
154
+ for x, y in zip(m_ctx, m_prefill):
155
+ ax.annotate(f"{y:.0f}ms", xy=(x, y),
156
+ xytext=(0, 10), textcoords='offset points',
157
+ ha='center', fontsize=8, color=C_MISTRAL)
158
+ for x, y in zip(l_ctx, l_prefill):
159
+ ax.annotate(f"{y:.0f}ms", xy=(x, y),
160
+ xytext=(0, -18), textcoords='offset points',
161
+ ha='center', fontsize=8, color=C_LLAMA)
162
+
163
+ ax.set_xlabel("Context Length (tokens)", fontsize=13)
164
+ ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
165
+ ax.set_title("Prefill Latency vs Context Length\nBoth Models",
166
+ fontsize=14, fontweight='bold')
167
+ ax.legend(fontsize=11)
168
+ ax.grid(True, alpha=0.3)
169
+ ax.set_xticks(m_ctx)
170
+ ax.set_xticklabels(["512","1K","2K","4K","8K","16K","32K"])
171
+ plt.tight_layout()
172
+ plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
173
+ dpi=150, bbox_inches='tight')
174
+ print("βœ… Saved figures/prefill_latency_both.png")
175
+
176
+ plt.close('all')
177
+ print("\nπŸŽ‰ All long context graphs saved!")