"""
Long context visualization β€” 4 methods comparison.
"""
import json
import matplotlib.pyplot as plt
import os

def load_long(model_name):
    path = os.path.expanduser(
        f"~/kv-hack/results/{model_name}/long_context_results.json"
    )
    with open(path) as f:
        return json.load(f)

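# Assumed shape of long_context_results.json, inferred from the keys this
# script reads; the example values below are illustrative, not measured.
# Rows that hit OOM may omit fields, which is why the plotting loops filter
# on "triton_mb" / "prefill_ms" before use.
#
# {"results": [
#    {"context_len": 32768,           # tokens
#     "fp16_mb": 4096.0,              # FP16 baseline KV cache (MB)
#     "uniform8_mb": 2048.0,          # uniform 8-bit quantization (MB)
#     "naive_real_gpu_mb": 2176.0,    # naive per-head uint8 (MB)
#     "triton_mb": 1126.4,            # Triton true 4-bit kernel (MB)
#     "triton_compression": 3.6,      # fp16_mb / triton_mb
#     "prefill_ms": 950.0},           # prefill latency (ms)
#    ...]}
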
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)

mistral = load_long("mistral-7b")
llama   = load_long("llama-3-8b")

C_FP16    = "#ef4444"
C_UNIFORM = "#f97316"
C_NAIVE   = "#a855f7"
C_MISTRAL = "#22c55e"
C_LLAMA   = "#3b82f6"

# ── GRAPH 1: Both Models 4 Methods ───────────────────
fig, axes = plt.subplots(1, 2, figsize=(18, 7))

for ax, data, triton_color, title in [
    (axes[0], mistral, C_MISTRAL, "Mistral-7B"),
    (axes[1], llama,   C_LLAMA,   "Llama-3-8B"),
]:
    valid = [r for r in data["results"] if "triton_mb" in r]
    ctx   = [r["context_len"]       for r in valid]
    fp16  = [r["fp16_mb"]           for r in valid]
    uni8  = [r["uniform8_mb"]       for r in valid]
    naive = [r["naive_real_gpu_mb"] for r in valid]
    triton = [r["triton_mb"]        for r in valid]

    ax.plot(ctx, fp16,   'o-', color=C_FP16,    linewidth=3, markersize=9, label="FP16 Baseline")
    ax.plot(ctx, uni8,   's-', color=C_UNIFORM,  linewidth=3, markersize=9, label="Uniform 8-bit")
    ax.plot(ctx, naive,  'D-', color=C_NAIVE,    linewidth=3, markersize=9, label="Naive Per-Head (uint8)")
    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=3, markersize=9, label="Triton True 4-bit (Ours)")

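    # shade the gap between the FP16 and Triton curves to highlight the savings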
    ax.fill_between(ctx, fp16, triton, alpha=0.07, color=triton_color)

    # annotate last point
    ax.annotate(f"{fp16[-1]/1024:.1f} GB",
                xy=(ctx[-1], fp16[-1]),
                xytext=(-50, 10), textcoords='offset points',
                color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{uni8[-1]/1024:.1f} GB",
                xy=(ctx[-1], uni8[-1]),
                xytext=(-50, 10), textcoords='offset points',
                color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{naive[-1]/1024:.1f} GB",
                xy=(ctx[-1], naive[-1]),
                xytext=(-50, -18), textcoords='offset points',
                color=C_NAIVE, fontweight='bold', fontsize=9)
    ax.annotate(f"{triton[-1]/1024:.1f} GB\n({valid[-1]['triton_compression']}x)",
                xy=(ctx[-1], triton[-1]),
                xytext=(-80, -35), textcoords='offset points',
                color=triton_color, fontweight='bold', fontsize=9)

    # OOM marker for llama
    if title == "Llama-3-8B":
        ax.axvline(x=ctx[-1], color=C_FP16, linestyle='--', alpha=0.5)
        ax.text(ctx[-1]*0.88, max(fp16)*0.88,
                "FP16\nOOM β†’", color=C_FP16,
                fontweight='bold', fontsize=10, ha='right')

    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length (4 Methods)",
                 fontsize=13, fontweight='bold')
    ax.legend(fontsize=10, loc='upper left')
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in ctx])

plt.suptitle("Per-Head Mixed-Precision KV Cache — Long Context Benchmark",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/long_context_4methods.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/long_context_4methods.png")


# ── GRAPH 2: The savings story at 32K ─────────────────
fig, ax = plt.subplots(figsize=(10, 6))

# use the Mistral 32K row (assumes a 32768-token run exists;
# next() raises StopIteration otherwise)
r32 = next(r for r in mistral["results"] if r["context_len"] == 32768)

methods = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8)", "Triton True\n4-bit (Ours)"]
values  = [r32["fp16_mb"], r32["uniform8_mb"], r32["naive_real_gpu_mb"], r32["triton_mb"]]
colors  = [C_FP16, C_UNIFORM, C_NAIVE, C_MISTRAL]

bars = ax.bar(methods, values, color=colors, width=0.5,
              edgecolor='white', linewidth=2)

for bar, val in zip(bars, values):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 30,
            f"{val/1024:.1f} GB", ha='center',
            fontweight='bold', fontsize=12)

# savings arrows
ax.annotate('', xy=(3, r32["triton_mb"]),
            xytext=(0, r32["fp16_mb"]),
            arrowprops=dict(arrowstyle='<->', color='gray', lw=2))
ax.text(1.5, (r32["fp16_mb"] + r32["triton_mb"])/2,
        f"Save {(r32['fp16_mb']-r32['triton_mb'])/1024:.1f} GB\n({r32['triton_compression']}x)",
        ha='center', color='gray', fontweight='bold', fontsize=11)

ax.set_ylabel("KV Cache Memory (MB)", fontsize=13)
ax.set_title("KV Cache Memory at 32K Context β€” Mistral-7B\nTriton saves 2.4GB vs FP16 baseline",
             fontsize=14, fontweight='bold')
ax.grid(True, axis='y', alpha=0.3)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_32k_4methods.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/memory_32k_4methods.png")


# ── GRAPH 3: Prefill Latency Both Models ──────────────
fig, ax = plt.subplots(figsize=(10, 5))

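# keep only rows that recorded a prefill time (skips any OOM/partial rows)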
m_valid   = [r for r in mistral["results"] if "prefill_ms" in r]
l_valid   = [r for r in llama["results"]   if "prefill_ms" in r]
m_ctx     = [r["context_len"] for r in m_valid]
l_ctx     = [r["context_len"] for r in l_valid]
m_prefill = [r["prefill_ms"]  for r in m_valid]
l_prefill = [r["prefill_ms"]  for r in l_valid]

ax.plot(m_ctx, m_prefill, 'o-', color=C_MISTRAL, linewidth=2.5,
        markersize=8, label="Mistral-7B")
ax.plot(l_ctx, l_prefill, 's-', color=C_LLAMA,   linewidth=2.5,
        markersize=8, label="Llama-3-8B")

for x, y in zip(m_ctx, m_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, 10), textcoords='offset points',
                ha='center', fontsize=8, color=C_MISTRAL)
for x, y in zip(l_ctx, l_prefill):
    ax.annotate(f"{y:.0f}ms", xy=(x, y),
                xytext=(0, -18), textcoords='offset points',
                ha='center', fontsize=8, color=C_LLAMA)

ax.set_xlabel("Context Length (tokens)", fontsize=13)
ax.set_ylabel("Prefill Latency (ms)", fontsize=13)
ax.set_title("Prefill Latency vs Context Length β€” Both Models",
             fontsize=14, fontweight='bold')
ax.legend(fontsize=11)
ax.grid(True, alpha=0.3)
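# tick positions use Mistral's context grid; assumes Llama ran the same lengths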
ax.set_xticks(m_ctx)
ax.set_xticklabels([f"{c//1024}K" if c >= 1024 else str(c) for c in m_ctx])
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/prefill_latency_both.png"),
            dpi=150, bbox_inches='tight')
print("✅ Saved figures/prefill_latency_both.png")

plt.close('all')
print("\n🎉 All long context graphs saved!")