Spaces:
Sleeping
Sleeping
| from __future__ import annotations | |
| import argparse | |
| import random | |
| from pathlib import Path | |
| from toxic_royale_env.logging_utils import StepLogRow, ensure_outputs_dirs, write_step_logs_csv | |
| from toxic_royale_env.models import ToxicRoyaleAction | |
| from toxic_royale_env.plot_utils import plot_training_curves | |
| from toxic_royale_env.server.toxic_royale_env_environment import ToxicRoyaleEnvironment | |
def pick_action(obs_state: dict, rng: random.Random) -> ToxicRoyaleAction:
    """Choose an action for the simple random baseline agent.

    Policy:
    - 20% of the time: wait (with a 10% chance of attaching a "yawn" emote).
    - Otherwise: play a random card from the hand into a random zone, with an
      8% chance of attaching a random emote.
    - If the hand is empty, fall back to a plain wait instead of crashing.

    This is *not* the RL-trained policy; it's just to generate logs/plots for
    debugging and to create "before training" baselines for the demo.

    Args:
        obs_state: Observation state mapping; must contain the key "my_hand"
            holding an iterable of playable cards.
        rng: Random source; all randomness is drawn from it, so a seeded
            instance makes the action sequence reproducible.

    Returns:
        The chosen ``ToxicRoyaleAction`` (kind "wait" or "play").

    Raises:
        KeyError: If ``obs_state`` has no "my_hand" entry.
    """
    zones = ("back_left", "back_right", "mid_left", "mid_right", "bridge_left", "bridge_right")
    emotes = ("laugh", "yawn", "cry", "thanks", "chicken", "wp")

    hand = list(obs_state["my_hand"])

    if rng.random() < 0.20:
        return ToxicRoyaleAction(kind="wait", emote=("yawn" if rng.random() < 0.1 else None))

    # rng.choice() raises IndexError on an empty sequence; with nothing to
    # play, waiting is the only sensible action. This guard sits after the
    # wait roll so the RNG draw order for non-empty hands is unchanged.
    if not hand:
        return ToxicRoyaleAction(kind="wait", emote=None)

    card = rng.choice(hand)
    zone = rng.choice(zones)

    # Emote sparingly on plays.
    emote = rng.choice(emotes) if rng.random() < 0.08 else None
    return ToxicRoyaleAction(kind="play", card=card, zone=zone, emote=emote)
def main(argv: list[str] | None = None) -> int:
    """Run baseline rollouts, log per-step metrics to CSV, and plot them.

    For each episode the environment is reset and stepped with actions from
    ``pick_action`` until the episode reports ``done`` or ``--max-steps`` is
    reached. One ``StepLogRow`` is recorded per step. All rows are then
    written to ``rollouts_metrics.csv`` in the logs directory and passed to
    ``plot_training_curves``.

    Args:
        argv: Command-line arguments (without the program name). ``None``
            makes argparse read ``sys.argv``.

    Returns:
        Process exit code; always 0 when the run completes.
    """
    ap = argparse.ArgumentParser()
    ap.add_argument("--episodes", type=int, default=20)
    ap.add_argument("--max-steps", type=int, default=240)
    ap.add_argument("--seed", type=int, default=7)
    args = ap.parse_args(argv)

    # Output directories are resolved relative to the directory one level
    # above the one containing this file.
    root = Path(__file__).resolve().parents[1]
    logs_dir, plots_dir = ensure_outputs_dirs(root)

    # A single seeded RNG drives every action choice across all episodes.
    rng = random.Random(args.seed)
    env = ToxicRoyaleEnvironment()
    rows: list[StepLogRow] = []

    for ep in range(args.episodes):
        obs = env.reset()
        invalid = 0  # cumulative invalid-action count within this episode

        for step in range(args.max_steps):
            action = pick_action(obs.state, rng)
            obs = env.step(action)

            # Read the flag from the observation produced by *this* step so
            # the running count includes the action just taken. Checking it
            # before stepping lagged the count by one row and never counted
            # the final action of an episode.
            # NOTE(review): assumes "invalid_action_last" refers to the action
            # that produced this observation - confirm against the environment.
            if obs.state.get("invalid_action_last"):
                invalid += 1

            rb = obs.reward_breakdown or {}
            rows.append(
                StepLogRow(
                    episode=ep,
                    step=step,
                    reward_total=float(obs.reward),
                    done=bool(obs.done),
                    my_crowns=int(obs.state.get("my_crowns", 0)),
                    opp_crowns=int(obs.state.get("opp_crowns", 0)),
                    opp_tilt=float(obs.state.get("opp_tilt_meter", 0.0)),
                    my_elixir=float(obs.state.get("my_elixir", 0.0)),
                    tower_damage=float(rb.get("tower_damage", 0.0)),
                    crown_differential=float(rb.get("crown_differential", 0.0)),
                    tilt_efficiency=float(rb.get("tilt_efficiency", 0.0)),
                    elixir_discipline=float(rb.get("elixir_discipline", 0.0)),
                    step_reward=float(rb.get("step", 0.0)),
                    invalid_action=invalid,
                )
            )

            if obs.done:
                break

    csv_path = logs_dir / "rollouts_metrics.csv"
    write_step_logs_csv(csv_path, rows)
    p1, p2 = plot_training_curves(csv_path, plots_dir)

    print(f"Wrote {csv_path}")
    print(f"Wrote {p1}")
    print(f"Wrote {p2}")
    return 0
if __name__ == "__main__":
    # Script entry point: run the rollout driver and exit the process with
    # the status code it returns.
    exit_code = main()
    raise SystemExit(exit_code)