# toxic-royale-env / scripts/run_rollouts.py
# Uploaded by omm7 via huggingface_hub ("Upload folder using huggingface_hub"), commit 05a09dc (verified).
from __future__ import annotations
import argparse
import random
from pathlib import Path
from toxic_royale_env.logging_utils import StepLogRow, ensure_outputs_dirs, write_step_logs_csv
from toxic_royale_env.models import ToxicRoyaleAction
from toxic_royale_env.plot_utils import plot_training_curves
from toxic_royale_env.server.toxic_royale_env_environment import ToxicRoyaleEnvironment
# Zones the baseline agent plays cards into, and the emotes it may attach.
# Kept as fixed-order tuples: rng.choice picks by index, so the order is part of
# what makes a seeded rollout reproducible.
_BASELINE_ZONES = ("back_left", "back_right", "mid_left", "mid_right", "bridge_left", "bridge_right")
_BASELINE_EMOTES = ("laugh", "yawn", "cry", "thanks", "chicken", "wp")


def pick_action(obs_state: dict, rng: random.Random) -> ToxicRoyaleAction:
    """Pick an action for a simple random baseline agent.

    Policy:
      - 20% of the time: wait, with a 10% chance of attaching a "yawn" emote.
      - otherwise: play a uniformly random card from ``obs_state["my_hand"]``
        into a uniformly random zone from ``_BASELINE_ZONES``, attaching a
        random emote from ``_BASELINE_EMOTES`` 8% of the time.
      - if the hand is empty (or ``"my_hand"`` is missing/None), wait instead
        of trying to play.

    This is *not* the RL-trained policy; it only generates logs/plots for
    debugging and "before training" baselines for the demo.

    Args:
        obs_state: Observation state dict; only ``"my_hand"`` is read.
        rng: Source of all randomness, so a seeded ``random.Random`` gives
            reproducible action sequences.

    Returns:
        A ``ToxicRoyaleAction`` of kind ``"wait"`` or ``"play"``.
    """
    raw_hand = obs_state.get("my_hand")
    hand = [] if raw_hand is None else list(raw_hand)

    if rng.random() < 0.20:
        return ToxicRoyaleAction(kind="wait", emote=("yawn" if rng.random() < 0.1 else None))

    if not hand:
        # rng.choice([]) raises IndexError; with nothing to play, just wait.
        return ToxicRoyaleAction(kind="wait", emote=None)

    card = rng.choice(hand)
    zone = rng.choice(_BASELINE_ZONES)
    # The condition's rng.random() is drawn before rng.choice, matching the
    # original draw order so seeded runs are unchanged.
    emote = rng.choice(_BASELINE_EMOTES) if rng.random() < 0.08 else None
    return ToxicRoyaleAction(kind="play", card=card, zone=zone, emote=emote)
def _row_from_obs(ep: int, step: int, obs, invalid: int) -> StepLogRow:
    """Flatten one post-step observation into a StepLogRow for the metrics CSV.

    Missing state keys and reward-breakdown entries default to 0. A ``None``
    reward or ``reward_breakdown`` is also treated as 0 / empty.
    """
    state = obs.state
    rb = obs.reward_breakdown or {}
    return StepLogRow(
        episode=ep,
        step=step,
        reward_total=0.0 if obs.reward is None else float(obs.reward),
        done=bool(obs.done),
        my_crowns=int(state.get("my_crowns", 0)),
        opp_crowns=int(state.get("opp_crowns", 0)),
        opp_tilt=float(state.get("opp_tilt_meter", 0.0)),
        my_elixir=float(state.get("my_elixir", 0.0)),
        tower_damage=float(rb.get("tower_damage", 0.0)),
        crown_differential=float(rb.get("crown_differential", 0.0)),
        tilt_efficiency=float(rb.get("tilt_efficiency", 0.0)),
        elixir_discipline=float(rb.get("elixir_discipline", 0.0)),
        step_reward=float(rb.get("step", 0.0)),
        invalid_action=invalid,
    )


def main(argv: list[str] | None = None) -> int:
    """Roll out the random baseline agent, log per-step metrics, and plot them.

    Runs ``--episodes`` episodes of at most ``--max-steps`` steps each. Every
    step is recorded as a ``StepLogRow`` in ``<logs_dir>/rollouts_metrics.csv``,
    and ``plot_training_curves`` is then called on that CSV. The output
    directories come from ``ensure_outputs_dirs`` under the repo root (the
    parent of this script's directory).

    Args:
        argv: Command-line arguments; ``None`` means use ``sys.argv[1:]``
            (argparse default).

    Returns:
        0, intended as the process exit status.
    """
    ap = argparse.ArgumentParser(
        description="Roll out the random baseline agent and write metrics CSV + plots."
    )
    ap.add_argument("--episodes", type=int, default=20, help="number of episodes to roll out")
    ap.add_argument("--max-steps", type=int, default=240, help="maximum env steps per episode")
    ap.add_argument("--seed", type=int, default=7, help="seed for the baseline agent's RNG")
    args = ap.parse_args(argv)

    root = Path(__file__).resolve().parents[1]
    logs_dir, plots_dir = ensure_outputs_dirs(root)

    # NOTE: only the agent's RNG is seeded. The environment is constructed
    # without a seed here, so any env-side randomness is not controlled by --seed.
    rng = random.Random(args.seed)
    env = ToxicRoyaleEnvironment()
    rows: list[StepLogRow] = []

    for ep in range(args.episodes):
        obs = env.reset()
        invalid = 0  # running count of invalid actions within this episode
        for step in range(args.max_steps):
            action = pick_action(obs.state, rng)
            obs = env.step(action)
            # Read the flag from the observation returned by step(), i.e. the
            # one produced by `action`. Reading it before step() (as earlier
            # revisions did) credited each invalid action to the following row
            # and never counted the final step's flag.
            if obs.state.get("invalid_action_last"):
                invalid += 1
            rows.append(_row_from_obs(ep, step, obs, invalid))
            if obs.done:
                break

    csv_path = logs_dir / "rollouts_metrics.csv"
    write_step_logs_csv(csv_path, rows)
    p1, p2 = plot_training_curves(csv_path, plots_dir)
    print(f"Wrote {csv_path}")
    print(f"Wrote {p1}")
    print(f"Wrote {p2}")
    return 0
if __name__ == "__main__":
    # Script entry point: main()'s return value becomes the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)