# Source: looped-laguna/scripts/plot_rrt_results.py
# Commit 4afd791 (e-p): "submission: combined README, looped + RRT reports, figures"
"""Figures for the RRT-on-Laguna result report. Numbers from the matched-reference
sweep (2M Python tokens, batch 16). Writes PNGs into reports/."""
from __future__ import annotations
from pathlib import Path
import matplotlib.pyplot as plt
# Output directory: "reports", a sibling of the directory holding this script.
# It is created when the module is loaded, before any figure is saved into it.
OUT = Path(__file__).resolve().parent.parent.joinpath("reports")
OUT.mkdir(exist_ok=True)

# One row per configuration, in this column order:
#   (label, compression %, tied@init ppl, tied final ppl, ref final ppl)
DATA = [
    ("main\n(rank 16)", 4.5, 13.198, 4.070, 3.058),
    ("r32\n(rank 32)", 4.5, 13.368, 3.976, 3.075),
    ("wide\n(rank 16)", 9.1, 17.173, 3.922, 2.959),
]

# Zero-shot untied ppl on held-out Python (mean across runs).
BASELINE = 12.33
def fig_recovery():
    """Draw the grouped-bar perplexity comparison and save ``fig_recovery.png``.

    For every row of ``DATA`` three bars are drawn side by side: tied at
    init, tied after KD, and the full model + LoRA reference (row indices 2,
    3 and 4).  A dashed horizontal line marks ``BASELINE``, and each
    "tied, after KD" bar is annotated with its value.  The PNG is written
    into ``OUT`` at 150 dpi.

    The figure is closed before returning, even if drawing or saving fails,
    so it does not stay registered with pyplot.
    """
    tick_labels = [row[0] for row in DATA]
    init = [row[2] for row in DATA]
    tied = [row[3] for row in DATA]
    ref = [row[4] for row in DATA]
    x = list(range(len(DATA)))  # one integer slot per configuration
    w = 0.26  # bar width; three bars per slot, offset by -w, 0, +w

    fig, ax = plt.subplots(figsize=(6.2, 3.6))
    try:
        ax.bar([i - w for i in x], init, w, label="tied, at init (no training)", color="#c44e52")
        ax.bar(x, tied, w, label="tied, after KD", color="#4c72b0")
        ax.bar([i + w for i in x], ref, w, label="full model + LoRA (reference)", color="#55a868")
        ax.axhline(BASELINE, ls="--", c="gray", lw=1, label=f"zero-shot baseline ({BASELINE:.1f})")
        ax.set_xticks(x)
        ax.set_xticklabels(tick_labels)
        ax.set_ylabel("held-out Python perplexity")
        ax.set_title("Tying perturbs perplexity; KD recovers to near the reference")
        ax.legend(fontsize=8, framealpha=0.9)
        # Value labels go on the middle ("tied, after KD") bars only.
        for i, v in enumerate(tied):
            ax.text(i, v + 0.2, f"{v:.2f}", ha="center", fontsize=8)
        fig.tight_layout()
        fig.savefig(OUT / "fig_recovery.png", dpi=150)
    finally:
        # pyplot keeps every figure it creates alive until explicitly closed.
        plt.close(fig)
def fig_gap():
    """Draw the per-configuration recovery gap and save ``fig_gap.png``.

    The gap is the tied final perplexity minus the reference final
    perplexity (indices 3 and 4 of each ``DATA`` row).  Each bar is annotated
    with its gap and the row's compression percentage (index 1).  The PNG is
    written into ``OUT`` at 150 dpi.

    The figure is closed before returning, even if drawing or saving fails,
    so it does not stay registered with pyplot.
    """
    labels = [row[0].replace("\n", " ") for row in DATA]  # single-line tick labels
    gaps = [row[3] - row[4] for row in DATA]
    comps = [row[1] for row in DATA]
    # NOTE(review): one color per bar, written for the three rows currently in
    # DATA with the third bar colored differently; revisit if DATA changes.
    colors = ["#4c72b0", "#4c72b0", "#dd8452"]

    fig, ax = plt.subplots(figsize=(6.2, 3.4))
    try:
        bars = ax.bar(labels, gaps, color=colors, width=0.55)
        ax.set_ylabel("recovery gap, tied minus reference (ppl)")
        ax.set_title("Recovery gap stays near 1 ppl across compression and rank")
        # NOTE(review): fixed upper limit; presumably chosen to leave room for
        # the two-line annotations above the current bars -- confirm if gaps grow.
        ax.set_ylim(0, 1.35)
        for b, g, c in zip(bars, gaps, comps):
            ax.text(b.get_x() + b.get_width() / 2, g + 0.03, f"{g:.2f} ppl\n{c:.1f}% smaller",
                    ha="center", fontsize=8)
        fig.tight_layout()
        fig.savefig(OUT / "fig_gap.png", dpi=150)
    finally:
        # pyplot keeps every figure it creates alive until explicitly closed.
        plt.close(fig)
if __name__ == "__main__":
    # Render both report figures, then say where they were written.
    for render in (fig_recovery, fig_gap):
        render()
    print(f"wrote {OUT}/fig_recovery.png and {OUT}/fig_gap.png")