"""Render SVG visuals for LLM checkpoint comparison.""" from __future__ import annotations import csv from pathlib import Path ARTIFACT_DIR = Path("artifacts/evals_llm") COMPARISON_CSV = ARTIFACT_DIR / "llm_comparison.csv" def _svg_header(width: int, height: int) -> list[str]: return [ f'', '', ] def _svg_footer() -> list[str]: return [""] def _rows() -> list[dict[str, str]]: with COMPARISON_CSV.open() as f: return list(csv.DictReader(f)) def plot_reward(rows: list[dict[str, str]]) -> None: tasks = [r["task_id"] for r in rows] base = [float(r["baseline_reward"]) for r in rows] trained = [float(r["trained_reward"]) for r in rows] width, height = 1360, 520 left, right, top, bottom = 80, 40, 70, 110 plot_w = width - left - right plot_h = height - top - bottom group_w = plot_w / max(len(tasks), 1) bar_w = max(group_w * 0.32, 10) lines = _svg_header(width, height) lines.append('Base vs Trained LLM Reward by Task') lines.append(f'') lines.append(f'') for tick in range(0, 6): value = tick / 5 y = top + plot_h - (value * plot_h) lines.append(f'') lines.append(f'{value:.1f}') for idx, task in enumerate(tasks): gx = left + (idx * group_w) + (group_w * 0.5) b_h = base[idx] * plot_h t_h = trained[idx] * plot_h b_x = gx - bar_w - 2 t_x = gx + 2 b_y = top + plot_h - b_h t_y = top + plot_h - t_h lines.append(f'') lines.append(f'') lines.append( f'{task}' ) legend_y = 52 lines.append(f'') lines.append(f'Base') lines.append(f'') lines.append(f'Trained') lines.extend(_svg_footer()) (ARTIFACT_DIR / "llm_reward_by_task.svg").write_text("\n".join(lines)) def plot_violations(rows: list[dict[str, str]]) -> None: tasks = [r["task_id"] for r in rows] base = [int(r["baseline_violations"]) for r in rows] trained = [int(r["trained_violations"]) for r in rows] max_v = max(max(base, default=0), max(trained, default=0), 1) width, height = 1360, 500 left, right, top, bottom = 80, 40, 70, 100 plot_w = width - left - right plot_h = height - top - bottom def point_x(i: int) -> float: return left + (i / max(len(tasks) - 1, 1)) * plot_w def point_y(v: int) -> float: return top + plot_h - ((v / max_v) * plot_h) lines = _svg_header(width, height) lines.append('Base vs Trained LLM Commitment Violations') lines.append(f'') lines.append(f'') for tick in range(max_v + 1): y = point_y(tick) lines.append(f'') lines.append(f'{tick}') base_points = " ".join(f"{point_x(i):.2f},{point_y(v):.2f}" for i, v in enumerate(base)) tr_points = " ".join(f"{point_x(i):.2f},{point_y(v):.2f}" for i, v in enumerate(trained)) lines.append(f'') lines.append(f'') for i, task in enumerate(tasks): x = point_x(i) lines.append(f'') lines.append(f'') lines.append( f'{task}' ) legend_y = 52 lines.append(f'') lines.append(f'Base') lines.append(f'') lines.append(f'Trained') lines.extend(_svg_footer()) (ARTIFACT_DIR / "llm_violations_before_after.svg").write_text("\n".join(lines)) def main() -> None: rows = _rows() plot_reward(rows) plot_violations(rows) print("Wrote checkpoint comparison SVG plots to", ARTIFACT_DIR) if __name__ == "__main__": main()