Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| from pathlib import Path | |
def plot_training_curves(csv_path: Path, out_dir: Path) -> tuple[Path, Path]:
    """
    Generate two judge-friendly PNGs from the CSV:
    - reward_curve.png: episode mean total reward and invalid action rate
    - reward_components.png: episode mean tower_damage / crown_differential / tilt_efficiency

    The CSV must have a header row containing the columns ``episode``,
    ``reward_total``, ``tower_damage``, ``crown_differential``,
    ``tilt_efficiency`` and ``invalid_action``. Rows are grouped by their
    integer ``episode`` value and each column is averaged per episode.

    Args:
        csv_path: CSV file to read (decoded as UTF-8).
        out_dir: Directory for the PNGs; created (with parents) if missing.

    Returns:
        ``(reward_curve_path, reward_components_path)``.

    Raises:
        KeyError: a required column is absent from the CSV header.
        ValueError: a cell cannot be parsed as a number.
    """
    import csv
    from collections import defaultdict

    # Aggregate by episode. newline="" lets the csv module do its own newline
    # handling, as its documentation requires for quoted embedded newlines.
    by_ep: defaultdict[int, list[dict[str, str]]] = defaultdict(list)
    with csv_path.open("r", encoding="utf-8", newline="") as f:
        for row in csv.DictReader(f):
            by_ep[int(row["episode"])].append(row)

    eps = sorted(by_ep)

    def episode_mean(column: str, cast=float) -> list[float]:
        # Every group holds at least one row, so the division is always safe.
        return [
            sum(cast(row[column]) for row in by_ep[ep]) / len(by_ep[ep])
            for ep in eps
        ]

    ep_mean_total = episode_mean("reward_total")
    ep_mean_tower = episode_mean("tower_damage")
    ep_mean_crowns = episode_mean("crown_differential")
    ep_mean_tilt = episode_mean("tilt_efficiency")
    # invalid_action is parsed as an int per row, so its mean is the fraction
    # of flagged rows when the column holds 0/1 values.
    ep_invalid_rate = episode_mean("invalid_action", cast=int)

    out_dir.mkdir(parents=True, exist_ok=True)
    p1 = out_dir / "reward_curve.png"
    p2 = out_dir / "reward_components.png"

    _line_plot_png(
        out_path=p1,
        title="ToxicRoyale — Reward & Invalid Action Rate",
        x=eps,
        series=[
            ("mean total reward", ep_mean_total),
            ("invalid action rate", ep_invalid_rate),
        ],
        y_label="value",
    )
    _line_plot_png(
        out_path=p2,
        title="ToxicRoyale — Reward Components (Episode Mean)",
        x=eps,
        series=[
            ("tower_damage", ep_mean_tower),
            ("crown_differential", ep_mean_crowns),
            ("tilt_efficiency", ep_mean_tilt),
        ],
        y_label="mean component value",
    )
    return p1, p2
def _line_plot_png(
    *,
    out_path: Path,
    title: str,
    x: list[int],
    series: list[tuple[str, list[float]]],
    y_label: str,
    width: int = 1100,
    height: int = 520,
) -> None:
    """Draw a multi-series line chart with Pillow and save it to ``out_path``.

    Args:
        out_path: Destination passed to ``Image.save``.
        title: Text drawn above the plot area.
        x: One label per data point. Points are spaced evenly by index; the
            values themselves are only used as tick labels.
        series: ``(label, values)`` pairs. All series share a single y-axis
            whose range spans every value (plus 5% padding).
        y_label: Text drawn at the left edge, vertically centred on the plot.
        width: Canvas width in pixels.
        height: Canvas height in pixels.

    If no series contains any value, an image holding only the title is saved.
    """
    from PIL import Image, ImageDraw, ImageFont

    # Canvas
    img = Image.new("RGB", (width, height), (255, 255, 255))
    d = ImageDraw.Draw(img)

    # Margins
    left, right, top, bottom = 70, 20, 55, 55
    plot_w = width - left - right
    plot_h = height - top - bottom

    # Font (best-effort): fall back to Pillow's built-in font when Arial is
    # not installed or FreeType support is unavailable.
    try:
        font = ImageFont.truetype("Arial.ttf", 16)
        font_b = ImageFont.truetype("Arial Bold.ttf", 18)
    except (OSError, ImportError):
        font = ImageFont.load_default()
        font_b = font

    # Title
    d.text((left, 15), title, fill=(0, 0, 0), font=font_b)

    # Determine y-range across all series
    vals = [v for _, ys in series for v in ys]
    if not vals:
        img.save(out_path)
        return
    y_min = min(vals)
    y_max = max(vals)
    if abs(y_max - y_min) < 1e-9:
        # Flat data: widen the range so the normalisation below never
        # divides by zero.
        y_max = y_min + 1.0
    pad = 0.05 * (y_max - y_min)
    y_min -= pad
    y_max += pad

    # Grid + y ticks. Drawn before the frame so the light grid lines at the
    # top and bottom do not paint over the black border.
    for i in range(6):
        yy = top + int(plot_h * i / 5)
        d.line((left, yy, left + plot_w, yy), fill=(235, 235, 235))
        y_val = y_max - (y_max - y_min) * i / 5
        d.text((left - 65, yy - 7), f"{y_val: .2f}", fill=(0, 0, 0), font=font)

    # Axes
    d.rectangle((left, top, left + plot_w, top + plot_h), outline=(0, 0, 0), width=2)
    d.text((10, top + plot_h / 2 - 8), y_label, fill=(0, 0, 0), font=font)

    # X ticks: at most six, spread evenly over the whole index range. The
    # step is derived from the number of ticks actually drawn, so short runs
    # (fewer than six points) get one tick per point instead of several
    # ticks stacked on the first points.
    n_points = len(x)
    n_ticks = min(6, n_points)
    for i in range(n_ticks):
        idx = i * (n_points - 1) // (n_ticks - 1) if n_ticks > 1 else 0
        # A lone point is plotted at the left edge (see xy below), so its
        # tick goes there as well.
        offset = int(plot_w * idx / (n_points - 1)) if n_points > 1 else 0
        xx = left + offset
        d.line((xx, top + plot_h, xx, top + plot_h + 6), fill=(0, 0, 0))
        d.text((xx - 8, top + plot_h + 10), str(x[idx]), fill=(0, 0, 0), font=font)

    # Colors
    palette = [(51, 102, 204), (220, 57, 18), (16, 150, 24), (153, 0, 153), (0, 153, 198)]

    def xy(i: int, yv: float) -> tuple[int, int]:
        """Map point index ``i`` and value ``yv`` to pixel coordinates."""
        if len(x) <= 1:
            x_norm = 0.0
        else:
            x_norm = i / (len(x) - 1)
        y_norm = (yv - y_min) / (y_max - y_min)
        px = left + int(plot_w * x_norm)
        # Pixel rows grow downward, so the normalised value is flipped.
        py = top + int(plot_h * (1.0 - y_norm))
        return px, py

    # Plot lines (colors cycle when there are more series than palette entries)
    for si, (_label, ys) in enumerate(series):
        color = palette[si % len(palette)]
        pts = [xy(i, float(yv)) for i, yv in enumerate(ys)]
        if len(pts) >= 2:
            d.line(pts, fill=color, width=3)
        elif len(pts) == 1:
            # A single point cannot form a line; mark it with a dot instead.
            px, py = pts[0]
            d.ellipse((px - 3, py - 3, px + 3, py + 3), fill=color)

    # Legend
    lx, ly = left + 10, top + 10
    for si, (label, _) in enumerate(series):
        color = palette[si % len(palette)]
        d.rectangle((lx, ly + si * 22 + 5, lx + 14, ly + si * 22 + 19), fill=color)
        d.text((lx + 20, ly + si * 22 + 4), label, fill=(0, 0, 0), font=font)

    img.save(out_path)