"""Render SVG visuals for LLM checkpoint comparison."""
from __future__ import annotations

import csv
import html
from pathlib import Path
# Directory holding the evaluation artifacts: the comparison CSV is read from
# here and the rendered SVGs are written here (relative to the working dir).
ARTIFACT_DIR = Path("artifacts") / "evals_llm"
# Per-task base-vs-trained comparison table consumed by the plotting functions.
COMPARISON_CSV = ARTIFACT_DIR.joinpath("llm_comparison.csv")
def _svg_header(width: int, height: int) -> list[str]:
return [
f'"]
def _rows() -> list[dict[str, str]]:
with COMPARISON_CSV.open() as f:
return list(csv.DictReader(f))
def plot_reward(rows: list[dict[str, str]]) -> None:
    """Write ``llm_reward_by_task.svg``: grouped base vs. trained reward bars.

    One pair of bars is drawn per task, in row order, on a fixed 0.0-1.0
    y axis. Rewards outside that range are clamped for drawing only.

    Args:
        rows: Comparison records with ``task_id``, ``baseline_reward`` and
            ``trained_reward`` columns; reward strings are parsed with
            ``float`` (a malformed value raises ``ValueError``).
    """
    tasks = [r["task_id"] for r in rows]
    base = [float(r["baseline_reward"]) for r in rows]
    trained = [float(r["trained_reward"]) for r in rows]
    width, height = 1360, 520
    left, right, top, bottom = 80, 40, 70, 110
    plot_w = width - left - right
    plot_h = height - top - bottom
    axis_y = top + plot_h  # pixel row of reward 0.0
    group_w = plot_w / max(len(tasks), 1)
    bar_w = max(group_w * 0.32, 10)
    base_color, trained_color = "#9aa0a6", "#1a73e8"

    def bar_height(reward: float) -> float:
        # Clamp to the fixed [0, 1] axis: a negative reward would otherwise give
        # a negative (invalid) rect height, and > 1 would overrun the title.
        return min(max(reward, 0.0), 1.0) * plot_h

    lines = _svg_header(width, height)
    lines.append(
        f'<text x="{width / 2:.1f}" y="32" text-anchor="middle" font-size="20" '
        'font-weight="bold">Base vs Trained LLM Reward by Task</text>'
    )
    # Gridlines and labels at 0.0, 0.2, ..., 1.0, drawn before the axes so the
    # axes paint on top of the zero gridline.
    for tick in range(0, 6):
        value = tick / 5
        y = axis_y - value * plot_h
        lines.append(
            f'<line x1="{left}" y1="{y:.2f}" x2="{width - right}" y2="{y:.2f}" '
            'stroke="#e5e5e5" stroke-width="1"/>'
        )
        lines.append(
            f'<text x="{left - 8}" y="{y + 4:.2f}" text-anchor="end" '
            f'font-size="12">{value:.1f}</text>'
        )
    # Y axis, then X axis along the zero baseline.
    lines.append(
        f'<line x1="{left}" y1="{top}" x2="{left}" y2="{axis_y}" '
        'stroke="#333333" stroke-width="1"/>'
    )
    lines.append(
        f'<line x1="{left}" y1="{axis_y}" x2="{width - right}" y2="{axis_y}" '
        'stroke="#333333" stroke-width="1"/>'
    )

    label_y = axis_y + 16
    for idx, (task, b_val, t_val) in enumerate(zip(tasks, base, trained)):
        gx = left + (idx + 0.5) * group_w  # centre of this task's slot
        b_h = bar_height(b_val)
        t_h = bar_height(t_val)
        # Base bar sits left of centre, trained bar right, with a 4px gap.
        b_x = gx - bar_w - 2
        t_x = gx + 2
        lines.append(
            f'<rect x="{b_x:.2f}" y="{axis_y - b_h:.2f}" width="{bar_w:.2f}" '
            f'height="{b_h:.2f}" fill="{base_color}"/>'
        )
        lines.append(
            f'<rect x="{t_x:.2f}" y="{axis_y - t_h:.2f}" width="{bar_w:.2f}" '
            f'height="{t_h:.2f}" fill="{trained_color}"/>'
        )
        # Task ids are free text from the CSV: escape them so '&' or '<'
        # cannot produce malformed XML.
        lines.append(
            f'<text x="{gx:.2f}" y="{label_y}" text-anchor="end" font-size="11" '
            f'transform="rotate(-35 {gx:.2f} {label_y})">{html.escape(task)}</text>'
        )

    legend_y = 52
    legend_x = width - right - 190
    lines.append(
        f'<rect x="{legend_x}" y="{legend_y - 10}" width="12" height="12" '
        f'fill="{base_color}"/>'
    )
    lines.append(f'<text x="{legend_x + 18}" y="{legend_y}" font-size="12">Base</text>')
    lines.append(
        f'<rect x="{legend_x + 80}" y="{legend_y - 10}" width="12" height="12" '
        f'fill="{trained_color}"/>'
    )
    lines.append(f'<text x="{legend_x + 98}" y="{legend_y}" font-size="12">Trained</text>')
    lines.extend(_svg_footer())
    (ARTIFACT_DIR / "llm_reward_by_task.svg").write_text("\n".join(lines), encoding="utf-8")
def plot_violations(rows: list[dict[str, str]]) -> None:
    """Write ``llm_violations_before_after.svg``: violation counts per task.

    Base and trained counts are drawn as two polylines with point markers.
    Tasks are spaced evenly along x in row order, and the y axis runs from 0
    to the largest count seen (at least 1).

    Args:
        rows: Comparison records with ``task_id``, ``baseline_violations``
            and ``trained_violations`` columns; counts are parsed with
            ``int`` (a malformed value raises ``ValueError``).
    """
    tasks = [r["task_id"] for r in rows]
    base = [int(r["baseline_violations"]) for r in rows]
    trained = [int(r["trained_violations"]) for r in rows]
    # Floor at 1 so point_y never divides by zero when every count is 0.
    max_v = max(max(base, default=0), max(trained, default=0), 1)
    width, height = 1360, 500
    left, right, top, bottom = 80, 40, 70, 100
    plot_w = width - left - right
    plot_h = height - top - bottom
    axis_y = top + plot_h  # pixel row of a zero count
    # Every integer gets a tick up to 10; above that, thin to roughly ten
    # intervals so the 12px labels do not overprint. For the positive max_v
    # guaranteed above, (max_v + 9) // 10 == ceil(max_v / 10).
    tick_step = (max_v + 9) // 10
    base_color, trained_color = "#9aa0a6", "#1a73e8"

    def point_x(i: int) -> float:
        # Spread tasks across the full width; with a single task the
        # denominator is floored at 1 and the point lands on the y axis.
        return left + (i / max(len(tasks) - 1, 1)) * plot_w

    def point_y(v: int) -> float:
        return axis_y - (v / max_v) * plot_h

    lines = _svg_header(width, height)
    lines.append(
        f'<text x="{width / 2:.1f}" y="32" text-anchor="middle" font-size="20" '
        'font-weight="bold">Base vs Trained LLM Commitment Violations</text>'
    )
    # Gridlines before the axes so the axes paint on top.
    for tick in range(0, max_v + 1, tick_step):
        y = point_y(tick)
        lines.append(
            f'<line x1="{left}" y1="{y:.2f}" x2="{width - right}" y2="{y:.2f}" '
            'stroke="#e5e5e5" stroke-width="1"/>'
        )
        lines.append(
            f'<text x="{left - 8}" y="{y + 4:.2f}" text-anchor="end" '
            f'font-size="12">{tick}</text>'
        )
    lines.append(
        f'<line x1="{left}" y1="{top}" x2="{left}" y2="{axis_y}" '
        'stroke="#333333" stroke-width="1"/>'
    )
    lines.append(
        f'<line x1="{left}" y1="{axis_y}" x2="{width - right}" y2="{axis_y}" '
        'stroke="#333333" stroke-width="1"/>'
    )

    if tasks:  # nothing to connect without data
        for color, values in ((base_color, base), (trained_color, trained)):
            points = " ".join(
                f"{point_x(i):.2f},{point_y(v):.2f}" for i, v in enumerate(values)
            )
            lines.append(
                f'<polyline points="{points}" fill="none" stroke="{color}" '
                'stroke-width="2.5"/>'
            )

    label_y = axis_y + 16
    for i, (task, b_val, t_val) in enumerate(zip(tasks, base, trained)):
        x = point_x(i)
        lines.append(
            f'<circle cx="{x:.2f}" cy="{point_y(b_val):.2f}" r="4" fill="{base_color}"/>'
        )
        lines.append(
            f'<circle cx="{x:.2f}" cy="{point_y(t_val):.2f}" r="4" fill="{trained_color}"/>'
        )
        # Task ids are free text from the CSV: escape them so '&' or '<'
        # cannot produce malformed XML.
        lines.append(
            f'<text x="{x:.2f}" y="{label_y}" text-anchor="end" font-size="11" '
            f'transform="rotate(-35 {x:.2f} {label_y})">{html.escape(task)}</text>'
        )

    legend_y = 52
    legend_x = width - right - 190
    for offset, color, name in ((0, base_color, "Base"), (80, trained_color, "Trained")):
        x0 = legend_x + offset
        lines.append(
            f'<line x1="{x0}" y1="{legend_y - 4}" x2="{x0 + 14}" y2="{legend_y - 4}" '
            f'stroke="{color}" stroke-width="3"/>'
        )
        lines.append(f'<text x="{x0 + 20}" y="{legend_y}" font-size="12">{name}</text>')
    lines.extend(_svg_footer())
    (ARTIFACT_DIR / "llm_violations_before_after.svg").write_text(
        "\n".join(lines), encoding="utf-8"
    )
def main() -> None:
    """Load the comparison CSV and render both checkpoint comparison SVGs."""
    comparison_rows = _rows()
    # Reward bars first, then the violations line chart.
    for render in (plot_reward, plot_violations):
        render(comparison_rows)
    print("Wrote checkpoint comparison SVG plots to", ARTIFACT_DIR)


if __name__ == "__main__":
    main()