Spaces:
Running
Running
| """Generate final submission evidence plots without external dependencies. | |
| Loss curve preference: | |
| 1. If `results/sft_warmup_metrics.json` exists (written by `training/grpo_train.py` | |
| from `trainer.state.log_history` during the HF Jobs run), plot every step from | |
| that file. | |
| 2. Otherwise, fall back to the transcribed sparse points from the HF Jobs console. | |
| The score plot is computed from results/final_phaseaware_model_eval.json. | |
| """ | |
| from __future__ import annotations | |
| import json | |
| import math | |
| import struct | |
| import zlib | |
| from collections import defaultdict | |
| from pathlib import Path | |
# Repository root: parents[1] is the directory one level above this script's folder.
ROOT = Path(__file__).resolve().parents[1]
# Folder the evaluation JSON is read from and the PNG plots are written to.
RESULTS = ROOT.joinpath("results")

# One RGB pixel as a (red, green, blue) triple; write_png emits 8 bits per channel.
Color = tuple[int, int, int]

# Palette shared by every chart in this module.
WHITE: Color = (255, 255, 255)  # default canvas background
BLACK: Color = (20, 24, 33)  # near-black ink for titles, axes and bar labels
GRID: Color = (220, 225, 232)  # horizontal grid lines
BLUE: Color = (42, 111, 219)  # loss curve / bar fill
GREEN: Color = (34, 139, 84)  # bar fill
ORANGE: Color = (232, 140, 45)  # loss point markers / bar fill
GRAY: Color = (108, 117, 125)  # subtitles, axis labels and footers
| def _png_chunk(kind: bytes, data: bytes) -> bytes: | |
| return struct.pack(">I", len(data)) + kind + data + struct.pack(">I", zlib.crc32(kind + data) & 0xFFFFFFFF) | |
def write_png(path: Path, width: int, height: int, pixels: list[list[Color]]) -> None:
    """Encode ``pixels`` as an 8-bit RGB PNG and write it to ``path``.

    ``pixels`` is a list of rows, each a list of (red, green, blue) triples.
    ``width`` and ``height`` are written to the IHDR chunk as given; they are
    not checked against the shape of ``pixels``. Missing parent directories
    of ``path`` are created.
    """
    scanlines = bytearray()
    for row in pixels:
        # Every scanline starts with its filter-type byte; 0 means "no filter".
        scanlines.append(0)
        for red, green, blue in row:
            scanlines += bytes((red, green, blue))

    path.parent.mkdir(parents=True, exist_ok=True)

    signature = b"\x89PNG\r\n\x1a\n"
    # IHDR: width, height, bit depth 8, colour type 2 (RGB), then
    # compression, filter and interlace methods all 0.
    ihdr = _png_chunk(b"IHDR", struct.pack(">IIBBBBB", width, height, 8, 2, 0, 0, 0))
    idat = _png_chunk(b"IDAT", zlib.compress(bytes(scanlines), level=9))
    iend = _png_chunk(b"IEND", b"")
    path.write_bytes(signature + ihdr + idat + iend)
def canvas(width: int, height: int, color: Color = WHITE) -> list[list[Color]]:
    """Return a new ``height`` x ``width`` pixel grid filled with ``color``.

    The grid is a list of independent row lists, indexed as ``img[y][x]``.
    """
    rows: list[list[Color]] = []
    for _ in range(height):
        rows.append([color] * width)
    return rows
def set_px(img: list[list[Color]], x: int, y: int, color: Color) -> None:
    """Set pixel ``(x, y)`` of ``img`` to ``color``; out-of-bounds writes are ignored.

    The horizontal bound is taken from the length of the first row.
    """
    if y < 0 or y >= len(img):
        return
    if x < 0 or x >= len(img[0]):
        return
    img[y][x] = color
def draw_line(img: list[list[Color]], x0: int, y0: int, x1: int, y1: int, color: Color, width: int = 2) -> None:
    """Draw a straight line from ``(x0, y0)`` to ``(x1, y1)`` onto ``img``.

    The line is walked with an integer error-accumulating step
    (Bresenham-style). At every visited point a filled square of side
    ``2 * (width // 2) + 1`` pixels is stamped, so ``width=1`` yields a
    1-pixel stroke and ``width=2`` a 3-pixel stroke. Pixels outside the
    image are dropped by ``set_px``.
    """
    delta_x = abs(x1 - x0)
    delta_y = -abs(y1 - y0)
    step_x = 1 if x0 < x1 else -1
    step_y = 1 if y0 < y1 else -1
    error = delta_x + delta_y

    radius = width // 2
    brush = range(-radius, radius + 1)

    cur_x, cur_y = x0, y0
    while True:
        for off_x in brush:
            for off_y in brush:
                set_px(img, cur_x + off_x, cur_y + off_y, color)
        if cur_x == x1 and cur_y == y1:
            return
        # Both tests use the error value from before this step, so a
        # diagonal move can advance x and y in the same iteration.
        doubled = 2 * error
        if doubled >= delta_y:
            error += delta_y
            cur_x += step_x
        if doubled <= delta_x:
            error += delta_x
            cur_y += step_y
def draw_rect(img: list[list[Color]], x0: int, y0: int, x1: int, y1: int, color: Color) -> None:
    """Fill the rectangle with inclusive corners ``(x0, y0)``-``(x1, y1)``.

    The rectangle is clipped to the image; the horizontal bound is taken
    from the length of the first row. Nothing is drawn when the clipped
    area is empty (including when ``x1 < x0`` or ``y1 < y0``).
    """
    row_start = max(0, y0)
    row_stop = min(len(img), y1 + 1)
    if row_start >= row_stop:
        return

    col_start = max(0, x0)
    col_stop = min(len(img[0]), x1 + 1)
    for row_index in range(row_start, row_stop):
        row = img[row_index]
        for col_index in range(col_start, col_stop):
            row[col_index] = color
| # Tiny 5x7 bitmap font for chart labels. | |
| FONT: dict[str, list[str]] = { | |
| " ": ["00000"] * 7, | |
| ".": ["00000", "00000", "00000", "00000", "00000", "01100", "01100"], | |
| "-": ["00000", "00000", "00000", "11111", "00000", "00000", "00000"], | |
| ">": ["10000", "01000", "00100", "00010", "00100", "01000", "10000"], | |
| "%": ["11001", "11010", "00100", "01000", "10110", "00110", "00000"], | |
| "/": ["00001", "00010", "00100", "01000", "10000", "00000", "00000"], | |
| ":": ["00000", "01100", "01100", "00000", "01100", "01100", "00000"], | |
| } | |
| def _font_for(ch: str) -> list[str]: | |
| if ch in FONT: | |
| return FONT[ch] | |
| if ch.isdigit(): | |
| digits = { | |
| "0": ["11111", "10001", "10011", "10101", "11001", "10001", "11111"], | |
| "1": ["00100", "01100", "00100", "00100", "00100", "00100", "01110"], | |
| "2": ["11110", "00001", "00001", "11110", "10000", "10000", "11111"], | |
| "3": ["11110", "00001", "00001", "01110", "00001", "00001", "11110"], | |
| "4": ["10010", "10010", "10010", "11111", "00010", "00010", "00010"], | |
| "5": ["11111", "10000", "10000", "11110", "00001", "00001", "11110"], | |
| "6": ["01111", "10000", "10000", "11110", "10001", "10001", "01110"], | |
| "7": ["11111", "00001", "00010", "00100", "01000", "01000", "01000"], | |
| "8": ["01110", "10001", "10001", "01110", "10001", "10001", "01110"], | |
| "9": ["01110", "10001", "10001", "01111", "00001", "00001", "11110"], | |
| } | |
| return digits[ch] | |
| if ch.isalpha(): | |
| letters = { | |
| "a": ["00000", "01110", "00001", "01111", "10001", "10011", "01101"], | |
| "b": ["10000", "10000", "10110", "11001", "10001", "10001", "11110"], | |
| "c": ["00000", "01111", "10000", "10000", "10000", "10000", "01111"], | |
| "d": ["00001", "00001", "01101", "10011", "10001", "10001", "01111"], | |
| "e": ["00000", "01110", "10001", "11111", "10000", "10000", "01111"], | |
| "f": ["00111", "01000", "01000", "11100", "01000", "01000", "01000"], | |
| "g": ["00000", "01111", "10001", "10001", "01111", "00001", "11110"], | |
| "h": ["10000", "10000", "10110", "11001", "10001", "10001", "10001"], | |
| "i": ["00100", "00000", "01100", "00100", "00100", "00100", "01110"], | |
| "j": ["00010", "00000", "00110", "00010", "00010", "10010", "01100"], | |
| "k": ["10000", "10010", "10100", "11000", "10100", "10010", "10001"], | |
| "l": ["01100", "00100", "00100", "00100", "00100", "00100", "01110"], | |
| "m": ["00000", "11010", "10101", "10101", "10101", "10101", "10101"], | |
| "n": ["00000", "10110", "11001", "10001", "10001", "10001", "10001"], | |
| "o": ["00000", "01110", "10001", "10001", "10001", "10001", "01110"], | |
| "p": ["00000", "11110", "10001", "10001", "11110", "10000", "10000"], | |
| "q": ["00000", "01111", "10001", "10001", "01111", "00001", "00001"], | |
| "r": ["00000", "10111", "11000", "10000", "10000", "10000", "10000"], | |
| "s": ["00000", "01111", "10000", "01110", "00001", "00001", "11110"], | |
| "t": ["01000", "01000", "11100", "01000", "01000", "01001", "00110"], | |
| "u": ["00000", "10001", "10001", "10001", "10001", "10011", "01101"], | |
| "v": ["00000", "10001", "10001", "10001", "01010", "01010", "00100"], | |
| "w": ["00000", "10001", "10001", "10101", "10101", "10101", "01010"], | |
| "x": ["00000", "10001", "01010", "00100", "01010", "10001", "10001"], | |
| "y": ["00000", "10001", "10001", "01111", "00001", "00001", "11110"], | |
| "z": ["00000", "11111", "00010", "00100", "01000", "10000", "11111"], | |
| } | |
| return letters[ch.lower()] | |
| return ["00000", "00000", "11111", "00101", "00100", "00000", "00000"] | |
def draw_text(img: list[list[Color]], x: int, y: int, text: str, color: Color = BLACK, scale: int = 2) -> None:
    """Draw ``text`` onto ``img`` with its top-left corner at ``(x, y)``.

    Each glyph cell is 5x7 font pixels; every set font pixel becomes a
    ``scale`` x ``scale`` filled square. Characters advance by ``6 * scale``
    image pixels (five glyph columns plus one column of spacing).
    """
    advance = 6 * scale
    for position, ch in enumerate(text):
        glyph_left = x + position * advance
        for row_no, bits in enumerate(_font_for(ch)):
            cell_top = y + row_no * scale
            for col_no, bit in enumerate(bits):
                if bit != "1":
                    continue
                cell_left = glyph_left + col_no * scale
                draw_rect(
                    img,
                    cell_left,
                    cell_top,
                    cell_left + scale - 1,
                    cell_top + scale - 1,
                    color,
                )
| _TRANSCRIBED_LOSS_POINTS: list[tuple[int, float]] = [ | |
| (1, 2.4762), | |
| (10, 1.4937), | |
| (25, 0.8778), | |
| (40, 0.3899), | |
| (60, 0.2426), | |
| (80, 0.1199), | |
| (100, 0.1287), | |
| (120, 0.1242), | |
| (140, 0.1076), | |
| (160, 0.0874), | |
| (180, 0.0912), | |
| (200, 0.0746), | |
| ] | |
def _load_loss_points() -> tuple[list[tuple[int, float]], str]:
    """Return the loss curve as ``(points, source)``.

    ``points`` is a list of ``(step, loss)`` pairs sorted by step. ``source``
    is ``"real"`` when the points came from ``results/sft_warmup_metrics.json``
    and ``"transcribed"`` when the built-in console-log points are used.

    The metrics file is expected to hold a sequence of mappings; only entries
    carrying both a ``"step"`` and a ``"loss"`` key are used. Loading is
    deliberately best-effort: a missing, unreadable or malformed file, or one
    with no usable points, falls back to the transcribed points instead of
    raising.
    """
    metrics_path = RESULTS / "sft_warmup_metrics.json"
    try:
        entries = json.loads(metrics_path.read_text(encoding="utf-8"))
        points = [
            (int(entry["step"]), float(entry["loss"]))
            for entry in entries
            if "step" in entry and "loss" in entry
        ]
    except (OSError, json.JSONDecodeError, KeyError, TypeError, ValueError, OverflowError):
        # OSError covers a missing or unreadable file; OverflowError covers
        # int() of an infinite step. Any of these means "no real data".
        points = []

    # json.loads accepts NaN/Infinity literals; such losses cannot be mapped
    # to a pixel row (round() raises on them), so they are dropped here.
    points = [point for point in points if math.isfinite(point[1])]
    points.sort(key=lambda item: item[0])
    if points:
        return points, "real"
    # Hand out a copy so callers cannot mutate the module-level constant.
    return list(_TRANSCRIBED_LOSS_POINTS), "transcribed"
def plot_loss() -> None:
    """Render the SFT loss curve to ``results/final_sft_loss_curve.png``.

    Points come from ``_load_loss_points``; the title, subtitle and footer
    state whether they are real per-step metrics or the transcribed console
    points. Individual point markers are drawn only for curves of at most
    30 points.

    The footer is positioned, and if necessary shrunk to scale 1, so that it
    fits inside the canvas. It was previously drawn at a fixed x=360 with
    scale 2, which cut off the end of both footer variants.
    """
    loss_points, source = _load_loss_points()
    canvas_width, canvas_height = 1000, 620
    margin = 55
    img = canvas(canvas_width, canvas_height)

    if source == "real":
        title = "SFT loss curve (real trainer.state.log_history)"
        first = loss_points[0][1]
        last = loss_points[-1][1]
        subtitle = f"{len(loss_points)} steps logged, {first:.3f} -> {last:.3f}"
    else:
        title = "SFT loss curve from final HF Jobs run"
        subtitle = "Transcribed console log points: 2.476 -> 0.076"
    draw_text(img, margin, 35, title, BLACK, 3)
    draw_text(img, margin, 82, subtitle, GRAY, 2)

    # Plot area, five horizontal grid bands, then the two axes.
    left, top, right, bottom = 95, 130, 940, 525
    for i in range(6):
        y = top + round((bottom - top) * i / 5)
        draw_line(img, left, y, right, y, GRID, 1)
    draw_line(img, left, top, left, bottom, BLACK, 2)
    draw_line(img, left, bottom, right, bottom, BLACK, 2)

    min_step = min(step for step, _ in loss_points)
    max_step = max(step for step, _ in loss_points)
    # The y axis spans 0..max_loss, never less than 2.6 and with 5% headroom.
    max_loss = max(2.6, max(loss for _, loss in loss_points) * 1.05)
    # Avoid dividing by zero when every point shares one step.
    span = max(1, max_step - min_step)

    coords: list[tuple[int, int]] = []
    for step, loss in loss_points:
        x = left + round((right - left) * (step - min_step) / span)
        y = bottom - round((bottom - top) * loss / max_loss)
        coords.append((x, y))
    for (x0, y0), (x1, y1) in zip(coords, coords[1:]):
        draw_line(img, x0, y0, x1, y1, BLUE, 4)
    if len(coords) <= 30:
        for x, y in coords:
            draw_rect(img, x - 4, y - 4, x + 4, y + 4, ORANGE)

    draw_text(img, left - 35, top - 8, f"{max_loss:.1f}", GRAY, 2)
    draw_text(img, left - 35, bottom - 8, "0.0", GRAY, 2)
    draw_text(img, left - 10, bottom + 25, f"step {min_step}", GRAY, 2)
    draw_text(img, right - 110, bottom + 25, f"step {max_step}", GRAY, 2)

    if source == "real":
        footer = f"Source: results/sft_warmup_metrics.json (n={len(loss_points)} steps)."
    else:
        footer = "Source: HF Jobs console log (sparse). Replace with sft_warmup_metrics.json for full curve."

    # draw_text advances 6 * scale pixels per character. Keep the footer at
    # scale 2 when it fits between the margins, otherwise drop to scale 1,
    # and pull it left of the preferred x=360 far enough to end inside the
    # right margin.
    footer_scale = 2
    if len(footer) * 6 * footer_scale > canvas_width - 2 * margin:
        footer_scale = 1
    footer_width = len(footer) * 6 * footer_scale
    footer_x = max(0, min(360, canvas_width - margin - footer_width))
    draw_text(img, footer_x, 565, footer, GRAY, footer_scale)

    write_png(RESULTS / "final_sft_loss_curve.png", len(img[0]), len(img), img)
def plot_scores() -> None:
    """Render mean score per task to ``results/final_score_by_task.png``.

    Reads ``results/final_phaseaware_model_eval.json``, groups each entry of
    its ``"episodes"`` list by ``"task_id"`` and averages the ``"score"``
    values. One bar is drawn for each of the tasks easy, medium, hard and
    cascade; task ids outside that list are ignored.

    A task with no episodes is labelled "n/a" and gets no bar. Previously
    such a task raised ``ZeroDivisionError`` while averaging an empty list.

    Raises whatever reading or parsing the JSON file raises, and ``KeyError``
    if an expected key is missing.
    """
    data = json.loads((RESULTS / "final_phaseaware_model_eval.json").read_text())
    by_task: dict[str, list[float]] = defaultdict(list)
    for episode in data["episodes"]:
        by_task[episode["task_id"]].append(float(episode["score"]))

    order = ["easy", "medium", "hard", "cascade"]
    # None marks a task without any episode so it can be rendered as "n/a".
    means = [
        sum(by_task[task]) / len(by_task[task]) if by_task.get(task) else None
        for task in order
    ]

    img = canvas(1000, 620)
    draw_text(img, 55, 35, "Final phase-aware trained LoRA score by task", BLACK, 3)
    # NOTE(review): this subtitle is a fixed string, not derived from `data`;
    # confirm it still matches the eval file whenever that file changes.
    draw_text(img, 55, 82, "Mean score 0.915, pass rate 100%, 12 episodes", GRAY, 2)

    # Plot area, five horizontal grid bands, then the two axes.
    left, top, right, bottom = 95, 130, 940, 525
    for i in range(6):
        y = top + round((bottom - top) * i / 5)
        draw_line(img, left, y, right, y, GRID, 1)
    draw_line(img, left, top, left, bottom, BLACK, 2)
    draw_line(img, left, bottom, right, bottom, BLACK, 2)

    bar_width = 120
    gap = 75
    colors = [GREEN, GREEN, BLUE, ORANGE]
    for idx, (task, mean, color) in enumerate(zip(order, means, colors)):
        x0 = left + gap + idx * (bar_width + gap)
        x1 = x0 + bar_width
        if mean is None:
            draw_text(img, x0 + 18, bottom + 25, task, BLACK, 2)
            draw_text(img, x0 + 20, bottom - 28, "n/a", GRAY, 2)
            continue
        # The y axis spans 0.0 (bottom) to 1.0 (top).
        y0 = bottom - round((bottom - top) * mean)
        draw_rect(img, x0, y0, x1, bottom - 1, color)
        draw_text(img, x0 + 18, bottom + 25, task, BLACK, 2)
        draw_text(img, x0 + 20, y0 - 28, f"{mean:.3f}", BLACK, 2)

    draw_text(img, left - 35, top - 8, "1.0", GRAY, 2)
    draw_text(img, left - 35, bottom - 8, "0.0", GRAY, 2)
    write_png(RESULTS / "final_score_by_task.png", len(img[0]), len(img), img)
def plot_before_after() -> None:
    """Render the raw-vs-trained comparison to ``results/before_vs_after_scores.png``.

    The two bar values are constants embedded in this function, not read
    from a results file.

    The footer is positioned, and if necessary shrunk to scale 1, so that it
    fits inside the canvas. It was previously drawn at a fixed x=330 with
    scale 2, which ran past the right edge and cut off the final figures.
    """
    canvas_width, canvas_height = 1000, 620
    margin = 55
    img = canvas(canvas_width, canvas_height)
    draw_text(img, margin, 35, "Raw base model vs final trained policy", BLACK, 3)
    draw_text(img, margin, 82, "Constrained eval before vs final phase-aware constrained eval after", GRAY, 2)

    # Plot area, five horizontal grid bands, then the two axes.
    left, top, right, bottom = 110, 130, 920, 525
    for i in range(6):
        y = top + round((bottom - top) * i / 5)
        draw_line(img, left, y, right, y, GRID, 1)
    draw_line(img, left, top, left, bottom, BLACK, 2)
    draw_line(img, left, bottom, right, bottom, BLACK, 2)

    bars = [
        ("raw", 0.23905299302262772, GRAY),
        ("trained", 0.9146806281246708, GREEN),
    ]
    for idx, (label, value, color) in enumerate(bars):
        x0 = left + 165 + idx * 290
        x1 = x0 + 150
        # The y axis spans 0.0 (bottom) to 1.0 (top).
        y0 = bottom - round((bottom - top) * value)
        draw_rect(img, x0, y0, x1, bottom - 1, color)
        draw_text(img, x0 + 25, bottom + 25, label, BLACK, 2)
        draw_text(img, x0 + 35, y0 - 28, f"{value:.3f}", BLACK, 2)

    draw_text(img, left - 35, top - 8, "1.0", GRAY, 2)
    draw_text(img, left - 35, bottom - 8, "0.0", GRAY, 2)

    footer = "Raw before: mean 0.239, pass 0%. Final after: mean 0.915, pass 100%."
    # draw_text advances 6 * scale pixels per character. Keep the footer at
    # scale 2 when it fits between the margins, otherwise drop to scale 1,
    # and pull it left of the preferred x=330 far enough to end inside the
    # right margin.
    footer_scale = 2
    if len(footer) * 6 * footer_scale > canvas_width - 2 * margin:
        footer_scale = 1
    footer_width = len(footer) * 6 * footer_scale
    footer_x = max(0, min(330, canvas_width - margin - footer_width))
    draw_text(img, footer_x, 565, footer, GRAY, footer_scale)

    write_png(RESULTS / "before_vs_after_scores.png", len(img[0]), len(img), img)
def main() -> None:
    """Render all three evidence plots, then print the path of each PNG."""
    plot_loss()
    plot_scores()
    plot_before_after()

    generated = (
        "final_sft_loss_curve.png",
        "final_score_by_task.png",
        "before_vs_after_scores.png",
    )
    for filename in generated:
        print(RESULTS / filename)


if __name__ == "__main__":
    main()