from __future__ import annotations from pathlib import Path def plot_training_curves(csv_path: Path, out_dir: Path) -> tuple[Path, Path]: """ Generate two judge-friendly PNGs from the CSV: - reward_curve.png: episode mean total reward - reward_components.png: episode mean tower_damage / crown_differential / tilt_efficiency """ import csv from PIL import Image, ImageDraw, ImageFont rows = [] with csv_path.open("r", encoding="utf-8") as f: r = csv.DictReader(f) for row in r: rows.append(row) # Aggregate by episode by_ep: dict[int, list[dict]] = {} for row in rows: ep = int(row["episode"]) by_ep.setdefault(ep, []).append(row) eps = sorted(by_ep.keys()) ep_mean_total = [] ep_mean_tower = [] ep_mean_crowns = [] ep_mean_tilt = [] ep_invalid_rate = [] for ep in eps: rr = by_ep[ep] n = max(1, len(rr)) ep_mean_total.append(sum(float(x["reward_total"]) for x in rr) / n) ep_mean_tower.append(sum(float(x["tower_damage"]) for x in rr) / n) ep_mean_crowns.append(sum(float(x["crown_differential"]) for x in rr) / n) ep_mean_tilt.append(sum(float(x["tilt_efficiency"]) for x in rr) / n) ep_invalid_rate.append(sum(int(x["invalid_action"]) for x in rr) / n) out_dir.mkdir(parents=True, exist_ok=True) p1 = out_dir / "reward_curve.png" p2 = out_dir / "reward_components.png" _line_plot_png( out_path=p1, title="ToxicRoyale — Reward & Invalid Action Rate", x=eps, series=[ ("mean total reward", ep_mean_total), ("invalid action rate", ep_invalid_rate), ], y_label="value", ) _line_plot_png( out_path=p2, title="ToxicRoyale — Reward Components (Episode Mean)", x=eps, series=[ ("tower_damage", ep_mean_tower), ("crown_differential", ep_mean_crowns), ("tilt_efficiency", ep_mean_tilt), ], y_label="mean component value", ) return p1, p2 def _line_plot_png( *, out_path: Path, title: str, x: list[int], series: list[tuple[str, list[float]]], y_label: str, width: int = 1100, height: int = 520, ) -> None: from PIL import Image, ImageDraw, ImageFont # Canvas img = Image.new("RGB", (width, height), (255, 255, 255)) d = ImageDraw.Draw(img) # Margins left, right, top, bottom = 70, 20, 55, 55 plot_w = width - left - right plot_h = height - top - bottom # Font (best-effort default) try: font = ImageFont.truetype("Arial.ttf", 16) font_b = ImageFont.truetype("Arial Bold.ttf", 18) except Exception: font = ImageFont.load_default() font_b = font # Title d.text((left, 15), title, fill=(0, 0, 0), font=font_b) # Determine y-range across all series vals = [v for _, ys in series for v in ys if ys] if not vals: img.save(out_path) return y_min = min(vals) y_max = max(vals) if abs(y_max - y_min) < 1e-9: y_max = y_min + 1.0 pad = 0.05 * (y_max - y_min) y_min -= pad y_max += pad # Axes d.rectangle((left, top, left + plot_w, top + plot_h), outline=(0, 0, 0), width=2) d.text((10, top + plot_h / 2 - 8), y_label, fill=(0, 0, 0), font=font) # Grid + ticks for i in range(6): yy = top + int(plot_h * i / 5) d.line((left, yy, left + plot_w, yy), fill=(235, 235, 235)) y_val = y_max - (y_max - y_min) * i / 5 d.text((left - 65, yy - 7), f"{y_val: .2f}", fill=(0, 0, 0), font=font) # X ticks if len(x) >= 2: for i in range(min(6, len(x))): idx = int(i * (len(x) - 1) / 5) if len(x) > 1 else 0 xx = left + int(plot_w * idx / (len(x) - 1)) d.line((xx, top + plot_h, xx, top + plot_h + 6), fill=(0, 0, 0)) d.text((xx - 8, top + plot_h + 10), str(x[idx]), fill=(0, 0, 0), font=font) # Colors palette = [(51, 102, 204), (220, 57, 18), (16, 150, 24), (153, 0, 153), (0, 153, 198)] def xy(i: int, yv: float) -> tuple[int, int]: if len(x) <= 1: x_norm = 0.0 else: x_norm = i / (len(x) - 1) y_norm = (yv - y_min) / (y_max - y_min) px = left + int(plot_w * x_norm) py = top + int(plot_h * (1.0 - y_norm)) return px, py # Plot lines for si, (label, ys) in enumerate(series): color = palette[si % len(palette)] pts = [xy(i, float(ys[i])) for i in range(len(ys))] if len(pts) >= 2: d.line(pts, fill=color, width=3) elif len(pts) == 1: px, py = pts[0] d.ellipse((px - 3, py - 3, px + 3, py + 3), fill=color) # Legend lx, ly = left + 10, top + 10 for si, (label, _) in enumerate(series): color = palette[si % len(palette)] d.rectangle((lx, ly + si * 22 + 5, lx + 14, ly + si * 22 + 19), fill=color) d.text((lx + 20, ly + si * 22 + 4), label, fill=(0, 0, 0), font=font) img.save(out_path)