File size: 2,461 Bytes
"""Figures for the RRT-on-Laguna result report. Numbers from the matched-reference
sweep (2M Python tokens, batch 16). Writes PNGs into reports/."""
from __future__ import annotations
from pathlib import Path
import matplotlib.pyplot as plt
# Output directory: "reports", a sibling of the directory that holds this file.
# It is created at import time if it does not exist yet.
_HERE = Path(__file__).resolve().parent
OUT = _HERE.parent / "reports"
OUT.mkdir(exist_ok=True)

# One row per configuration. Columns, in order:
#   label | compression % | tied@init ppl | tied final ppl | ref final ppl
DATA = [
    ("main\n(rank 16)", 4.5, 13.198, 4.070, 3.058),
    ("r32\n(rank 32)", 4.5, 13.368, 3.976, 3.075),
    ("wide\n(rank 16)", 9.1, 17.173, 3.922, 2.959),
]

# Zero-shot untied perplexity on held-out Python (mean across runs).
BASELINE = 12.33
def fig_recovery():
    """Draw the grouped-bar perplexity comparison and save it as fig_recovery.png.

    For every row of DATA three bars are drawn side by side: the tied model at
    init, the tied model after KD, and the full model + LoRA reference.
    BASELINE is overlaid as a dashed horizontal line and each "after KD" bar is
    annotated with its value. The PNG is written into OUT at 150 dpi.

    The figure is closed before returning (also when saving fails) so pyplot
    does not keep it registered after the file has been written.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.6))
    try:
        positions = range(len(DATA))
        width = 0.26

        # Columns 2..4 of each DATA row hold the three perplexities.
        init_ppl = [row[2] for row in DATA]
        tied_ppl = [row[3] for row in DATA]
        ref_ppl = [row[4] for row in DATA]

        # Three bars per group: shifted one bar-width left, centred, and right.
        ax.bar([p - width for p in positions], init_ppl, width,
               label="tied, at init (no training)", color="#c44e52")
        ax.bar(list(positions), tied_ppl, width,
               label="tied, after KD", color="#4c72b0")
        ax.bar([p + width for p in positions], ref_ppl, width,
               label="full model + LoRA (reference)", color="#55a868")
        ax.axhline(BASELINE, ls="--", c="gray", lw=1,
                   label=f"zero-shot baseline ({BASELINE:.1f})")

        ax.set_xticks(list(positions))
        ax.set_xticklabels([row[0] for row in DATA])
        ax.set_ylabel("held-out Python perplexity")
        ax.set_title("Tying perturbs perplexity; KD recovers to near the reference")
        ax.legend(fontsize=8, framealpha=0.9)

        # Only the centre ("after KD") bars get a numeric label, placed just
        # above the bar top.
        for pos, ppl in enumerate(tied_ppl):
            ax.text(pos, ppl + 0.2, f"{ppl:.2f}", ha="center", fontsize=8)

        fig.tight_layout()
        fig.savefig(OUT / "fig_recovery.png", dpi=150)
    finally:
        # plt.subplots registers the figure with pyplot; release it explicitly.
        plt.close(fig)
def fig_gap():
    """Draw the per-configuration recovery gap and save it as fig_gap.png.

    The gap of a DATA row is its tied final perplexity minus its reference
    final perplexity. One bar is drawn per row, annotated with the gap and the
    row's compression percentage. The PNG is written into OUT at 150 dpi.

    The figure is closed before returning (also when saving fails) so pyplot
    does not keep it registered after the file has been written.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.4))
    try:
        # Tick labels on one line: the DATA labels carry an embedded newline.
        labels = [row[0].replace("\n", " ") for row in DATA]
        gaps = [row[3] - row[4] for row in DATA]
        compressions = [row[1] for row in DATA]

        # NOTE(review): the palette is hard-coded for three bars (blue, blue,
        # orange) and is not derived from DATA -- presumably the orange bar
        # singles out the last configuration; confirm before adding rows.
        colors = ["#4c72b0", "#4c72b0", "#dd8452"]
        bars = ax.bar(labels, gaps, color=colors, width=0.55)

        ax.set_ylabel("recovery gap, tied minus reference (ppl)")
        ax.set_title("Recovery gap stays near 1 ppl across compression and rank")
        # Fixed y-range; presumably chosen to leave room for the two-line
        # labels drawn above each bar.
        ax.set_ylim(0, 1.35)

        for bar, gap, comp in zip(bars, gaps, compressions):
            centre = bar.get_x() + bar.get_width() / 2
            ax.text(centre, gap + 0.03, f"{gap:.2f} ppl\n{comp:.1f}% smaller",
                    ha="center", fontsize=8)

        fig.tight_layout()
        fig.savefig(OUT / "fig_gap.png", dpi=150)
    finally:
        # plt.subplots registers the figure with pyplot; release it explicitly.
        plt.close(fig)
# Script entry point: render both figures, then report where they were written.
if __name__ == "__main__":
    for render in (fig_recovery, fig_gap):
        render()
    print(f"wrote {OUT}/fig_recovery.png and {OUT}/fig_gap.png")
|