File size: 6,236 Bytes
d53a65c
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
"""Render SVG visuals for LLM checkpoint comparison."""

from __future__ import annotations

import csv
from pathlib import Path
from xml.sax.saxutils import escape

# Directory (relative path) that holds the comparison CSV and receives the
# SVG files written by the plot functions in this module.
ARTIFACT_DIR = Path("artifacts/evals_llm")
# CSV file loaded by _rows(); the plot functions look up its rows by column name.
COMPARISON_CSV = ARTIFACT_DIR / "llm_comparison.csv"


def _svg_header(width: int, height: int) -> list[str]:
    return [
        f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
        '<rect width="100%" height="100%" fill="#FFFFFF"/>',
    ]


def _svg_footer() -> list[str]:
    return ["</svg>"]


def _rows() -> list[dict[str, str]]:
    """Load every row of ``COMPARISON_CSV`` as a dict keyed by the header row.

    Returns:
        One dict per data row, in file order, as produced by
        ``csv.DictReader``.

    Raises:
        OSError: if the CSV file cannot be opened (for example, it is missing).
    """
    # The csv module asks for newline="" on the file object so that it can do
    # its own line-ending handling; without it, line breaks embedded in quoted
    # fields are not interpreted correctly.
    with COMPARISON_CSV.open(newline="") as f:
        return list(csv.DictReader(f))


def plot_reward(rows: list[dict[str, str]]) -> None:
    """Write a grouped bar chart of baseline vs. trained reward per task.

    Each row contributes one group of two bars (grey = baseline, blue =
    trained) labelled with its ``task_id``. The result is written to
    ``llm_reward_by_task.svg`` inside ``ARTIFACT_DIR``; the directory is not
    created here.

    The y-axis is fixed to the range 0.0-1.0. Rewards outside that range are
    clamped for drawing so a bar never gets a negative height or extends
    past the plot area.

    Args:
        rows: Comparison rows; each must provide ``task_id``,
            ``baseline_reward`` and ``trained_reward``.

    Raises:
        KeyError: if a row lacks one of the required columns.
        ValueError: if a reward value cannot be parsed as a float.
    """
    tasks = [r["task_id"] for r in rows]
    base = [float(r["baseline_reward"]) for r in rows]
    trained = [float(r["trained_reward"]) for r in rows]

    width, height = 1360, 520
    left, right, top, bottom = 80, 40, 70, 110
    plot_w = width - left - right
    plot_h = height - top - bottom
    axis_y = top + plot_h  # y coordinate of the x-axis (reward 0.0)
    label_y = axis_y + 22
    # max(..., 1) keeps the division defined when there are no rows.
    group_w = plot_w / max(len(tasks), 1)
    bar_w = max(group_w * 0.32, 10)

    def bar_height(reward: float) -> float:
        # Clamp to the 0..1 axis: a negative <rect> height is invalid SVG and
        # a value above 1 would draw over the title and legend.
        return min(max(reward, 0.0), 1.0) * plot_h

    lines = _svg_header(width, height)
    lines.append('<text x="80" y="35" font-size="22" font-family="Arial" fill="#111827">Base vs Trained LLM Reward by Task</text>')
    lines.append(f'<line x1="{left}" y1="{axis_y}" x2="{left+plot_w}" y2="{axis_y}" stroke="#374151" stroke-width="1"/>')
    lines.append(f'<line x1="{left}" y1="{top}" x2="{left}" y2="{axis_y}" stroke="#374151" stroke-width="1"/>')

    # Horizontal gridlines and labels at 0.0, 0.2, ..., 1.0.
    for tick in range(6):
        value = tick / 5
        y = axis_y - (value * plot_h)
        lines.append(f'<line x1="{left}" y1="{y:.2f}" x2="{left+plot_w}" y2="{y:.2f}" stroke="#E5E7EB" stroke-width="1"/>')
        lines.append(f'<text x="{left-38}" y="{y+5:.2f}" font-size="12" font-family="Arial" fill="#374151">{value:.1f}</text>')

    for idx, (task, base_reward, trained_reward) in enumerate(zip(tasks, base, trained)):
        gx = left + (idx * group_w) + (group_w * 0.5)  # horizontal centre of the group
        b_h = bar_height(base_reward)
        t_h = bar_height(trained_reward)
        b_x = gx - bar_w - 2
        t_x = gx + 2
        lines.append(f'<rect x="{b_x:.2f}" y="{axis_y - b_h:.2f}" width="{bar_w:.2f}" height="{b_h:.2f}" fill="#9CA3AF"/>')
        lines.append(f'<rect x="{t_x:.2f}" y="{axis_y - t_h:.2f}" width="{bar_w:.2f}" height="{t_h:.2f}" fill="#2563EB"/>')
        # task_id comes from the CSV, so escape &, < and > to keep the SVG well-formed.
        lines.append(
            f'<text x="{gx:.2f}" y="{label_y}" font-size="10" text-anchor="middle" '
            f'font-family="Arial" fill="#374151" transform="rotate(25 {gx:.2f},{label_y})">{escape(task)}</text>'
        )

    legend_y = 52
    lines.append(f'<rect x="{width-310}" y="{legend_y-10}" width="12" height="12" fill="#9CA3AF"/>')
    lines.append(f'<text x="{width-292}" y="{legend_y}" font-size="12" font-family="Arial" fill="#111827">Base</text>')
    lines.append(f'<rect x="{width-230}" y="{legend_y-10}" width="12" height="12" fill="#2563EB"/>')
    lines.append(f'<text x="{width-212}" y="{legend_y}" font-size="12" font-family="Arial" fill="#111827">Trained</text>')
    lines.extend(_svg_footer())
    # The document carries no XML declaration, so readers treat it as UTF-8;
    # write it that way instead of relying on the locale's default encoding.
    (ARTIFACT_DIR / "llm_reward_by_task.svg").write_text("\n".join(lines), encoding="utf-8")


def plot_violations(rows: list[dict[str, str]]) -> None:
    """Write a two-series line chart of commitment violations per task.

    The baseline series is drawn in red and the trained series in green, one
    point per row, with the row's ``task_id`` as the x-axis label. The result
    is written to ``llm_violations_before_after.svg`` inside ``ARTIFACT_DIR``;
    the directory is not created here.

    The y-axis runs from 0 to the largest violation count seen (at least 1).
    Gridlines are drawn for every integer while that maximum is 10 or less
    and are thinned to a coarser step above that so labels stay legible.

    Args:
        rows: Comparison rows; each must provide ``task_id``,
            ``baseline_violations`` and ``trained_violations``.

    Raises:
        KeyError: if a row lacks one of the required columns.
        ValueError: if a violation count cannot be parsed as an int.
    """
    tasks = [r["task_id"] for r in rows]
    base = [int(r["baseline_violations"]) for r in rows]
    trained = [int(r["trained_violations"]) for r in rows]
    # The trailing 1 keeps max_v positive so point_y never divides by zero.
    max_v = max(max(base, default=0), max(trained, default=0), 1)

    width, height = 1360, 500
    left, right, top, bottom = 80, 40, 70, 100
    plot_w = width - left - right
    plot_h = height - top - bottom
    axis_y = top + plot_h  # y coordinate of the x-axis (zero violations)
    label_y = axis_y + 20

    def point_x(i: int) -> float:
        # Spread points across the full width; a single point sits at the left edge.
        return left + (i / max(len(tasks) - 1, 1)) * plot_w

    def point_y(v: int) -> float:
        return axis_y - ((v / max_v) * plot_h)

    lines = _svg_header(width, height)
    lines.append('<text x="80" y="35" font-size="22" font-family="Arial" fill="#111827">Base vs Trained LLM Commitment Violations</text>')
    lines.append(f'<line x1="{left}" y1="{axis_y}" x2="{left+plot_w}" y2="{axis_y}" stroke="#374151" stroke-width="1"/>')
    lines.append(f'<line x1="{left}" y1="{top}" x2="{left}" y2="{axis_y}" stroke="#374151" stroke-width="1"/>')

    # Ceiling of max_v / 10: step is 1 up to max_v == 10, then grows so that
    # only about ten gridlines are ever drawn.
    tick_step = (max_v + 9) // 10
    for tick in range(0, max_v + 1, tick_step):
        y = point_y(tick)
        lines.append(f'<line x1="{left}" y1="{y:.2f}" x2="{left+plot_w}" y2="{y:.2f}" stroke="#E5E7EB" stroke-width="1"/>')
        lines.append(f'<text x="{left-24}" y="{y+5:.2f}" font-size="12" font-family="Arial" fill="#374151">{tick}</text>')

    base_points = " ".join(f"{point_x(i):.2f},{point_y(v):.2f}" for i, v in enumerate(base))
    tr_points = " ".join(f"{point_x(i):.2f},{point_y(v):.2f}" for i, v in enumerate(trained))
    lines.append(f'<polyline points="{base_points}" fill="none" stroke="#DC2626" stroke-width="2"/>')
    lines.append(f'<polyline points="{tr_points}" fill="none" stroke="#059669" stroke-width="2"/>')

    for i, (task, base_v, trained_v) in enumerate(zip(tasks, base, trained)):
        x = point_x(i)
        lines.append(f'<circle cx="{x:.2f}" cy="{point_y(base_v):.2f}" r="3" fill="#DC2626"/>')
        lines.append(f'<circle cx="{x:.2f}" cy="{point_y(trained_v):.2f}" r="3" fill="#059669"/>')
        # task_id comes from the CSV, so escape &, < and > to keep the SVG well-formed.
        lines.append(
            f'<text x="{x:.2f}" y="{label_y}" font-size="10" text-anchor="middle" '
            f'font-family="Arial" fill="#374151" transform="rotate(25 {x:.2f},{label_y})">{escape(task)}</text>'
        )

    legend_y = 52
    lines.append(f'<line x1="{width-320}" y1="{legend_y-5}" x2="{width-300}" y2="{legend_y-5}" stroke="#DC2626" stroke-width="2"/>')
    lines.append(f'<text x="{width-295}" y="{legend_y}" font-size="12" font-family="Arial" fill="#111827">Base</text>')
    lines.append(f'<line x1="{width-230}" y1="{legend_y-5}" x2="{width-210}" y2="{legend_y-5}" stroke="#059669" stroke-width="2"/>')
    lines.append(f'<text x="{width-205}" y="{legend_y}" font-size="12" font-family="Arial" fill="#111827">Trained</text>')
    lines.extend(_svg_footer())
    # The document carries no XML declaration, so readers treat it as UTF-8;
    # write it that way instead of relying on the locale's default encoding.
    (ARTIFACT_DIR / "llm_violations_before_after.svg").write_text("\n".join(lines), encoding="utf-8")


def main() -> None:
    """Load the comparison rows, render both SVG charts, and print where they went."""
    comparison = _rows()
    # Reward chart first, then violations — same order as the original calls.
    for render in (plot_reward, plot_violations):
        render(comparison)
    print("Wrote checkpoint comparison SVG plots to", ARTIFACT_DIR)


# Render the charts only when this file is executed as a script, not when it
# is imported as a module.
if __name__ == "__main__":
    main()