"""Visualise per-head 4-bit KV-cache quantization error.

Reads the sensitivity map at ``results/mistral-7b/sensitivity_map.json``,
renders a layer x head heatmap of each head's ``"4bit"`` error, saves it to
``figures/sensitivity_heatmap.png`` and prints the least / most sensitive heads.

Expected JSON layout (as indexed below): ``{layer: {head: {"4bit": err, ...}}}``
where layer and head keys are stringified integers.
"""

import json
from pathlib import Path

import numpy as np

SENSITIVITY_PATH = "results/mistral-7b/sensitivity_map.json"
FIGURE_PATH = "figures/sensitivity_heatmap.png"
TOP_K = 10


def load_sensitivity(path=SENSITIVITY_PATH):
    """Return the parsed sensitivity map stored at *path*."""
    with open(path, encoding="utf-8") as f:
        return json.load(f)


def build_error_matrix(sens, precision="4bit"):
    """Return a ``(num_layers, num_heads)`` float array of *precision* errors.

    The shape is derived from the largest layer/head index present, so maps
    whose keys are sparse, ragged, or lack a layer "0" do not raise. Cells with
    no entry are NaN (not 0) so they are never mistaken for insensitive heads.
    """
    layer_ids = [int(layer) for layer in sens]
    head_ids = [int(head) for heads in sens.values() for head in heads]
    num_layers = max(layer_ids, default=-1) + 1
    num_heads = max(head_ids, default=-1) + 1

    err = np.full((num_layers, num_heads), np.nan)
    for layer, heads in sens.items():
        for head, errors in heads.items():
            err[int(layer), int(head)] = errors[precision]
    return err


def rank_heads(err):
    """Return ``(error, layer, head)`` tuples sorted by ascending error.

    NaN cells (heads with no measurement) are skipped.
    """
    num_layers, num_heads = err.shape
    return sorted(
        (float(err[layer, head]), layer, head)
        for layer in range(num_layers)
        for head in range(num_heads)
        if not np.isnan(err[layer, head])
    )


def plot_heatmap(err, out_path=FIGURE_PATH):
    """Save a heatmap of *err* (layers on y, heads on x) to *out_path*."""
    # Imported here so the data helpers above can be used without matplotlib.
    import matplotlib.pyplot as plt

    fig, ax = plt.subplots(figsize=(12, 8))
    # hot_r maps low error to light colours and high error to dark ones,
    # which is what the title promises.
    im = ax.imshow(err, aspect='auto', cmap='hot_r')
    ax.set_xlabel("Attention Head", fontsize=12)
    ax.set_ylabel("Layer", fontsize=12)
    ax.set_title("4-bit KV Cache Quantization Error per Head\n(darker = more sensitive = needs higher precision)", fontsize=13)
    fig.colorbar(im, ax=ax, label="MSE Reconstruction Error")
    fig.tight_layout()

    # savefig does not create missing directories.
    Path(out_path).parent.mkdir(parents=True, exist_ok=True)
    fig.savefig(out_path, dpi=150)
    plt.close(fig)
    print(f"✅ Saved {out_path}")


def print_extremes(ranked, top_k=TOP_K):
    """Print the *top_k* least and most sensitive heads from *ranked*.

    *ranked* must be sorted by ascending error (as returned by rank_heads).
    The most-sensitive list is printed worst-first.
    """
    print(f"\n🟢 {top_k} LEAST sensitive heads (safe to quantize to 4-bit):")
    for err, layer, head in ranked[:top_k]:
        print(f" Layer {layer:2d}, Head {head}: error={err:.4f}")

    print(f"\n🔴 {top_k} MOST sensitive heads (keep at 8-bit):")
    for err, layer, head in ranked[::-1][:top_k]:
        print(f" Layer {layer:2d}, Head {head}: error={err:.4f}")


def main():
    """Load the sensitivity map, save the heatmap and print the extremes."""
    sens = load_sensitivity()
    if not sens:
        raise SystemExit(f"No sensitivity data found in {SENSITIVITY_PATH}")
    err_4bit = build_error_matrix(sens)
    plot_heatmap(err_4bit)
    print_extremes(rank_heads(err_4bit))


if __name__ == "__main__":
    main()