"""Figures for the RRT-on-Laguna result report.

Numbers from the matched-reference sweep (2M Python tokens, batch 16).
Writes PNGs into reports/.
"""
from __future__ import annotations

from pathlib import Path

import matplotlib.pyplot as plt

OUT = Path(__file__).resolve().parent.parent / "reports"
OUT.mkdir(exist_ok=True)

# config: (label, compression %, tied@init ppl, tied final ppl, ref final ppl)
DATA = [
    ("main\n(rank 16)", 4.5, 13.198, 4.070, 3.058),
    ("r32\n(rank 32)", 4.5, 13.368, 3.976, 3.075),
    ("wide\n(rank 16)", 9.1, 17.173, 3.922, 2.959),
]
BASELINE = 12.33  # zero-shot untied ppl on held-out Python (mean across runs)


def _save_and_close(fig, filename: str) -> None:
    """Tighten the layout of *fig*, write it to ``OUT / filename``, then close it.

    pyplot keeps a reference to every figure it creates until the figure is
    closed explicitly, so the close happens in ``finally`` to release it even
    when saving fails.
    """
    try:
        fig.tight_layout()
        fig.savefig(OUT / filename, dpi=150)
    finally:
        plt.close(fig)


def fig_recovery() -> None:
    """Write ``fig_recovery.png``: grouped perplexity bars for each DATA row.

    Each config gets three bars (tied at init, tied after KD, reference) and
    a dashed horizontal line marks BASELINE. Only the "tied, after KD" bars
    carry a numeric label.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.6))
    positions = list(range(len(DATA)))
    w = 0.26  # bar width; also the offset between bars within one group

    names = [label for label, _, _, _, _ in DATA]
    init = [at_init for _, _, at_init, _, _ in DATA]
    tied = [tied_final for _, _, _, tied_final, _ in DATA]
    ref = [ref_final for _, _, _, _, ref_final in DATA]

    ax.bar([i - w for i in positions], init, w,
           label="tied, at init (no training)", color="#c44e52")
    ax.bar(positions, tied, w, label="tied, after KD", color="#4c72b0")
    ax.bar([i + w for i in positions], ref, w,
           label="full model + LoRA (reference)", color="#55a868")
    ax.axhline(BASELINE, ls="--", c="gray", lw=1,
               label=f"zero-shot baseline ({BASELINE:.1f})")

    ax.set_xticks(positions)
    ax.set_xticklabels(names)
    ax.set_ylabel("held-out Python perplexity")
    ax.set_title("Tying perturbs perplexity; KD recovers to near the reference")
    ax.legend(fontsize=8, framealpha=0.9)

    # Annotate the middle ("tied, after KD") bar of each group with its value.
    for i, v in enumerate(tied):
        ax.text(i, v + 0.2, f"{v:.2f}", ha="center", fontsize=8)

    _save_and_close(fig, "fig_recovery.png")


def fig_gap() -> None:
    """Write ``fig_gap.png``: one bar per DATA row showing tied minus reference ppl.

    Every bar is annotated with its gap and the row's compression percentage.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.4))

    # Labels are flattened onto one line for this chart's x-axis.
    labels = [label.replace("\n", " ") for label, _, _, _, _ in DATA]
    gaps = [tied_final - ref_final for _, _, _, tied_final, ref_final in DATA]
    comps = [compression for _, compression, _, _, _ in DATA]

    # NOTE(review): the palette is positional and written for three DATA rows;
    # presumably the last colour singles out the "wide" config. Confirm and
    # extend this list if rows are added or reordered.
    colors = ["#4c72b0", "#4c72b0", "#dd8452"]

    bars = ax.bar(labels, gaps, color=colors, width=0.55)
    ax.set_ylabel("recovery gap, tied minus reference (ppl)")
    ax.set_title("Recovery gap stays near 1 ppl across compression and rank")
    ax.set_ylim(0, 1.35)

    # Centre a two-line annotation just above each bar.
    for b, g, c in zip(bars, gaps, comps):
        ax.text(b.get_x() + b.get_width() / 2, g + 0.03,
                f"{g:.2f} ppl\n{c:.1f}% smaller", ha="center", fontsize=8)

    _save_and_close(fig, "fig_gap.png")


if __name__ == "__main__":
    fig_recovery()
    fig_gap()
    print(f"wrote {OUT}/fig_recovery.png and {OUT}/fig_gap.png")