File size: 6,235 Bytes
a09b1f5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
"""Post-process an SFT training log into publication-ready plots.

Run this once the user pastes back the contents of
``runs/sft_v1_stdout.log`` (the file the Colab notebook ``tee``s into).

It produces two PNGs:

  * ``docs/img/sft_loss_curve.png``  — training loss per logging step
  * ``docs/img/sft_4bar.png``        — random / no_op / scripted / sft_adapter
                                       on hard_drift, side-by-side

Inputs
------
A path to the log file. The script is forgiving — if the eval table is
missing, it produces only the loss curve. If the loss lines are missing,
it produces only the 4-bar.

Usage
-----
    python -m scripts.sft_postprocess runs/sft_v1_stdout.log
"""

from __future__ import annotations

import argparse
import re
import statistics
import sys
from pathlib import Path

import matplotlib

matplotlib.use("Agg")
import matplotlib.pyplot as plt


# ---------------- Loss curve ----------------

# One decimal number, optionally in scientific notation (``1.25``, ``.5``,
# ``5e-05``). The trailing lookahead refuses to stop in the middle of a longer
# dotted token such as ``1.2.3``, so malformed values are skipped instead of
# being truncated to a plausible-looking prefix.
_LOSS_NUMBER = r"(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?:[eE][-+]?[0-9]+)?(?!\.?[0-9])"

# Tried in order on every log line; the first pattern that matches wins.
# In every pattern the loss value is the LAST capture group.
LOSS_PATTERNS = [
    # Quoted key followed by a colon:  'loss': 1.234
    re.compile(rf"'loss':\s*({_LOSS_NUMBER})"),
    # Bare key with an equals sign:    loss = 1.234
    re.compile(rf"\bloss\s*=\s*({_LOSS_NUMBER})"),
    # Step-prefixed line:              step 12 ... loss: 1.234
    re.compile(rf"\bstep\s+(\d+).*loss\s*[=:]\s*({_LOSS_NUMBER})"),
]


def parse_loss(log_text: str) -> list[float]:
    """Return a list of training-loss values in temporal order.

    Each line of ``log_text`` is tested against ``LOSS_PATTERNS`` in order and
    contributes at most one value (from the first pattern that matches).

    Args:
        log_text: Full text of the training log.

    Returns:
        The parsed losses, in the order their lines appear. Values outside
        the open interval ``(0, 100)`` are dropped.
    """
    losses: list[float] = []
    for line in log_text.splitlines():
        for pat in LOSS_PATTERNS:
            m = pat.search(line)
            if m is None:
                continue
            # The patterns only capture well-formed numbers, so float() cannot
            # raise here; the loss is always the last group of the pattern.
            val = float(m.group(m.lastindex))
            # Sanity window: discards zeros and implausibly large values.
            if 0.0 < val < 100.0:
                losses.append(val)
            # First matching pattern decides the line, even if the value was
            # rejected by the sanity window.
            break
    return losses


def _trailing_mean(values: list[float], window: int) -> list[float]:
    """Return, per index, the mean of the last ``window`` values up to it.

    The first ``window - 1`` entries average over however many values exist.
    """
    return [
        statistics.mean(values[max(0, end - window) : end])
        for end in range(1, len(values) + 1)
    ]


def plot_loss(losses: list[float], out_path: Path) -> None:
    """Plot the training-loss curve and save it as a PNG.

    With five or more points a trailing moving average is overlaid; its
    window is ``max(3, len(losses) // 25)`` points, which is the number shown
    in the legend.

    Args:
        losses: Loss values in logging order. When empty, nothing is written
            and a ``[skip]`` message is printed instead.
        out_path: Destination file for the figure.
    """
    if not losses:
        print("[skip] loss curve — no loss values parsed from log")
        return
    fig, ax = plt.subplots(figsize=(8.0, 4.4))
    try:
        xs = list(range(1, len(losses) + 1))
        ax.plot(xs, losses, color="#2a7fbf", lw=1.3)
        if len(losses) >= 5:
            # Light moving average for the eye. The helper averages exactly
            # `win` points so the legend label matches what is drawn.
            win = max(3, len(losses) // 25)
            ax.plot(
                xs,
                _trailing_mean(losses, win),
                color="red",
                lw=1.0,
                alpha=0.7,
                label=f"moving avg (window={win})",
            )
            ax.legend(loc="upper right", framealpha=0.95)
        ax.set_xlabel("logging step")
        ax.set_ylabel("training loss")
        ax.set_title(
            f"SFT training loss (Qwen2.5-3B + LoRA, {len(losses)} logged steps)\n"
            "Unsloth train_on_responses_only, masked to assistant tokens"
        )
        ax.grid(True, alpha=0.3)
        # Operate on this figure explicitly rather than on pyplot's implicit
        # "current figure".
        fig.tight_layout()
        fig.savefig(out_path, dpi=140, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    print(f"[ok]   wrote {out_path} ({len(losses)} loss points)")


# ---------------- Eval table ----------------

# A plain decimal score (``0.754``, ``1``, ``.5``). The lookahead stops the
# match from ending inside a longer dotted token such as ``1.2.3``, while a
# sentence-ending period after the number is simply left unconsumed.
_SCORE_NUMBER = r"(?:[0-9]+\.?[0-9]*|\.[0-9]+)(?!\.?[0-9])"

# One eval-table row, e.g. ``hard_drift  n=20  trained=0.812  scripted=0.754``.
EVAL_RE = re.compile(
    r"^\s*(?P<task>easy_cashless|medium_multi_payer|hard_drift)\s+"
    r"n=(?P<n>\d+)\s+trained=(?P<sft>" + _SCORE_NUMBER + r")\s+"
    r"scripted=(?P<scripted>" + _SCORE_NUMBER + r")",
    re.MULTILINE,
)


def parse_eval(log_text: str) -> dict[str, dict[str, float]]:
    """Extract the per-task evaluation rows from a training log.

    Args:
        log_text: Full text of the training log.

    Returns:
        A mapping from task name to ``{"sft_adapter", "scripted", "n"}``,
        where ``n`` is kept as an ``int``. Rows that do not match ``EVAL_RE``
        (including rows whose scores are not well-formed numbers) are ignored.
        If a task appears more than once, the last occurrence wins.
    """
    out: dict[str, dict[str, float]] = {}
    for m in EVAL_RE.finditer(log_text):
        # EVAL_RE only captures well-formed numbers, so these conversions
        # cannot raise on a matched row.
        out[m.group("task")] = {
            "sft_adapter": float(m.group("sft")),
            "scripted": float(m.group("scripted")),
            "n": int(m.group("n")),
        }
    return out


def plot_4bar(eval_table: dict[str, dict[str, float]], out_path: Path) -> None:
    """Render random/no_op/scripted/sft_adapter on hard_drift only.

    The three baseline bars use fixed values defined in this function; only
    the ``sft_adapter`` bar comes from ``eval_table``.

    Args:
        eval_table: Mapping of task name to its scores. Only
            ``eval_table["hard_drift"]["sft_adapter"]`` is read. When there is
            no ``"hard_drift"`` entry, nothing is written and a ``[skip]``
            message is printed instead.
        out_path: Destination file for the figure.
    """
    if "hard_drift" not in eval_table:
        print("[skip] 4-bar — no hard_drift row found in eval table")
        return
    sft = eval_table["hard_drift"]["sft_adapter"]
    # Hard-locked 20-seed measured baselines on hard_drift after grader v3.1
    # (P6 oscillation penalty + B2/B3 bonus gating). See
    # docs/baseline_reproducibility.csv.
    baselines = {"random": 0.108, "no_op": 0.079, "scripted": 0.754}
    names = ["random", "no_op", "scripted", "sft_adapter"]
    means = [baselines["random"], baselines["no_op"], baselines["scripted"], sft]
    colors = ["#888888", "#bbbb44", "#2a7fbf", "#cc3399"]

    fig, ax = plt.subplots(figsize=(7.0, 4.6))
    try:
        bars = ax.bar(names, means, color=colors, edgecolor="black", linewidth=0.5)
        # Print each bar's value just above its top edge.
        for bar, m in zip(bars, means):
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                m + 0.015,
                f"{m:.3f}",
                ha="center",
                va="bottom",
                fontsize=9,
            )
        # Dashed reference line at the scripted baseline.
        ax.axhline(
            baselines["scripted"],
            ls="--",
            color="#2a7fbf",
            lw=1.0,
            alpha=0.7,
            label=f"scripted ceiling = {baselines['scripted']:.3f}",
        )
        delta = sft - baselines["scripted"]
        direction = "above" if delta > 0 else "≤"
        ax.set_title(
            f"Trained model on hard_drift: SFT adapter = {sft:.3f} "
            f"({direction} scripted by {abs(delta):.3f})"
        )
        ax.set_ylabel("Composite grader score")
        # Leave headroom above the tallest bar for its value label.
        ax.set_ylim(0, max(0.95, max(means) + 0.10))
        ax.legend(loc="upper left", framealpha=0.95)
        # Operate on this figure explicitly rather than on pyplot's implicit
        # "current figure".
        fig.tight_layout()
        fig.savefig(out_path, dpi=140, bbox_inches="tight")
    finally:
        # pyplot keeps every figure alive until it is closed explicitly.
        plt.close(fig)
    print(f"[ok]   wrote {out_path}  (sft_adapter on hard_drift = {sft:.3f})")


# ---------------- Main ----------------


def main() -> int:
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("log_path", type=Path, help="Path to sft_v1_stdout.log")
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=Path("docs/img"),
        help="Directory for output PNGs (default: docs/img)",
    )
    args = parser.parse_args()

    if not args.log_path.exists():
        print(f"[err] log not found: {args.log_path}", file=sys.stderr)
        return 2

    args.out_dir.mkdir(parents=True, exist_ok=True)
    log_text = args.log_path.read_text()

    losses = parse_loss(log_text)
    plot_loss(losses, args.out_dir / "sft_loss_curve.png")

    eval_table = parse_eval(log_text)
    plot_4bar(eval_table, args.out_dir / "sft_4bar.png")

    if not losses and not eval_table:
        print(
            "[warn] neither loss nor eval table parsed; check log format",
            file=sys.stderr,
        )
        return 1
    return 0


if __name__ == "__main__":
    raise SystemExit(main())