File size: 2,461 Bytes
4afd791
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
"""Figures for the RRT-on-Laguna result report. Numbers from the matched-reference
sweep (2M Python tokens, batch 16). Writes PNGs into reports/."""

from __future__ import annotations

from pathlib import Path

import matplotlib.pyplot as plt

# Output directory for the rendered PNGs: a "reports" folder next to this
# file's parent directory. It is created at import time if it is missing.
OUT = Path(__file__).resolve().parent.parent.joinpath("reports")
OUT.mkdir(exist_ok=True)

# One tuple per configuration, in the order the bars are drawn.
# Field layout:
#   (label, compression %, tied@init ppl, tied final ppl, ref final ppl)
DATA = [
    (
        "main\n(rank 16)",
        4.5,
        13.198,
        4.070,
        3.058,
    ),
    (
        "r32\n(rank 32)",
        4.5,
        13.368,
        3.976,
        3.075,
    ),
    (
        "wide\n(rank 16)",
        9.1,
        17.173,
        3.922,
        2.959,
    ),
]

# Zero-shot untied ppl on held-out Python (mean across runs).
BASELINE = 12.33


def fig_recovery() -> None:
    """Draw the grouped-bar perplexity comparison and save it as a PNG.

    Each row of ``DATA`` becomes one group of three bars: the tied model at
    init, the tied model after KD, and the full-model + LoRA reference. A
    dashed grey line marks ``BASELINE``. Only the "tied, after KD" bars carry
    a numeric label. The result is written to ``OUT / "fig_recovery.png"`` at
    150 dpi, after which the figure is closed.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.6))
    x = range(len(DATA))
    w = 0.26  # bar width; three bars per group occupy 0.78 of each unit slot
    # Row layout: (label, compression %, tied@init ppl, tied final ppl, ref final ppl).
    init = [d[2] for d in DATA]
    tied = [d[3] for d in DATA]
    ref = [d[4] for d in DATA]
    # Shift the outer two series by one bar width so each group is centred on its tick.
    ax.bar([i - w for i in x], init, w, label="tied, at init (no training)", color="#c44e52")
    ax.bar([i for i in x], tied, w, label="tied, after KD", color="#4c72b0")
    ax.bar([i + w for i in x], ref, w, label="full model + LoRA (reference)", color="#55a868")
    ax.axhline(BASELINE, ls="--", c="gray", lw=1, label=f"zero-shot baseline ({BASELINE:.1f})")
    ax.set_xticks(list(x))
    ax.set_xticklabels([d[0] for d in DATA])
    ax.set_ylabel("held-out Python perplexity")
    ax.set_title("Tying perturbs perplexity; KD recovers to near the reference")
    ax.legend(fontsize=8, framealpha=0.9)
    # Print the value just above each centre ("tied, after KD") bar.
    for i, v in enumerate(tied):
        ax.text(i, v + 0.2, f"{v:.2f}", ha="center", fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / "fig_recovery.png", dpi=150)
    # pyplot keeps every figure it creates registered until it is closed
    # explicitly; release it so repeated calls do not accumulate open figures.
    plt.close(fig)


def fig_gap() -> None:
    """Draw the per-configuration recovery gap and save it as a PNG.

    The gap for each row of ``DATA`` is the tied final perplexity minus the
    reference final perplexity. One bar is drawn per configuration and
    annotated with the gap and the configuration's compression percentage.
    The result is written to ``OUT / "fig_gap.png"`` at 150 dpi, after which
    the figure is closed.
    """
    fig, ax = plt.subplots(figsize=(6.2, 3.4))
    # The labels in DATA contain a line break; flatten them to one line here.
    labels = [d[0].replace("\n", " ") for d in DATA]
    # Row layout: (label, compression %, tied@init ppl, tied final ppl, ref final ppl).
    gaps = [d[3] - d[4] for d in DATA]
    comps = [d[1] for d in DATA]
    # Colours are positional and follow the row order of DATA.
    # NOTE(review): the third colour presumably sets the 9.1% configuration
    # apart from the two 4.5% ones - confirm, and update if DATA changes.
    colors = ["#4c72b0", "#4c72b0", "#dd8452"]
    bars = ax.bar(labels, gaps, color=colors, width=0.55)
    ax.set_ylabel("recovery gap, tied minus reference (ppl)")
    ax.set_title("Recovery gap stays near 1 ppl across compression and rank")
    # Fixed y-range; the space above the bars holds the two-line labels.
    # NOTE(review): a gap close to or above 1.35 would be clipped.
    ax.set_ylim(0, 1.35)
    for b, g, c in zip(bars, gaps, comps):
        # Centre the label horizontally on its bar, slightly above the top.
        ax.text(b.get_x() + b.get_width() / 2, g + 0.03, f"{g:.2f} ppl\n{c:.1f}% smaller",
                ha="center", fontsize=8)
    fig.tight_layout()
    fig.savefig(OUT / "fig_gap.png", dpi=150)
    # pyplot keeps every figure it creates registered until it is closed
    # explicitly; release it so repeated calls do not accumulate open figures.
    plt.close(fig)


if __name__ == "__main__":
    # Script entry point: render both figures, then report where they went.
    for render in (fig_recovery, fig_gap):
        render()
    print(f"wrote {OUT}/fig_recovery.png and {OUT}/fig_gap.png")