"""Post-process an SFT training log into publication-ready plots.

Run this after the contents of ``runs/sft_v1_stdout.log`` (the file the
Colab notebook ``tee``s into) have been copied back locally.
It produces two PNGs (under ``--out-dir``, default ``docs/img``):

* ``docs/img/sft_loss_curve.png`` — training loss per logging step
* ``docs/img/sft_4bar.png`` — random / no_op / scripted / sft_adapter
  scores on hard_drift, side-by-side

Inputs
------
A path to the log file. The script is forgiving: if the eval table is
missing, it produces only the loss curve; if the loss lines are missing,
it produces only the 4-bar chart.

Usage
-----
    python -m scripts.sft_postprocess runs/sft_v1_stdout.log [--out-dir DIR]
"""
from __future__ import annotations
import argparse
import re
import statistics
import sys
from pathlib import Path
import matplotlib
matplotlib.use("Agg")
import matplotlib.pyplot as plt
# ---------------- Loss curve ----------------
# A loss value: plain decimal ("1.234", ".5") with an optional exponent
# ("1.5e-05"). A bare ``[0-9.]+`` would stop at the "e" and silently report
# the mantissa (1.5) as the loss.
_LOSS_NUMBER = r"(\d*\.?\d+(?:[eE][-+]?\d+)?)"

# Tried in order on each log line; the first pattern that matches wins, and
# the loss value is always the pattern's LAST capture group.
LOSS_PATTERNS = [
    # Dict-style log lines, e.g. {'loss': 1.234, 'grad_norm': ...}.
    # The leading quote keeps 'eval_loss' / 'train_loss' keys out.
    re.compile(r"'loss':\s*" + _LOSS_NUMBER),
    # "loss = 1.234" / "loss=1.234"; \b rejects eval_loss / train_loss.
    re.compile(r"\bloss\s*=\s*" + _LOSS_NUMBER),
    # "step 10 ... loss: 1.234"; \bloss likewise rejects eval_loss etc.
    re.compile(r"\bstep\s+(\d+).*\bloss\s*[=:]\s*" + _LOSS_NUMBER),
]
def parse_loss(log_text: str) -> list[float]:
    """Extract training-loss values from *log_text*, keeping log order.

    Each line is tested against ``LOSS_PATTERNS`` in priority order and only
    the first pattern that matches is used; the loss is read from that
    match's last capture group. Captures that are not valid floats, or that
    fall outside the open interval (0, 100), are dropped as implausible.
    """
    collected: list[float] = []
    for text_line in log_text.splitlines():
        # Lazily try patterns in order, stopping at the first hit.
        hit = next(
            (found for found in (rx.search(text_line) for rx in LOSS_PATTERNS) if found),
            None,
        )
        if hit is None:
            continue
        try:
            value = float(hit.group(hit.lastindex))
        except ValueError:
            continue
        if 0.0 < value < 100.0:
            collected.append(value)
    return collected
def plot_loss(losses: list[float], out_path: Path) -> None:
    """Plot training loss per logging step and save it as a PNG.

    Draws the raw loss series. When there are at least 5 points, it also
    draws a trailing moving average over the most recent ``win`` values,
    where ``win = max(3, len(losses) // 25)``. Prints a ``[skip]`` notice and
    writes nothing when *losses* is empty.

    Args:
        losses: Loss values in logging order (e.g. from ``parse_loss``).
        out_path: Destination PNG path.
    """
    if not losses:
        print("[skip] loss curve — no loss values parsed from log")
        return
    fig, ax = plt.subplots(figsize=(8.0, 4.4))
    try:
        xs = list(range(1, len(losses) + 1))
        ax.plot(xs, losses, color="#2a7fbf", lw=1.3)
        if len(losses) >= 5:
            # Light moving average for the eye. The slice holds exactly `win`
            # samples (fewer near the start), so the curve matches the
            # window size shown in the legend label.
            win = max(3, len(losses) // 25)
            ma = [
                statistics.mean(losses[max(0, i - win + 1) : i + 1])
                for i in range(len(losses))
            ]
            ax.plot(xs, ma, color="red", lw=1.0, alpha=0.7, label=f"moving avg (window={win})")
            ax.legend(loc="upper right", framealpha=0.95)
        ax.set_xlabel("logging step")
        ax.set_ylabel("training loss")
        ax.set_title(
            f"SFT training loss (Qwen2.5-3B + LoRA, {len(losses)} logged steps)\n"
            "Unsloth train_on_responses_only, masked to assistant tokens"
        )
        ax.grid(True, alpha=0.3)
        fig.tight_layout()
        fig.savefig(out_path, dpi=140, bbox_inches="tight")
    finally:
        # pyplot keeps every figure it creates alive until it is closed.
        plt.close(fig)
    print(f"[ok] wrote {out_path} ({len(losses)} loss points)")
# ---------------- Eval table ----------------
# Matches one per-task row of the eval summary. An illustrative row:
#   hard_drift  n=20  trained=0.612  scripted=0.754
# - task: only the three task names listed in the alternation are recognised.
# - n: sample count (digits only).
# - trained=... is captured as group "sft"; scripted=... as "scripted".
#   Both scores use [0-9.]+, so a signed (negative) score would not match.
# re.MULTILINE lets ``^`` anchor at the start of every line of the log.
EVAL_RE = re.compile(
    r"^\s*(?P<task>easy_cashless|medium_multi_payer|hard_drift)\s+"
    r"n=(?P<n>\d+)\s+trained=(?P<sft>[0-9.]+)\s+"
    r"scripted=(?P<scripted>[0-9.]+)",
    re.MULTILINE,
)
def parse_eval(log_text: str) -> dict[str, dict[str, float]]:
    """Collect every per-task eval row matched by ``EVAL_RE``.

    Returns a mapping ``task -> {"sft_adapter": float, "scripted": float,
    "n": int}``. If the same task appears on several lines, the row that
    occurs last in the log wins.
    """
    return {
        row["task"]: {
            "sft_adapter": float(row["sft"]),
            "scripted": float(row["scripted"]),
            "n": int(row["n"]),
        }
        for row in EVAL_RE.finditer(log_text)
    }
def plot_4bar(
    eval_table: dict[str, dict[str, float]],
    out_path: Path,
    *,
    baselines: dict[str, float] | None = None,
) -> None:
    """Render random/no_op/scripted/sft_adapter scores on hard_drift only.

    Only the ``sft_adapter`` bar comes from *eval_table*; the other three are
    fixed reference baselines. Prints a ``[skip]`` notice and writes nothing
    when *eval_table* has no ``hard_drift`` row.

    Args:
        eval_table: Mapping as returned by ``parse_eval``. Only
            ``eval_table["hard_drift"]["sft_adapter"]`` is read.
        out_path: Destination PNG path.
        baselines: Optional override for the reference scores. It must
            contain ``"random"``, ``"no_op"`` and ``"scripted"``. Defaults to
            the hard-locked hard_drift baselines below.
    """
    if "hard_drift" not in eval_table:
        print("[skip] 4-bar — no hard_drift row found in eval table")
        return
    sft = eval_table["hard_drift"]["sft_adapter"]
    if baselines is None:
        # Hard-locked 20-seed measured baselines on hard_drift after grader v3.1
        # (P6 oscillation penalty + B2/B3 bonus gating). See
        # docs/baseline_reproducibility.csv.
        baselines = {"random": 0.108, "no_op": 0.079, "scripted": 0.754}
    scripted = baselines["scripted"]
    names = ["random", "no_op", "scripted", "sft_adapter"]
    means = [baselines["random"], baselines["no_op"], scripted, sft]
    colors = ["#888888", "#bbbb44", "#2a7fbf", "#cc3399"]

    # Compare at the displayed precision, so a tiny float difference is not
    # reported as "above scripted by 0.000".
    gap = round(sft - scripted, 3)
    if gap > 0:
        relation = f"above scripted by {gap:.3f}"
    elif gap < 0:
        relation = f"below scripted by {-gap:.3f}"
    else:
        relation = "equal to scripted"

    fig, ax = plt.subplots(figsize=(7.0, 4.6))
    try:
        bars = ax.bar(names, means, color=colors, edgecolor="black", linewidth=0.5)
        for bar, m in zip(bars, means):
            # Value label just above each bar.
            ax.text(
                bar.get_x() + bar.get_width() / 2,
                m + 0.015,
                f"{m:.3f}",
                ha="center",
                va="bottom",
                fontsize=9,
            )
        ax.axhline(
            scripted,
            ls="--",
            color="#2a7fbf",
            lw=1.0,
            alpha=0.7,
            label=f"scripted ceiling = {scripted:.3f}",
        )
        ax.set_title(f"Trained model on hard_drift: SFT adapter = {sft:.3f} ({relation})")
        ax.set_ylabel("Composite grader score")
        ax.set_ylim(0, max(0.95, max(means) + 0.10))
        ax.legend(loc="upper left", framealpha=0.95)
        fig.tight_layout()
        fig.savefig(out_path, dpi=140, bbox_inches="tight")
    finally:
        # pyplot keeps every figure it creates alive until it is closed.
        plt.close(fig)
    print(f"[ok] wrote {out_path} (sft_adapter on hard_drift = {sft:.3f})")
# ---------------- Main ----------------
def main() -> int:
    """Command-line entry point: parse the log, then write both plots.

    Returns:
        0 when the loss curve and/or the eval table was parsed,
        1 when neither was found in the log,
        2 when *log_path* is not an existing regular file.
    """
    parser = argparse.ArgumentParser(
        description=__doc__,
        # Keep the module docstring's line breaks and bullet list in --help.
        formatter_class=argparse.RawDescriptionHelpFormatter,
    )
    parser.add_argument("log_path", type=Path, help="Path to sft_v1_stdout.log")
    parser.add_argument(
        "--out-dir",
        type=Path,
        default=Path("docs/img"),
        help="Directory for output PNGs (default: docs/img)",
    )
    args = parser.parse_args()
    # is_file() rather than exists(): a directory argument is reported here
    # instead of crashing later inside read_text().
    if not args.log_path.is_file():
        print(f"[err] log not found or not a regular file: {args.log_path}", file=sys.stderr)
        return 2
    args.out_dir.mkdir(parents=True, exist_ok=True)
    # Decode explicitly as UTF-8: a tee'd notebook log may contain non-ASCII
    # characters (e.g. progress bars), and the platform default encoding may
    # not handle them. Undecodable bytes are replaced instead of aborting.
    log_text = args.log_path.read_text(encoding="utf-8", errors="replace")
    losses = parse_loss(log_text)
    plot_loss(losses, args.out_dir / "sft_loss_curve.png")
    eval_table = parse_eval(log_text)
    plot_4bar(eval_table, args.out_dir / "sft_4bar.png")
    if not losses and not eval_table:
        print(
            "[warn] neither loss nor eval table parsed; check log format",
            file=sys.stderr,
        )
        return 1
    return 0
if __name__ == "__main__":
    # Propagate main()'s integer return value as the process exit status.
    sys.exit(main())