nervousystem-env / scripts /plot_final_evidence.py
vx7sh's picture
docs: publish real SFT log_history and add provenance note
1a3f787
"""Generate final submission evidence plots without external dependencies.
Loss curve preference:
1. If `results/sft_warmup_metrics.json` exists (written by `training/grpo_train.py`
from `trainer.state.log_history` during the HF Jobs run), plot every step from
that file.
2. Otherwise, fall back to the transcribed sparse points from the HF Jobs console.
The per-task score bars are computed from results/final_phaseaware_model_eval.json;
the before-vs-after plot uses hard-coded raw and trained mean scores.
"""
from __future__ import annotations
import json
import math
import struct
import zlib
from collections import defaultdict
from pathlib import Path
# Repository root (two levels above this file) and the directory holding both
# the JSON inputs read below and the PNG files this script writes.
ROOT = Path(__file__).resolve().parents[1]
RESULTS = ROOT.joinpath("results")

# One pixel: an (red, green, blue) triple of 0-255 channel values.
Color = tuple[int, int, int]

# Chart palette.
WHITE: Color = (255, 255, 255)   # default canvas background
BLACK: Color = (20, 24, 33)      # axes, titles and value labels
GRID: Color = (220, 225, 232)    # horizontal grid lines
BLUE: Color = (42, 111, 219)
GREEN: Color = (34, 139, 84)
ORANGE: Color = (232, 140, 45)
GRAY: Color = (108, 117, 125)    # subtitles, tick labels and footers
def _png_chunk(kind: bytes, data: bytes) -> bytes:
return struct.pack(">I", len(data)) + kind + data + struct.pack(">I", zlib.crc32(kind + data) & 0xFFFFFFFF)
def write_png(path: Path, width: int, height: int, pixels: list[list[Color]]) -> None:
    """Encode *pixels* as an 8-bit truecolour (RGB) PNG and write it to *path*.

    Every scanline is prefixed with filter byte 0 (no filtering) and the whole
    image goes into a single zlib-compressed IDAT chunk. Parent directories of
    *path* are created if missing.

    Args:
        path: Destination file; overwritten if it exists.
        width: Image width in pixels, written to the IHDR chunk.
        height: Image height in pixels, written to the IHDR chunk.
        pixels: ``height`` rows, each holding ``width`` ``(r, g, b)`` tuples
            with channel values in 0..255.

    Raises:
        ValueError: if the declared size is not positive or does not match the
            shape of *pixels*. Without this check, a mismatch would produce a
            file whose header disagrees with its data, i.e. a corrupt PNG.
            A channel value outside 0..255 also raises ``ValueError`` (from
            ``bytearray``).
    """
    if width <= 0 or height <= 0:
        raise ValueError(f"PNG dimensions must be positive, got {width}x{height}")
    if len(pixels) != height or any(len(row) != width for row in pixels):
        raise ValueError(f"pixel grid does not match declared size {width}x{height}")
    raw = bytearray()
    for row in pixels:
        raw.append(0)  # filter type 0: scanline bytes stored as-is
        for red, green, blue in row:
            raw.extend((red, green, blue))
    path.parent.mkdir(parents=True, exist_ok=True)
    path.write_bytes(
        b"\x89PNG\r\n\x1a\n"
        # IHDR: width, height, bit depth 8, colour type 2 (RGB),
        # compression 0, filter method 0, no interlace.
        + _png_chunk(b"IHDR", struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0))
        + _png_chunk(b"IDAT", zlib.compress(bytes(raw), level=9))
        + _png_chunk(b"IEND", b"")
    )
def canvas(width: int, height: int, color: Color = WHITE) -> list[list[Color]]:
    """Create a blank image: *height* rows of *width* pixels, all set to *color*.

    Rows are independent lists, so writing a pixel in one row never affects
    another.
    """
    return [[color] * width for _ in range(height)]
def set_px(img: list[list[Color]], x: int, y: int, color: Color) -> None:
    """Paint pixel (*x*, *y*) with *color*; coordinates outside the image are ignored.

    The horizontal bound is taken from the first row, so all rows are assumed
    to share its width.
    """
    if not 0 <= y < len(img):
        return
    if not 0 <= x < len(img[0]):
        return
    img[y][x] = color
def draw_line(img: list[list[Color]], x0: int, y0: int, x1: int, y1: int, color: Color, width: int = 2) -> None:
    """Rasterise the segment from (*x0*, *y0*) to (*x1*, *y1*) onto *img*.

    Walks the segment with integer Bresenham steps. At every visited point it
    stamps a square brush covering offsets ``-(width // 2)`` .. ``+(width // 2)``
    on both axes. The stroke is therefore ``2 * (width // 2) + 1`` pixels thick
    (NOTE(review): an even *width* draws one pixel thicker than requested).
    Pixels that fall outside the image are skipped.
    """
    rows = len(img)
    cols = len(img[0]) if rows else 0
    reach = width // 2
    brush = [(ox, oy) for ox in range(-reach, reach + 1) for oy in range(-reach, reach + 1)]
    step_x = 1 if x0 < x1 else -1
    step_y = 1 if y0 < y1 else -1
    run = abs(x1 - x0)
    rise = -abs(y1 - y0)
    error = run + rise
    cx, cy = x0, y0
    while True:
        for ox, oy in brush:
            px, py = cx + ox, cy + oy
            if 0 <= py < rows and 0 <= px < cols:
                img[py][px] = color
        if cx == x1 and cy == y1:
            return
        doubled = 2 * error
        if doubled >= rise:
            error += rise
            cx += step_x
        if doubled <= run:
            error += run
            cy += step_y
def draw_rect(img: list[list[Color]], x0: int, y0: int, x1: int, y1: int, color: Color) -> None:
    """Fill the inclusive rectangle [*x0*, *x1*] x [*y0*, *y1*] with *color*.

    The rectangle is clipped to the image; the horizontal bound comes from the
    first row. An inverted rectangle (x1 < x0 or y1 < y0) paints nothing.
    """
    ys = range(max(0, y0), min(len(img), y1 + 1))
    if not ys:
        return
    xs = range(max(0, x0), min(len(img[0]), x1 + 1))
    for y in ys:
        target_row = img[y]
        for x in xs:
            target_row[x] = color
# Tiny 5x7 bitmap font for chart labels: each glyph is 7 row strings of 5
# columns, where "1" marks an inked pixel. This table holds punctuation only;
# digits and letters are resolved by _font_for(). Any character the captions
# render must have an entry here, including "(", ")", ",", "_" and "=" used in
# titles and footers. Otherwise it is drawn as _font_for()'s placeholder glyph.
FONT: dict[str, list[str]] = {
    " ": ["00000"] * 7,
    ".": ["00000", "00000", "00000", "00000", "00000", "01100", "01100"],
    ",": ["00000", "00000", "00000", "00000", "01100", "00100", "01000"],
    "-": ["00000", "00000", "00000", "11111", "00000", "00000", "00000"],
    "_": ["00000", "00000", "00000", "00000", "00000", "00000", "11111"],
    "+": ["00000", "00100", "00100", "11111", "00100", "00100", "00000"],
    "=": ["00000", "00000", "11111", "00000", "11111", "00000", "00000"],
    ">": ["10000", "01000", "00100", "00010", "00100", "01000", "10000"],
    "<": ["00001", "00010", "00100", "01000", "00100", "00010", "00001"],
    "%": ["11001", "11010", "00100", "01000", "10110", "00110", "00000"],
    "/": ["00001", "00010", "00100", "01000", "10000", "00000", "00000"],
    ":": ["00000", "01100", "01100", "00000", "01100", "01100", "00000"],
    "(": ["00010", "00100", "01000", "01000", "01000", "00100", "00010"],
    ")": ["01000", "00100", "00010", "00010", "00010", "00100", "01000"],
}
# Digit and lowercase-letter glyphs, built once at import time rather than on
# every _font_for() call. Same 5x7 layout as FONT ("1" = inked pixel).
_DIGIT_GLYPHS: dict[str, list[str]] = {
    "0": ["11111", "10001", "10011", "10101", "11001", "10001", "11111"],
    "1": ["00100", "01100", "00100", "00100", "00100", "00100", "01110"],
    "2": ["11110", "00001", "00001", "11110", "10000", "10000", "11111"],
    "3": ["11110", "00001", "00001", "01110", "00001", "00001", "11110"],
    "4": ["10010", "10010", "10010", "11111", "00010", "00010", "00010"],
    "5": ["11111", "10000", "10000", "11110", "00001", "00001", "11110"],
    "6": ["01111", "10000", "10000", "11110", "10001", "10001", "01110"],
    "7": ["11111", "00001", "00010", "00100", "01000", "01000", "01000"],
    "8": ["01110", "10001", "10001", "01110", "10001", "10001", "01110"],
    "9": ["01110", "10001", "10001", "01111", "00001", "00001", "11110"],
}
_LETTER_GLYPHS: dict[str, list[str]] = {
    "a": ["00000", "01110", "00001", "01111", "10001", "10011", "01101"],
    "b": ["10000", "10000", "10110", "11001", "10001", "10001", "11110"],
    "c": ["00000", "01111", "10000", "10000", "10000", "10000", "01111"],
    "d": ["00001", "00001", "01101", "10011", "10001", "10001", "01111"],
    "e": ["00000", "01110", "10001", "11111", "10000", "10000", "01111"],
    "f": ["00111", "01000", "01000", "11100", "01000", "01000", "01000"],
    "g": ["00000", "01111", "10001", "10001", "01111", "00001", "11110"],
    "h": ["10000", "10000", "10110", "11001", "10001", "10001", "10001"],
    "i": ["00100", "00000", "01100", "00100", "00100", "00100", "01110"],
    "j": ["00010", "00000", "00110", "00010", "00010", "10010", "01100"],
    "k": ["10000", "10010", "10100", "11000", "10100", "10010", "10001"],
    "l": ["01100", "00100", "00100", "00100", "00100", "00100", "01110"],
    "m": ["00000", "11010", "10101", "10101", "10101", "10101", "10101"],
    "n": ["00000", "10110", "11001", "10001", "10001", "10001", "10001"],
    "o": ["00000", "01110", "10001", "10001", "10001", "10001", "01110"],
    "p": ["00000", "11110", "10001", "10001", "11110", "10000", "10000"],
    "q": ["00000", "01111", "10001", "10001", "01111", "00001", "00001"],
    "r": ["00000", "10111", "11000", "10000", "10000", "10000", "10000"],
    "s": ["00000", "01111", "10000", "01110", "00001", "00001", "11110"],
    "t": ["01000", "01000", "11100", "01000", "01000", "01001", "00110"],
    "u": ["00000", "10001", "10001", "10001", "10001", "10011", "01101"],
    "v": ["00000", "10001", "10001", "10001", "01010", "01010", "00100"],
    "w": ["00000", "10001", "10001", "10101", "10101", "10101", "01010"],
    "x": ["00000", "10001", "01010", "00100", "01010", "10001", "10001"],
    "y": ["00000", "10001", "10001", "01111", "00001", "00001", "11110"],
    "z": ["00000", "11111", "00010", "00100", "01000", "10000", "11111"],
}
# Placeholder drawn for any character that has no glyph.
_FALLBACK_GLYPH: list[str] = ["00000", "00000", "11111", "00101", "00100", "00000", "00000"]


def _font_for(ch: str) -> list[str]:
    """Return the 5x7 bitmap rows for the single character *ch*.

    Lookup order is ``FONT`` (punctuation), then ASCII digits, then ASCII
    letters. Letters are matched case-insensitively, so uppercase renders with
    the lowercase glyph. Any other character gets a placeholder glyph instead
    of raising, including non-ASCII digits and letters such as "²" or "é",
    for which ``str.isdigit``/``str.isalpha`` are true.
    The returned list is shared, so callers must treat it as read-only.
    """
    glyph = FONT.get(ch)
    if glyph is None:
        glyph = _DIGIT_GLYPHS.get(ch)
    if glyph is None:
        glyph = _LETTER_GLYPHS.get(ch.lower())
    if glyph is None:
        glyph = _FALLBACK_GLYPH
    return glyph
def draw_text(img: list[list[Color]], x: int, y: int, text: str, color: Color = BLACK, scale: int = 2) -> None:
    """Render *text* with the 5x7 bitmap font, top-left corner at (*x*, *y*).

    Each font pixel becomes a *scale* x *scale* square. Characters advance by
    ``6 * scale`` pixels: 5 glyph columns plus 1 column of spacing. Nothing
    wraps; text running past the image edge is clipped by ``draw_rect``.
    """
    advance = 6 * scale
    for index, char in enumerate(text):
        origin_x = x + index * advance
        for row_idx, bits in enumerate(_font_for(char)):
            cell_top = y + row_idx * scale
            for col_idx, bit in enumerate(bits):
                if bit != "1":
                    continue
                cell_left = origin_x + col_idx * scale
                draw_rect(img, cell_left, cell_top, cell_left + scale - 1, cell_top + scale - 1, color)
_TRANSCRIBED_LOSS_POINTS: list[tuple[int, float]] = [
(1, 2.4762),
(10, 1.4937),
(25, 0.8778),
(40, 0.3899),
(60, 0.2426),
(80, 0.1199),
(100, 0.1287),
(120, 0.1242),
(140, 0.1076),
(160, 0.0874),
(180, 0.0912),
(200, 0.0746),
]
def _load_loss_points() -> tuple[list[tuple[int, float]], str]:
"""Prefer real per-step loss from sft_warmup_metrics.json over transcribed points."""
metrics_path = RESULTS / "sft_warmup_metrics.json"
if metrics_path.exists():
try:
entries = json.loads(metrics_path.read_text())
points = [
(int(entry["step"]), float(entry["loss"]))
for entry in entries
if "step" in entry and "loss" in entry
]
points.sort(key=lambda item: item[0])
if points:
return points, "real"
except (json.JSONDecodeError, KeyError, TypeError, ValueError):
pass
return _TRANSCRIBED_LOSS_POINTS, "transcribed"
def plot_loss() -> None:
    """Render the SFT loss curve to ``results/final_sft_loss_curve.png``.

    Points come from ``_load_loss_points()``. The title, subtitle and footer
    state whether they are the real per-step ``log_history`` or the sparse
    transcribed fallback. Steps map linearly onto the x-axis. The y-axis spans
    0 .. max(2.6, 1.05 * largest loss). Square markers are drawn on the points
    only when there are at most 30 of them.
    """
    loss_points, source = _load_loss_points()
    img = canvas(1000, 620)
    first = loss_points[0][1]
    last = loss_points[-1][1]
    if source == "real":
        title = "SFT loss curve (real trainer.state.log_history)"
        subtitle = f"{len(loss_points)} steps logged, {first:.3f} -> {last:.3f}"
        footer_lines = [f"Source: results/sft_warmup_metrics.json (n={len(loss_points)} steps)."]
    else:
        title = "SFT loss curve from final HF Jobs run"
        # Derived from the points rather than typed in, so the caption always
        # matches what is plotted.
        subtitle = f"Transcribed console log points: {first:.3f} -> {last:.3f}"
        # At scale 2 each glyph advances 12 px. As a single line this note is
        # wider than the 1000 px canvas, so it is split in two.
        footer_lines = [
            "Source: HF Jobs console log (sparse).",
            "Replace with sft_warmup_metrics.json for full curve.",
        ]
    draw_text(img, 55, 35, title, BLACK, 3)
    draw_text(img, 55, 82, subtitle, GRAY, 2)

    left, top, right, bottom = 95, 130, 940, 525
    for i in range(6):
        y = top + round((bottom - top) * i / 5)
        draw_line(img, left, y, right, y, GRID, 1)
    draw_line(img, left, top, left, bottom, BLACK, 2)
    draw_line(img, left, bottom, right, bottom, BLACK, 2)

    min_step = min(step for step, _ in loss_points)
    max_step = max(step for step, _ in loss_points)
    max_loss = max(2.6, max(loss for _, loss in loss_points) * 1.05)
    span = max(1, max_step - min_step)  # avoid dividing by zero for a single step
    coords = [
        (
            left + round((right - left) * (step - min_step) / span),
            bottom - round((bottom - top) * loss / max_loss),
        )
        for step, loss in loss_points
    ]
    for (x0, y0), (x1, y1) in zip(coords, coords[1:]):
        draw_line(img, x0, y0, x1, y1, BLUE, 4)
    if len(coords) <= 30:
        for x, y in coords:
            draw_rect(img, x - 4, y - 4, x + 4, y + 4, ORANGE)

    draw_text(img, left - 35, top - 8, f"{max_loss:.1f}", GRAY, 2)
    draw_text(img, left - 35, bottom - 8, "0.0", GRAY, 2)
    draw_text(img, left - 10, bottom + 25, f"step {min_step}", GRAY, 2)
    draw_text(img, right - 110, bottom + 25, f"step {max_step}", GRAY, 2)
    # Footer starts at the title's left margin so it fits on the canvas. Lines
    # are 22 px apart, placed below the step labels (which end before y=565).
    for index, line in enumerate(footer_lines):
        draw_text(img, 55, 568 + index * 22, line, GRAY, 2)
    write_png(RESULTS / "final_sft_loss_curve.png", len(img[0]), len(img), img)
def plot_scores() -> None:
    """Render the mean episode score per task to ``results/final_score_by_task.png``.

    Reads ``results/final_phaseaware_model_eval.json``, groups its ``episodes``
    by ``task_id`` and draws one bar per task in the fixed order easy, medium,
    hard, cascade. Each bar is labelled with its mean to three decimals. Bar
    heights are clamped to the 0..1 axis. A task with no episodes has an
    undefined mean and is labelled ``n/a`` instead of getting a bar.
    """
    data = json.loads((RESULTS / "final_phaseaware_model_eval.json").read_text(encoding="utf-8"))
    by_task: dict[str, list[float]] = defaultdict(list)
    for episode in data["episodes"]:
        by_task[episode["task_id"]].append(float(episode["score"]))
    order = ["easy", "medium", "hard", "cascade"]
    # None marks a task with no episodes (averaging it would divide by zero).
    means: list[float | None] = [
        sum(by_task[task]) / len(by_task[task]) if by_task[task] else None for task in order
    ]

    img = canvas(1000, 620)
    draw_text(img, 55, 35, "Final phase-aware trained LoRA score by task", BLACK, 3)
    # NOTE(review): this subtitle is a hard-coded snapshot, not derived from the
    # JSON read above; update it whenever the eval file changes.
    draw_text(img, 55, 82, "Mean score 0.915, pass rate 100%, 12 episodes", GRAY, 2)
    left, top, right, bottom = 95, 130, 940, 525
    for i in range(6):
        y = top + round((bottom - top) * i / 5)
        draw_line(img, left, y, right, y, GRID, 1)
    draw_line(img, left, top, left, bottom, BLACK, 2)
    draw_line(img, left, bottom, right, bottom, BLACK, 2)

    bar_width = 120
    gap = 75
    colors = [GREEN, GREEN, BLUE, ORANGE]
    for idx, (task, mean, color) in enumerate(zip(order, means, colors)):
        x0 = left + gap + idx * (bar_width + gap)
        x1 = x0 + bar_width
        draw_text(img, x0 + 18, bottom + 25, task, BLACK, 2)
        if mean is None:
            draw_text(img, x0 + 20, bottom - 28, "n/a", GRAY, 2)
            continue
        # Clamp so an out-of-range mean cannot paint outside the plot frame.
        fraction = min(1.0, max(0.0, mean))
        y0 = bottom - round((bottom - top) * fraction)
        draw_rect(img, x0, y0, x1, bottom - 1, color)
        draw_text(img, x0 + 20, y0 - 28, f"{mean:.3f}", BLACK, 2)
    draw_text(img, left - 35, top - 8, "1.0", GRAY, 2)
    draw_text(img, left - 35, bottom - 8, "0.0", GRAY, 2)
    write_png(RESULTS / "final_score_by_task.png", len(img[0]), len(img), img)
def plot_before_after() -> None:
    """Render the raw-vs-trained mean score bars to ``results/before_vs_after_scores.png``.

    The two mean scores and the footer's pass rates are hard-coded here; this
    function reads no input files.
    """
    img = canvas(1000, 620)
    draw_text(img, 55, 35, "Raw base model vs final trained policy", BLACK, 3)
    draw_text(img, 55, 82, "Constrained eval before vs final phase-aware constrained eval after", GRAY, 2)
    left, top, right, bottom = 110, 130, 920, 525
    for i in range(6):
        y = top + round((bottom - top) * i / 5)
        draw_line(img, left, y, right, y, GRID, 1)
    draw_line(img, left, top, left, bottom, BLACK, 2)
    draw_line(img, left, bottom, right, bottom, BLACK, 2)
    bars = [
        ("raw", 0.23905299302262772, GRAY),
        ("trained", 0.9146806281246708, GREEN),
    ]
    for idx, (label, value, color) in enumerate(bars):
        x0 = left + 165 + idx * 290
        x1 = x0 + 150
        y0 = bottom - round((bottom - top) * value)
        draw_rect(img, x0, y0, x1, bottom - 1, color)
        draw_text(img, x0 + 25, bottom + 25, label, BLACK, 2)
        draw_text(img, x0 + 35, y0 - 28, f"{value:.3f}", BLACK, 2)
    draw_text(img, left - 35, top - 8, "1.0", GRAY, 2)
    draw_text(img, left - 35, bottom - 8, "0.0", GRAY, 2)
    footer = "Raw before: mean 0.239, pass 0%. Final after: mean 0.915, pass 100%."
    # Centre the footer: at scale 2 each glyph advances 12 px, so this 68-char
    # line is 816 px wide and a fixed x of 330 would clip its end off the canvas.
    footer_x = max(0, (len(img[0]) - len(footer) * 6 * 2) // 2)
    draw_text(img, footer_x, 565, footer, GRAY, 2)
    write_png(RESULTS / "before_vs_after_scores.png", len(img[0]), len(img), img)
def main() -> None:
    """Regenerate all three evidence charts, then print the path of each PNG."""
    renderers = (plot_loss, plot_scores, plot_before_after)
    for render in renderers:
        render()
    outputs = (
        "final_sft_loss_curve.png",
        "final_score_by_task.png",
        "before_vs_after_scores.png",
    )
    for filename in outputs:
        print(RESULTS / filename)


if __name__ == "__main__":
    main()