| """ |
| Generate publication-ready graphs β 4 methods comparison. |
| """ |
| import json |
| import matplotlib.pyplot as plt |
| import numpy as np |
| import os |
|
|
| def load_results(model_name): |
| path = os.path.expanduser( |
| f"~/kv-hack/results/{model_name}/benchmark_results.json" |
| ) |
| with open(path) as f: |
| return json.load(f) |
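
# Expected layout of benchmark_results.json, inferred from the keys accessed
# below (the exact schema is an assumption, not something verified here):
#   "compression": list of {"context_len", "fp16_mb", "uniform8_mb",
#                           "naive_real_gpu_mb", "triton_mb"}
#   "summary": {"naive_real_8k_mb", "naive_real_compression_8k", "triton_8k_mb",
#               "triton_compression_8k", "triton_vs_8bit_8k"}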
|
|
mistral = load_results("mistral-7b")
llama = load_results("llama-3-8b")


# One color per method: red = FP16, orange = uniform 8-bit, purple = naive
# per-head, green = Triton 4-bit (ours); blue marks the Llama-3-8B panel.
C_FP16 = "#ef4444"
C_UNIFORM = "#f97316"
C_NAIVE = "#a855f7"
C_TRITON = "#22c55e"
C_LLAMA = "#3b82f6"
|
|
os.makedirs(os.path.expanduser("~/kv-hack/figures"), exist_ok=True)
|
|
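# Figure 1: KV cache memory vs context length, one panel per model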
fig, axes = plt.subplots(1, 2, figsize=(18, 7))


for ax, results, title, triton_color in [
    (axes[0], mistral, "Mistral-7B", C_TRITON),
    (axes[1], llama, "Llama-3-8B", C_LLAMA),
]:
    ctx = [r["context_len"] for r in results["compression"]]
    fp16 = [r["fp16_mb"] for r in results["compression"]]
    uni8 = [r["uniform8_mb"] for r in results["compression"]]
    naive = [r["naive_real_gpu_mb"] for r in results["compression"]]
    triton = [r["triton_mb"] for r in results["compression"]]
|
|
    ax.plot(ctx, fp16, 'o-', color=C_FP16, linewidth=2.5, markersize=8, label="FP16 Baseline")
    ax.plot(ctx, uni8, 's-', color=C_UNIFORM, linewidth=2.5, markersize=8, label="Uniform 8-bit")
    ax.plot(ctx, naive, 'D-', color=C_NAIVE, linewidth=2.5, markersize=8, label="Naive Per-Head (uint8)")
    ax.plot(ctx, triton, '^-', color=triton_color, linewidth=2.5, markersize=8, label="Triton True 4-bit (Ours)")
|
|
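    # Annotate the rightmost (8K-context) point of each curve; the xytext
    # offsets are in display points relative to that data point.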
    s = results["summary"]
    ax.annotate(f"{fp16[-1]:.0f} MB",
                xy=(8192, fp16[-1]), xytext=(-60, 10),
                textcoords='offset points', color=C_FP16, fontweight='bold', fontsize=9)
    ax.annotate(f"{uni8[-1]:.0f} MB",
                xy=(8192, uni8[-1]), xytext=(-60, 10),
                textcoords='offset points', color=C_UNIFORM, fontweight='bold', fontsize=9)
    ax.annotate(f"{naive[-1]:.0f} MB",
                xy=(8192, naive[-1]), xytext=(-60, -18),
                textcoords='offset points', color=C_NAIVE, fontweight='bold', fontsize=9)
    ax.annotate(f"{triton[-1]:.0f} MB\n({s['triton_compression_8k']}x)",
                xy=(8192, triton[-1]), xytext=(-80, -35),
                textcoords='offset points', color=triton_color, fontweight='bold', fontsize=9)
|
|
    ax.set_xlabel("Context Length (tokens)", fontsize=12)
    ax.set_ylabel("KV Cache Memory (MB)", fontsize=12)
    ax.set_title(f"{title}\nKV Cache Memory vs Context Length", fontsize=13, fontweight='bold')
    ax.legend(fontsize=10)
    ax.grid(True, alpha=0.3)
    ax.set_xticks(ctx)
    ax.set_xticklabels(["512", "1K", "2K", "4K", "8K"])
|
|
plt.suptitle("Per-Head Mixed-Precision KV Cache: 4-Method Comparison",
             fontsize=14, fontweight='bold', y=1.02)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/memory_vs_context_4methods.png"),
            dpi=150, bbox_inches='tight')
print("Saved figures/memory_vs_context_4methods.png")
|
|
|
|
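# Figure 2: compression vs FP16 at 8K context, grouped bars for both models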
fig, ax = plt.subplots(figsize=(12, 7))


x = np.arange(4)
width = 0.35
labels = ["FP16\nBaseline", "Uniform\n8-bit", "Naive Per-Head\n(uint8 actual)", "Triton True\n4-bit (Ours)"]
|
|
m_ratios = [
    1.0,
    2.0,
    mistral["summary"]["naive_real_compression_8k"],
    mistral["summary"]["triton_compression_8k"],
]
l_ratios = [
    1.0,
    2.0,
    llama["summary"]["naive_real_compression_8k"],
    llama["summary"]["triton_compression_8k"],
]
|
|
colors = [C_FP16, C_UNIFORM, C_NAIVE, C_TRITON]
|
|
bars1 = ax.bar(x - width/2, m_ratios, width,
               label="Mistral-7B", color=colors,
               edgecolor='white', linewidth=1.5, alpha=0.9)
bars2 = ax.bar(x + width/2, l_ratios, width,
               label="Llama-3-8B", color=colors,
               edgecolor='white', linewidth=1.5, alpha=0.6,
               hatch='//')
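# Both model groups reuse the same per-method colors; the Llama-3-8B bars are
# distinguished by hatching and a lighter alpha.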
|
|
for bar, ratio in zip(bars1, m_ratios):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
            f"{ratio:.2f}x", ha='center', fontweight='bold', fontsize=11)
for bar, ratio in zip(bars2, l_ratios):
    ax.text(bar.get_x() + bar.get_width()/2, bar.get_height() + 0.03,
            f"{ratio:.2f}x", ha='center', fontweight='bold', fontsize=10,
            color='gray')
|
|
ax.set_xticks(x)
ax.set_xticklabels(labels, fontsize=11)
ax.set_ylabel("Compression vs FP16", fontsize=13)
ax.set_title("KV Cache Compression at 8K Context\n4-Method Comparison: Mistral-7B vs Llama-3-8B",
             fontsize=14, fontweight='bold')
ax.set_ylim(0, 2.8)
ax.legend(fontsize=11)
ax.grid(True, axis='y', alpha=0.3)
ax.axhline(y=1.0, color='gray', linestyle='--', alpha=0.4)
|
|
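# Shade the x-range of the Triton bars and label it to highlight our method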
ax.add_patch(plt.Rectangle((2.5, 0), 1.0, 2.8,
                           alpha=0.05, color=C_TRITON, zorder=0))
ax.text(3.0, 2.65, "Our method", ha='center',
        color=C_TRITON, fontweight='bold', fontsize=10)
|
|
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/compression_bar_4methods.png"),
            dpi=150, bbox_inches='tight')
print("Saved figures/compression_bar_4methods.png")
|
|
|
|
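# Figure 3: full results summary rendered as a table image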
fig, ax = plt.subplots(figsize=(14, 5))
ax.axis('off')


s_m = mistral["summary"]
s_l = llama["summary"]
|
|
table_data = [
    ["Model", "Method", "KV @ 8K", "vs FP16", "vs 8-bit", "Perplexity", "Speed"],
    ["Mistral-7B", "FP16 Baseline", "1073 MB", "1.00x", "-", "14.23", "37.4 t/s"],
    ["Mistral-7B", "Uniform 8-bit", "537 MB", "2.00x", "1.00x", "~same", "~same"],
    ["Mistral-7B", "Naive Per-Head (uint8)", f"{s_m['naive_real_8k_mb']} MB", f"{s_m['naive_real_compression_8k']}x", "1.00x", "~same", "~same"],
    ["Mistral-7B", "Triton True 4-bit (Ours)", f"{s_m['triton_8k_mb']} MB", f"{s_m['triton_compression_8k']}x", f"{s_m['triton_vs_8bit_8k']}x", "14.23", "37.4 t/s"],
    ["Llama-3-8B", "FP16 Baseline", "1073 MB", "1.00x", "-", "20.70", "36.8 t/s"],
    ["Llama-3-8B", "Uniform 8-bit", "537 MB", "2.00x", "1.00x", "~same", "~same"],
    ["Llama-3-8B", "Naive Per-Head (uint8)", f"{s_l['naive_real_8k_mb']} MB", f"{s_l['naive_real_compression_8k']}x", "1.00x", "~same", "~same"],
    ["Llama-3-8B", "Triton True 4-bit (Ours)", f"{s_l['triton_8k_mb']} MB", f"{s_l['triton_compression_8k']}x", f"{s_l['triton_vs_8bit_8k']}x", "20.70", "36.8 t/s"],
]
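# The FP16/uniform-8-bit sizes, perplexity, and tokens/s values above are
# hardcoded literals (presumably from the same benchmark runs); only the
# per-head and Triton cells are filled from the loaded JSON summaries.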
|
|
table = ax.table(
    cellText=table_data[1:],
    colLabels=table_data[0],
    cellLoc='center',
    loc='center',
)
table.auto_set_font_size(False)
table.set_fontsize(9)
table.scale(1.2, 2.0)
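# Dark header row, plus highlights for the two "Ours" rows: table row 4
# (Mistral-7B Triton) in green and row 8 (Llama-3-8B Triton) in blue,
# counting the header as row 0.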
|
|
for j in range(7):
    table[0, j].set_facecolor("#1e293b")
    table[0, j].set_text_props(color='white', fontweight='bold')
    table[4, j].set_facecolor("#dcfce7")
    table[8, j].set_facecolor("#dbeafe")
|
|
plt.title("Full Results: Per-Head Mixed-Precision KV Cache (4 Methods)",
          fontsize=13, fontweight='bold', pad=20)
plt.tight_layout()
plt.savefig(os.path.expanduser("~/kv-hack/figures/results_table_4methods.png"),
            dpi=150, bbox_inches='tight')
print("Saved figures/results_table_4methods.png")
|
|
plt.close('all')
print("\nAll 4-method graphs saved!")