| from abc import ABC |
|
|
| import gymnasium as gym |
| from gymnasium import spaces |
| import numpy as np |
|
|
|
|
class NimGameEnv(gym.Env, ABC):
    """Environment for a simple two-player game of Nim.

    The state is a list of piles of stones. Players alternate turns; on each
    turn the current player chooses one pile and removes one or more stones
    from it. The game ends once every pile is empty. The rules described for
    this game say that the player who takes the last stone loses.

    NOTE(review): ``_calculate_reward`` returns ``1`` for the move that
    empties the last pile, which does not match the "last stone loses" rule
    above. The reward is kept as it was so existing callers see the same
    values; confirm which of the two is intended.

    Observations are lists with one integer per pile (the stones left in that
    pile). Actions are ``(pile_index, num_stones)`` pairs.

    Attributes
    ----------
    starting_stick_piles : list
        Pile sizes every episode starts from (a private copy of the
        constructor argument).
    num_piles : int
        Number of piles.
    max_stones : int
        Size of the largest starting pile.
    piles : list
        Current number of stones in each pile.
    current_player : int
        Index (0 or 1) of the player whose turn it is.
    action_space : spaces.MultiDiscrete
        ``[num_piles, max_stones + 1]``. The second component includes 0,
        which ``_is_valid_action`` rejects.
    observation_space : spaces.MultiDiscrete
        ``max_stones + 1`` possible values for each pile.
    """

    def __init__(self, starting_stick_piles=(3, 5, 7)):
        """Create the environment.

        Parameters
        ----------
        starting_stick_piles : sequence of int, optional
            Initial number of stones in each pile. Defaults to ``3, 5, 7``.

        Raises
        ------
        ValueError
            If no piles are given or any pile size is negative.
        """
        super().__init__()

        # Copy so that neither a shared default nor the caller's own list can
        # be mutated through (or underneath) this environment.
        starting_piles = list(starting_stick_piles)
        if not starting_piles:
            raise ValueError("starting_stick_piles must contain at least one pile")
        if any(size < 0 for size in starting_piles):
            raise ValueError("starting_stick_piles must not contain negative pile sizes")

        self.starting_stick_piles = starting_piles
        self.num_piles = len(starting_piles)
        self.max_stones = max(starting_piles)
        self.piles = self._init_piles()
        self.current_player = 0
        self.action_space = spaces.MultiDiscrete([self.num_piles, self.max_stones + 1])
        self.observation_space = spaces.MultiDiscrete([self.max_stones + 1] * self.num_piles)

    def step(self, action):
        """Apply one move for the current player.

        Parameters
        ----------
        action : tuple
            ``(pile, num_stones)``: the index of the chosen pile and the
            number of stones to remove from it.

        Returns
        -------
        observation : list
            The number of stones in each pile after the move (a copy of the
            internal state).
        reward : int
            ``1`` if this move ended the game, otherwise ``0``.
        done : bool
            Whether the game has ended.
        info : dict
            Always empty.

        Raises
        ------
        ValueError
            If ``_is_valid_action`` rejects the action.

        Notes
        -----
        This returns the four-element ``(obs, reward, done, info)`` tuple, not
        Gymnasium's five-element ``(obs, reward, terminated, truncated, info)``
        form; it is kept for compatibility with existing callers.
        """
        if not self._is_valid_action(action):
            raise ValueError("Invalid action")

        pile, num_stones = action
        self.piles[pile] -= num_stones

        done = self._is_game_over()
        reward = self._calculate_reward()

        # The turn passes to the other player even after the final move.
        self.current_player = (self.current_player + 1) % 2

        # Hand out a copy so later moves do not rewrite observations the
        # caller is still holding.
        return list(self.piles), reward, done, {}

    def reset(self, *, seed=None, options=None):
        """Reset the environment to its initial state.

        Parameters
        ----------
        seed : int, optional
            Forwarded to ``gym.Env.reset``.
        options : dict, optional
            Accepted so Gymnasium-style ``reset(options=...)`` calls do not
            fail; it is not used.

        Returns
        -------
        text_observation : str
            Human-readable description of the piles.
        piles : list
            The number of stones in each pile (a copy of the internal state).
        """
        super().reset(seed=seed)
        self.piles = self._init_piles()
        self.current_player = 0
        text_observation = "The piles contain " + ", ".join(map(str, self.piles)) + " sticks."
        return text_observation, list(self.piles)

    def _init_piles(self):
        """Return a fresh list of piles copied from ``starting_stick_piles``."""
        return list(self.starting_stick_piles)

    def _generate_random_stones(self):
        """Return a random stone count between 1 and ``max_stones`` inclusive.

        Draws from NumPy's global random state. Not called by any other
        method of this class.
        """
        # randint's upper bound is exclusive, hence the + 1.
        return np.random.randint(1, self.max_stones + 1)

    def _is_valid_action(self, action):
        """Return whether ``action`` is a legal move in the current state.

        A move is legal when the pile index is in range and the number of
        stones is at least 1, at most ``max_stones``, and no more than the
        chosen pile currently holds.
        """
        pile, num_stones = action
        return (
            0 <= pile < self.num_piles
            and 0 < num_stones <= self.max_stones
            and num_stones <= self.piles[pile]
        )

    def _is_game_over(self):
        """Return True when every pile is empty."""
        return all(pile == 0 for pile in self.piles)

    def _calculate_reward(self):
        """Return 1 if the game is over in the current state, otherwise 0."""
        return 1 if self._is_game_over() else 0
|
|