|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| """
|
| Multi-environment GRPO training with OpenEnv: Wordle + Catch in the same training run.
|
|
|
| Demonstrates how to wrap multiple environments in a single `environment_factory` class. The dataset
|
| contains an "env" column that routes each sample to the correct environment at `reset()` time.
|
|
|
| Usage:
|
| python examples/scripts/openenv/multi_env.py \\
|
| --wordle-url https://openenv-wordle.hf.space \\
|
| --catch-url https://openenv-openspiel-env.hf.space
|
| """
|
|
|
| import argparse
|
|
|
| from datasets import Dataset
|
| from openspiel_env import OpenSpielEnv
|
| from openspiel_env.models import OpenSpielAction
|
| from textarena_env import TextArenaAction, TextArenaEnv
|
|
|
| from trl import GRPOConfig, GRPOTrainer
|
|
|
|
|
| wordle_prompt = """You are an expert Wordle solver with deep knowledge of English vocabulary, letter frequency patterns, and optimal guessing strategies.
|
|
|
| Follow these rules to play Wordle:
|
|
|
| 1. The target is a 5-letter English word
|
| 2. You have 6 attempts to guess the correct word
|
| 3. After each guess, you receive color-coded feedback:
|
| - GREEN (G): Letter is correct and in the correct position
|
| - YELLOW (Y): Letter is in the word but in the wrong position
|
| - GRAY (X): Letter is not in the word at all
|
| 4. All guesses must be valid 5-letter English words
|
| 5. You cannot reuse a word you've already guessed
|
| 6. Use the tool `guess` to make a guess.
|
| """
|
|
|
| catch_prompt = """You are an AI agent playing the game **Catch**.
|
|
|
| ### Game Description
|
| - The game is played on a **10×5 grid**.
|
| - There is one **falling ball** and one **paddle** that you control at the bottom.
|
| - The objective is to **move the paddle left or right to catch the ball** as it falls.
|
| - The episode ends when the ball reaches the bottom row:
|
| - You get **+1 reward** if you catch it.
|
| - You get **–1 reward** if you miss it.
|
|
|
| ### Observation Format
|
| Each observation is a flattened 10x5 grid (list of 50 floats).
|
| - 1.0 → occupied (ball or paddle)
|
| - 0.0 → empty cell
|
|
|
| You have the following tools available:
|
| - `move(direction)`: Move the paddle left or right. Direction must be "left" or "right".
|
| - `stay`: Do nothing and let the ball fall one step.
|
|
|
| Observe the grid, determine where the ball is relative to the paddle, then move accordingly.
|
| """
|
|
|
# Default environment server URLs. They seed the MultiEnv class attributes and
# serve as the defaults for the --wordle-url / --catch-url command-line flags.
DEFAULT_WORDLE_URL = "https://openenv-wordle.hf.space"
DEFAULT_CATCH_URL = "https://openenv-openspiel-env.hf.space"

# Catch grid dimensions, used to turn a flat observation index into (row, column).
CATCH_ROWS = 10
CATCH_COLS = 5
|
|
|
|
|
| def _format_catch_obs(info_state: list[float]) -> str:
|
| """Convert the flat 50-float observation into a readable text description."""
|
| ball_row = ball_col = paddle_col = None
|
| for idx, val in enumerate(info_state):
|
| if val == 1.0:
|
| r, c = divmod(idx, CATCH_COLS)
|
| if r < CATCH_ROWS - 1:
|
| ball_row, ball_col = r + 1, c + 1
|
| else:
|
| paddle_col = c + 1
|
| parts = []
|
| if ball_row is not None and ball_col is not None:
|
| parts.append(f"Ball: row {ball_row}/{CATCH_ROWS}, column {ball_col}/{CATCH_COLS}")
|
| if paddle_col is not None:
|
| parts.append(f"Paddle: column {paddle_col}/{CATCH_COLS}")
|
| if ball_col is not None and paddle_col is not None:
|
| diff = ball_col - paddle_col
|
| if diff < 0:
|
| parts.append(f"The ball is {abs(diff)} column(s) to the LEFT of the paddle.")
|
| elif diff > 0:
|
| parts.append(f"The ball is {diff} column(s) to the RIGHT of the paddle.")
|
| else:
|
| parts.append("The ball is directly above the paddle.")
|
| return "\n".join(parts)
|
|
|
|
|
class MultiEnv:
    """One environment object that plays either Wordle or Catch per episode.

    ``reset(env=...)`` selects the game for the episode. ``guess`` is the Wordle
    action; ``move`` and ``stay`` are the Catch actions. Calling an action that
    belongs to the game that is not active raises ``ValueError``. The reward of
    the most recent step and the episode-finished flag are kept on ``reward``
    and ``done`` so the reward functions can read them afterwards.

    NOTE(review): the docstrings of ``guess``, ``move`` and ``stay`` are
    presumably what the trainer turns into tool descriptions - confirm before
    rewording them. Helpers are underscore-prefixed, presumably so they are not
    exposed as tools - verify against the trainer.
    """

    # Server URLs live on the class and are read as ``MultiEnv.<name>`` in
    # reset(), so overriding them once applies to every instance.
    wordle_url = DEFAULT_WORDLE_URL
    catch_url = DEFAULT_CATCH_URL

    def __init__(self):
        self._wordle_client = None
        self._catch_client = None
        # Full Wordle transcript seen so far; guess() returns only what is new.
        self._last_full_feedback = ""
        self.active = None
        self.reward = 0.0
        self.done = False

    @staticmethod
    def _close_client(client) -> None:
        """Close a client left over from an earlier episode, ignoring failures.

        Args:
            client: The client to close, or ``None`` (in which case nothing happens).
        """
        if client is None:
            return
        try:
            client.close()
        except Exception:
            # Best effort: a stale connection that fails to close must not stop
            # the next episode from starting. Record it instead of hiding it.
            import logging

            logging.getLogger(__name__).debug(
                "Ignoring error while closing a stale environment client", exc_info=True
            )

    def reset(self, **kwargs) -> str | None:
        """Start a new episode in the environment named by ``env``.

        Args:
            **kwargs: Only ``env`` is read: ``"wordle"`` (the default when the
                key is absent) or ``"catch"``. Other keys are ignored.

        Returns:
            The initial observation text: the first Wordle message, or the
            formatted Catch grid.

        Raises:
            ValueError: If ``env`` names an unknown environment.
        """
        self.active = kwargs.get("env", "wordle")
        self.reward = 0.0
        self.done = False

        if self.active not in ("wordle", "catch"):
            raise ValueError(f"Unknown environment: {self.active}")

        # Each episode opens a fresh client, so release both old ones. That
        # includes the client of the game NOT selected now, which would
        # otherwise stay open until that game is selected again.
        self._close_client(self._wordle_client)
        self._close_client(self._catch_client)
        self._wordle_client = None
        self._catch_client = None

        if self.active == "wordle":
            self._wordle_client = TextArenaEnv(base_url=MultiEnv.wordle_url)
            result = self._wordle_client.reset()
            self._last_full_feedback = result.observation.messages[0].content
            return self._last_full_feedback

        self._catch_client = OpenSpielEnv(base_url=MultiEnv.catch_url)
        result = self._catch_client.reset()
        self.done = result.observation.done
        return _format_catch_obs(result.observation.info_state)

    def guess(self, guess: str) -> str:
        """
        Make a guess in the Wordle environment.

        Args:
            guess: The guessed word, formatted as '[abcde]'

        Returns:
            The feedback message from the environment.
        """
        # Raises ValueError when Wordle is not the active game or the game is over.
        if self.active != "wordle":
            raise ValueError("guess is only available in Wordle")
        if self.done:
            raise ValueError("Game over.")
        result = self._wordle_client.step(TextArenaAction(message=guess))
        full_feedback = result.observation.messages[0].content
        # The slice assumes the message content is cumulative, so dropping the
        # previously seen prefix leaves only the feedback for this guess.
        feedback = full_feedback[len(self._last_full_feedback) :]
        self._last_full_feedback = full_feedback
        if "You attempted an invalid move" in feedback:
            # Invalid moves earn nothing, whatever the server reports.
            self.reward = 0.0
        else:
            # Store a float even if the server sends no reward (same as Catch),
            # because wordle_reward uses None to mean "a different environment".
            self.reward = result.reward or 0.0
        self.done = result.done
        return feedback

    def _catch_action(self, action_id: int) -> str:
        """Send one Catch action and record the resulting reward and done flag.

        Args:
            action_id: Action identifier; this class sends 0 for left, 1 for
                stay and 2 for right.

        Returns:
            The formatted grid observation after the step.

        Raises:
            ValueError: If the episode has already finished.
        """
        if self.done:
            raise ValueError("Episode is done.")
        result = self._catch_client.step(OpenSpielAction(action_id=action_id, game_name="catch"))
        self.reward = result.reward or 0.0
        self.done = result.observation.done
        return _format_catch_obs(result.observation.info_state)

    def move(self, direction: str) -> str:
        """Move the paddle left or right.

        Args:
            direction: Direction to move, either "left" or "right".

        Returns:
            The observation after moving.
        """
        # Raises ValueError when Catch is not active or the direction is unknown.
        if self.active != "catch":
            raise ValueError("move is only available in Catch")
        if direction == "left":
            action_id = 0
        elif direction == "right":
            action_id = 2
        else:
            raise ValueError(f"Invalid direction {direction!r}: must be 'left' or 'right'.")
        return self._catch_action(action_id)

    def stay(self) -> str:
        """Do nothing and let the ball fall one step.

        Returns:
            The observation after staying.
        """
        # Raises ValueError when Catch is not the active game.
        if self.active != "catch":
            raise ValueError("stay is only available in Catch")
        return self._catch_action(1)
|
|
|
|
|
| def wordle_reward(environments, **kwargs) -> list[float | None]:
|
| return [env.reward if env.active == "wordle" else None for env in environments]
|
|
|
|
|
| def catch_reward(environments, **kwargs) -> list[float | None]:
|
| rewards = []
|
| for env in environments:
|
| if env.active != "catch":
|
| rewards.append(None)
|
| elif env.done:
|
|
|
| rewards.append(max(env.reward, 0.0))
|
| else:
|
| rewards.append(0.0)
|
| return rewards
|
|
|
|
|
def main() -> None:
    """Parse the environment URLs, build the mixed dataset and run GRPO training."""
    cli = argparse.ArgumentParser(description="Multi-environment GRPO training")
    cli.add_argument("--wordle-url", default=DEFAULT_WORDLE_URL, help="Wordle environment URL")
    cli.add_argument("--catch-url", default=DEFAULT_CATCH_URL, help="Catch environment URL")
    # parse_known_args tolerates flags this parser does not define; they are not used here.
    args, _unused = cli.parse_known_args()

    # Set the URLs on the class so every environment instance picks them up.
    MultiEnv.wordle_url = args.wordle_url
    MultiEnv.catch_url = args.catch_url

    # Same number of samples per game; the "env" column tells reset() which game to start.
    samples_per_env = 500
    wordle_chat = [{"role": "user", "content": wordle_prompt}]
    catch_chat = [{"role": "user", "content": catch_prompt}]
    columns = {
        "prompt": [wordle_chat] * samples_per_env + [catch_chat] * samples_per_env,
        "env": ["wordle"] * samples_per_env + ["catch"] * samples_per_env,
    }
    dataset = Dataset.from_dict(columns)

    training_args = GRPOConfig(
        report_to="wandb",
        log_completions=True,
        num_completions_to_print=2,
        logging_steps=1,
        chat_template_kwargs={"enable_thinking": False},
        max_completion_length=1024,
    )
    trainer = GRPOTrainer(
        model="Qwen/Qwen3-1.7B",
        reward_funcs=[wordle_reward, catch_reward],
        train_dataset=dataset,
        args=training_args,
        environment_factory=MultiEnv,
    )
    trainer.train()
|
|
|
|
|
| if __name__ == "__main__":
|
| main()
|
|
|