Spaces:
Sleeping
Sleeping
File size: 6,477 Bytes
"""Generate judge-friendly SVG plots from the evaluation comparison CSV.

This module intentionally avoids matplotlib so that plotting stays
deterministic in restricted CI/sandbox environments.
"""
from __future__ import annotations

import csv
import html
from pathlib import Path
# Directory holding the evaluation artifacts: the input CSV and the SVGs written below.
ARTIFACT_DIR: Path = Path("artifacts", "evals")
# Comparison table read by _load_rows.
COMPARISON_CSV: Path = ARTIFACT_DIR.joinpath("comparison.csv")
def _load_rows() -> list[dict[str, str]]:
with COMPARISON_CSV.open() as f:
return list(csv.DictReader(f))
def _svg_header(width: int, height: int) -> list[str]:
return [
f'<svg xmlns="http://www.w3.org/2000/svg" width="{width}" height="{height}" viewBox="0 0 {width} {height}">',
'<rect width="100%" height="100%" fill="#FFFFFF"/>',
]
def _svg_footer() -> list[str]:
return ["</svg>"]
def plot_reward_by_task(rows: list[dict[str, str]]) -> None:
    """Write ``reward_by_task.svg``: grouped bars of baseline vs improved reward.

    One bar group is drawn per row, in input order. The baseline reward is
    grey (left), the improved reward is blue (right), and the ``task_id`` is
    written underneath at a 25-degree angle.

    The y-axis covers ``[0, 1]`` and is widened only when a reward lies outside
    that range. This keeps such values inside the plot area instead of running
    over the title or producing a negative ``<rect>`` height, which is invalid
    SVG.

    Args:
        rows: Comparison rows providing ``task_id``, ``baseline_reward`` and
            ``improved_reward``. Reward cells are parsed with ``float``.

    Raises:
        KeyError: if a row lacks one of those columns.
        ValueError: if a reward cell cannot be parsed as a float.

    The SVG is written as UTF-8 into ``ARTIFACT_DIR``.
    """
    tasks = [row["task_id"] for row in rows]
    baseline = [float(row["baseline_reward"]) for row in rows]
    improved = [float(row["improved_reward"]) for row in rows]

    width, height = 1360, 520
    left, right, top, bottom = 80, 40, 70, 110
    plot_w = width - left - right
    plot_h = height - top - bottom
    group_w = plot_w / max(len(tasks), 1)
    bar_w = max(group_w * 0.32, 10)

    # Value range of the y-axis: [0, 1] unless some reward falls outside it.
    y_min = min([0.0, *baseline, *improved])
    y_max = max([1.0, *baseline, *improved])
    y_span = y_max - y_min

    def to_y(value: float) -> float:
        """Map a reward to its SVG y coordinate (SVG y grows downward)."""
        return top + plot_h - ((value - y_min) / y_span) * plot_h

    zero_y = to_y(0.0)

    def bar_geometry(value: float) -> tuple[float, float]:
        """Return ``(y, height)`` of a bar spanning from zero to *value*."""
        bar_h = (abs(value) / y_span) * plot_h
        # Positive bars rise above the zero line; negative ones hang below it.
        return (zero_y - bar_h if value >= 0 else zero_y), bar_h

    lines = _svg_header(width, height)
    lines.append('<text x="80" y="35" font-size="22" font-family="Arial" fill="#111827">Baseline vs Improved Reward by Task</text>')
    for tick in range(6):
        value = y_min + y_span * tick / 5
        y = to_y(value)
        lines.append(f'<line x1="{left}" y1="{y:.2f}" x2="{left+plot_w}" y2="{y:.2f}" stroke="#E5E7EB" stroke-width="1"/>')
        lines.append(f'<text x="{left-38}" y="{y+5:.2f}" font-size="12" font-family="Arial" fill="#374151">{value:.1f}</text>')
    # Axes are drawn after the gridlines so the light bottom gridline does not
    # paint over the dark x-axis (later SVG elements render on top).
    lines.append(f'<line x1="{left}" y1="{top+plot_h}" x2="{left+plot_w}" y2="{top+plot_h}" stroke="#374151" stroke-width="1"/>')
    lines.append(f'<line x1="{left}" y1="{top}" x2="{left}" y2="{top+plot_h}" stroke="#374151" stroke-width="1"/>')
    if y_min < 0:
        # With negative rewards the x-axis is no longer at zero, so mark zero explicitly.
        lines.append(f'<line x1="{left}" y1="{zero_y:.2f}" x2="{left+plot_w}" y2="{zero_y:.2f}" stroke="#374151" stroke-width="1" stroke-dasharray="4 3"/>')

    label_y = top + plot_h + 22
    for idx, (task, b_value, i_value) in enumerate(zip(tasks, baseline, improved)):
        gx = left + (idx * group_w) + (group_w * 0.5)
        b_y, b_h = bar_geometry(b_value)
        i_y, i_h = bar_geometry(i_value)
        b_x = gx - bar_w - 2
        i_x = gx + 2
        lines.append(f'<rect x="{b_x:.2f}" y="{b_y:.2f}" width="{bar_w:.2f}" height="{b_h:.2f}" fill="#9CA3AF"/>')
        lines.append(f'<rect x="{i_x:.2f}" y="{i_y:.2f}" width="{bar_w:.2f}" height="{i_h:.2f}" fill="#2563EB"/>')
        # task_id comes from the CSV, so escape it: a raw '<' or '&' would
        # produce malformed XML.
        lines.append(
            f'<text x="{gx:.2f}" y="{label_y}" font-size="10" text-anchor="middle" '
            f'font-family="Arial" fill="#374151" transform="rotate(25 {gx:.2f},{label_y})">'
            f"{html.escape(task, quote=False)}</text>"
        )

    legend_y = 52
    lines.append(f'<rect x="{width-300}" y="{legend_y-10}" width="12" height="12" fill="#9CA3AF"/>')
    lines.append(f'<text x="{width-282}" y="{legend_y}" font-size="12" font-family="Arial" fill="#111827">Baseline</text>')
    lines.append(f'<rect x="{width-210}" y="{legend_y-10}" width="12" height="12" fill="#2563EB"/>')
    lines.append(f'<text x="{width-192}" y="{legend_y}" font-size="12" font-family="Arial" fill="#111827">Improved</text>')
    lines.extend(_svg_footer())
    (ARTIFACT_DIR / "reward_by_task.svg").write_text("\n".join(lines), encoding="utf-8")
def plot_violation_before_after(rows: list[dict[str, str]]) -> None:
    """Write ``violations_before_after.svg``: violation counts per task as two lines.

    Tasks are spread evenly along the x-axis in input order. The baseline count
    is drawn as a red polyline with markers and the improved count as a green
    one. Each ``task_id`` is written underneath at a 25-degree angle.

    The y-axis starts at 0 and reaches at least 1. It gets at most about 10
    gridlines: one per integer while the largest count is <= 10, and a coarser
    integer step (with the axis top rounded up to a whole step) beyond that.

    Args:
        rows: Comparison rows providing ``task_id``, ``baseline_violations``
            and ``improved_violations``. Counts are parsed with ``int``.

    Raises:
        KeyError: if a row lacks one of those columns.
        ValueError: if a count cell is not an integer literal.

    The SVG is written as UTF-8 into ``ARTIFACT_DIR``.
    """
    tasks = [row["task_id"] for row in rows]
    baseline = [int(row["baseline_violations"]) for row in rows]
    improved = [int(row["improved_violations"]) for row in rows]
    max_v = max(max(baseline, default=0), max(improved, default=0), 1)

    # One gridline per count is unreadable for large counts, so cap the number
    # of tick intervals. tick_step is ceil(max_v / max_intervals), i.e. 1 while
    # max_v <= 10. axis_max rounds max_v up to a whole step so the top gridline
    # sits exactly on the top edge of the plot.
    max_intervals = 10
    tick_step = (max_v + max_intervals - 1) // max_intervals
    axis_max = ((max_v + tick_step - 1) // tick_step) * tick_step

    width, height = 1360, 500
    left, right, top, bottom = 80, 40, 70, 100
    plot_w = width - left - right
    plot_h = height - top - bottom

    def point_x(idx: int) -> float:
        """x coordinate of the idx-th task, spreading tasks over the full width."""
        return left + (idx / max(len(tasks) - 1, 1)) * plot_w

    def point_y(value: int) -> float:
        """y coordinate of a count: 0 at the bottom, axis_max at the top."""
        return top + plot_h - ((value / axis_max) * plot_h)

    lines = _svg_header(width, height)
    lines.append('<text x="80" y="35" font-size="22" font-family="Arial" fill="#111827">Commitment Violations Before vs After</text>')
    for tick in range(0, axis_max + 1, tick_step):
        y = point_y(tick)
        lines.append(f'<line x1="{left}" y1="{y:.2f}" x2="{left+plot_w}" y2="{y:.2f}" stroke="#E5E7EB" stroke-width="1"/>')
        lines.append(f'<text x="{left-24}" y="{y+5:.2f}" font-size="12" font-family="Arial" fill="#374151">{tick}</text>')
    # Axes are drawn after the gridlines so the light zero gridline does not
    # paint over the dark x-axis (later SVG elements render on top).
    lines.append(f'<line x1="{left}" y1="{top+plot_h}" x2="{left+plot_w}" y2="{top+plot_h}" stroke="#374151" stroke-width="1"/>')
    lines.append(f'<line x1="{left}" y1="{top}" x2="{left}" y2="{top+plot_h}" stroke="#374151" stroke-width="1"/>')

    baseline_points = " ".join(f"{point_x(i):.2f},{point_y(v):.2f}" for i, v in enumerate(baseline))
    improved_points = " ".join(f"{point_x(i):.2f},{point_y(v):.2f}" for i, v in enumerate(improved))
    lines.append(f'<polyline points="{baseline_points}" fill="none" stroke="#DC2626" stroke-width="2"/>')
    lines.append(f'<polyline points="{improved_points}" fill="none" stroke="#059669" stroke-width="2"/>')

    label_y = top + plot_h + 20
    for i, (task, b_value, i_value) in enumerate(zip(tasks, baseline, improved)):
        x = point_x(i)
        lines.append(f'<circle cx="{x:.2f}" cy="{point_y(b_value):.2f}" r="3" fill="#DC2626"/>')
        lines.append(f'<circle cx="{x:.2f}" cy="{point_y(i_value):.2f}" r="3" fill="#059669"/>')
        # task_id comes from the CSV, so escape it: a raw '<' or '&' would
        # produce malformed XML.
        lines.append(
            f'<text x="{x:.2f}" y="{label_y}" font-size="10" text-anchor="middle" '
            f'font-family="Arial" fill="#374151" transform="rotate(25 {x:.2f},{label_y})">'
            f"{html.escape(task, quote=False)}</text>"
        )

    legend_y = 52
    lines.append(f'<line x1="{width-320}" y1="{legend_y-5}" x2="{width-300}" y2="{legend_y-5}" stroke="#DC2626" stroke-width="2"/>')
    lines.append(f'<text x="{width-295}" y="{legend_y}" font-size="12" font-family="Arial" fill="#111827">Baseline</text>')
    lines.append(f'<line x1="{width-220}" y1="{legend_y-5}" x2="{width-200}" y2="{legend_y-5}" stroke="#059669" stroke-width="2"/>')
    lines.append(f'<text x="{width-195}" y="{legend_y}" font-size="12" font-family="Arial" fill="#111827">Improved</text>')
    lines.extend(_svg_footer())
    (ARTIFACT_DIR / "violations_before_after.svg").write_text("\n".join(lines), encoding="utf-8")
def main() -> None:
    """Load the comparison CSV once and render each SVG plot from it.

    The reward plot is rendered before the violations plot. The output
    directory is printed at the end.
    """
    comparison_rows = _load_rows()
    for render in (plot_reward_by_task, plot_violation_before_after):
        render(comparison_rows)
    print("Wrote SVG plots to", ARTIFACT_DIR)
if __name__ == "__main__":
    # Entry point when run as a script; importing the module does not plot.
    main()
|