| """Figures for the RRT-on-Laguna result report. Numbers from the matched-reference |
| sweep (2M Python tokens, batch 16). Writes PNGs into reports/.""" |
|
|
| from __future__ import annotations |
|
|
| from pathlib import Path |
|
|
| import matplotlib.pyplot as plt |
|
|
# Figures are saved into the "reports" directory that sits next to this
# file's parent directory. It is created at import time if it is missing
# (its parent is not created; it already exists because this file lives
# underneath it).
OUT = Path(__file__).resolve().parent.parent.joinpath("reports")
OUT.mkdir(parents=False, exist_ok=True)
|
|
| |
# Perplexity drawn as the dashed "zero-shot baseline" reference line.
BASELINE = 12.33

# One row per configuration. Columns, as the plotting functions index them:
#   [0] tick label (the "\n" splits it over two lines on the x axis)
#   [1] size reduction in percent (printed as "% smaller")
#   [2] perplexity of the tied model at init, before any training
#   [3] perplexity of the tied model after KD
#   [4] perplexity of the full model + LoRA reference
DATA = [
    ("main\n(rank 16)", 4.5, 13.198, 4.070, 3.058),
    ("r32\n(rank 32)",  4.5, 13.368, 3.976, 3.075),
    ("wide\n(rank 16)", 9.1, 17.173, 3.922, 2.959),
]
|
|
|
|
def fig_recovery():
    """Draw the grouped-bar perplexity comparison and save it as a PNG.

    For every row of ``DATA`` three bars are drawn side by side: the tied
    model at init (column 2), the tied model after KD (column 3) and the
    full model + LoRA reference (column 4). ``BASELINE`` is added as a dashed
    horizontal line, and each "after KD" bar is annotated with its value.
    The result is written to ``OUT / "fig_recovery.png"`` at 150 dpi.

    The figure is closed before returning, whether or not saving succeeded:
    pyplot keeps every figure it creates alive until it is explicitly closed,
    so repeated calls would otherwise accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.6))
    try:
        positions = list(range(len(DATA)))
        width = 0.26  # three bars per group, leaving a gap between groups
        init = [row[2] for row in DATA]
        tied = [row[3] for row in DATA]
        ref = [row[4] for row in DATA]

        # Left / centre / right bar of each group.
        ax.bar([p - width for p in positions], init, width,
               label="tied, at init (no training)", color="#c44e52")
        ax.bar(positions, tied, width,
               label="tied, after KD", color="#4c72b0")
        ax.bar([p + width for p in positions], ref, width,
               label="full model + LoRA (reference)", color="#55a868")
        ax.axhline(BASELINE, ls="--", c="gray", lw=1,
                   label=f"zero-shot baseline ({BASELINE:.1f})")

        ax.set_xticks(positions)
        ax.set_xticklabels([row[0] for row in DATA])
        ax.set_ylabel("held-out Python perplexity")
        ax.set_title("Tying perturbs perplexity; KD recovers to near the reference")
        ax.legend(fontsize=8, framealpha=0.9)

        # Print the value just above each centre ("after KD") bar.
        for pos, value in zip(positions, tied):
            ax.text(pos, value + 0.2, f"{value:.2f}", ha="center", fontsize=8)

        fig.tight_layout()
        fig.savefig(OUT / "fig_recovery.png", dpi=150)
    finally:
        plt.close(fig)
|
|
|
|
def fig_gap():
    """Draw the per-configuration recovery gap and save it as a PNG.

    One bar is drawn per row of ``DATA``; its height is the tied-after-KD
    perplexity (column 3) minus the reference perplexity (column 4). Each bar
    is annotated with the gap and with column 1 printed as "% smaller". The
    first two bars are blue and the third is orange (the colours are fixed by
    position), and the y axis is fixed to the range 0 to 1.35. The result is
    written to ``OUT / "fig_gap.png"`` at 150 dpi.

    The figure is closed before returning, whether or not saving succeeded:
    pyplot keeps every figure it creates alive until it is explicitly closed,
    so repeated calls would otherwise accumulate open figures.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.4))
    try:
        # The tick labels hold a line break for the grouped chart; flatten it.
        labels = [row[0].replace("\n", " ") for row in DATA]
        gaps = [row[3] - row[4] for row in DATA]
        comps = [row[1] for row in DATA]
        colors = ["#4c72b0", "#4c72b0", "#dd8452"]

        bars = ax.bar(labels, gaps, color=colors, width=0.55)
        ax.set_ylabel("recovery gap, tied minus reference (ppl)")
        ax.set_title("Recovery gap stays near 1 ppl across compression and rank")
        ax.set_ylim(0, 1.35)

        # Two-line annotation centred just above each bar.
        for bar, gap, comp in zip(bars, gaps, comps):
            ax.text(bar.get_x() + bar.get_width() / 2, gap + 0.03,
                    f"{gap:.2f} ppl\n{comp:.1f}% smaller",
                    ha="center", fontsize=8)

        fig.tight_layout()
        fig.savefig(OUT / "fig_gap.png", dpi=150)
    finally:
        plt.close(fig)
|
|
|
|
if __name__ == "__main__":
    # Script entry point: render both figures, then report where they went.
    fig_recovery()
    fig_gap()
    # Build the reported paths the same way the figures are saved (OUT / name)
    # so the message uses the platform's path separator throughout.
    print(f"wrote {OUT / 'fig_recovery.png'} and {OUT / 'fig_gap.png'}")
|
|