Atharva committed on
Commit a8df3de · 0 Parent(s)

Initial hackathon submission export
.gitignore ADDED
@@ -0,0 +1,21 @@
# Environment and secrets — never commit
.env
.env.local
*.env

# Python
__pycache__/
*.py[cod]
*.egg-info/
.eggs/
dist/
build/

# Virtual environments
.venv/
venv/
env/

# IDE / OS
.idea/
.DS_Store
README.md ADDED
@@ -0,0 +1,195 @@
# OpenEnv-WolfeClick

OpenEnv-WolfeClick is a reinforcement learning environment and training workflow for competitive Pokemon battles with large language models.

The project was built for the OpenEnv hackathon to answer a specific question: can an LLM learn to act in a partially observable, adversarial, long-horizon environment where legal actions are constrained, rewards are delayed, and the opponent is another agent?

This repo focuses on that environment and a minimal Colab training path.

## Why I Built This

Pokemon battles are a strong multi-agent training environment for LLMs because they require:

- hidden information and opponent modeling
- long-horizon planning over many turns
- legal action grounding under a constrained action space
- adapting to a changing world state after every action
- balancing local rewards against later consequences

I built this environment to make those properties trainable with a simple `reset()` / `step()` loop and a small JSON action interface.

## What is in this repo

- [`src/smogon_rl`](src/smogon_rl): environment, state formatting, action space, reward shaping, and client code
- [`trainer.ipynb`](trainer.ipynb): main Colab notebook for warm-up SFT, rollout collection, and GRPO training
- [`examples`](examples): small local examples
- [`pyproject.toml`](pyproject.toml): package metadata

## Environment design

### State design

The state is not a raw simulator dump. It is a structured markdown representation designed to preserve strategic information while remaining readable to an LLM.

Each prompt includes:

- the active self Pokemon
- the active opponent Pokemon
- HP, status, ability, item, and current stat modifiers
- the full self team roster with currently known moves
- opponent history and revealed information
- the exact legal actions available this turn

This is implemented in the environment wrapper and state formatter:

- [`src/smogon_rl/openenv_sync_env.py`](src/smogon_rl/openenv_sync_env.py)
- [`src/smogon_rl/state_formatter.py`](src/smogon_rl/state_formatter.py)

My design goal was to expose enough information for strategic decisions without giving the model shortcuts that bypass the game structure.

### Action design

The action space is deliberately constrained.

The model must emit exactly one JSON object:

```json
{"action": "move" | "switch", "choice": "Exact Name of Move or Pokemon"}
```

At every step, legal actions are enumerated from the current battle state using:

- [`src/smogon_rl/action_space.py`](src/smogon_rl/action_space.py)

This module does three important things:

- enumerates legal moves and switches for the turn
- builds the action instruction block shown to the model
- validates model outputs against the legal action set

This matters because I do not want the model to “sort of” describe an action. I want the environment to enforce a concrete legal interface.

### Reward design

The environment reward is shaped but still tied to battle outcomes.

Reward computation lives in:

- [`src/smogon_rl/reward.py`](src/smogon_rl/reward.py)

The reward includes:

- damage dealt to the opponent
- damage taken by the agent
- knockouts and faint penalties
- healing value
- setup value and opponent setup penalties
- passive damage value
- status penalties

The environment wrapper in [`src/smogon_rl/openenv_sync_env.py`](src/smogon_rl/openenv_sync_env.py) adds practical rollout constraints:

- illegal action fallback handling
- illegal action penalties
- an anti-stall living penalty
- battle length caps
- no-progress termination penalties

This separation is intentional:

- `reward.py` captures battle-quality shaping
- the env wrapper handles rollout hygiene and training throughput

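As an illustration of the shaping idea only — the actual terms and coefficients live in `reward.py` and the env wrapper, and every number below is invented — a per-turn reward of this general shape:

```python
def shaped_reward(
    dmg_dealt: float,  # fraction of opponent HP removed this turn
    dmg_taken: float,  # fraction of own HP lost this turn
    kos: int,          # opponent Pokemon knocked out this turn
    faints: int,       # own Pokemon fainted this turn
    illegal: bool,     # model emitted an unparseable/illegal action
) -> float:
    # All coefficients are invented for illustration; the real values
    # live in reward.py and EnvConfig.
    if illegal:
        return -5.0
    reward = 2.0 * dmg_dealt - 1.0 * dmg_taken
    reward += 3.0 * kos - 3.0 * faints
    reward += -0.05  # anti-stall living penalty from the env wrapper
    return reward

print(shaped_reward(0.4, 0.1, 1, 0, False))  # → 3.65
```

The point is the structure: progress toward winning is rewarded every turn, stalling and illegal actions are taxed, and the terminal outcome still dominates through KO terms.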
## Training design

### 1. Warm-up SFT

The notebook begins with a supervised warm-up stage so the model learns to emit valid action JSON for the battle-state prompt format.

This does not claim strategic mastery. It only ensures the model is good enough to participate in the environment without collapsing into malformed outputs.

### 2. Real rollout collection

The policy is then run in real Pokemon Showdown battles. For each turn, the notebook stores:

- `prompt`
- `collected_action`
- `collected_reward`

This makes the rollout data usable for GRPO training while preserving the exact environment reward signal.

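Each stored row is just a dict with those three keys. A minimal sketch of the collection pattern, with stand-ins for the env and policy (the real loop lives in `trainer.ipynb`):

```python
# Stand-ins so the collection pattern runs standalone; in the notebook
# these are the real PokemonShowdownEnv and the LoRA policy being trained.
def stub_policy(prompt: str) -> str:
    return '{"action": "move", "choice": "surf"}'

def stub_env_step(action_json: str):
    return "## Battle State (next turn)", 0.5, True, {}

rollout_rows = []
prompt = "## Battle State\n..."
action = stub_policy(prompt)
_obs, reward, _done, _info = stub_env_step(action)
rollout_rows.append(
    {
        "prompt": prompt,
        "collected_action": action,
        "collected_reward": reward,
    }
)
print(rollout_rows[0]["collected_reward"])  # → 0.5
```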
121
+ ### 3. GRPO training
122
+
123
+ The GRPO reward used in the notebook is a wrapper around the stored rollout reward.
124
+
125
+ It is designed to preserve ranking pressure inside a completion group:
126
+
127
+ - malformed output is penalized strongly
128
+ - valid but different actions are penalized lightly
129
+ - the action matching the executed rollout action receives the collected environment reward plus a positive margin
130
+
131
+ That matters because raw rollout rewards alone do not always create a clean learning signal for group-relative optimization.
132
+
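A sketch of that wrapper logic (the penalty values and margin here are illustrative, not the notebook's exact numbers):

```python
def grpo_reward(
    completion: str,
    rollout_action: str,
    rollout_reward: float,
    legal_actions: set[str],
) -> float:
    """Score one completion in a GRPO group against a stored rollout row."""
    action = completion.strip()
    if action not in legal_actions:
        # Malformed or illegal output: strong penalty.
        return -10.0
    if action == rollout_action:
        # The action that was actually executed: env reward plus a margin.
        return rollout_reward + 1.0
    # Legal but different from the executed action: light penalty.
    return -0.5

legal_set = {
    '{"action": "move", "choice": "surf"}',
    '{"action": "switch", "choice": "Blissey"}',
}
print(grpo_reward('{"action": "move", "choice": "surf"}',
                  '{"action": "move", "choice": "surf"}', 2.0, legal_set))  # → 3.0
```

Because GRPO normalizes rewards within a completion group, the fixed gaps between these three tiers are what drive the gradient, not the absolute values.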
## How it works end to end

1. Start Pokemon Showdown locally in Colab.
2. Create the OpenEnv-style synchronous environment.
3. Format the battle state into markdown.
4. Enumerate legal actions.
5. Generate one JSON action from the model.
6. Execute the action in the environment.
7. Receive the next state, reward, done flag, and info.
8. Store rollout rows.
9. Train with GRPO on the collected rows.

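Steps 2–8 reduce to the usual interaction loop. Shown here against a stub env so it runs standalone; the real version of this loop is in `examples/run_single_episode.py`:

```python
import json

class StubEnv:
    # Stand-in for PokemonShowdownEnv so the loop shape runs standalone.
    def __init__(self) -> None:
        self.turns = 0

    def reset(self) -> str:
        self.turns = 0
        return "## Battle State (turn 1)"

    def step(self, action_json: str):
        self.turns += 1
        done = self.turns >= 3  # pretend the battle ends after 3 steps
        return f"## Battle State (turn {self.turns + 1})", 0.1, done, {}

env = StubEnv()
obs = env.reset()
total, done = 0.0, False
while not done:
    # A real policy would condition on `obs` and the instructions in `info`.
    action = json.dumps({"action": "move", "choice": "surf"})
    obs, reward, done, info = env.step(action)
    total += reward
print(round(total, 1))  # → 0.3
```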
## How to use

### Local package install

From the repo root:

```bash
python3 -m pip install -e .
```

### Colab training

Open [`trainer.ipynb`](trainer.ipynb) in Colab and run it top to bottom.

The notebook does the following:

1. clones or uses the repo
2. installs the training stack
3. loads the model and LoRA adapter
4. starts a local Pokemon Showdown server
5. runs JSON warm-up SFT
6. collects rollout data from real battles
7. trains with GRPO
8. optionally saves the adapter to the Hugging Face Hub

### Requirements

- a GPU runtime in Colab
- a local Pokemon Showdown server started from the notebook
- a Hugging Face token, only if you want to push adapters

## Current status

This repo now has a working end-to-end path where:

- real battle rollouts are collected from the environment
- valid action JSON is produced reliably after warm-up
- GRPO can train on real rollout data in the non-quantized plain TRL path

This is the basis for my hackathon demo and benchmark runs.

## Submission notes

This repo is intended to be my clean hackathon submission repo.

Linked artifacts to add before submission:

- Hugging Face model repo
- Hugging Face Space using OpenEnv stable release `0.2.1`
- benchmark/results file
- 1-minute demo video
examples/run_single_episode.py ADDED
@@ -0,0 +1,52 @@
from __future__ import annotations

import json

from smogon_rl.action_space import ActionOption, enumerate_actions
from smogon_rl.config import EnvConfig
from smogon_rl.openenv_sync_env import PokemonShowdownEnv


def main() -> None:
    config = EnvConfig()
    env = PokemonShowdownEnv(config=config)

    print("Starting a single gen4randombattle episode.")
    obs = env.reset()
    print("Initial state (truncated):")
    print("\n".join(obs.splitlines()[:40]))

    done = False
    total_reward = 0.0
    step_idx = 0

    while not done and step_idx < config.max_steps_per_battle:
        step_idx += 1
        print(f"\n=== Step {step_idx} ===")

        # Naive policy: query valid actions from the environment and always pick
        # the first one. A real agent would send `obs` and `info["instructions"]`
        # to an LLM and use its JSON response here.
        battle = env._ensure_battle()  # type: ignore[attr-defined]
        valid_actions = enumerate_actions(battle)
        if not valid_actions:
            print("No valid actions available; terminating.")
            break

        chosen: ActionOption = valid_actions[0]
        action_json = {"action": chosen.action_type, "choice": chosen.choice}
        obs, reward, done, info = env.step(json.dumps(action_json))

        total_reward += reward
        print(f"Chosen action: {action_json}")
        print(f"Reward: {reward:.3f}, Done: {done}")
        print("State (truncated):")
        print("\n".join(obs.splitlines()[:20]))

    print(f"\nTotal reward: {total_reward}")


if __name__ == "__main__":
    main()
pyproject.toml ADDED
@@ -0,0 +1,30 @@
[project]
name = "smogon-rl"
version = "0.1.0"
description = "Theory-of-Mind Pokémon RL environment using poke-env and OpenEnv."
readme = "README.md"
requires-python = ">=3.10"
authors = [
    { name = "Atharva" }
]
dependencies = [
    "poke-env>=0.8.0,<0.9.0",
    "numpy>=1.24.0",
    "pydantic>=2.0.0",
]

[project.optional-dependencies]
dev = [
    "pytest>=7.0.0",
    "ruff>=0.5.0",
]

[build-system]
requires = ["hatchling"]
build-backend = "hatchling.build"

[tool.uv]
package = "smogon-rl"

[tool.uv.sources]
src/smogon_rl/__init__.py ADDED
@@ -0,0 +1,13 @@
"""
Smogon-RL core package.

This package provides:
- An async poke-env client for Pokémon Showdown battles.
- A synchronous, OpenEnv-style wrapper exposing reset/step.
- State formatting, action space handling, and reward shaping utilities.
"""

from .config import DEFAULT_BATTLE_FORMAT

__all__ = ["DEFAULT_BATTLE_FORMAT"]
src/smogon_rl/action_space.py ADDED
@@ -0,0 +1,138 @@
from __future__ import annotations

import json
import re
from dataclasses import dataclass
from typing import List, Literal, Optional

from pydantic import BaseModel, ValidationError
from poke_env.environment.battle import Battle
from poke_env.environment.move import Move
from poke_env.environment.pokemon import Pokemon

# Match a single JSON object with "action" and "choice" (handles <think>...</think> + JSON).
_ACTION_JSON_RE = re.compile(
    r'\{\s*"action"\s*:\s*"(?:move|switch)"\s*,\s*"choice"\s*:\s*"[^"]*"\s*\}',
    re.IGNORECASE,
)


ActionType = Literal["move", "switch"]


@dataclass
class ActionOption:
    """Concrete action option available in the current state."""

    action_type: ActionType
    choice: str
    move: Optional[Move] = None
    pokemon: Optional[Pokemon] = None


class ActionJSON(BaseModel):
    """Strict JSON schema the LLM must output."""

    action: ActionType
    choice: str


def enumerate_actions(battle: Battle) -> List[ActionOption]:
    """Enumerate up to 4 moves and up to 5 switches for the current state."""
    options: List[ActionOption] = []

    # Moves
    for move in battle.available_moves[:4]:
        if getattr(move, "current_pp", 1) <= 0:
            continue
        choice = move.id
        options.append(ActionOption(action_type="move", choice=choice, move=move))

    # Switches
    for pokemon in battle.available_switches[:5]:
        if pokemon.fainted:
            continue
        choice = pokemon.species or pokemon.nickname or "Unknown"
        options.append(
            ActionOption(action_type="switch", choice=choice, pokemon=pokemon)
        )

    return options


def _normalize_choice(s: str) -> str:
    """Normalize choice for comparison: lowercase, spaces to hyphens (matches poke-env move ids)."""
    return s.strip().lower().replace(" ", "-")


def extract_action_json_from_text(text: str) -> Optional[str]:
    """Extract a single action JSON object from model output that may contain thinking or prose.

    Strips think tags first, then looks for our schema in the remainder (or in the full string).
    Returns the first matching JSON substring, or None if none found.
    """
    if not text or not text.strip():
        return None
    # Strip think blocks first so we prefer content after thinking.
    stripped = re.sub(r"<think>.*?</think>", "", text, flags=re.DOTALL).strip()
    for candidate in (stripped, text):
        match = _ACTION_JSON_RE.search(candidate)
        if match:
            return match.group(0)
    return None


def parse_llm_action(raw_output: str, valid_actions: List[ActionOption]) -> ActionJSON:
    """Parse and validate the LLM JSON output against the current action set.

    The model must output:

        {
            "action": "move" | "switch",
            "choice": "Exact Name of Move or Pokemon"
        }

    Choice matching is case-insensitive and normalizes spaces to hyphens so
    "Flamethrower" and "Thunder Wave" match env ids "flamethrower" and "thunder-wave".
    """
    try:
        payload = json.loads(raw_output)
    except json.JSONDecodeError as exc:
        raise ValueError(f"Model output is not valid JSON: {exc}") from exc

    try:
        action = ActionJSON.model_validate(payload)
    except ValidationError as exc:
        raise ValueError(f"Model JSON does not match schema: {exc}") from exc

    want_norm = _normalize_choice(action.choice)
    matched = None
    for a in valid_actions:
        if a.action_type != action.action:
            continue
        if _normalize_choice(a.choice) == want_norm:
            matched = a
            break
    if matched is None:
        valid_desc = [
            {"action": a.action_type, "choice": a.choice} for a in valid_actions
        ]
        raise ValueError(
            f"Invalid action selection {action.model_dump()}. "
            f"Valid options are: {valid_desc}"
        )
    # Return with the env's exact choice string so downstream uses the right id.
    return ActionJSON(action=action.action, choice=matched.choice)


def build_action_instructions(valid_actions: List[ActionOption]) -> str:
    """Build a short instruction string describing the JSON schema and options."""
    lines = [
        "You must choose exactly one action and output pure JSON with this schema:",
        "",
        '{"action": "move" | "switch", "choice": "Exact Name of Move or Pokemon"}',
        "",
        "Valid options for this state:",
    ]
    for opt in valid_actions:
        lines.append(f"- action: {opt.action_type!r}, choice: {opt.choice!r}")
    return "\n".join(lines)
src/smogon_rl/config.py ADDED
@@ -0,0 +1,28 @@
from __future__ import annotations

from dataclasses import dataclass


DEFAULT_BATTLE_FORMAT = "gen4randombattle"


@dataclass
class EnvConfig:
    """Configuration for the Pokémon RL environment."""

    battle_format: str = DEFAULT_BATTLE_FORMAT
    # Hard cap to prevent very long battles from dominating rollout wall-time.
    max_steps_per_battle: int = 30
    poll_interval_seconds: float = 0.2
    open_timeout: float = 25.0
    show_replays: bool = False
    verbose_logging: bool = False
    log_every_n_steps: int = 25
    poll_heartbeat_seconds: float = 5.0
    min_battle_reward: float = -100.0
    max_no_progress_steps: int = 2
    # Small per-step time penalty to bias toward faster, decisive games.
    step_living_penalty: float = -0.05
    # Additional truncation/timeout penalties.
    no_progress_termination_penalty: float = -1.0
    max_steps_termination_penalty: float = -2.0
src/smogon_rl/openenv_sync_env.py ADDED
@@ -0,0 +1,260 @@
from __future__ import annotations

import json
from dataclasses import dataclass, field
from typing import Any, Dict, Optional, Tuple

from poke_env.environment.battle import Battle
from poke_env.player.battle_order import BattleOrder

from .action_space import (
    ActionJSON,
    ActionOption,
    build_action_instructions,
    enumerate_actions,
    extract_action_json_from_text,
    parse_llm_action,
)
from .config import EnvConfig
from .pokeenv_client import PokeEnvClient
from .reward import (
    BattleStateSummary,
    ILLEGAL_ACTION_PENALTY,
    RewardTrackingState,
    calculate_reward,
    count_new_passive_hits_for_turn,
    summarize_battle_state,
)
from .state_formatter import OpponentHistoryTracker, format_battle_state


@dataclass
class PokemonShowdownEnv:
    """Synchronous, OpenEnv-style wrapper around a poke-env battle.

    The environment exposes a simple Gymnasium-like / OpenEnv-like API:

        obs = env.reset()
        obs, reward, done, info = env.step(action_json_str)

    where `action_json_str` is a JSON string describing a move or switch using
    the constrained 9-action space.
    """

    config: EnvConfig = field(default_factory=EnvConfig)
    _client: PokeEnvClient = field(init=False)
    _opponent_history: OpponentHistoryTracker = field(init=False)
    _reward_trackers: RewardTrackingState = field(init=False)
    _prev_state: Optional[BattleStateSummary] = field(init=False, default=None)
    _steps_this_battle: int = field(init=False, default=0)
    # Running total of passive hits — updated O(k) per step via the single-turn
    # scanner, never by re-scanning the full observation history.
    _cumulative_passive_hits: int = field(init=False, default=0)
    _battle_index: int = field(init=False, default=0)
    _battle_reward_total: float = field(init=False, default=0.0)
    _no_progress_steps: int = field(init=False, default=0)

    def __post_init__(self) -> None:
        self._client = PokeEnvClient(config=self.config)
        self._opponent_history = OpponentHistoryTracker()
        self._reward_trackers = RewardTrackingState()

    def _log(self, message: str) -> None:
        if self.config.verbose_logging:
            print(f"[PokemonShowdownEnv] {message}", flush=True)

    # ------------------------------------------------------------------ API

    def reset(self) -> str:
        """Start a new battle and return the initial markdown state."""
        self._battle_index += 1
        self._client.start_new_battle()
        self._opponent_history = OpponentHistoryTracker()
        self._reward_trackers = RewardTrackingState()
        self._steps_this_battle = 0
        self._cumulative_passive_hits = 0
        self._battle_reward_total = 0.0
        self._no_progress_steps = 0

        battle = self._wait_for_battle_or_raise()
        self._log(
            f"Battle {self._battle_index} started at turn={battle.turn} "
            f"(format={self.config.battle_format})."
        )
        self._prev_state = summarize_battle_state(battle, self._cumulative_passive_hits)
        return format_battle_state(battle, self._opponent_history)

    def step(self, action_json: str | Dict[str, Any]) -> Tuple[str, float, bool, Dict[str, Any]]:
        """Apply one action and return (state_str, reward, done, info)."""
        battle = self._ensure_battle()
        if battle.finished:
            raise RuntimeError("Cannot call step() on a finished battle. Call reset().")

        self._steps_this_battle += 1
        if self._steps_this_battle > self.config.max_steps_per_battle:
            return self._terminal_from_truncation(battle)

        valid_actions = enumerate_actions(battle)
        if isinstance(action_json, dict):
            raw = json.dumps(action_json)
        else:
            raw = action_json

        used_fallback = False
        try:
            parsed = parse_llm_action(raw, valid_actions)
            order = self._to_battle_order(parsed, valid_actions, battle)
        except ValueError:
            extracted = extract_action_json_from_text(raw)
            if extracted is not None:
                try:
                    parsed = parse_llm_action(extracted, valid_actions)
                    order = self._to_battle_order(parsed, valid_actions, battle)
                except ValueError:
                    used_fallback = True
            else:
                used_fallback = True
        if used_fallback:
            opt = valid_actions[0]
            from poke_env.player import Player as PlayerCls
            if opt.action_type == "move" and opt.move is not None:
                order = PlayerCls.create_order(opt.move)
            else:
                order = PlayerCls.create_order(opt.pokemon)

        previous_turn = battle.turn
        self._client.send_action(order)
        new_battle = self._client.wait_for_battle_update(previous_turn) or battle

        # Increment the passive-hit counter by scanning only the turn that just
        # resolved — O(k) where k = events on that single turn, not O(total turns).
        self._cumulative_passive_hits += count_new_passive_hits_for_turn(
            new_battle, previous_turn
        )

        prev_state = self._prev_state or summarize_battle_state(battle, self._cumulative_passive_hits)
        curr_state = summarize_battle_state(new_battle, self._cumulative_passive_hits)

        active = new_battle.active_pokemon
        opponent_active = new_battle.opponent_active_pokemon

        if used_fallback:
            reward = ILLEGAL_ACTION_PENALTY
        else:
            reward = calculate_reward(
                prev_state=prev_state,
                curr_state=curr_state,
                action=ActionJSON(action=parsed.action, choice=parsed.choice),
                trackers=self._reward_trackers,
                active=active,
                opponent_active=opponent_active,
            )
        # Small time cost per turn to discourage excessively long battles.
        reward += self.config.step_living_penalty

        self._prev_state = curr_state
        if new_battle.turn == previous_turn and not new_battle.finished:
            self._no_progress_steps += 1
        else:
            self._no_progress_steps = 0

        done_reason: Optional[str] = None
        done = False
        if new_battle.finished:
            done = True
            done_reason = "battle_finished"
        elif self._steps_this_battle >= self.config.max_steps_per_battle:
            done = True
            done_reason = "max_steps"
            reward += self.config.max_steps_termination_penalty
        elif (self._battle_reward_total + reward) <= self.config.min_battle_reward:
            done = True
            done_reason = "min_battle_reward"
        elif self._no_progress_steps >= self.config.max_no_progress_steps:
            done = True
            done_reason = "no_progress_timeout"
            reward += self.config.no_progress_termination_penalty

        self._battle_reward_total += reward

        # If we terminate early (not a natural finished battle), forfeit cleanly
        # so the next reset starts from a free player/session state.
        if done and not new_battle.finished and done_reason in {
            "max_steps",
            "min_battle_reward",
            "no_progress_timeout",
        }:
            try:
                self._client.forfeit_current_battle()
            except Exception:
                pass

        obs = format_battle_state(new_battle, self._opponent_history)
        info: Dict[str, Any] = {
            "turn": new_battle.turn,
            "valid_actions": [
                {"action": a.action_type, "choice": a.choice} for a in valid_actions
            ],
            "instructions": build_action_instructions(valid_actions),
            "battle_finished": new_battle.finished,
            "reason": done_reason,
            "action_illegal": used_fallback,
            "battle_reward_total": self._battle_reward_total,
            "no_progress_steps": self._no_progress_steps,
        }
        if self.config.verbose_logging:
            should_log_step = (
                used_fallback
                or done
                or self._steps_this_battle == 1
                or self._steps_this_battle % max(1, self.config.log_every_n_steps) == 0
            )
            if should_log_step:
                self._log(
                    f"battle={self._battle_index} step={self._steps_this_battle} "
                    f"turn={new_battle.turn} reward={reward:.3f} "
                    f"running_reward={self._battle_reward_total:.3f} "
                    f"illegal_action={used_fallback} done={done}"
                )
        return obs, reward, done, info

    # ------------------------------------------------------------------ helpers

    def _wait_for_battle_or_raise(self) -> Battle:
        battle = self._client.battle
        if battle is None:
            battle = self._client.wait_for_battle_update(previous_turn=0)
        if battle is None:
            raise RuntimeError("Failed to obtain initial battle from poke-env.")
        return battle

    def _ensure_battle(self) -> Battle:
        battle = self._client.battle
        if battle is None:
            raise RuntimeError("No active battle. Call reset() first.")
        return battle

    def _terminal_from_truncation(self, battle: Battle) -> Tuple[str, float, bool, Dict[str, Any]]:
        obs = format_battle_state(battle, self._opponent_history)
        info: Dict[str, Any] = {
            "turn": battle.turn,
            "battle_finished": battle.finished,
            "reason": "max_steps",
        }
        return obs, self.config.max_steps_termination_penalty, True, info

    @staticmethod
    def _to_battle_order(
        parsed: ActionJSON,
        valid_actions: list[ActionOption],
        battle: Battle,
    ) -> BattleOrder:
        from poke_env.player import Player as PlayerCls

        for opt in valid_actions:
            if opt.action_type == parsed.action and opt.choice == parsed.choice:
                if opt.action_type == "move" and opt.move is not None:
                    return PlayerCls.create_order(opt.move)
                if opt.action_type == "switch" and opt.pokemon is not None:
                    return PlayerCls.create_order(opt.pokemon)
        raise ValueError(f"Could not map parsed action {parsed} to a BattleOrder")
src/smogon_rl/pokeenv_client.py ADDED
@@ -0,0 +1,304 @@
from __future__ import annotations

import asyncio
import threading
import time
from dataclasses import dataclass, field
from typing import Optional

from poke_env.environment.battle import Battle
from poke_env.player import Player, RandomPlayer
from poke_env.player.battle_order import BattleOrder
from poke_env.ps_client.server_configuration import LocalhostServerConfiguration

from .config import DEFAULT_BATTLE_FORMAT, EnvConfig


class RLPlayer(Player):
    """Player controlled externally via an asyncio queue of BattleOrders."""

    def __init__(self, action_queue: "asyncio.Queue[BattleOrder]", **kwargs) -> None:
        super().__init__(**kwargs)
        self._action_queue: "asyncio.Queue[BattleOrder]" = action_queue

    async def choose_move(self, battle: Battle) -> BattleOrder:
        return await self._action_queue.get()


@dataclass
class PokeEnvClient:
    """Asynchronous client that manages poke-env battles in a background loop.

    Players are created ONCE when the loop starts and reused across battles to
    avoid Showdown nametaken errors from zombie connections.
    """

    config: EnvConfig

    def __post_init__(self) -> None:
        self._loop: Optional[asyncio.AbstractEventLoop] = None
        self._thread: Optional[threading.Thread] = None
        self._action_queue: Optional["asyncio.Queue[BattleOrder]"] = None
        self._rl_player: Optional[RLPlayer] = None
        self._opponent: Optional[RandomPlayer] = None
        self._battle_task: Optional[asyncio.Future] = None
        # Snapshot of existing battle tags before we request a new battle.
        self._known_battle_tags: set[str] = set()
        self._awaiting_new_battle: bool = False
        # Stored reference to the battle we are in (set when .battle is read).
        # Used for forfeit so we always target the right battle.
        self._current_battle: Optional[Battle] = None

    def _log(self, message: str) -> None:
        if self.config.verbose_logging:
            print(f"[PokeEnvClient] {message}", flush=True)

    # -------------------------------------------------------------------------
    # Event loop management
    # -------------------------------------------------------------------------

    def start(self) -> None:
        """Start the background asyncio loop and create players (once)."""
        if self._loop is not None:
            return

        loop = asyncio.new_event_loop()

        def _run_loop() -> None:
            asyncio.set_event_loop(loop)
            loop.run_forever()

        thread = threading.Thread(target=_run_loop, daemon=True)
        thread.start()

        self._loop = loop
        self._thread = thread
        self._log("Background event loop started.")

        # Create players once; they stay connected for the lifetime of this env.
        self._action_queue = asyncio.Queue()
        fmt = self.config.battle_format or DEFAULT_BATTLE_FORMAT

        async def _create_players() -> None:
            self._rl_player = RLPlayer(
                action_queue=self._action_queue,
                battle_format=fmt,
                server_configuration=LocalhostServerConfiguration,
            )
            self._opponent = RandomPlayer(
                battle_format=fmt,
                server_configuration=LocalhostServerConfiguration,
            )

        future = asyncio.run_coroutine_threadsafe(_create_players(), loop)
        future.result(timeout=15.0)
        # Give the server a moment to register both connections.
        time.sleep(1.0)
        self._log("Players created and connected.")

    def stop(self) -> None:
        """Stop the background loop and clean up."""
        if self._loop is None:
            return
        self._loop.call_soon_threadsafe(self._loop.stop)
        if self._thread is not None:
            self._thread.join(timeout=5.0)
        self._loop = None
        self._thread = None
        self._battle_task = None
        self._rl_player = None
        self._opponent = None
        self._action_queue = None
        self._known_battle_tags = set()
        self._awaiting_new_battle = False
        self._current_battle = None
        self._log("Background event loop stopped.")

    def restart(self) -> None:
        """Hard-restart loop + players to recover from stuck/cancelled battles."""
        self._log("Restarting client event loop and players.")
        self.stop()
        self.start()

    # -------------------------------------------------------------------------
124
+ # Battle lifecycle
125
+ # -------------------------------------------------------------------------
126
+
127
+ def forfeit_current_battle(self) -> None:
128
+ """Forfeit the current Showdown battle if it is still in progress.
129
+
130
+ Must be called before start_new_battle() when the env ends a battle early
131
+ (e.g. due to min_battle_reward) so the player is freed for the next battle.
132
+ """
133
+ if self._loop is None or self._rl_player is None:
134
+ return
135
+ # Use stored battle so we forfeit the one we were in, not whatever .battle returns now.
136
+ battle = self._current_battle if self._current_battle is not None else self.battle
137
+ if battle is None or battle.finished:
138
+ return
139
+
140
+ room = battle.battle_tag
141
+
142
+ async def _do_forfeit() -> None:
143
+ try:
144
+ await self._rl_player.send_message("/forfeit", room)
145
+ except Exception:
146
+ pass
147
+
148
+ try:
149
+ fut = asyncio.run_coroutine_threadsafe(_do_forfeit(), self._loop)
150
+ fut.result(timeout=5.0)
151
+ except Exception:
152
+ pass
153
+ # Give the server time to end the battle and free both players.
154
+ time.sleep(1.5)
155
+ self._current_battle = None
156
+ self._log("Forfeited current battle.")
157
+
158
+ def start_new_battle(self) -> None:
159
+ """Launch a new battle using the already-connected players."""
160
+ if self._loop is None:
161
+ self.start()
162
+ assert self._loop is not None
163
+ assert self._rl_player is not None
164
+ assert self._opponent is not None
165
+
166
+ # Forfeit any ongoing Showdown battle before starting a new one so the
167
+ # player is not stuck mid-battle when battle_against is called again.
168
+ self.forfeit_current_battle()
169
+
170
+ # Let the previous battle task finish cleanly (server will end battle
171
+ # after forfeit). If it does not settle, hard-restart the client.
172
+ restart_required = False
173
+ if self._battle_task is not None and not self._battle_task.done():
174
+ try:
175
+ self._battle_task.result(timeout=25.0)
176
+ except Exception:
177
+ self._battle_task.cancel()
178
+ self._log("Previous battle task timed out or failed; requesting client restart.")
179
+ restart_required = True
180
+ else:
181
+ self._log("Previous battle task finished.")
182
+
183
+ if restart_required:
184
+ # Hard recovery path: refresh websocket connections and players.
185
+ self.restart()
186
+ assert self._loop is not None
187
+ assert self._rl_player is not None
188
+ assert self._opponent is not None
189
+
190
+ self._current_battle = None # Will be set when the new battle appears.
191
+
192
+ # Let the server fully free both players before we start the next battle.
193
+ time.sleep(2.0)
194
+
195
+ # Fresh action queue for this battle.
196
+ self._action_queue = asyncio.Queue()
197
+ self._rl_player._action_queue = self._action_queue
198
+
199
+ # Record current battle tags so .battle can wait for a genuinely new one.
200
+ self._known_battle_tags = set(self._rl_player.battles.keys())
201
+ self._awaiting_new_battle = True
202
+
203
+ async def _run_battle() -> None:
204
+ await self._rl_player.battle_against(self._opponent, n_battles=1)
205
+
206
+ self._battle_task = asyncio.run_coroutine_threadsafe(
207
+ _run_battle(), self._loop
208
+ )
209
+ self._log(
210
+ f"Launching new battle in format "
211
+ f"{self.config.battle_format or DEFAULT_BATTLE_FORMAT}."
212
+ )
213
+ time.sleep(self.config.poll_interval_seconds)
214
+
215
+ @property
216
+ def battle(self) -> Optional[Battle]:
217
+ """Return the current Battle for this run, or None if not started yet."""
218
+ if self._rl_player is None or not self._rl_player.battles:
219
+ return None
220
+
221
+ # During reset(), wait for a battle tag that did not exist before
222
+ # start_new_battle() was called.
223
+ if self._awaiting_new_battle:
224
+ unseen = [
225
+ b
226
+ for tag, b in self._rl_player.battles.items()
227
+ if tag not in self._known_battle_tags
228
+ ]
229
+ if not unseen:
230
+ return None
231
+ active_unseen = [b for b in unseen if not b.finished]
232
+ b = active_unseen[-1] if active_unseen else unseen[-1]
233
+ self._awaiting_new_battle = False
234
+ self._current_battle = b
235
+ return b
236
+
237
+ battles = list(self._rl_player.battles.values())
238
+ active = [b for b in battles if not b.finished]
239
+ if active:
240
+ b = active[-1]
241
+ self._current_battle = b
242
+ return b
243
+ # All finished — return the latest one (covers the case where the battle
244
+ # ended before we got a chance to poll it).
245
+ b = battles[-1]
246
+ self._current_battle = b
247
+ return b
248
+
249
+ def send_action(self, order: BattleOrder) -> None:
250
+ """Submit an action for the RL player to execute."""
251
+ if self._loop is None or self._action_queue is None:
252
+ raise RuntimeError("PokeEnvClient has not been started.")
253
+
254
+ async def _enqueue() -> None:
255
+ assert self._action_queue is not None
256
+ await self._action_queue.put(order)
257
+
258
+ asyncio.run_coroutine_threadsafe(_enqueue(), self._loop)
259
+ self._log("Submitted action to RLPlayer queue.")
260
+
261
+ def wait_for_battle_update(self, previous_turn: int) -> Optional[Battle]:
262
+ """Block until the battle advances to a new turn or ends."""
263
+ start_time = time.time()
264
+ heartbeat_every = max(self.config.poll_heartbeat_seconds, self.config.poll_interval_seconds)
265
+ next_heartbeat_at = start_time + heartbeat_every
266
+ while True:
267
+ battle = self.battle
268
+ if battle is None:
269
+ now = time.time()
270
+ if now > next_heartbeat_at:
271
+ elapsed = now - start_time
272
+ self._log(
273
+ f"Still waiting for battle object "
274
+ f"({elapsed:.1f}s elapsed, previous_turn={previous_turn})."
275
+ )
276
+ next_heartbeat_at = now + heartbeat_every
277
+ if now - start_time > self.config.open_timeout:
278
+ self._log("Timed out waiting for initial battle object.")
279
+ return None
280
+ time.sleep(self.config.poll_interval_seconds)
281
+ continue
282
+
283
+ if battle.finished or battle.turn > previous_turn:
284
+ self._log(
285
+ f"Battle update received: turn={battle.turn}, finished={battle.finished}."
286
+ )
287
+ return battle
288
+
289
+ now = time.time()
290
+ if now > next_heartbeat_at:
291
+ elapsed = now - start_time
292
+ self._log(
293
+ f"Waiting for turn advance: current_turn={battle.turn}, "
294
+ f"previous_turn={previous_turn}, elapsed={elapsed:.1f}s."
295
+ )
296
+ next_heartbeat_at = now + heartbeat_every
297
+
298
+ if now - start_time > self.config.open_timeout:
299
+ self._log(
300
+ f"Turn-advance wait timed out at turn={battle.turn}; returning last state."
301
+ )
302
+ return battle
303
+
304
+ time.sleep(self.config.poll_interval_seconds)
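
The background-loop pattern above — a daemon thread running an asyncio loop, with synchronous callers submitting coroutines via `asyncio.run_coroutine_threadsafe` — can be exercised without poke-env or a Showdown server. A minimal standalone sketch (the queue contents and action string are illustrative, not part of the real `BattleOrder` API):

```python
import asyncio
import threading

# Background event loop in a daemon thread, as in PokeEnvClient.start().
loop = asyncio.new_event_loop()
threading.Thread(target=loop.run_forever, daemon=True).start()

async def _make_queue() -> "asyncio.Queue[str]":
    # Create the queue on the loop's own thread so older Python versions
    # bind it to the correct event loop.
    return asyncio.Queue()

queue = asyncio.run_coroutine_threadsafe(_make_queue(), loop).result(timeout=5.0)

async def consume() -> str:
    # Stands in for RLPlayer.choose_move awaiting the next BattleOrder.
    return await queue.get()

pending = asyncio.run_coroutine_threadsafe(consume(), loop)
# Stands in for PokeEnvClient.send_action enqueueing an order.
asyncio.run_coroutine_threadsafe(queue.put("move thunderbolt"), loop)
result = pending.result(timeout=5.0)
print(result)
loop.call_soon_threadsafe(loop.stop)
```

The same `run_coroutine_threadsafe` bridge is what lets the synchronous `reset()`/`step()` interface drive poke-env's fully asynchronous players.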
src/smogon_rl/reward.py ADDED
@@ -0,0 +1,320 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List, Optional
5
+
6
+ from poke_env.environment.battle import Battle
7
+ from poke_env.environment.pokemon import Pokemon
8
+
9
+ from .action_space import ActionJSON
10
+ from .state_formatter import hp_fraction_to_percent
11
+
12
+ # Hefty penalty when the model outputs an illegal action (e.g. a hallucinated Pokemon).
13
+ # Used during rollout collection; recorded as collected_reward so GRPO learns to avoid illegal outputs.
14
+ ILLEGAL_ACTION_PENALTY = -10.0
15
+
16
+
17
+ @dataclass
18
+ class BattleStateSummary:
19
+ self_team_hp_percent: float
20
+ opp_team_hp_percent: float
21
+ self_fainted: int
22
+ opp_fainted: int
23
+ self_statuses: Dict[str, Optional[str]]
24
+ opp_statuses: Dict[str, Optional[str]]
25
+ self_stat_stages: Dict[str, Dict[str, int]]
26
+ opp_stat_stages: Dict[str, Dict[str, int]]
27
+ opponent_passive_hits: int
28
+
29
+
30
+ @dataclass
31
+ class RewardTrackingState:
32
+ healing_reward_used: float = 0.0
33
+ per_pokemon_setup_reward_used: Dict[str, float] = field(default_factory=dict)
34
+ passive_hits_total: int = 0
35
+
36
+
37
+ def _team_hp_and_faints(team: Dict[str, Pokemon]) -> tuple[float, int]:
38
+ total_hp = 0.0
39
+ total_max_hp = 0.0
40
+ fainted = 0
41
+ for mon in team.values():
42
+ if mon.max_hp is None or mon.max_hp <= 0:
43
+ continue
44
+ total_hp += max(0, mon.current_hp)
45
+ total_max_hp += mon.max_hp
46
+ if mon.fainted:
47
+ fainted += 1
48
+ if total_max_hp <= 0:
49
+ return 0.0, fainted
50
+ return (total_hp / total_max_hp) * 100.0, fainted
51
+
52
+
53
+ def _collect_statuses(team: Dict[str, Pokemon]) -> Dict[str, Optional[str]]:
54
+ return {
55
+ mon.species or key: (str(mon.status) if mon.status is not None else None)
56
+ for key, mon in team.items()
57
+ }
58
+
59
+
60
+ def _collect_stat_stages(team: Dict[str, Pokemon]) -> Dict[str, Dict[str, int]]:
61
+ return {mon.species or key: dict(mon.boosts) for key, mon in team.items()}
62
+
63
+
64
+ def _passive_events_in_turn(events: list, opponent_role: str) -> int:
65
+ """Count passive-damage hits for the opponent in one turn's raw event list."""
66
+ count = 0
67
+ for event in events:
68
+ if not event or event[0] != "-damage":
69
+ continue
70
+ if len(event) < 2:
71
+ continue
72
+ if not event[1].startswith(opponent_role):
73
+ continue
74
+ # "[from]" in any trailing field marks an external/passive damage source:
75
+ # e.g. "[from] brn", "[from] Stealth Rock", "[from] Leech Seed", etc.
76
+ if any("[from]" in part for part in event[2:]):
77
+ count += 1
78
+ return count
79
+
80
+
81
+ def count_new_passive_hits_for_turn(battle: Battle, turn_number: int) -> int:
82
+ """Count passive damage hits the opponent took on a single, specific turn.
83
+
84
+ Designed for O(k) per step use: only the events from `turn_number` are
85
+ scanned. The caller accumulates the running total across turns.
86
+
87
+ Parameters
88
+ ----------
89
+ battle:
90
+ The current poke-env Battle object.
91
+ turn_number:
92
+ The turn whose Observation.events should be inspected (usually the
93
+ turn that just resolved, i.e., the value of `battle.turn` before
94
+ the action was submitted).
95
+ """
96
+ obs = battle.observations.get(turn_number)
97
+ if obs is None:
98
+ return 0
99
+ opponent_role = "p2" if battle.player_role == "p1" else "p1"
100
+ return _passive_events_in_turn(obs.events, opponent_role)
101
+
102
+
103
+ def _count_passive_hits_on_opponent(battle: Battle) -> int:
104
+ """Full-scan fallback: count cumulative passive hits across all observed turns.
105
+
106
+ This is O(total events) and should only be called once on reset() to
107
+ establish a baseline. Per-step increments should use
108
+ `count_new_passive_hits_for_turn` instead.
109
+ """
110
+ opponent_role = "p2" if battle.player_role == "p1" else "p1"
111
+ count = 0
112
+ for obs in battle.observations.values():
113
+ count += _passive_events_in_turn(obs.events, opponent_role)
114
+ return count
115
+
116
+
117
+ def summarize_battle_state(battle: Battle, cumulative_passive_hits: int = 0) -> BattleStateSummary:
118
+ """Snapshot the current battle state into a plain dataclass.
119
+
120
+ Parameters
121
+ ----------
122
+ battle:
123
+ The live poke-env Battle object.
124
+ cumulative_passive_hits:
125
+ Running total of passive damage hits the opponent has taken this
126
+ battle, maintained by the caller (e.g. PokemonShowdownEnv) using
127
+ `count_new_passive_hits_for_turn` to keep each step O(k).
128
+ Defaults to 0 for the initial state on reset().
129
+ """
130
+ self_hp, self_fainted = _team_hp_and_faints(battle.team)
131
+ opp_hp, opp_fainted = _team_hp_and_faints(battle.opponent_team)
132
+ self_statuses = _collect_statuses(battle.team)
133
+ opp_statuses = _collect_statuses(battle.opponent_team)
134
+ self_stats = _collect_stat_stages(battle.team)
135
+ opp_stats = _collect_stat_stages(battle.opponent_team)
136
+ return BattleStateSummary(
137
+ self_team_hp_percent=self_hp,
138
+ opp_team_hp_percent=opp_hp,
139
+ self_fainted=self_fainted,
140
+ opp_fainted=opp_fainted,
141
+ self_statuses=self_statuses,
142
+ opp_statuses=opp_statuses,
143
+ self_stat_stages=self_stats,
144
+ opp_stat_stages=opp_stats,
145
+ opponent_passive_hits=cumulative_passive_hits,
146
+ )
147
+
148
+
149
+ def _status_penalty(prev_statuses: Dict[str, Optional[str]], curr_statuses: Dict[str, Optional[str]]) -> float:
150
+ penalty = 0.0
151
+ for key, curr in curr_statuses.items():
152
+ prev = prev_statuses.get(key)
153
+ if prev == curr:
154
+ continue
155
+ if curr is None:
156
+ # Could be a status cure handled elsewhere.
157
+ continue
158
+ code = curr.lower()
159
+ if code in {"brn", "psn", "tox"}:
160
+ penalty -= 0.5
161
+ elif code in {"par", "frz", "slp", "conf"}:
162
+ penalty -= 1.0
163
+ return penalty
164
+
165
+
166
+ def _healing_reward(prev_hp: float, curr_hp: float, trackers: RewardTrackingState) -> float:
167
+ if curr_hp <= prev_hp:
168
+ return 0.0
169
+ healed = curr_hp - prev_hp
170
+ raw = (healed / 10.0) # +1.0 per 10% healed
171
+ remaining_cap = max(0.0, 3.0 - trackers.healing_reward_used)
172
+ reward = min(raw, remaining_cap)
173
+ trackers.healing_reward_used += reward
174
+ return reward
175
+
176
+
177
+ def _setup_reward(
178
+ prev_stats: Dict[str, Dict[str, int]],
179
+ curr_stats: Dict[str, Dict[str, int]],
180
+ active: Pokemon,
181
+ trackers: RewardTrackingState,
182
+ ) -> float:
183
+ active_key = active.species or "active"
184
+ prev = prev_stats.get(active_key, {})
185
+ curr = curr_stats.get(active_key, {})
186
+ delta_stages = 0
187
+ for stat, curr_stage in curr.items():
188
+ prev_stage = prev.get(stat, 0)
189
+ if curr_stage > prev_stage:
190
+ delta_stages += curr_stage - prev_stage
191
+ if delta_stages <= 0:
192
+ return 0.0
193
+ if hp_fraction_to_percent(active.current_hp_fraction) <= 50.0:
194
+ return 0.0
195
+
196
+ raw = 0.5 * delta_stages
197
+ used = trackers.per_pokemon_setup_reward_used.get(active_key, 0.0)
198
+ remaining_cap = max(0.0, 2.0 - used)
199
+ reward = min(raw, remaining_cap)
200
+ trackers.per_pokemon_setup_reward_used[active_key] = used + reward
201
+ return reward
202
+
203
+
204
+ def _opponent_setup_penalty(
205
+ prev_stats: Dict[str, Dict[str, int]],
206
+ curr_stats: Dict[str, Dict[str, int]],
207
+ ) -> float:
208
+ penalty = 0.0
209
+ for key, curr in curr_stats.items():
210
+ prev = prev_stats.get(key, {})
211
+ for stat, curr_stage in curr.items():
212
+ prev_stage = prev.get(stat, 0)
213
+ if curr_stage > prev_stage:
214
+ penalty -= 0.5 * (curr_stage - prev_stage)
215
+ return penalty
216
+
217
+
218
+ def _passive_damage_reward(
219
+ prev_hits: int,
220
+ curr_hits: int,
221
+ trackers: RewardTrackingState,
222
+ ) -> float:
223
+ if curr_hits <= prev_hits:
224
+ return 0.0
225
+ delta = curr_hits - prev_hits
226
+ trackers.passive_hits_total += delta
227
+ # Note: scales with the cumulative hit total rather than the per-turn
+ # delta, so each new batch of passive chip is worth slightly more
+ # as hazards and status accumulate over the battle.
+ return 0.01 * trackers.passive_hits_total
228
+
229
+
230
+ def _damage_rewards(prev: BattleStateSummary, curr: BattleStateSummary) -> float:
231
+ reward = 0.0
232
+ # Damage dealt: +1.0 per 10% opponent HP reduced
233
+ if curr.opp_team_hp_percent < prev.opp_team_hp_percent:
234
+ delta = prev.opp_team_hp_percent - curr.opp_team_hp_percent
235
+ reward += delta / 10.0
236
+ # Damage taken: -1.0 per 10% self HP lost
237
+ if curr.self_team_hp_percent < prev.self_team_hp_percent:
238
+ delta = prev.self_team_hp_percent - curr.self_team_hp_percent
239
+ reward -= delta / 10.0
240
+ return reward
241
+
242
+
243
+ def _knockout_rewards(prev: BattleStateSummary, curr: BattleStateSummary) -> float:
244
+ reward = 0.0
245
+ if curr.opp_fainted > prev.opp_fainted:
246
+ reward += 3.0 * (curr.opp_fainted - prev.opp_fainted)
247
+ if curr.self_fainted > prev.self_fainted:
248
+ reward -= 3.0 * (curr.self_fainted - prev.self_fainted)
249
+ return reward
250
+
251
+
252
+ def calculate_reward(
253
+ prev_state: BattleStateSummary,
254
+ curr_state: BattleStateSummary,
255
+ action: ActionJSON,
256
+ trackers: RewardTrackingState,
257
+ active: Optional[Pokemon] = None,
258
+ opponent_active: Optional[Pokemon] = None,
259
+ move_was_super_effective: bool = False,
260
+ move_hit: bool = True,
261
+ move_was_immune: bool = False,
262
+ team_status_cured: bool = False,
263
+ ) -> float:
264
+ """Compute shaped reward between two consecutive battle summaries.
265
+
266
+ The additional keyword arguments allow the caller to provide extra context from
267
+ the last action (type effectiveness, accuracy result, status cures) that are
268
+ not fully recoverable from the static battle snapshots alone.
269
+ """
270
+ reward = 0.0
271
+
272
+ # Core mechanics
273
+ reward += _damage_rewards(prev_state, curr_state)
274
+ reward += _knockout_rewards(prev_state, curr_state)
275
+
276
+ # Strategic nudges: type effectiveness and accuracy
277
+ if action.action == "move":
278
+ if move_was_super_effective:
279
+ reward += 0.5
280
+ if move_was_immune:
281
+ reward -= 1.0
282
+ if not move_hit:
283
+ reward -= 0.25
284
+
285
+ # Healing
286
+ reward += _healing_reward(
287
+ prev_state.self_team_hp_percent,
288
+ curr_state.self_team_hp_percent,
289
+ trackers,
290
+ )
291
+
292
+ # Status cures (e.g., Aromatherapy)
293
+ if team_status_cured:
294
+ reward += 1.0
295
+
296
+ # Setup sweeping (self) and opponent setup
297
+ if active is not None:
298
+ reward += _setup_reward(
299
+ prev_state.self_stat_stages,
300
+ curr_state.self_stat_stages,
301
+ active,
302
+ trackers,
303
+ )
304
+ reward += _opponent_setup_penalty(
305
+ prev_state.opp_stat_stages,
306
+ curr_state.opp_stat_stages,
307
+ )
308
+
309
+ # Passive damage / hazards
310
+ reward += _passive_damage_reward(
311
+ prev_state.opponent_passive_hits,
312
+ curr_state.opponent_passive_hits,
313
+ trackers,
314
+ )
315
+
316
+ # Status afflictions
317
+ reward += _status_penalty(prev_state.self_statuses, curr_state.self_statuses)
318
+
319
+ return reward
320
+
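
The damage and knockout terms above are simple linear shapes. As a sanity check, the arithmetic can be replicated standalone (this re-derives the constants used in `_damage_rewards` and `_knockout_rewards` rather than importing the module):

```python
def damage_reward(prev_self_hp, curr_self_hp, prev_opp_hp, curr_opp_hp):
    # +1.0 per 10% of opponent team HP removed, -1.0 per 10% of own HP lost.
    reward = 0.0
    if curr_opp_hp < prev_opp_hp:
        reward += (prev_opp_hp - curr_opp_hp) / 10.0
    if curr_self_hp < prev_self_hp:
        reward -= (prev_self_hp - curr_self_hp) / 10.0
    return reward

def knockout_reward(prev_self_ko, curr_self_ko, prev_opp_ko, curr_opp_ko):
    # +/-3.0 per faint, mirroring _knockout_rewards.
    return 3.0 * max(0, curr_opp_ko - prev_opp_ko) - 3.0 * max(0, curr_self_ko - prev_self_ko)

# One turn: dealt 25% team damage, took 10%, scored one KO.
total = damage_reward(100.0, 90.0, 100.0, 75.0) + knockout_reward(0, 0, 0, 1)
print(total)  # 2.5 - 1.0 + 3.0 = 4.5
```

Keeping these core terms an order of magnitude larger than the shaping nudges (0.25–1.0) biases the policy toward winning exchanges rather than farming side rewards.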
src/smogon_rl/state_formatter.py ADDED
@@ -0,0 +1,181 @@
1
+ from __future__ import annotations
2
+
3
+ from dataclasses import dataclass, field
4
+ from typing import Dict, List, Optional
5
+
6
+ from poke_env.environment.battle import Battle
7
+ from poke_env.environment.pokemon import Pokemon
8
+
9
+
10
+ @dataclass
11
+ class OpponentMonHistory:
12
+ name: str
13
+ last_known_hp_percent: float
14
+ status: Optional[str]
15
+ revealed_moves: List[str] = field(default_factory=list)
16
+ revealed_item: Optional[str] = None
17
+ revealed_ability: Optional[str] = None
18
+
19
+
20
+ @dataclass
21
+ class OpponentHistoryTracker:
22
+ revealed: Dict[str, OpponentMonHistory] = field(default_factory=dict)
23
+
24
+ def update_from_battle(self, battle: Battle) -> None:
25
+ for mon in battle.opponent_team.values():
26
+ if not mon.species:
27
+ continue
28
+ key = mon.species
29
+ entry = self.revealed.get(
30
+ key,
31
+ OpponentMonHistory(
32
+ name=mon.species,
33
+ last_known_hp_percent=hp_fraction_to_percent(mon.current_hp_fraction),
34
+ status=str(mon.status) if mon.status is not None else None,
35
+ ),
36
+ )
37
+ entry.last_known_hp_percent = hp_fraction_to_percent(mon.current_hp_fraction)
38
+ entry.status = str(mon.status) if mon.status is not None else None
39
+
40
+ for move in mon.moves.values():
41
+ move_name = move.id
42
+ if move_name not in entry.revealed_moves:
43
+ entry.revealed_moves.append(move_name)
44
+
45
+ if mon.item is not None:
46
+ entry.revealed_item = mon.item
47
+ if mon.ability is not None:
48
+ entry.revealed_ability = mon.ability
49
+
50
+ self.revealed[key] = entry
51
+
52
+
53
+ def hp_fraction_to_percent(fraction: float | None) -> float:
54
+ if fraction is None:
55
+ return 0.0
56
+ return max(0.0, min(1.0, float(fraction))) * 100.0
57
+
58
+
59
+ def _format_stat_modifiers(pokemon: Pokemon) -> str:
60
+ parts: List[str] = []
61
+ for stat, stage in pokemon.boosts.items():
62
+ if stage == 0:
63
+ continue
64
+ sign = "+" if stage > 0 else ""
65
+ parts.append(f"{stat.capitalize()} {sign}{stage}")
66
+ return ", ".join(parts) if parts else "None"
67
+
68
+
69
+ def _estimate_speed_range(pokemon: Pokemon) -> str:
70
+ base_speed = pokemon.base_stats.get("spe", 0)
71
+ if base_speed <= 0:
72
+ return "Unknown"
73
+
74
+ level = 100  # Assumes a level-100 format; use pokemon.level for variable-level formats.
75
+ # Floor: 0 IVs/EVs with a hindering nature; ceiling: 31 IVs, 252 EVs
+ # (floor(252 / 4) = 63 stat points) with a boosting nature.
+ min_speed = int((((2 * base_speed) * level) / 100 + 5) * 0.9)
76
+ max_speed = int((((2 * base_speed + 31 + (252 // 4)) * level) / 100 + 5) * 1.1)
77
+ return f"{min_speed}-{max_speed}"
78
+
79
+
80
+ def _format_pokemon_line(pokemon: Pokemon) -> str:
81
+ hp = hp_fraction_to_percent(pokemon.current_hp_fraction)
82
+ status = str(pokemon.status) if pokemon.status is not None else "OK"
83
+ item = pokemon.item or "?"
84
+ return f"- {pokemon.species or '?'} HP:{hp:.0f}% {status} Item:{item}"
85
+
86
+
87
+ def _format_moveset_section(pokemon: Pokemon) -> str:
88
+ if not pokemon.moves:
89
+ return " Moves: [unknown]"
90
+ parts = []
91
+ for move in pokemon.moves.values():
92
+ bp = move.base_power or 0
93
+ t = move.type.name[0] if move.type is not None else "?"
94
+ parts.append(f"{move.id}({t}{bp})")
95
+ return " Moves: " + " | ".join(parts)
96
+
97
+
98
+ def format_battle_state(battle: Battle, opponent_history: OpponentHistoryTracker) -> str:
99
+ """Format the full battle state into a markdown string for the LLM.
100
+
101
+ Structure:
102
+ - Part A: Active field (self and opponent).
103
+ - Part B: Full self roster and movesets.
104
+ - Part C: Opponent history (revealed bench, revealed info).
105
+ """
106
+ opponent_history.update_from_battle(battle)
107
+
108
+ lines: List[str] = []
109
+
110
+ # ------------------------------------------------------------------ Part A
111
+ lines.append("## Part A: Active Field")
112
+
113
+ # Self active
114
+ self_active = battle.active_pokemon
115
+ if self_active is not None:
116
+ self_hp = hp_fraction_to_percent(self_active.current_hp_fraction)
117
+ self_status = (
118
+ str(self_active.status) if self_active.status is not None else "Healthy"
119
+ )
120
+ self_ability = self_active.ability or "Unknown"
121
+ self_item = self_active.item or "None"
122
+ self_mods = _format_stat_modifiers(self_active)
123
+ lines.append("### Active Self")
124
+ lines.append(
125
+ f"- Name: {self_active.species or 'Unknown'}\n"
126
+ f"- HP: {self_hp:.1f}%\n"
127
+ f"- Status: {self_status}\n"
128
+ f"- Ability: {self_ability}\n"
129
+ f"- Item: {self_item}\n"
130
+ f"- Stat Modifiers: {self_mods}"
131
+ )
132
+ else:
133
+ lines.append("### Active Self\n- None")
134
+
135
+ # Opponent active
136
+ opp_active = battle.opponent_active_pokemon
137
+ if opp_active is not None:
138
+ opp_hp = hp_fraction_to_percent(opp_active.current_hp_fraction)
139
+ opp_status = (
140
+ str(opp_active.status) if opp_active.status is not None else "Healthy"
141
+ )
142
+ opp_speed_range = _estimate_speed_range(opp_active)
143
+ lines.append("### Active Opponent")
144
+ lines.append(
145
+ f"- Name: {opp_active.species or 'Unknown'}\n"
146
+ f"- HP: {opp_hp:.1f}%\n"
147
+ f"- Status: {opp_status}\n"
148
+ f"- Speed Range: {opp_speed_range}"
149
+ )
150
+ else:
151
+ lines.append("### Active Opponent\n- None")
152
+
153
+ # ------------------------------------------------------------------ Part B
154
+ lines.append("\n## Part B: Full Self Roster")
155
+ if not battle.team:
156
+ lines.append("- [Unknown team]")
157
+ else:
158
+ for mon in battle.team.values():
159
+ lines.append(_format_pokemon_line(mon))
160
+ lines.append(_format_moveset_section(mon))
161
+
162
+ # ------------------------------------------------------------------ Part C
163
+ lines.append("\n## Part C: Opponent History")
164
+ if not opponent_history.revealed:
165
+ lines.append("- No opponent Pokémon revealed yet.")
166
+ else:
167
+ for entry in opponent_history.revealed.values():
168
+ lines.append(
169
+ f"- {entry.name} | Last HP: {entry.last_known_hp_percent:.1f}% | "
170
+ f"Status: {entry.status or 'Healthy'}"
171
+ )
172
+ if entry.revealed_moves:
173
+ moves = ", ".join(entry.revealed_moves)
174
+ lines.append(f" - Revealed moves: {moves}")
175
+ if entry.revealed_item:
176
+ lines.append(f" - Revealed item: {entry.revealed_item}")
177
+ if entry.revealed_ability:
178
+ lines.append(f" - Revealed ability: {entry.revealed_ability}")
179
+
180
+ return "\n".join(lines)
181
+
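
The speed bracket produced by `_estimate_speed_range` can be checked standalone. This sketch repeats its arithmetic (0 IVs/EVs with a hindering nature for the floor, 31 IVs and 252 EVs with a boosting nature for the ceiling, assuming level 100) — a rough bracket for the prompt, not an exact stat calculation:

```python
def estimate_speed_range(base_speed: int, level: int = 100) -> tuple[int, int]:
    # Mirrors _estimate_speed_range above; EVs contribute floor(252 / 4) = 63
    # stat points, and the 0.9 / 1.1 factors model hindering/boosting natures.
    if base_speed <= 0:
        raise ValueError("unknown base speed")
    min_speed = int((((2 * base_speed) * level) / 100 + 5) * 0.9)
    max_speed = int((((2 * base_speed + 31 + 252 // 4) * level) / 100 + 5) * 1.1)
    return min_speed, max_speed

print(estimate_speed_range(100))  # a base-100 Speed Pokemon -> (184, 328)
```

Giving the LLM a range instead of a point estimate keeps the prompt honest about what is actually observable: the opponent's investment and nature are hidden information.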
trainer.ipynb ADDED
The diff for this file is too large to render. See raw diff