# Source: trl-mcsd/examples/scripts/openenv/multi_env.py
# Commit 1fa3c6c (verified) by ihbkaiser: "Implement MCSD for experimental SDPO"
# Copyright 2020-2026 The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# /// script
# dependencies = [
# "trl",
# "trackio",
# "openenv-textarena @ git+https://huggingface.co/spaces/openenv/wordle",
# "openenv-openspiel-env @ git+https://huggingface.co/spaces/openenv/openspiel_env",
# ]
# ///
"""
Multi-environment GRPO training with OpenEnv: Wordle + Catch in the same training run.
Demonstrates how to wrap multiple environments in a single `environment_factory` class. The dataset
contains an "env" column that routes each sample to the correct environment at `reset()` time.
Usage:
python examples/scripts/openenv/multi_env.py \\
--wordle-url https://openenv-wordle.hf.space \\
--catch-url https://openenv-openspiel-env.hf.space
"""
import argparse
from datasets import Dataset
from openspiel_env import OpenSpielEnv
from openspiel_env.models import OpenSpielAction
from textarena_env import TextArenaAction, TextArenaEnv
from trl import GRPOConfig, GRPOTrainer
# Instruction text used as the user message of every Wordle sample in the training dataset.
# It names the `guess` tool, which is the method MultiEnv exposes for Wordle.
wordle_prompt = """You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies.
Follow these rules to play Wordle:
1. The target is a 5-letter English word
2. You have 6 attempts to guess the correct word
3. After each guess, you receive color-coded feedback:
- GREEN (G): Letter is correct and in the correct position
- YELLOW (Y): Letter is in the word but in the wrong position
- GRAY (X): Letter is not in the word at all
4. All guesses must be valid 5-letter English words
5. You cannot reuse a word you've already guessed
6. Use the tool `guess` to make a guess.
"""
# Instruction text used as the user message of every Catch sample in the training dataset.
# It names the `move` and `stay` tools, which are the methods MultiEnv exposes for Catch.
catch_prompt = """You are an AI agent playing the game **Catch**.
### Game Description
- The game is played on a **10×5 grid**.
- There is one **falling ball** and one **paddle** that you control at the bottom.
- The objective is to **move the paddle left or right to catch the ball** as it falls.
- The episode ends when the ball reaches the bottom row:
- You get **+1 reward** if you catch it.
- You get **–1 reward** if you miss it.
### Observation Format
Each observation is a flattened 10x5 grid (list of 50 floats).
- 1.0 → occupied (ball or paddle)
- 0.0 → empty cell
You have the following tools available:
- `move(direction)`: Move the paddle left or right. Direction must be "left" or "right".
- `stay`: Do nothing and let the ball fall one step.
Observe the grid, determine where the ball is relative to the paddle, then move accordingly.
"""
# Default server endpoints for the two environments; used as the argparse defaults
# for --wordle-url / --catch-url and as the initial MultiEnv class attributes.
DEFAULT_WORDLE_URL = "https://openenv-wordle.hf.space"
DEFAULT_CATCH_URL = "https://openenv-openspiel-env.hf.space"
# Catch grid dimensions, used to decode the flattened observation (row-major, CATCH_COLS per row).
CATCH_ROWS = 10
CATCH_COLS = 5
def _format_catch_obs(info_state: list[float]) -> str:
    """Render the flattened Catch grid as a short, human-readable summary.

    Cells equal to 1.0 are treated as occupied. An occupied cell in the last
    grid row is reported as the paddle; an occupied cell in any other row is
    reported as the ball. Rows and columns are reported 1-based. If several
    cells qualify, the one with the highest index wins.

    Args:
        info_state: Flat, row-major grid values with CATCH_COLS entries per row.

    Returns:
        Newline-joined description lines; an empty string when no occupied
        cell is found.
    """
    bottom_row = CATCH_ROWS - 1
    occupied = [position for position, cell in enumerate(info_state) if cell == 1.0]

    ball = None  # (row, column), 1-based
    paddle = None  # column, 1-based
    for position in occupied:
        row, col = divmod(position, CATCH_COLS)
        if row >= bottom_row:
            paddle = col + 1
        else:
            ball = (row + 1, col + 1)

    lines = []
    if ball is not None:
        lines.append(f"Ball: row {ball[0]}/{CATCH_ROWS}, column {ball[1]}/{CATCH_COLS}")
    if paddle is not None:
        lines.append(f"Paddle: column {paddle}/{CATCH_COLS}")

    # The relative hint is only possible when both objects were located.
    if ball is not None and paddle is not None:
        offset = ball[1] - paddle
        if offset == 0:
            lines.append("The ball is directly above the paddle.")
        elif offset > 0:
            lines.append(f"The ball is {offset} column(s) to the RIGHT of the paddle.")
        else:
            lines.append(f"The ball is {abs(offset)} column(s) to the LEFT of the paddle.")

    return "\n".join(lines)
class MultiEnv:
    """Environment wrapper that plays either Wordle or Catch, chosen per episode.

    `reset()` reads the ``env`` keyword to select the game. The public methods
    `guess` (Wordle) and `move` / `stay` (Catch) act on the active game and raise
    `ValueError` when called while the other game is active. After every step the
    latest reward and termination flag are stored on `self.reward` / `self.done`.
    """

    # Server endpoints. They are read from the class at `reset()` time, so assigning
    # `MultiEnv.wordle_url` / `MultiEnv.catch_url` affects all instances.
    wordle_url = DEFAULT_WORDLE_URL
    catch_url = DEFAULT_CATCH_URL

    def __init__(self):
        self._wordle_client = None
        self._catch_client = None
        # Full Wordle message text seen so far; `guess` strips this prefix to return only new text.
        self._last_full_feedback = ""
        self.active = None  # name of the game selected by the last reset, or None before any reset
        self.reward = 0.0
        self.done = False

    def _release_clients(self) -> None:
        """Close and forget any client left over from a previous episode.

        Closing is best effort: a failure is logged at debug level and otherwise
        ignored so that a stale connection cannot prevent a new episode.
        """
        import logging

        for attr in ("_wordle_client", "_catch_client"):
            client = getattr(self, attr)
            setattr(self, attr, None)
            if client is None:
                continue
            try:
                client.close()
            except Exception:
                logging.getLogger(__name__).debug("Ignoring error while closing %s", attr, exc_info=True)

    def reset(self, **kwargs) -> str | None:
        """Start a new episode in the environment named by the ``env`` keyword.

        Args:
            **kwargs: Keyword arguments from the caller. Only ``env`` is read; it
                defaults to "wordle" when absent.

        Returns:
            The initial observation text: the first Wordle message, or the formatted
            Catch grid summary.

        Raises:
            ValueError: If ``env`` is neither "wordle" nor "catch".
        """
        self.active = kwargs.get("env", "wordle")
        self.reward = 0.0
        self.done = False

        # Validate before touching any connection so an unknown name leaves clients untouched.
        if self.active not in ("wordle", "catch"):
            raise ValueError(f"Unknown environment: {self.active}")

        # A fresh client is created for every episode, so neither old client is needed any more.
        self._release_clients()

        if self.active == "wordle":
            self._wordle_client = TextArenaEnv(base_url=MultiEnv.wordle_url)
            result = self._wordle_client.reset()
            self._last_full_feedback = result.observation.messages[0].content
            return self._last_full_feedback

        self._catch_client = OpenSpielEnv(base_url=MultiEnv.catch_url)
        result = self._catch_client.reset()
        self.done = result.observation.done
        return _format_catch_obs(result.observation.info_state)

    def guess(self, guess: str) -> str:
        """
        Make a guess in the Wordle environment.

        Args:
            guess: The guessed word, formatted as '[abcde]'

        Returns:
            The feedback message from the environment.
        """
        if self.active != "wordle":
            raise ValueError("guess is only available in Wordle")
        if self.done:
            raise ValueError("Game over.")

        result = self._wordle_client.step(TextArenaAction(message=guess))
        full_feedback = result.observation.messages[0].content
        # The message content accumulates across turns; return only the part added by this step.
        feedback = full_feedback[len(self._last_full_feedback) :]
        self._last_full_feedback = full_feedback

        if "You attempted an invalid move" in feedback:
            # Invalid guesses earn nothing, regardless of what the step reported.
            self.reward = 0.0
        else:
            # Treat a missing reward as 0.0, matching how Catch steps are handled.
            self.reward = result.reward or 0.0
        self.done = result.done
        return feedback

    def _catch_action(self, action_id: int) -> str:
        """Send one Catch action and record the resulting reward and done flag.

        Args:
            action_id: Action id forwarded to the Catch game (0, 1 or 2 as used by
                `move` and `stay`).

        Returns:
            The formatted observation after the step.

        Raises:
            ValueError: If the episode has already finished.
        """
        if self.done:
            raise ValueError("Episode is done.")
        result = self._catch_client.step(OpenSpielAction(action_id=action_id, game_name="catch"))
        self.reward = result.reward or 0.0
        self.done = result.observation.done
        return _format_catch_obs(result.observation.info_state)

    def move(self, direction: str) -> str:
        """Move the paddle left or right.

        Args:
            direction: Direction to move, either "left" or "right".

        Returns:
            The observation after moving.
        """
        if self.active != "catch":
            raise ValueError("move is only available in Catch")
        # Action ids: 0 moves left, 2 moves right (1 is used by `stay`).
        action_ids = {"left": 0, "right": 2}
        if direction not in action_ids:
            raise ValueError(f"Invalid direction {direction!r}: must be 'left' or 'right'.")
        return self._catch_action(action_ids[direction])

    def stay(self) -> str:
        """Do nothing and let the ball fall one step.

        Returns:
            The observation after staying.
        """
        if self.active != "catch":
            raise ValueError("stay is only available in Catch")
        return self._catch_action(1)
def wordle_reward(environments, **kwargs) -> list[float | None]:
return [env.reward if env.active == "wordle" else None for env in environments]
def catch_reward(environments, **kwargs) -> list[float | None]:
rewards = []
for env in environments:
if env.active != "catch":
rewards.append(None)
elif env.done:
# Catch gives +1 for catching, -1 for missing. Clamp to [0, 1] for GRPO advantage estimation.
rewards.append(max(env.reward, 0.0))
else:
rewards.append(0.0) # Incomplete episode
return rewards
def main() -> None:
    """Parse the environment URLs, build the mixed dataset, and run GRPO training.

    Side effects: overwrites `MultiEnv.wordle_url` and `MultiEnv.catch_url` with the
    command-line values, then trains until `GRPOTrainer.train()` returns.
    """
    parser = argparse.ArgumentParser(description="Multi-environment GRPO training")
    parser.add_argument("--wordle-url", default=DEFAULT_WORDLE_URL, help="Wordle environment URL")
    parser.add_argument("--catch-url", default=DEFAULT_CATCH_URL, help="Catch environment URL")
    # parse_known_args tolerates unrecognised flags instead of aborting; they are not used here.
    args, _ = parser.parse_known_args()

    # The URLs live on the class because the trainer receives the class itself as the factory.
    MultiEnv.wordle_url = args.wordle_url
    MultiEnv.catch_url = args.catch_url

    samples_per_env = 500
    wordle_sample = [{"role": "user", "content": wordle_prompt}]
    catch_sample = [{"role": "user", "content": catch_prompt}]
    dataset = Dataset.from_dict(
        {
            "prompt": [wordle_sample] * samples_per_env + [catch_sample] * samples_per_env,
            # Read by MultiEnv.reset() to pick the game for each sample.
            "env": ["wordle"] * samples_per_env + ["catch"] * samples_per_env,
        }
    )

    trainer = GRPOTrainer(
        model="Qwen/Qwen3-1.7B",
        reward_funcs=[wordle_reward, catch_reward],
        train_dataset=dataset,
        args=GRPOConfig(
            # trackio is the tracker declared in this script's inline dependency list.
            report_to="trackio",
            log_completions=True,
            num_completions_to_print=2,
            logging_steps=1,
            chat_template_kwargs={"enable_thinking": False},
            max_completion_length=1024,
        ),
        environment_factory=MultiEnv,
    )
    trainer.train()
# Run training only when this file is executed as a script, not when it is imported.
if __name__ == "__main__":
    main()