Spaces:
Sleeping
Sleeping
File size: 5,140 Bytes
05a09dc | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 | from __future__ import annotations
from pathlib import Path
def plot_training_curves(csv_path: Path, out_dir: Path) -> tuple[Path, Path]:
    """
    Generate two judge-friendly PNGs from a training-log CSV.

    Rows are grouped by their ``episode`` column and each metric is averaged
    over that episode's rows; episodes are plotted in ascending order.

    - reward_curve.png: episode mean ``reward_total`` together with the
      invalid-action rate (mean of the ``invalid_action`` flag per episode)
    - reward_components.png: episode mean tower_damage / crown_differential /
      tilt_efficiency

    Args:
        csv_path: UTF-8 CSV whose header includes ``episode``, ``reward_total``,
            ``tower_damage``, ``crown_differential``, ``tilt_efficiency`` and
            ``invalid_action``.
        out_dir: Directory the PNGs are written to; created (with parents)
            if missing.

    Returns:
        ``(reward_curve_path, reward_components_path)``.

    Raises:
        KeyError: a required column is absent from the CSV header.
        ValueError: a cell cannot be parsed as a number.
    """
    import csv
    from collections import defaultdict

    def _flag(raw: str) -> float:
        # Accept numeric flags ("0"/"1", as the int() parse previously required)
        # and also "True"/"False", which csv.DictWriter writes for Python bools.
        text = raw.strip().lower()
        if text == "true":
            return 1.0
        if text == "false":
            return 0.0
        return float(text)

    # Group rows by episode. newline="" is what the csv module requires so that
    # quoted fields containing newlines are parsed correctly.
    by_ep: dict[int, list[dict[str, str]]] = defaultdict(list)
    with csv_path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            by_ep[int(row["episode"])].append(row)

    eps = sorted(by_ep)

    def _episode_means(column: str, parse=float) -> list[float]:
        # Each group has at least one row (groups only exist once a row was
        # appended), so dividing by its length is safe.
        return [
            sum(parse(r[column]) for r in by_ep[ep]) / len(by_ep[ep]) for ep in eps
        ]

    ep_mean_total = _episode_means("reward_total")
    ep_mean_tower = _episode_means("tower_damage")
    ep_mean_crowns = _episode_means("crown_differential")
    ep_mean_tilt = _episode_means("tilt_efficiency")
    ep_invalid_rate = _episode_means("invalid_action", _flag)

    out_dir.mkdir(parents=True, exist_ok=True)
    p1 = out_dir / "reward_curve.png"
    p2 = out_dir / "reward_components.png"

    _line_plot_png(
        out_path=p1,
        title="ToxicRoyale — Reward & Invalid Action Rate",
        x=eps,
        series=[
            ("mean total reward", ep_mean_total),
            ("invalid action rate", ep_invalid_rate),
        ],
        y_label="value",
    )
    _line_plot_png(
        out_path=p2,
        title="ToxicRoyale — Reward Components (Episode Mean)",
        x=eps,
        series=[
            ("tower_damage", ep_mean_tower),
            ("crown_differential", ep_mean_crowns),
            ("tilt_efficiency", ep_mean_tilt),
        ],
        y_label="mean component value",
    )
    return p1, p2
def _line_plot_png(
    *,
    out_path: Path,
    title: str,
    x: list[int],
    series: list[tuple[str, list[float]]],
    y_label: str,
    width: int = 1100,
    height: int = 520,
) -> None:
    """
    Render a simple multi-series line chart to a PNG using only Pillow.

    Points are placed by their *index*, evenly spaced across the plot width;
    ``x`` only supplies the tick labels (``ys[i]`` is drawn above ``x[i]``).
    All series share one y-range, padded by 5% on each side.

    Args:
        out_path: Destination PNG path.
        title: Text drawn above the plot.
        x: Tick labels for the point positions.
        series: ``(legend label, y values)`` pairs. The palette has five
            colours and repeats after that.
        y_label: Text drawn at the left edge, vertically centred.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    Non-finite y values (NaN/inf) are ignored when computing the y-range and
    split a line into separate segments. If no finite value exists at all,
    the saved image contains only the title.
    """
    import math

    from PIL import Image, ImageDraw, ImageFont

    # Canvas
    img = Image.new("RGB", (width, height), (255, 255, 255))
    d = ImageDraw.Draw(img)

    # Margins
    left, right, top, bottom = 70, 20, 55, 55
    plot_w = width - left - right
    plot_h = height - top - bottom

    # Fonts (best effort). Each face falls back independently, so a missing
    # bold face does not also throw away a working regular face.
    # truetype raises OSError for an unreadable font and ImportError when
    # Pillow was built without FreeType.
    try:
        font = ImageFont.truetype("Arial.ttf", 16)
    except (OSError, ImportError):
        font = ImageFont.load_default()
    try:
        font_b = ImageFont.truetype("Arial Bold.ttf", 18)
    except (OSError, ImportError):
        font_b = font

    # Title
    d.text((left, 15), title, fill=(0, 0, 0), font=font_b)

    # Shared y-range from finite values only: NaN/inf would poison min/max
    # and make the int() pixel conversion below raise.
    vals = [fv for _, ys in series for fv in map(float, ys) if math.isfinite(fv)]
    if not vals:
        img.save(out_path)
        return
    y_min, y_max = min(vals), max(vals)
    if abs(y_max - y_min) < 1e-9:
        y_max = y_min + 1.0
    pad = 0.05 * (y_max - y_min)
    y_min -= pad
    y_max += pad
    y_span = y_max - y_min

    # Horizontal grid + y tick labels (6 ticks, top row = y_max). Drawn before
    # the frame so the light grid lines do not overwrite the black border.
    for i in range(6):
        yy = top + int(plot_h * i / 5)
        d.line((left, yy, left + plot_w, yy), fill=(235, 235, 235))
        y_val = y_max - y_span * i / 5
        d.text((left - 65, yy - 7), f"{y_val: .2f}", fill=(0, 0, 0), font=font)

    # Axes frame and y-axis label
    d.rectangle((left, top, left + plot_w, top + plot_h), outline=(0, 0, 0), width=2)
    d.text((10, top + plot_h / 2 - 8), y_label, fill=(0, 0, 0), font=font)

    n = len(x)

    def to_px(i: int) -> int:
        # Pixel column for point index i; a lone point sits at the left edge.
        return left + (int(plot_w * i / (n - 1)) if n > 1 else 0)

    def to_py(yv: float) -> int:
        return top + int(plot_h * (1.0 - (yv - y_min) / y_span))

    # X ticks: up to 6, spread over the *whole* index range. Dividing by the
    # actual tick count (not a fixed 5) keeps ticks distinct for 2-5 points
    # and guarantees the last point is labelled.
    if n == 1:
        tick_idx = [0]
    elif n >= 2:
        n_ticks = min(6, n)
        tick_idx = [i * (n - 1) // (n_ticks - 1) for i in range(n_ticks)]
    else:
        tick_idx = []
    for idx in tick_idx:
        xx = to_px(idx)
        d.line((xx, top + plot_h, xx, top + plot_h + 6), fill=(0, 0, 0))
        d.text((xx - 8, top + plot_h + 10), str(x[idx]), fill=(0, 0, 0), font=font)

    # Colors
    palette = [(51, 102, 204), (220, 57, 18), (16, 150, 24), (153, 0, 153), (0, 153, 198)]

    # Plot lines: each series is split into runs of consecutive finite values;
    # runs of 2+ points become polylines, isolated points become dots.
    for si, (_, ys) in enumerate(series):
        color = palette[si % len(palette)]
        runs: list[list[tuple[int, int]]] = [[]]
        for i, yv in enumerate(map(float, ys)):
            if math.isfinite(yv):
                runs[-1].append((to_px(i), to_py(yv)))
            elif runs[-1]:
                runs.append([])
        for pts in runs:
            if len(pts) >= 2:
                d.line(pts, fill=color, width=3)
            elif len(pts) == 1:
                px, py = pts[0]
                d.ellipse((px - 3, py - 3, px + 3, py + 3), fill=color)

    # Legend (top-left inside the plot area)
    lx, ly = left + 10, top + 10
    for si, (label, _) in enumerate(series):
        color = palette[si % len(palette)]
        d.rectangle((lx, ly + si * 22 + 5, lx + 14, ly + si * 22 + 19), fill=color)
        d.text((lx + 20, ly + si * 22 + 4), label, fill=(0, 0, 0), font=font)

    img.save(out_path)
|