Spaces:
Runtime error
Runtime error
| # Copyright (c) Meta Platforms, Inc. and affiliates. | |
| # All rights reserved. | |
| # | |
| # This source code is licensed under the BSD-style license found in the | |
| # LICENSE file in the root directory of this source tree. | |
| """ | |
| Snake Environment Implementation. | |
| A multi-agent snake game environment that wraps marlenv's Snake-v1. | |
| This implementation provides a single-agent interface by wrapping the | |
| multi-agent marlenv environment. | |
| """ | |
from typing import Optional
from uuid import uuid4

import gym
import marlenv.envs  # Register marlenv environments with gym
import numpy as np

# Support both in-repo and standalone imports
# In-repo imports (when running from OpenEnv repository)
from core.env_server.interfaces import Environment
from core.env_server.types import State
from envs.snake_env import SnakeAction, SnakeObservation

# from openenv_core.env_server.interfaces import Environment
# from openenv_core.env_server.types import State
class SingleAgentWrapper(gym.Wrapper):
    """
    Convert the multi-agent marlenv interface to a single-agent one.

    marlenv's Snake-v1 returns per-agent containers (arrays or lists) for
    observations, rewards, and done flags. This wrapper unwraps agent 0's
    entry from each so the environment behaves like a standard single-agent
    gym env, without triggering gym 0.24.1's strict type checking on done
    flags.
    """

    def __init__(self, env):
        super().__init__(env)
        # Multi-agent envs expose per-agent spaces as indexable containers;
        # re-expose agent 0's spaces as this env's spaces.
        if hasattr(env.observation_space, "__getitem__"):
            self.observation_space = env.observation_space[0]
        if hasattr(env.action_space, "__getitem__"):
            self.action_space = env.action_space[0]

    @staticmethod
    def _unwrap_obs(obs):
        """Return agent 0's observation from a possibly multi-agent obs."""
        # Array shaped (num_agents, H, W, C) with one agent -> (H, W, C)
        if hasattr(obs, "shape") and len(obs.shape) == 4 and obs.shape[0] == 1:
            return obs[0]
        # Per-agent sequence -> first agent's entry
        if isinstance(obs, (list, tuple)):
            return obs[0]
        return obs

    def reset(self, **kwargs):
        """Reset the wrapped env and return agent 0's observation."""
        return self._unwrap_obs(self.env.reset(**kwargs))

    def step(self, action):
        """
        Step the wrapped env with a single action.

        The action is wrapped in a one-element list for the multi-agent env,
        and the per-agent returns are unwrapped for agent 0.
        """
        obs, rewards, dones, info = self.env.step([action])
        obs = self._unwrap_obs(obs)
        # rewards/dones may be per-agent sequences (list or tuple)
        reward = rewards[0] if isinstance(rewards, (list, tuple)) else rewards
        done = dones[0] if isinstance(dones, (list, tuple)) else dones
        # gym 0.24.1 requires a plain Python bool, not a numpy bool
        return obs, reward, bool(done), info
class SnakeEnvironment(Environment):
    """
    A snake game environment that wraps marlenv's Snake-v1.

    This environment provides a single-agent interface to the multi-agent
    snake game. The snake must navigate a grid, eat fruits, and avoid walls
    and its own body.

    Args:
        height: Height of the grid map (default: 20)
        width: Width of the grid map (default: 20)
        snake_length: Initial length of the snake (default: 3)
        vision_range: Vision range for partial observability (default: None for full grid)
        observer: 'snake' for relative actions or 'human' for global directions (default: 'snake')
        max_episode_steps: Maximum steps per episode (default: 1000)
        reward_dict: Custom reward function (default: fruit=1.0, others=0.0)

    Example:
        >>> env = SnakeEnvironment()
        >>> obs = env.reset()
        >>> print(obs.alive)  # True
        >>>
        >>> obs = env.step(SnakeAction(action=1))  # Turn left
        >>> print(obs.episode_score)
        >>> print(obs.reward)
    """

    def __init__(
        self,
        height: int = 20,
        width: int = 20,
        snake_length: int = 3,
        vision_range: Optional[int] = None,
        observer: str = "snake",
        max_episode_steps: int = 1000,
        reward_dict: Optional[dict] = None,
    ):
        """Initialize the snake environment."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Default reward function: +1 per fruit, -1 on death, +100 on win,
        # small positive reward for each surviving step.
        if reward_dict is None:
            reward_dict = {
                "fruit": 1.0,
                "kill": 0.0,
                "lose": -1.0,
                "win": 100.0,
                "time": 0.001,
            }
        # Instantiate marlenv's SnakeEnv directly instead of gym.make so
        # gym 0.24.1 does not add wrappers with strict type checking.
        from marlenv.envs.snake_env import SnakeEnv as MarlenvSnake

        self.base_env = MarlenvSnake(
            height=height,
            width=width,
            num_snakes=1,  # single-agent configuration
            snake_length=snake_length,
            vision_range=vision_range,
            frame_stack=1,
            observer=observer,
            reward_dict=reward_dict,
            max_episode_steps=max_episode_steps,
        )
        # Adapt the multi-agent API to a single-agent one.
        self.env = SingleAgentWrapper(self.base_env)
        # Per-episode statistics, reset at the start of each episode.
        self._episode_score = 0.0
        self._episode_fruits = 0
        self._episode_kills = 0

    def _grid_as_list(self) -> list:
        """Return the current grid as nested lists ([] if unavailable)."""
        return self.base_env.grid.tolist() if hasattr(self.base_env, "grid") else []

    @staticmethod
    def _obs_as_list(obs):
        """Convert a numpy observation to nested lists (pass through otherwise)."""
        return obs.tolist() if isinstance(obs, np.ndarray) else obs

    def reset(self) -> SnakeObservation:
        """
        Start a new episode.

        Returns:
            SnakeObservation with the initial game state.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._episode_score = 0.0
        self._episode_fruits = 0
        self._episode_kills = 0
        obs = self.env.reset()
        return SnakeObservation(
            grid=self._grid_as_list(),
            observation=self._obs_as_list(obs),
            episode_score=self._episode_score,
            episode_steps=self._state.step_count,
            episode_fruits=self._episode_fruits,
            episode_kills=self._episode_kills,
            alive=True,
            done=False,
            reward=0.0,
        )

    def step(self, action: SnakeAction) -> SnakeObservation:  # type: ignore[override]
        """
        Execute a step in the environment.

        Args:
            action: SnakeAction containing the action to take.

        Returns:
            SnakeObservation with the result of the action.
        """
        self._state.step_count += 1
        obs, reward, done, info = self.env.step(action.action)
        self._episode_score += reward
        # Sync per-episode counters from marlenv's info dict when present
        # (values appear to be per-agent lists; agent 0 is the only snake —
        # TODO confirm against marlenv). Previously the counters were never
        # updated, so the fallback value was always the stale initial 0.
        if "episode_fruits" in info:
            self._episode_fruits = int(info["episode_fruits"][0])
        if "episode_kills" in info:
            self._episode_kills = int(info["episode_kills"][0])
        return SnakeObservation(
            grid=self._grid_as_list(),
            observation=self._obs_as_list(obs),
            episode_score=self._episode_score,
            episode_steps=self._state.step_count,
            episode_fruits=self._episode_fruits,
            episode_kills=self._episode_kills,
            alive=not done,
            done=done,
            reward=float(reward),
            metadata={"info": info},
        )

    def state(self) -> State:
        """
        Get the current environment state.

        Returns:
            Current State with episode_id and step_count.
        """
        return self._state