import json
import os

import matplotlib.pyplot as plt
import numpy as np
|
|
# Load the per-head quantization sensitivity map.
# Assumed schema (inferred from how it is indexed in this script — confirm
# against the producer): {"<layer>": {"<head>": {"4bit": <error>, ...}, ...}, ...}
with open("results/mistral-7b/sensitivity_map.json", encoding="utf-8") as f:
    sens = json.load(f)

if not sens:
    raise ValueError("sensitivity map is empty; nothing to analyze")

# Size the grid from the largest layer/head index present rather than from
# len(sens) and layer "0" alone: this avoids a KeyError when layer "0" is
# absent and an IndexError when layers have uneven or non-contiguous heads.
num_layers = 1 + max(int(layer) for layer in sens)
num_heads = 1 + max(int(head) for heads in sens.values() for head in heads)
|
|
| |
# Assemble a (num_layers, num_heads) matrix of the "4bit" error values.
err_4bit = np.zeros((num_layers, num_heads))
# Track which cells actually received a value. An absent (layer, head) entry
# stays at 0.0, which is indistinguishable from a perfectly quantizable head
# in any plot or ranking built from this matrix, so surface it explicitly.
filled = np.zeros((num_layers, num_heads), dtype=bool)
for l in sens:
    for h in sens[l]:
        err_4bit[int(l), int(h)] = sens[l][h]["4bit"]
        filled[int(l), int(h)] = True

missing = np.argwhere(~filled)
if missing.size:
    print(
        f"⚠️  {len(missing)} (layer, head) cells have no entry in the "
        "sensitivity map and are treated as 0.0 error"
    )
|
|
# Render the per-head error matrix as a heatmap. With the reversed "hot"
# colormap, larger errors map to darker cells.
fig, ax = plt.subplots(figsize=(12, 8))
im = ax.imshow(err_4bit, aspect='auto', cmap='hot_r')
ax.set_xlabel("Attention Head", fontsize=12)
ax.set_ylabel("Layer", fontsize=12)
ax.set_title("4-bit KV Cache Quantization Error per Head\n(darker = more sensitive = needs higher precision)", fontsize=13)
fig.colorbar(im, ax=ax, label="MSE Reconstruction Error")
fig.tight_layout()
# savefig does not create missing directories; ensure the output dir exists.
os.makedirs("figures", exist_ok=True)
fig.savefig("figures/sensitivity_heatmap.png", dpi=150)
print("✅ Saved figures/sensitivity_heatmap.png")
|
|
| |
# Rank every (layer, head) cell by its 4-bit error, ascending. Tuples compare
# by error first, then layer and head index as tie-breakers.
flat = sorted(
    (err_4bit[l, h], l, h) for l in range(num_layers) for h in range(num_heads)
)

print("\n🟢 10 LEAST sensitive heads (safe to quantize to 4-bit):")
for err, l, h in flat[:10]:
    # .4g keeps small MSE values readable instead of rounding them to 0.0000.
    print(f" Layer {l:2d}, Head {h}: error={err:.4g}")


print("\n🔴 10 MOST sensitive heads (keep at 8-bit):")
# Walk the tail in reverse so the single most sensitive head is listed first.
for err, l, h in reversed(flat[-10:]):
    print(f" Layer {l:2d}, Head {h}: error={err:.4g}")
|
|