Commit a8df3de by Atharva
Initial hackathon submission export
Files changed:

- .gitignore +21 -0
- README.md +195 -0
- examples/run_single_episode.py +52 -0
- pyproject.toml +30 -0
- src/smogon_rl/__init__.py +13 -0
- src/smogon_rl/action_space.py +138 -0
- src/smogon_rl/config.py +28 -0
- src/smogon_rl/openenv_sync_env.py +260 -0
- src/smogon_rl/pokeenv_client.py +304 -0
- src/smogon_rl/reward.py +320 -0
- src/smogon_rl/state_formatter.py +181 -0
- trainer.ipynb +0 -0
.gitignore
ADDED
@@ -0,0 +1,21 @@

```gitignore
# Environment and secrets — never commit
.env
.env.local
*.env

# Python
__pycache__/
*.py[cod]
*.egg-info/
.eggs/
dist/
build/

# Virtual environments
.venv/
venv/
env/

# IDE / OS
.idea/
.DS_Store
```
README.md
ADDED
@@ -0,0 +1,195 @@
# OpenEnv-WolfeClick

OpenEnv-WolfeClick is a reinforcement learning environment and training workflow for competitive Pokemon battles with large language models.

The project was built for the OpenEnv hackathon to answer a specific question: can an LLM learn to act in a partially observable, adversarial, long-horizon environment where legal actions are constrained, rewards are delayed, and the opponent is another agent?

This repo focuses on that environment and a minimal Colab training path.

## Why I Built This

Pokemon battles are a strong multi-agent training environment for LLMs because they require:

- hidden information and opponent modeling
- long-horizon planning over many turns
- legal action grounding under a constrained action space
- adapting to a changing world state after every action
- balancing local rewards against later consequences

I built this environment to make those properties trainable with a simple `reset()` / `step()` loop and a small JSON action interface.
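The loop shape that interface describes can be sketched with a stand-in environment. The toy class below only mirrors the `reset()` / `step()` signature and return types; it is not the real `PokemonShowdownEnv`, and the state strings and reward values are invented for illustration.

```python
import json


class ToyEnv:
    """Stand-in with the same reset()/step() shape as the real environment."""

    def __init__(self) -> None:
        self.turn = 0

    def reset(self) -> str:
        # Returns the initial state as a markdown string.
        self.turn = 0
        return "## Turn 0\nActive: Heracross (100% HP)"

    def step(self, action_json: str):
        # Accepts one JSON action, returns (obs, reward, done, info).
        self.turn += 1
        action = json.loads(action_json)
        obs = f"## Turn {self.turn}\nLast action: {action['choice']}"
        reward = 1.0 if action["action"] == "move" else 0.0
        done = self.turn >= 2
        return obs, reward, done, {"turn": self.turn}


env = ToyEnv()
obs = env.reset()
done = False
total = 0.0
while not done:
    obs, reward, done, info = env.step(
        json.dumps({"action": "move", "choice": "megahorn"})
    )
    total += reward
```

A real agent would replace the hard-coded action with an LLM call conditioned on `obs`.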
## What is in this repo

- [`src/smogon_rl`](src/smogon_rl): environment, state formatting, action space, reward shaping, and client code
- [`trainer.ipynb`](trainer.ipynb): main Colab notebook for warm-up SFT, rollout collection, and GRPO training
- [`examples`](examples): small local examples
- [`pyproject.toml`](pyproject.toml): package metadata

## Environment design

### State design

The state is not a raw simulator dump. It is a structured markdown representation designed to preserve strategic information while remaining readable to an LLM.

Each prompt includes:

- the active self Pokemon
- the active opponent Pokemon
- HP, status, ability, item, and current stat modifiers
- the full self team roster with currently known moves
- opponent history and revealed information
- the exact legal actions available this turn
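Purely as an illustration of the shape (the exact layout lives in `state_formatter.py` and will differ; the Pokemon, values, and section names here are invented):

```markdown
## Turn 12

### Your active Pokemon
Heracross: 72% HP, no status, Ability: Guts, Item: Choice Band, Atk +1

### Opponent active Pokemon
Skarmory: 100% HP, revealed moves: Spikes, Roost

### Your team
- Heracross (active)
- Blissey: 100% HP, known moves: Softboiled, Seismic Toss

### Legal actions this turn
- {"action": "move", "choice": "megahorn"}
- {"action": "switch", "choice": "Blissey"}
```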
This is implemented through the environment wrapper and state formatter:

- [`src/smogon_rl/openenv_sync_env.py`](src/smogon_rl/openenv_sync_env.py)
- [`src/smogon_rl/state_formatter.py`](src/smogon_rl/state_formatter.py)

My design goal was to expose enough information for strategic decisions without giving the model shortcuts that bypass the game structure.

### Action design

The action space is deliberately constrained.

The model must emit exactly one JSON object:

```json
{"action": "move" | "switch", "choice": "Exact Name of Move or Pokemon"}
```

At every step, legal actions are enumerated from the current battle state using:

- [`src/smogon_rl/action_space.py`](src/smogon_rl/action_space.py)

This module does three important things:

- enumerates legal moves and switches for the turn
- builds the action instruction block shown to the model
- validates model outputs against the legal action set

This matters because I do not want the model to “sort of” describe an action. I want the environment to enforce a concrete legal interface.
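The validation step can be condensed to the following sketch. The real `action_space.py` additionally strips `<think>` blocks and validates with a pydantic schema; the legal-action set here is hypothetical.

```python
import json

# Hypothetical legal-action set for one turn.
LEGAL = [
    {"action": "move", "choice": "flamethrower"},
    {"action": "move", "choice": "thunder-wave"},
    {"action": "switch", "choice": "Blissey"},
]


def normalize(choice: str) -> str:
    # Case-insensitive, spaces-to-hyphens comparison, matching poke-env move ids.
    return choice.strip().lower().replace(" ", "-")


def validate(raw: str, legal: list) -> dict:
    payload = json.loads(raw)  # raises on malformed JSON
    for option in legal:
        if (option["action"] == payload["action"]
                and normalize(option["choice"]) == normalize(payload["choice"])):
            return option  # canonical choice string the env expects
    raise ValueError("action not in the legal set")


picked = validate('{"action": "move", "choice": "Thunder Wave"}', LEGAL)
```

Returning the environment's canonical entry (rather than the model's raw string) means downstream code always sends a known-good move id.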
### Reward design

The environment reward is shaped but still tied to battle outcomes.

Reward computation lives in:

- [`src/smogon_rl/reward.py`](src/smogon_rl/reward.py)

The reward includes:

- damage dealt to the opponent
- damage taken by the agent
- knockouts and faint penalties
- healing value
- setup value and opponent setup penalties
- passive damage value
- status penalties

The environment wrapper in [`src/smogon_rl/openenv_sync_env.py`](src/smogon_rl/openenv_sync_env.py) adds practical rollout constraints:

- illegal action fallback handling
- illegal action penalties
- an anti-stall living penalty
- battle length caps
- no-progress termination penalties

This separation is intentional:

- `reward.py` captures battle-quality shaping
- the env wrapper handles rollout hygiene and training throughput

## Training design

### 1. Warm-up SFT

The notebook begins with a supervised warm-up stage so the model learns to emit valid action JSON for the battle-state prompt format.

This does not claim strategic mastery. It only ensures the model is good enough to participate in the environment without collapsing into malformed outputs.

### 2. Real rollout collection

The policy is then run in real Pokemon Showdown battles. For each turn, the notebook stores:

- `prompt`
- `collected_action`
- `collected_reward`

This makes the rollout data usable for GRPO training while preserving the exact environment reward signal.
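Each stored row is self-contained, for example one JSON line per turn. The field contents below are illustrative stand-ins; real prompts are the full battle-state markdown.

```python
import json

# Illustrative rollout row with the three fields described above.
row = {
    "prompt": "## Turn 12\n...battle state markdown...\n\nValid options: ...",
    "collected_action": '{"action": "switch", "choice": "Blissey"}',
    "collected_reward": 0.35,
}

line = json.dumps(row)       # one JSONL line per turn
restored = json.loads(line)  # round-trips losslessly
```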
### 3. GRPO training

The GRPO reward used in the notebook is a wrapper around the stored rollout reward.

It is designed to preserve ranking pressure inside a completion group:

- malformed output is penalized strongly
- valid but different actions are penalized lightly
- the action matching the executed rollout action receives the collected environment reward plus a positive margin

That matters because raw rollout rewards alone do not always create a clean learning signal for group-relative optimization.
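That shaping rule can be sketched as follows. The constants and the tuple action encoding are illustrative placeholders, not the notebook's actual values.

```python
def grpo_completion_reward(completion_action, rollout_action, collected_reward,
                           legal_actions, malformed_penalty=-5.0,
                           valid_mismatch_penalty=-0.5, match_margin=1.0):
    """Group-relative reward for one sampled completion.

    completion_action is None when the completion failed to parse as action JSON.
    """
    if completion_action is None or completion_action not in legal_actions:
        return malformed_penalty                # malformed or illegal: strong penalty
    if completion_action == rollout_action:
        return collected_reward + match_margin  # executed action: env reward + margin
    return valid_mismatch_penalty               # legal but different: light penalty


legal = [("move", "megahorn"), ("switch", "Blissey")]
r_match = grpo_completion_reward(("move", "megahorn"), ("move", "megahorn"), 0.5, legal)
r_other = grpo_completion_reward(("switch", "Blissey"), ("move", "megahorn"), 0.5, legal)
r_bad = grpo_completion_reward(None, ("move", "megahorn"), 0.5, legal)
```

The margin guarantees the executed action outranks every sibling in its group whenever its environment reward is not strongly negative.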
## How it works end to end

1. Start Pokemon Showdown locally in Colab.
2. Create the OpenEnv-style synchronous environment.
3. Format the battle state into markdown.
4. Enumerate legal actions.
5. Generate one JSON action from the model.
6. Execute the action in the environment.
7. Receive the next state, reward, done flag, and info.
8. Store rollout rows.
9. Train with GRPO on the collected rows.

## How to use

### Local package install

From the repo root:

```bash
python3 -m pip install -e .
```

### Colab training

Open [`trainer.ipynb`](trainer.ipynb) in Colab and run it top to bottom.

The notebook does the following:

1. clones or uses the repo
2. installs the training stack
3. loads the model and LoRA adapter
4. starts a local Pokemon Showdown server
5. runs JSON warm-up SFT
6. collects rollout data from real battles
7. trains with GRPO
8. optionally saves the adapter to the Hugging Face Hub

### Requirements

- a GPU runtime in Colab
- a local Pokemon Showdown server started from the notebook
- a Hugging Face token, only if you want to push adapters

## Current status

This repo now has a working end-to-end path where:

- real battle rollouts are collected from the environment
- valid action JSON is produced reliably after warm-up
- GRPO can train on real rollout data via the plain, non-quantized TRL path

This is the basis for my hackathon demo and benchmark runs.

## Submission notes

This is intended to be my clean hackathon submission repo.

Linked artifacts to add before submission:

- Hugging Face model repo
- Hugging Face Space using OpenEnv stable release `0.2.1`
- benchmark/results file
- 1-minute demo video
examples/run_single_episode.py
ADDED
@@ -0,0 +1,52 @@

```python
from __future__ import annotations

import json
from pathlib import Path

from smogon_rl.action_space import ActionOption, enumerate_actions
from smogon_rl.config import EnvConfig
from smogon_rl.openenv_sync_env import PokemonShowdownEnv


def main() -> None:
    config = EnvConfig()
    env = PokemonShowdownEnv(config=config)

    print("Starting a single gen4randombattle episode.")
    obs = env.reset()
    print("Initial state (truncated):")
    print("\n".join(obs.splitlines()[:40]))

    done = False
    total_reward = 0.0
    step_idx = 0

    while not done and step_idx < config.max_steps_per_battle:
        step_idx += 1
        print(f"\n=== Step {step_idx} ===")

        # Naive policy: query valid actions from the environment and always pick
        # the first one. A real agent would send `obs` and `info["instructions"]`
        # to an LLM and use its JSON response here.
        battle = env._ensure_battle()  # type: ignore[attr-defined]
        valid_actions = enumerate_actions(battle)
        if not valid_actions:
            print("No valid actions available; terminating.")
            break

        chosen: ActionOption = valid_actions[0]
        action_json = {"action": chosen.action_type, "choice": chosen.choice}
        obs, reward, done, info = env.step(json.dumps(action_json))

        total_reward += reward
        print(f"Chosen action: {action_json}")
        print(f"Reward: {reward:.3f}, Done: {done}")
        print("State (truncated):")
        print("\n".join(obs.splitlines()[:20]))

    print(f"\nTotal reward: {total_reward}")


if __name__ == "__main__":
    main()
```
pyproject.toml
ADDED
@@ -0,0 +1,30 @@

```toml
[project]
name = "smogon-rl"
version = "0.1.0"
description = "Theory-of-Mind Pokémon RL environment using poke-env and OpenEnv."
readme = "README.md"
requires-python = ">=3.10"
authors = [
    { name = "Atharva" }
]
dependencies = [
    "poke-env>=0.8.0,<0.9.0",
    "numpy>=1.24.0",
    "pydantic>=2.0.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "ruff>=0.5.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.uv]
package = "smogon-rl"

[tool.uv.sources]
```
src/smogon_rl/__init__.py
ADDED
@@ -0,0 +1,13 @@

```python
"""
Smogon-RL core package.

This package provides:
- An async poke-env client for Pokémon Showdown battles.
- A synchronous, OpenEnv-style wrapper exposing reset/step.
- State formatting, action space handling, and reward shaping utilities.
"""

from .config import DEFAULT_BATTLE_FORMAT

__all__ = ["DEFAULT_BATTLE_FORMAT"]
```
src/smogon_rl/action_space.py
ADDED
@@ -0,0 +1,138 @@

```python
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import List, Literal, Optional

from pydantic import BaseModel, ValidationError
from poke_env.environment.battle import Battle
from poke_env.environment.move import Move
from poke_env.environment.pokemon import Pokemon

# Match a single JSON object with "action" and "choice" (handles <think>...</think> + JSON).
_ACTION_JSON_RE = re.compile(
    r'\{\s*"action"\s*:\s*"(?:move|switch)"\s*,\s*"choice"\s*:\s*"[^"]*"\s*\}',
    re.IGNORECASE,
)


ActionType = Literal["move", "switch"]


@dataclass
class ActionOption:
    """Concrete action option available in the current state."""

    action_type: ActionType
    choice: str
    move: Optional[Move] = None
    pokemon: Optional[Pokemon] = None


class ActionJSON(BaseModel):
    """Strict JSON schema the LLM must output."""

    action: ActionType
    choice: str


def enumerate_actions(battle: Battle) -> List[ActionOption]:
    """Enumerate up to 4 moves and up to 5 switches for the current state."""
    options: List[ActionOption] = []

    # Moves
    for move in battle.available_moves[:4]:
        if getattr(move, "current_pp", 1) <= 0:
            continue
        choice = move.id
        options.append(ActionOption(action_type="move", choice=choice, move=move))

    # Switches
    for pokemon in battle.available_switches[:5]:
        if pokemon.fainted:
            continue
        choice = pokemon.species or pokemon.nickname or "Unknown"
        options.append(
            ActionOption(action_type="switch", choice=choice, pokemon=pokemon)
        )

    return options


def _normalize_choice(s: str) -> str:
    """Normalize choice for comparison: lowercase, spaces to hyphens (matches poke-env move ids)."""
    return s.strip().lower().replace(" ", "-")


def extract_action_json_from_text(text: str) -> Optional[str]:
    """Extract a single action JSON object from model output that may contain thinking or prose.

    Strips think tags first, then looks for our schema in the remainder (or in the full string).
    Returns the first matching JSON substring, or None if none found.
    """
    if not text or not text.strip():
        return None
    # Strip think blocks first so we prefer content after thinking.
    stripped = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    for candidate in (stripped, text):
        match = _ACTION_JSON_RE.search(candidate)
        if match:
            return match.group(0)
    return None


def parse_llm_action(raw_output: str, valid_actions: List[ActionOption]) -> ActionJSON:
    """Parse and validate the LLM JSON output against the current action set.

    The model must output:
        {
            "action": "move" | "switch",
            "choice": "Exact Name of Move or Pokemon"
        }
    Choice matching is case-insensitive and normalizes spaces to hyphens so
    "Flamethrower" and "Thunder Wave" match env ids "flamethrower" and "thunder-wave".
    """
    try:
        payload = json.loads(raw_output)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Model output is not valid JSON: {exc}") from exc

    try:
        action = ActionJSON.model_validate(payload)
    except ValidationError as exc:
        raise ValueError(f"Model JSON does not match schema: {exc}") from exc

    want_norm = _normalize_choice(action.choice)
    matched = None
    for a in valid_actions:
        if a.action_type != action.action:
            continue
        if _normalize_choice(a.choice) == want_norm:
            matched = a
            break
    if matched is None:
        valid_desc = [
            {"action": a.action_type, "choice": a.choice} for a in valid_actions
        ]
        raise ValueError(
            f"Invalid action selection {action.model_dump()}. "
            f"Valid options are: {valid_desc}"
        )
    # Return with the env's exact choice string so downstream uses the right id.
    return ActionJSON(action=action.action, choice=matched.choice)


def build_action_instructions(valid_actions: List[ActionOption]) -> str:
    """Build a short instruction string describing the JSON schema and options."""
    lines = [
        "You must choose exactly one action and output pure JSON with this schema:",
        "",
        '{"action": "move" | "switch", "choice": "Exact Name of Move or Pokemon"}',
        "",
        "Valid options for this state:",
    ]
    for opt in valid_actions:
        lines.append(f"- action: {opt.action_type!r}, choice: {opt.choice!r}")
    return "\n".join(lines)
```
src/smogon_rl/config.py
ADDED
@@ -0,0 +1,28 @@

```python
from __future__ import annotations

from dataclasses import dataclass


DEFAULT_BATTLE_FORMAT = "gen4randombattle"


@dataclass
class EnvConfig:
    """Configuration for the Pokémon RL environment."""

    battle_format: str = DEFAULT_BATTLE_FORMAT
    # Hard cap to prevent very long battles from dominating rollout wall-time.
    max_steps_per_battle: int = 30
    poll_interval_seconds: float = 0.2
    open_timeout: float = 25.0
    show_replays: bool = False
    verbose_logging: bool = False
    log_every_n_steps: int = 25
    poll_heartbeat_seconds: float = 5.0
    min_battle_reward: float = -100.0
    max_no_progress_steps: int = 2
    # Small per-step time penalty to bias toward faster, decisive games.
    step_living_penalty: float = -0.05
    # Additional truncation/timeout penalties.
    no_progress_termination_penalty: float = -1.0
    max_steps_termination_penalty: float = -2.0
```
src/smogon_rl/openenv_sync_env.py
ADDED
@@ -0,0 +1,260 @@

```python
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

from poke_env.environment.battle import Battle
from poke_env.player.player import Player

from .action_space import (
    ActionJSON,
    ActionOption,
    build_action_instructions,
    enumerate_actions,
    extract_action_json_from_text,
    parse_llm_action,
)
from .config import EnvConfig
from .pokeenv_client import PokeEnvClient
from .reward import (
    BattleStateSummary,
    ILLEGAL_ACTION_PENALTY,
    RewardTrackingState,
    calculate_reward,
    count_new_passive_hits_for_turn,
    summarize_battle_state,
)
from .state_formatter import OpponentHistoryTracker, format_battle_state


@dataclass
class PokemonShowdownEnv:
    """Synchronous, OpenEnv-style wrapper around a poke-env battle.

    The environment exposes a simple Gymnasium-like / OpenEnv-like API:

        obs = env.reset()
        obs, reward, done, info = env.step(action_json_str)

    where `action_json_str` is a JSON string describing a move or switch using
    the constrained 9-action space.
    """

    config: EnvConfig = field(default_factory=EnvConfig)
    _client: PokeEnvClient = field(init=False)
    _opponent_history: OpponentHistoryTracker = field(init=False)
    _reward_trackers: RewardTrackingState = field(init=False)
    _prev_state: Optional[BattleStateSummary] = field(init=False, default=None)
    _steps_this_battle: int = field(init=False, default=0)
    # Running total of passive hits — updated O(k) per step via the single-turn
    # scanner, never by re-scanning the full observation history.
    _cumulative_passive_hits: int = field(init=False, default=0)
    _battle_index: int = field(init=False, default=0)
    _battle_reward_total: float = field(init=False, default=0.0)
    _no_progress_steps: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        self._client = PokeEnvClient(config=self.config)
        self._opponent_history = OpponentHistoryTracker()
        self._reward_trackers = RewardTrackingState()

    def _log(self, message: str) -> None:
        if self.config.verbose_logging:
            print(f"[PokemonShowdownEnv] {message}", flush=True)

    # ------------------------------------------------------------------ API

    def reset(self) -> str:
        """Start a new battle and return the initial markdown state."""
        self._battle_index += 1
        self._client.start_new_battle()
        self._opponent_history = OpponentHistoryTracker()
        self._reward_trackers = RewardTrackingState()
        self._steps_this_battle = 0
        self._cumulative_passive_hits = 0
        self._battle_reward_total = 0.0
        self._no_progress_steps = 0

        battle = self._wait_for_battle_or_raise()
        self._log(
            f"Battle {self._battle_index} started at turn={battle.turn} "
            f"(format={self.config.battle_format})."
        )
        self._prev_state = summarize_battle_state(battle, self._cumulative_passive_hits)
        return format_battle_state(battle, self._opponent_history)

    def step(self, action_json: str | Dict[str, Any]) -> Tuple[str, float, bool, Dict[str, Any]]:
        """Apply one action and return (state_str, reward, done, info)."""
        battle = self._ensure_battle()
        if battle.finished:
            raise RuntimeError("Cannot call step() on a finished battle. Call reset().")

        self._steps_this_battle += 1
        if self._steps_this_battle > self.config.max_steps_per_battle:
            return self._terminal_from_truncation(battle)

        valid_actions = enumerate_actions(battle)
        if isinstance(action_json, dict):
            raw = json.dumps(action_json)
        else:
            raw = action_json

        used_fallback = False
        try:
            parsed = parse_llm_action(raw, valid_actions)
            order = self._to_battle_order(parsed, valid_actions, battle)
        except ValueError:
            extracted = extract_action_json_from_text(raw)
            if extracted is not None:
                try:
                    parsed = parse_llm_action(extracted, valid_actions)
                    order = self._to_battle_order(parsed, valid_actions, battle)
                except ValueError:
                    used_fallback = True
            else:
                used_fallback = True
        if used_fallback:
            opt = valid_actions[0]
            from poke_env.player import Player as PlayerCls
            if opt.action_type == "move" and opt.move is not None:
                order = PlayerCls.create_order(opt.move)
            else:
                order = PlayerCls.create_order(opt.pokemon)

        previous_turn = battle.turn
        self._client.send_action(order)
        new_battle = self._client.wait_for_battle_update(previous_turn) or battle

        # Increment the passive-hit counter by scanning only the turn that just
        # resolved — O(k) where k = events on that single turn, not O(total turns).
        self._cumulative_passive_hits += count_new_passive_hits_for_turn(
            new_battle, previous_turn
        )

        prev_state = self._prev_state or summarize_battle_state(battle, self._cumulative_passive_hits)
        curr_state = summarize_battle_state(new_battle, self._cumulative_passive_hits)

        active = new_battle.active_pokemon
        opponent_active = new_battle.opponent_active_pokemon

        if used_fallback:
            reward = ILLEGAL_ACTION_PENALTY
        else:
            reward = calculate_reward(
                prev_state=prev_state,
                curr_state=curr_state,
                action=ActionJSON(action=parsed.action, choice=parsed.choice),
                trackers=self._reward_trackers,
                active=active,
                opponent_active=opponent_active,
            )
        # Small time cost per turn to discourage excessively long battles.
        reward += self.config.step_living_penalty

        self._prev_state = curr_state
        if new_battle.turn == previous_turn and not new_battle.finished:
            self._no_progress_steps += 1
        else:
            self._no_progress_steps = 0

        done_reason: Optional[str] = None
        done = False
        if new_battle.finished:
            done = True
            done_reason = "battle_finished"
        elif self._steps_this_battle >= self.config.max_steps_per_battle:
            done = True
            done_reason = "max_steps"
            reward += self.config.max_steps_termination_penalty
        elif (self._battle_reward_total + reward) <= self.config.min_battle_reward:
            done = True
            done_reason = "min_battle_reward"
        elif self._no_progress_steps >= self.config.max_no_progress_steps:
            done = True
            done_reason = "no_progress_timeout"
            reward += self.config.no_progress_termination_penalty

        self._battle_reward_total += reward

        # If we terminate early (not a natural finished battle), forfeit cleanly
        # so the next reset starts from a free player/session state.
        if done and not new_battle.finished and done_reason in {
            "max_steps",
            "min_battle_reward",
            "no_progress_timeout",
        }:
            try:
                self._client.forfeit_current_battle()
            except Exception:
                pass

        obs = format_battle_state(new_battle, self._opponent_history)
        info: Dict[str, Any] = {
            "turn": new_battle.turn,
            "valid_actions": [
                {"action": a.action_type, "choice": a.choice} for a in valid_actions
            ],
            "instructions": build_action_instructions(valid_actions),
            "battle_finished": new_battle.finished,
            "reason": done_reason,
```
|
| 201 |
+
"action_illegal": used_fallback,
|
| 202 |
+
"battle_reward_total": self._battle_reward_total,
|
| 203 |
+
"no_progress_steps": self._no_progress_steps,
|
| 204 |
+
}
|
| 205 |
+
if self.config.verbose_logging:
|
| 206 |
+
should_log_step = (
|
| 207 |
+
used_fallback
|
| 208 |
+
or done
|
| 209 |
+
or self._steps_this_battle == 1
|
| 210 |
+
or self._steps_this_battle % max(1, self.config.log_every_n_steps) == 0
|
| 211 |
+
)
|
| 212 |
+
if should_log_step:
|
| 213 |
+
self._log(
|
| 214 |
+
f"battle={self._battle_index} step={self._steps_this_battle} "
|
| 215 |
+
f"turn={new_battle.turn} reward={reward:.3f} "
|
| 216 |
+
f"running_reward={self._battle_reward_total:.3f} "
|
| 217 |
+
f"illegal_action={used_fallback} done={done}"
|
| 218 |
+
)
|
| 219 |
+
return obs, reward, done, info
|
| 220 |
+
|
| 221 |
+
# ------------------------------------------------------------------ helpers
|
| 222 |
+
|
| 223 |
+
def _wait_for_battle_or_raise(self) -> Battle:
|
| 224 |
+
battle = self._client.battle
|
| 225 |
+
if battle is None:
|
| 226 |
+
battle = self._client.wait_for_battle_update(previous_turn=0)
|
| 227 |
+
if battle is None:
|
| 228 |
+
raise RuntimeError("Failed to obtain initial battle from poke-env.")
|
| 229 |
+
return battle
|
| 230 |
+
|
| 231 |
+
def _ensure_battle(self) -> Battle:
|
| 232 |
+
battle = self._client.battle
|
| 233 |
+
if battle is None:
|
| 234 |
+
raise RuntimeError("No active battle. Call reset() first.")
|
| 235 |
+
return battle
|
| 236 |
+
|
| 237 |
+
def _terminal_from_truncation(self, battle: Battle) -> Tuple[str, float, bool, Dict[str, Any]]:
|
| 238 |
+
obs = format_battle_state(battle, self._opponent_history)
|
| 239 |
+
info: Dict[str, Any] = {
|
| 240 |
+
"turn": battle.turn,
|
| 241 |
+
"battle_finished": battle.finished,
|
| 242 |
+
"reason": "max_steps",
|
| 243 |
+
}
|
| 244 |
+
return obs, self.config.max_steps_termination_penalty, True, info
|
| 245 |
+
|
| 246 |
+
@staticmethod
|
| 247 |
+
def _to_battle_order(
|
| 248 |
+
parsed: ActionJSON,
|
| 249 |
+
valid_actions: list[ActionOption],
|
| 250 |
+
battle: Battle,
|
| 251 |
+
) -> "Player.create_order.__annotations__['return']":
|
| 252 |
+
from poke_env.player import Player as PlayerCls
|
| 253 |
+
|
| 254 |
+
for opt in valid_actions:
|
| 255 |
+
if opt.action_type == parsed.action and opt.choice == parsed.choice:
|
| 256 |
+
if opt.action_type == "move" and opt.move is not None:
|
| 257 |
+
return PlayerCls.create_order(opt.move)
|
| 258 |
+
if opt.action_type == "switch" and opt.pokemon is not None:
|
| 259 |
+
return PlayerCls.create_order(opt.pokemon)
|
| 260 |
+
raise ValueError(f"Could not map parsed action {parsed} to a BattleOrder")
|
src/smogon_rl/pokeenv_client.py
ADDED
|
@@ -0,0 +1,304 @@
from __future__ import annotations

import asyncio
import threading
import time
from dataclasses import dataclass
from typing import Optional

from poke_env.environment.battle import Battle
from poke_env.player import Player, RandomPlayer
from poke_env.player.battle_order import BattleOrder
from poke_env.ps_client.server_configuration import LocalhostServerConfiguration

from .config import DEFAULT_BATTLE_FORMAT, EnvConfig


class RLPlayer(Player):
    """Player controlled externally via an asyncio queue of BattleOrders."""

    def __init__(self, action_queue: "asyncio.Queue[BattleOrder]", **kwargs) -> None:
        super().__init__(**kwargs)
        self._action_queue: "asyncio.Queue[BattleOrder]" = action_queue

    async def choose_move(self, battle: Battle) -> BattleOrder:
        return await self._action_queue.get()


@dataclass
class PokeEnvClient:
    """Asynchronous client that manages poke-env battles in a background loop.

    Players are created ONCE when the loop starts and reused across battles to
    avoid Showdown "name taken" errors from zombie connections.
    """

    config: EnvConfig

    def __post_init__(self) -> None:
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._action_queue: Optional["asyncio.Queue[BattleOrder]"] = None
        self._rl_player: Optional[RLPlayer] = None
        self._opponent: Optional[RandomPlayer] = None
        self._battle_task: Optional[asyncio.Future] = None
        # Snapshot of existing battle tags before we request a new battle.
        self._known_battle_tags: set[str] = set()
        self._awaiting_new_battle: bool = False
        # Stored reference to the battle we are in (set when .battle is read).
        # Used for forfeit so we always target the right battle.
        self._current_battle: Optional[Battle] = None

    def _log(self, message: str) -> None:
        if self.config.verbose_logging:
            print(f"[PokeEnvClient] {message}", flush=True)

    # -------------------------------------------------------------------------
    # Event loop management
    # -------------------------------------------------------------------------

    def start(self) -> None:
        """Start the background asyncio loop and create players (once)."""
        if self._loop is not None:
            return

        loop = asyncio.new_event_loop()

        def _run_loop() -> None:
            asyncio.set_event_loop(loop)
            loop.run_forever()

        thread = threading.Thread(target=_run_loop, daemon=True)
        thread.start()

        self._loop = loop
        self._thread = thread
        self._log("Background event loop started.")

        # Create players once; they stay connected for the lifetime of this env.
        self._action_queue = asyncio.Queue()
        fmt = self.config.battle_format or DEFAULT_BATTLE_FORMAT

        async def _create_players() -> None:
            self._rl_player = RLPlayer(
                action_queue=self._action_queue,
                battle_format=fmt,
                server_configuration=LocalhostServerConfiguration,
            )
            self._opponent = RandomPlayer(
                battle_format=fmt,
                server_configuration=LocalhostServerConfiguration,
            )

        future = asyncio.run_coroutine_threadsafe(_create_players(), loop)
        future.result(timeout=15.0)
        # Give the server a moment to register both connections.
        time.sleep(1.0)
        self._log("Players created and connected.")

    def stop(self) -> None:
        """Stop the background loop and clean up."""
        if self._loop is None:
            return
        self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread is not None:
            self._thread.join(timeout=5.0)
        self._loop = None
        self._thread = None
        self._battle_task = None
        self._rl_player = None
        self._opponent = None
        self._action_queue = None
        self._known_battle_tags = set()
        self._awaiting_new_battle = False
        self._current_battle = None
        self._log("Background event loop stopped.")

    def restart(self) -> None:
        """Hard-restart loop + players to recover from stuck/cancelled battles."""
        self._log("Restarting client event loop and players.")
        self.stop()
        self.start()

    # -------------------------------------------------------------------------
    # Battle lifecycle
    # -------------------------------------------------------------------------

    def forfeit_current_battle(self) -> None:
        """Forfeit the current Showdown battle if it is still in progress.

        Must be called before start_new_battle() when the env ends a battle early
        (e.g. due to min_battle_reward) so the player is freed for the next battle.
        """
        if self._loop is None or self._rl_player is None:
            return
        # Use the stored battle so we forfeit the one we were in, not whatever
        # .battle returns now.
        battle = self._current_battle if self._current_battle is not None else self.battle
        if battle is None or battle.finished:
            return

        room = battle.battle_tag

        async def _do_forfeit() -> None:
            try:
                await self._rl_player.send_message("/forfeit", room)
            except Exception:
                pass

        try:
            fut = asyncio.run_coroutine_threadsafe(_do_forfeit(), self._loop)
            fut.result(timeout=5.0)
        except Exception:
            pass
        # Give the server time to end the battle and free both players.
        time.sleep(1.5)
        self._current_battle = None
        self._log("Forfeited current battle.")

    def start_new_battle(self) -> None:
        """Launch a new battle using the already-connected players."""
        if self._loop is None:
            self.start()
        assert self._loop is not None
        assert self._rl_player is not None
        assert self._opponent is not None

        # Forfeit any ongoing Showdown battle before starting a new one so the
        # player is not stuck mid-battle when battle_against is called again.
        self.forfeit_current_battle()

        # Let the previous battle task finish cleanly (the server ends the battle
        # after a forfeit). If it does not settle, hard-restart the client.
        restart_required = False
        if self._battle_task is not None and not self._battle_task.done():
            try:
                self._battle_task.result(timeout=25.0)
            except Exception:
                self._battle_task.cancel()
                self._log("Previous battle task timed out or failed; requesting client restart.")
                restart_required = True
            else:
                self._log("Previous battle task finished.")

        if restart_required:
            # Hard recovery path: refresh websocket connections and players.
            self.restart()
            assert self._loop is not None
            assert self._rl_player is not None
            assert self._opponent is not None

        self._current_battle = None  # Will be set when the new battle appears.

        # Let the server fully free both players before we start the next battle.
        time.sleep(2.0)

        # Fresh action queue for this battle.
        self._action_queue = asyncio.Queue()
        self._rl_player._action_queue = self._action_queue

        # Record current battle tags so .battle can wait for a genuinely new one.
        self._known_battle_tags = set(self._rl_player.battles.keys())
        self._awaiting_new_battle = True

        async def _run_battle() -> None:
            await self._rl_player.battle_against(self._opponent, n_battles=1)

        self._battle_task = asyncio.run_coroutine_threadsafe(
            _run_battle(), self._loop
        )
        self._log(
            f"Launching new battle in format "
            f"{self.config.battle_format or DEFAULT_BATTLE_FORMAT}."
        )
        time.sleep(self.config.poll_interval_seconds)

    @property
    def battle(self) -> Optional[Battle]:
        """Return the current Battle for this run, or None if not started yet."""
        if self._rl_player is None or not self._rl_player.battles:
            return None

        # During reset(), wait for a battle tag that did not exist before
        # start_new_battle() was called.
        if self._awaiting_new_battle:
            unseen = [
                b
                for tag, b in self._rl_player.battles.items()
                if tag not in self._known_battle_tags
            ]
            if not unseen:
                return None
            active_unseen = [b for b in unseen if not b.finished]
            b = active_unseen[-1] if active_unseen else unseen[-1]
            self._awaiting_new_battle = False
            self._current_battle = b
            return b

        battles = list(self._rl_player.battles.values())
        active = [b for b in battles if not b.finished]
        if active:
            b = active[-1]
            self._current_battle = b
            return b
        # All finished — return the latest one (covers the case where the battle
        # ended before we got a chance to poll it).
        b = battles[-1]
        self._current_battle = b
        return b

    def send_action(self, order: BattleOrder) -> None:
        """Submit an action for the RL player to execute."""
        if self._loop is None or self._action_queue is None:
            raise RuntimeError("PokeEnvClient has not been started.")

        async def _enqueue() -> None:
            assert self._action_queue is not None
            await self._action_queue.put(order)

        asyncio.run_coroutine_threadsafe(_enqueue(), self._loop)
        self._log("Submitted action to RLPlayer queue.")

    def wait_for_battle_update(self, previous_turn: int) -> Optional[Battle]:
        """Block until the battle advances to a new turn or ends."""
        start_time = time.time()
        heartbeat_every = max(self.config.poll_heartbeat_seconds, self.config.poll_interval_seconds)
        next_heartbeat_at = start_time + heartbeat_every
        while True:
            battle = self.battle
            if battle is None:
                now = time.time()
                if now > next_heartbeat_at:
                    elapsed = now - start_time
                    self._log(
                        f"Still waiting for battle object "
                        f"({elapsed:.1f}s elapsed, previous_turn={previous_turn})."
                    )
                    next_heartbeat_at = now + heartbeat_every
                if now - start_time > self.config.open_timeout:
                    self._log("Timed out waiting for initial battle object.")
                    return None
                time.sleep(self.config.poll_interval_seconds)
                continue

            if battle.finished or battle.turn > previous_turn:
                self._log(
                    f"Battle update received: turn={battle.turn}, finished={battle.finished}."
                )
                return battle

            now = time.time()
            if now > next_heartbeat_at:
                elapsed = now - start_time
                self._log(
                    f"Waiting for turn advance: current_turn={battle.turn}, "
                    f"previous_turn={previous_turn}, elapsed={elapsed:.1f}s."
                )
                next_heartbeat_at = now + heartbeat_every

            if now - start_time > self.config.open_timeout:
                self._log(
                    f"Turn-advance wait timed out at turn={battle.turn}; returning last state."
                )
                return battle

            time.sleep(self.config.poll_interval_seconds)
src/smogon_rl/reward.py
ADDED
|
@@ -0,0 +1,320 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional

from poke_env.environment.battle import Battle
from poke_env.environment.pokemon import Pokemon

from .action_space import ActionJSON
from .state_formatter import hp_fraction_to_percent

# Hefty penalty when the model outputs an illegal action (e.g. a hallucinated
# Pokemon). Used during rollout collection; recorded as collected_reward so
# GRPO learns to avoid illegal outputs.
ILLEGAL_ACTION_PENALTY = -10.0


@dataclass
class BattleStateSummary:
    self_team_hp_percent: float
    opp_team_hp_percent: float
    self_fainted: int
    opp_fainted: int
    self_statuses: Dict[str, Optional[str]]
    opp_statuses: Dict[str, Optional[str]]
    self_stat_stages: Dict[str, Dict[str, int]]
    opp_stat_stages: Dict[str, Dict[str, int]]
    opponent_passive_hits: int


@dataclass
class RewardTrackingState:
    healing_reward_used: float = 0.0
    per_pokemon_setup_reward_used: Dict[str, float] = field(default_factory=dict)
    passive_hits_total: int = 0


def _team_hp_and_faints(team: Dict[str, Pokemon]) -> tuple[float, int]:
    total_hp = 0.0
    total_max_hp = 0.0
    fainted = 0
    for mon in team.values():
        if mon.max_hp is None or mon.max_hp <= 0:
            continue
        total_hp += max(0, mon.current_hp or 0)
        total_max_hp += mon.max_hp
        if mon.fainted:
            fainted += 1
    if total_max_hp <= 0:
        return 0.0, fainted
    return (total_hp / total_max_hp) * 100.0, fainted


def _collect_statuses(team: Dict[str, Pokemon]) -> Dict[str, Optional[str]]:
    return {
        mon.species or key: (str(mon.status) if mon.status is not None else None)
        for key, mon in team.items()
    }


def _collect_stat_stages(team: Dict[str, Pokemon]) -> Dict[str, Dict[str, int]]:
    return {mon.species or key: dict(mon.boosts) for key, mon in team.items()}


def _passive_events_in_turn(events: list, opponent_role: str) -> int:
    """Count passive-damage hits for the opponent in one turn's raw event list."""
    count = 0
    for event in events:
        if not event or event[0] != "-damage":
            continue
        if len(event) < 2:
            continue
        if not event[1].startswith(opponent_role):
            continue
        # "[from]" in any trailing field marks an external/passive damage source:
        # e.g. "[from] brn", "[from] Stealth Rock", "[from] Leech Seed", etc.
        if any("[from]" in part for part in event[2:]):
            count += 1
    return count


def count_new_passive_hits_for_turn(battle: Battle, turn_number: int) -> int:
    """Count passive damage hits the opponent took on a single, specific turn.

    Designed for O(k) per-step use: only the events from `turn_number` are
    scanned. The caller accumulates the running total across turns.

    Parameters
    ----------
    battle:
        The current poke-env Battle object.
    turn_number:
        The turn whose Observation.events should be inspected (usually the
        turn that just resolved, i.e., the value of `battle.turn` before
        the action was submitted).
    """
    obs = battle.observations.get(turn_number)
    if obs is None:
        return 0
    opponent_role = "p2" if battle.player_role == "p1" else "p1"
    return _passive_events_in_turn(obs.events, opponent_role)


def _count_passive_hits_on_opponent(battle: Battle) -> int:
    """Full-scan fallback: count cumulative passive hits across all observed turns.

    This is O(total events) and should only be called once on reset() to
    establish a baseline. Per-step increments should use
    `count_new_passive_hits_for_turn` instead.
    """
    opponent_role = "p2" if battle.player_role == "p1" else "p1"
    count = 0
    for obs in battle.observations.values():
        count += _passive_events_in_turn(obs.events, opponent_role)
    return count


def summarize_battle_state(battle: Battle, cumulative_passive_hits: int = 0) -> BattleStateSummary:
    """Snapshot the current battle state into a plain dataclass.

    Parameters
    ----------
    battle:
        The live poke-env Battle object.
    cumulative_passive_hits:
        Running total of passive damage hits the opponent has taken this
        battle, maintained by the caller (e.g. PokemonShowdownEnv) using
        `count_new_passive_hits_for_turn` to keep each step O(k).
        Defaults to 0 for the initial state on reset().
    """
    self_hp, self_fainted = _team_hp_and_faints(battle.team)
    opp_hp, opp_fainted = _team_hp_and_faints(battle.opponent_team)
    self_statuses = _collect_statuses(battle.team)
    opp_statuses = _collect_statuses(battle.opponent_team)
    self_stats = _collect_stat_stages(battle.team)
    opp_stats = _collect_stat_stages(battle.opponent_team)
    return BattleStateSummary(
        self_team_hp_percent=self_hp,
        opp_team_hp_percent=opp_hp,
        self_fainted=self_fainted,
        opp_fainted=opp_fainted,
        self_statuses=self_statuses,
        opp_statuses=opp_statuses,
        self_stat_stages=self_stats,
        opp_stat_stages=opp_stats,
        opponent_passive_hits=cumulative_passive_hits,
    )


def _status_penalty(prev_statuses: Dict[str, Optional[str]], curr_statuses: Dict[str, Optional[str]]) -> float:
    penalty = 0.0
    for key, curr in curr_statuses.items():
        prev = prev_statuses.get(key)
        if prev == curr:
            continue
        if curr is None:
            # Could be a status cure, which is rewarded elsewhere.
            continue
        code = curr.lower()
        if code in {"brn", "psn", "tox"}:
            penalty -= 0.5
        elif code in {"par", "frz", "slp", "conf"}:
            penalty -= 1.0
    return penalty


def _healing_reward(prev_hp: float, curr_hp: float, trackers: RewardTrackingState) -> float:
    if curr_hp <= prev_hp:
        return 0.0
    healed = curr_hp - prev_hp
    raw = healed / 10.0  # +1.0 per 10% healed
    remaining_cap = max(0.0, 3.0 - trackers.healing_reward_used)
    reward = min(raw, remaining_cap)
    trackers.healing_reward_used += reward
    return reward


def _setup_reward(
    prev_stats: Dict[str, Dict[str, int]],
    curr_stats: Dict[str, Dict[str, int]],
    active: Pokemon,
    trackers: RewardTrackingState,
) -> float:
    active_key = active.species or "active"
    prev = prev_stats.get(active_key, {})
    curr = curr_stats.get(active_key, {})
    delta_stages = 0
    for stat, curr_stage in curr.items():
        prev_stage = prev.get(stat, 0)
        if curr_stage > prev_stage:
            delta_stages += curr_stage - prev_stage
    if delta_stages <= 0:
        return 0.0
    if hp_fraction_to_percent(active.current_hp_fraction) <= 50.0:
        return 0.0

    raw = 0.5 * delta_stages
    used = trackers.per_pokemon_setup_reward_used.get(active_key, 0.0)
    remaining_cap = max(0.0, 2.0 - used)
    reward = min(raw, remaining_cap)
    trackers.per_pokemon_setup_reward_used[active_key] = used + reward
    return reward


def _opponent_setup_penalty(
    prev_stats: Dict[str, Dict[str, int]],
    curr_stats: Dict[str, Dict[str, int]],
) -> float:
    penalty = 0.0
    for key, curr in curr_stats.items():
        prev = prev_stats.get(key, {})
        for stat, curr_stage in curr.items():
            prev_stage = prev.get(stat, 0)
            if curr_stage > prev_stage:
                penalty -= 0.5 * (curr_stage - prev_stage)
    return penalty


def _passive_damage_reward(
    prev_hits: int,
    curr_hits: int,
    trackers: RewardTrackingState,
) -> float:
    if curr_hits <= prev_hits:
        return 0.0
    delta = curr_hits - prev_hits
    trackers.passive_hits_total += delta
    return 0.01 * trackers.passive_hits_total


def _damage_rewards(prev: BattleStateSummary, curr: BattleStateSummary) -> float:
    reward = 0.0
    # Damage dealt: +1.0 per 10% opponent HP reduced.
    if curr.opp_team_hp_percent < prev.opp_team_hp_percent:
        delta = prev.opp_team_hp_percent - curr.opp_team_hp_percent
        reward += delta / 10.0
    # Damage taken: -1.0 per 10% self HP lost.
    if curr.self_team_hp_percent < prev.self_team_hp_percent:
        delta = prev.self_team_hp_percent - curr.self_team_hp_percent
        reward -= delta / 10.0
    return reward


def _knockout_rewards(prev: BattleStateSummary, curr: BattleStateSummary) -> float:
    reward = 0.0
    if curr.opp_fainted > prev.opp_fainted:
        reward += 3.0 * (curr.opp_fainted - prev.opp_fainted)
    if curr.self_fainted > prev.self_fainted:
        reward -= 3.0 * (curr.self_fainted - prev.self_fainted)
    return reward


def calculate_reward(
    prev_state: BattleStateSummary,
    curr_state: BattleStateSummary,
    action: ActionJSON,
    trackers: RewardTrackingState,
    active: Optional[Pokemon] = None,
    opponent_active: Optional[Pokemon] = None,
    move_was_super_effective: bool = False,
    move_hit: bool = True,
    move_was_immune: bool = False,
    team_status_cured: bool = False,
) -> float:
    """Compute the shaped reward between two consecutive battle summaries.

    The additional keyword arguments let the caller provide extra context from
    the last action (type effectiveness, accuracy result, status cures) that is
    not fully recoverable from the static battle snapshots alone.
    """
    reward = 0.0

    # Core mechanics
    reward += _damage_rewards(prev_state, curr_state)
    reward += _knockout_rewards(prev_state, curr_state)

    # Strategic nudges: type effectiveness and accuracy
    if action.action == "move":
        if move_was_super_effective:
            reward += 0.5
        if move_was_immune:
            reward -= 1.0
        if not move_hit:
            reward -= 0.25

    # Healing
    reward += _healing_reward(
        prev_state.self_team_hp_percent,
        curr_state.self_team_hp_percent,
        trackers,
    )

    # Status cures (e.g., Aromatherapy)
    if team_status_cured:
        reward += 1.0

    # Setup sweeping (self) and opponent setup
    if active is not None:
        reward += _setup_reward(
            prev_state.self_stat_stages,
            curr_state.self_stat_stages,
            active,
            trackers,
        )
        reward += _opponent_setup_penalty(
            prev_state.opp_stat_stages,
+
curr_state.opp_stat_stages,
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
# Passive damage / hazards
|
| 310 |
+
reward += _passive_damage_reward(
|
| 311 |
+
prev_state.opponent_passive_hits,
|
| 312 |
+
curr_state.opponent_passive_hits,
|
| 313 |
+
trackers,
|
| 314 |
+
)
|
| 315 |
+
|
| 316 |
+
# Status afflictions
|
| 317 |
+
reward += _status_penalty(prev_state.self_statuses, curr_state.self_statuses)
|
| 318 |
+
|
| 319 |
+
return reward
|
| 320 |
+
|
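For intuition about the damage and knockout shaping above, the arithmetic can be exercised standalone. The `Summary` dataclass below is a hypothetical stand-in for `BattleStateSummary`, mirroring only the fields these two helpers read (team HP percentages summed across six Pokémon, so up to 600%):

```python
from dataclasses import dataclass


@dataclass
class Summary:
    # Hypothetical stand-in for BattleStateSummary (only the fields used here)
    opp_team_hp_percent: float
    self_team_hp_percent: float
    opp_fainted: int
    self_fainted: int


def damage_rewards(prev: Summary, curr: Summary) -> float:
    # +1.0 per 10% opponent HP reduced, -1.0 per 10% self HP lost
    reward = 0.0
    if curr.opp_team_hp_percent < prev.opp_team_hp_percent:
        reward += (prev.opp_team_hp_percent - curr.opp_team_hp_percent) / 10.0
    if curr.self_team_hp_percent < prev.self_team_hp_percent:
        reward -= (prev.self_team_hp_percent - curr.self_team_hp_percent) / 10.0
    return reward


def knockout_rewards(prev: Summary, curr: Summary) -> float:
    # +/-3.0 per faint, mirroring _knockout_rewards
    reward = 0.0
    reward += 3.0 * max(0, curr.opp_fainted - prev.opp_fainted)
    reward -= 3.0 * max(0, curr.self_fainted - prev.self_fainted)
    return reward


# One turn: dealt 80% of a mon's HP (with a KO), took 15% in return
prev = Summary(600.0, 600.0, 0, 0)
curr = Summary(520.0, 585.0, 1, 0)
total = damage_rewards(prev, curr) + knockout_rewards(prev, curr)
print(total)  # 8.0 - 1.5 + 3.0 = 9.5
```

This is only a sketch of the shaping terms, not the full `calculate_reward` pipeline, which also layers in healing, setup, hazard, and status terms.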
src/smogon_rl/state_formatter.py
ADDED
|
@@ -0,0 +1,181 @@
from __future__ import annotations

from dataclasses import dataclass, field
from typing import Dict, List, Optional

from poke_env.environment.battle import Battle
from poke_env.environment.pokemon import Pokemon


@dataclass
class OpponentMonHistory:
    name: str
    last_known_hp_percent: float
    status: Optional[str]
    revealed_moves: List[str] = field(default_factory=list)
    revealed_item: Optional[str] = None
    revealed_ability: Optional[str] = None


@dataclass
class OpponentHistoryTracker:
    revealed: Dict[str, OpponentMonHistory] = field(default_factory=dict)

    def update_from_battle(self, battle: Battle) -> None:
        for mon in battle.opponent_team.values():
            if not mon.species:
                continue
            key = mon.species
            entry = self.revealed.get(
                key,
                OpponentMonHistory(
                    name=mon.species,
                    last_known_hp_percent=hp_fraction_to_percent(mon.current_hp_fraction),
                    status=str(mon.status) if mon.status is not None else None,
                ),
            )
            entry.last_known_hp_percent = hp_fraction_to_percent(mon.current_hp_fraction)
            entry.status = str(mon.status) if mon.status is not None else None

            for move in mon.moves.values():
                move_name = move.id
                if move_name not in entry.revealed_moves:
                    entry.revealed_moves.append(move_name)

            if mon.item is not None:
                entry.revealed_item = mon.item
            if mon.ability is not None:
                entry.revealed_ability = mon.ability

            self.revealed[key] = entry


def hp_fraction_to_percent(fraction: float | None) -> float:
    if fraction is None:
        return 0.0
    return max(0.0, min(1.0, float(fraction))) * 100.0


def _format_stat_modifiers(pokemon: Pokemon) -> str:
    parts: List[str] = []
    for stat, stage in pokemon.boosts.items():
        if stage == 0:
            continue
        sign = "+" if stage > 0 else ""
        parts.append(f"{stat.capitalize()} {sign}{stage}")
    return ", ".join(parts) if parts else "None"


def _estimate_speed_range(pokemon: Pokemon) -> str:
    base_speed = pokemon.base_stats.get("spe", 0)
    if base_speed <= 0:
        return "Unknown"

    level = 100
    min_speed = int((((2 * base_speed) * level) / 100 + 5) * 0.9)
    max_speed = int((((2 * base_speed + 31 + (252 // 4)) * level) / 100 + 5) * 1.1)
    return f"{min_speed}-{max_speed}"


def _format_pokemon_line(pokemon: Pokemon) -> str:
    hp = hp_fraction_to_percent(pokemon.current_hp_fraction)
    status = str(pokemon.status) if pokemon.status is not None else "OK"
    item = pokemon.item or "?"
    return f"- {pokemon.species or '?'} HP:{hp:.0f}% {status} Item:{item}"


def _format_moveset_section(pokemon: Pokemon) -> str:
    if not pokemon.moves:
        return " Moves: [unknown]"
    parts = []
    for move in pokemon.moves.values():
        bp = move.base_power or 0
        t = move.type.name[0] if move.type is not None else "?"
        parts.append(f"{move.id}({t}{bp})")
    return " Moves: " + " | ".join(parts)


def format_battle_state(battle: Battle, opponent_history: OpponentHistoryTracker) -> str:
    """Format the full battle state into a markdown string for the LLM.

    Structure:
    - Part A: Active field (self and opponent).
    - Part B: Full self roster and movesets.
    - Part C: Opponent history (revealed bench, revealed info).
    """
    opponent_history.update_from_battle(battle)

    lines: List[str] = []

    # ------------------------------------------------------------------ Part A
    lines.append("## Part A: Active Field")

    # Self active
    self_active = battle.active_pokemon
    if self_active is not None:
        self_hp = hp_fraction_to_percent(self_active.current_hp_fraction)
        self_status = (
            str(self_active.status) if self_active.status is not None else "Healthy"
        )
        self_ability = self_active.ability or "Unknown"
        self_item = self_active.item or "None"
        self_mods = _format_stat_modifiers(self_active)
        lines.append("### Active Self")
        lines.append(
            f"- Name: {self_active.species or 'Unknown'}\n"
            f"- HP: {self_hp:.1f}%\n"
            f"- Status: {self_status}\n"
            f"- Ability: {self_ability}\n"
            f"- Item: {self_item}\n"
            f"- Stat Modifiers: {self_mods}"
        )
    else:
        lines.append("### Active Self\n- None")

    # Opponent active
    opp_active = battle.opponent_active_pokemon
    if opp_active is not None:
        opp_hp = hp_fraction_to_percent(opp_active.current_hp_fraction)
        opp_status = (
            str(opp_active.status) if opp_active.status is not None else "Healthy"
        )
        opp_speed_range = _estimate_speed_range(opp_active)
        lines.append("### Active Opponent")
        lines.append(
            f"- Name: {opp_active.species or 'Unknown'}\n"
            f"- HP: {opp_hp:.1f}%\n"
            f"- Status: {opp_status}\n"
            f"- Speed Range: {opp_speed_range}"
        )
    else:
        lines.append("### Active Opponent\n- None")

    # ------------------------------------------------------------------ Part B
    lines.append("\n## Part B: Full Self Roster")
    if not battle.team:
        lines.append("- [Unknown team]")
    else:
        for mon in battle.team.values():
            lines.append(_format_pokemon_line(mon))
            lines.append(_format_moveset_section(mon))

    # ------------------------------------------------------------------ Part C
    lines.append("\n## Part C: Opponent History")
    if not opponent_history.revealed:
        lines.append("- No opponent Pokémon revealed yet.")
    else:
        for entry in opponent_history.revealed.values():
            lines.append(
                f"- {entry.name} | Last HP: {entry.last_known_hp_percent:.1f}% | "
                f"Status: {entry.status or 'Healthy'}"
            )
            if entry.revealed_moves:
                moves = ", ".join(entry.revealed_moves)
                lines.append(f"  - Revealed moves: {moves}")
            if entry.revealed_item:
                lines.append(f"  - Revealed item: {entry.revealed_item}")
            if entry.revealed_ability:
                lines.append(f"  - Revealed ability: {entry.revealed_ability}")

    return "\n".join(lines)
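As a sanity check on `_estimate_speed_range` above: at level 100 it brackets a 0 EV/0 IV spread scaled by 0.9 (roughly a hindering nature) against a 31 IV/252 EV spread scaled by 1.1 (a boosting nature). Reproducing the arithmetic standalone for a hypothetical base-100-Speed Pokémon:

```python
# Same formula as _estimate_speed_range, inlined for a base Speed of 100
base_speed = 100
level = 100

# Low end: 0 IVs, 0 EVs, 0.9x nature
min_speed = int((((2 * base_speed) * level) / 100 + 5) * 0.9)
# High end: 31 IVs, 252 EVs (252 // 4 = 63 stat points), 1.1x nature
max_speed = int((((2 * base_speed + 31 + (252 // 4)) * level) / 100 + 5) * 1.1)

print(f"{min_speed}-{max_speed}")  # 184-328
```

The range is deliberately coarse (no level scaling below 100, EVs folded in as a flat `252 // 4`), since it only feeds a prompt hint rather than a damage calculation.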
trainer.ipynb
ADDED
|
The diff for this file is too large to render.
See raw diff