toxic-royale-env / plot_utils.py
omm7's picture
Upload folder using huggingface_hub
05a09dc verified
from __future__ import annotations
from pathlib import Path
def plot_training_curves(csv_path: Path, out_dir: Path) -> tuple[Path, Path]:
    """Render two summary PNGs from a training-log CSV.

    Rows are grouped by their ``episode`` column and every metric is averaged
    over the rows of each episode. Two images are then written into
    ``out_dir`` (created, including parents, if it does not exist):

    - ``reward_curve.png``: episode mean of ``reward_total`` together with the
      episode mean of ``invalid_action`` (plotted as "invalid action rate").
    - ``reward_components.png``: episode means of ``tower_damage``,
      ``crown_differential`` and ``tilt_efficiency``.

    Args:
        csv_path: UTF-8 CSV file with a header row containing the columns
            ``episode``, ``reward_total``, ``tower_damage``,
            ``crown_differential``, ``tilt_efficiency`` and ``invalid_action``.
        out_dir: Directory that receives the two PNG files.

    Returns:
        The paths of ``reward_curve.png`` and ``reward_components.png``,
        in that order.

    Raises:
        KeyError: If a required column is missing from a row.
        ValueError: If ``episode`` / ``invalid_action`` is not an integer
            literal or a metric column is not a float literal.
    """
    import csv
    from collections import defaultdict

    # Group rows by episode while streaming the file; no intermediate list of
    # all rows is needed. newline="" is what the csv module expects so that
    # newlines embedded in quoted fields are handled correctly.
    by_ep: defaultdict[int, list[dict[str, str]]] = defaultdict(list)
    with csv_path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            by_ep[int(row["episode"])].append(row)

    eps = sorted(by_ep)

    def episode_means(column: str, cast=float) -> list[float]:
        # Every group holds at least one row, so the division is always safe.
        return [
            sum(cast(r[column]) for r in by_ep[ep]) / len(by_ep[ep])
            for ep in eps
        ]

    ep_mean_total = episode_means("reward_total")
    ep_mean_tower = episode_means("tower_damage")
    ep_mean_crowns = episode_means("crown_differential")
    ep_mean_tilt = episode_means("tilt_efficiency")
    # invalid_action is parsed as an integer; its per-episode mean is the
    # fraction of invalid rows when the column holds 0/1 flags.
    ep_invalid_rate = episode_means("invalid_action", cast=int)

    out_dir.mkdir(parents=True, exist_ok=True)
    p1 = out_dir / "reward_curve.png"
    p2 = out_dir / "reward_components.png"

    _line_plot_png(
        out_path=p1,
        title="ToxicRoyale — Reward & Invalid Action Rate",
        x=eps,
        series=[
            ("mean total reward", ep_mean_total),
            ("invalid action rate", ep_invalid_rate),
        ],
        y_label="value",
    )
    _line_plot_png(
        out_path=p2,
        title="ToxicRoyale — Reward Components (Episode Mean)",
        x=eps,
        series=[
            ("tower_damage", ep_mean_tower),
            ("crown_differential", ep_mean_crowns),
            ("tilt_efficiency", ep_mean_tilt),
        ],
        y_label="mean component value",
    )
    return p1, p2
def _line_plot_png(
    *,
    out_path: Path,
    title: str,
    x: list[int],
    series: list[tuple[str, list[float]]],
    y_label: str,
    width: int = 1100,
    height: int = 520,
) -> None:
    """Draw a minimal multi-series line chart with Pillow and save it.

    All series share one y-axis whose range spans every value in ``series``
    plus 5% padding. Points are spaced evenly along the x-axis by their index;
    the values in ``x`` are only used as tick labels.

    Args:
        out_path: Destination file passed to ``Image.save``.
        title: Text drawn above the plot area.
        x: Labels for the x positions (one per data point).
        series: ``(legend label, y values)`` pairs; the i-th y value is drawn
            at the i-th x position.
        y_label: Text drawn to the left of the plot area.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    If ``series`` contains no values at all, an image holding only the title
    is saved.
    """
    from PIL import Image, ImageDraw, ImageFont

    black = (0, 0, 0)

    # Canvas
    img = Image.new("RGB", (width, height), (255, 255, 255))
    d = ImageDraw.Draw(img)

    # Margins
    left, right, top, bottom = 70, 20, 55, 55
    plot_w = width - left - right
    plot_h = height - top - bottom

    # Fonts are best-effort: a missing font file (OSError) or a Pillow build
    # without FreeType (ImportError) falls back instead of failing the plot.
    try:
        font = ImageFont.truetype("Arial.ttf", 16)
    except (OSError, ImportError):
        font = ImageFont.load_default()
    # Loaded separately so a missing bold face does not discard a regular
    # font that loaded successfully.
    try:
        font_b = ImageFont.truetype("Arial Bold.ttf", 18)
    except (OSError, ImportError):
        font_b = font

    # Title
    d.text((left, 15), title, fill=black, font=font_b)

    # Determine y-range across all series
    vals = [v for _, ys in series for v in ys]
    if not vals:
        img.save(out_path)
        return
    y_min = min(vals)
    y_max = max(vals)
    if abs(y_max - y_min) < 1e-9:
        # Flat data: widen the range so the normalisation never divides by 0.
        y_max = y_min + 1.0
    pad = 0.05 * (y_max - y_min)
    y_min -= pad
    y_max += pad

    # Axes
    d.rectangle((left, top, left + plot_w, top + plot_h), outline=black, width=2)
    d.text((10, top + plot_h / 2 - 8), y_label, fill=black, font=font)

    # Horizontal grid lines + y tick labels (6 evenly spaced levels)
    for i in range(6):
        yy = top + int(plot_h * i / 5)
        d.line((left, yy, left + plot_w, yy), fill=(235, 235, 235))
        y_val = y_max - (y_max - y_min) * i / 5
        d.text((left - 65, yy - 7), f"{y_val: .2f}", fill=black, font=font)

    # X ticks: at most 6, spread from the first to the last x index. The
    # index step is derived from the number of ticks actually drawn, so short
    # inputs (fewer than 6 points) get one tick per point instead of every
    # tick collapsing onto the first index.
    n_x = len(x)
    n_ticks = min(6, n_x)
    for t in range(n_ticks):
        idx = t * (n_x - 1) // (n_ticks - 1) if n_ticks > 1 else 0
        xx = left + (int(plot_w * idx / (n_x - 1)) if n_x > 1 else 0)
        d.line((xx, top + plot_h, xx, top + plot_h + 6), fill=black)
        d.text((xx - 8, top + plot_h + 10), str(x[idx]), fill=black, font=font)

    # Colors
    palette = [(51, 102, 204), (220, 57, 18), (16, 150, 24), (153, 0, 153), (0, 153, 198)]

    def xy(i: int, yv: float) -> tuple[int, int]:
        # Map (point index, value) to pixel coordinates inside the plot area;
        # pixel y grows downwards, hence the 1.0 - y_norm flip.
        x_norm = 0.0 if n_x <= 1 else i / (n_x - 1)
        y_norm = (yv - y_min) / (y_max - y_min)
        return left + int(plot_w * x_norm), top + int(plot_h * (1.0 - y_norm))

    # Plot lines (a lone point is drawn as a small dot)
    for si, (_, ys) in enumerate(series):
        color = palette[si % len(palette)]
        pts = [xy(i, float(yv)) for i, yv in enumerate(ys)]
        if len(pts) >= 2:
            d.line(pts, fill=color, width=3)
        elif len(pts) == 1:
            px, py = pts[0]
            d.ellipse((px - 3, py - 3, px + 3, py + 3), fill=color)

    # Legend: one colour swatch + label per series, stacked in the top-left
    # corner of the plot area.
    lx, ly = left + 10, top + 10
    for si, (label, _) in enumerate(series):
        color = palette[si % len(palette)]
        row_y = ly + si * 22
        d.rectangle((lx, row_y + 5, lx + 14, row_y + 19), fill=color)
        d.text((lx + 20, row_y + 4), label, fill=black, font=font)

    img.save(out_path)