Spaces:
Runtime error
Runtime error
File size: 8,031 Bytes
92c98ca f6d455c 92c98ca f6d455c 92c98ca f6d455c 92c98ca f6d455c 92c98ca f6d455c 92c98ca f6d455c 92c98ca |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Snake Environment Implementation.
A multi-agent snake game environment that wraps marlenv's Snake-v1.
This implementation provides a single-agent interface by wrapping the
multi-agent marlenv environment.
"""
from typing import Optional
from uuid import uuid4

import gym
import marlenv.envs  # Register marlenv environments with gym
import numpy as np

# Support both in-repo and standalone imports
# In-repo imports (when running from OpenEnv repository)
from core.env_server.interfaces import Environment
from core.env_server.types import State
from envs.snake_env import SnakeAction, SnakeObservation

# from openenv_core.env_server.interfaces import Environment
# from openenv_core.env_server.types import State
class SingleAgentWrapper(gym.Wrapper):
    """
    Adapter exposing a multi-agent marlenv environment as single-agent.

    Unwraps the per-agent observations, rewards, and done flags that marlenv
    returns into plain single-agent values, without triggering gym 0.24.1's
    strict type checking on done flags.
    """

    def __init__(self, env):
        super().__init__(env)
        # marlenv exposes per-agent spaces; take agent 0's when indexable.
        if hasattr(env.observation_space, "__getitem__"):
            self.observation_space = env.observation_space[0]
        if hasattr(env.action_space, "__getitem__"):
            self.action_space = env.action_space[0]

    def reset(self, **kwargs):
        raw_obs = self.env.reset(**kwargs)
        return self._first_agent(raw_obs)

    def step(self, action):
        # The wrapped env expects one action per agent.
        raw_obs, raw_rewards, raw_dones, info = self.env.step([action])
        obs = self._first_agent(raw_obs)
        reward = raw_rewards[0] if isinstance(raw_rewards, list) else raw_rewards
        flag = raw_dones[0] if isinstance(raw_dones, list) else raw_dones
        # Plain bool (not numpy bool) keeps gym's done-flag check happy.
        return obs, reward, bool(flag), info

    @staticmethod
    def _first_agent(obs):
        """Strip the leading agent dimension from an observation, if present."""
        # A (num_agents, H, W, C) array with one agent becomes (H, W, C).
        if hasattr(obs, "shape") and len(obs.shape) == 4 and obs.shape[0] == 1:
            return obs[0]
        # A per-agent list yields the first agent's entry.
        if isinstance(obs, list):
            return obs[0]
        return obs
class SnakeEnvironment(Environment):
    """
    A snake game environment that wraps marlenv's Snake-v1.

    Provides a single-agent interface to the multi-agent snake game. The
    snake must navigate a grid, eat fruits, and avoid walls and its own body.

    Args:
        height: Height of the grid map (default: 20)
        width: Width of the grid map (default: 20)
        snake_length: Initial length of the snake (default: 3)
        vision_range: Vision range for partial observability
            (default: None for full grid)
        observer: 'snake' for relative actions or 'human' for global
            directions (default: 'snake')
        max_episode_steps: Maximum steps per episode (default: 1000)
        reward_dict: Custom reward function (default: fruit=1.0, lose=-1.0,
            win=100.0, time=0.001, kill=0.0)

    Example:
        >>> env = SnakeEnvironment()
        >>> obs = env.reset()
        >>> print(obs.alive)  # True
        >>>
        >>> obs = env.step(SnakeAction(action=1))  # Turn left
        >>> print(obs.episode_score)
        >>> print(obs.reward)
    """

    def __init__(
        self,
        height: int = 20,
        width: int = 20,
        snake_length: int = 3,
        vision_range: Optional[int] = None,
        observer: str = "snake",
        max_episode_steps: int = 1000,
        reward_dict: Optional[dict] = None,
    ):
        """Initialize the snake environment."""
        self._state = State(episode_id=str(uuid4()), step_count=0)
        # Default reward shaping: positive for fruit/win, negative for losing,
        # a tiny per-step bonus for surviving.
        if reward_dict is None:
            reward_dict = {
                "fruit": 1.0,
                "kill": 0.0,
                "lose": -1.0,
                "win": 100.0,
                "time": 0.001,
            }
        # Instantiate marlenv's SnakeEnv directly (instead of gym.make) to
        # avoid gym 0.24.1's extra wrappers and their strict type checks.
        from marlenv.envs.snake_env import SnakeEnv as MarlenvSnake

        self.base_env = MarlenvSnake(
            height=height,
            width=width,
            num_snakes=1,  # Single agent
            snake_length=snake_length,
            vision_range=vision_range,
            frame_stack=1,
            observer=observer,
            reward_dict=reward_dict,
            max_episode_steps=max_episode_steps,
        )
        # Convert the multi-agent API to a single-agent one.
        self.env = SingleAgentWrapper(self.base_env)
        # Per-episode statistics; refreshed on every reset().
        self._episode_score = 0.0
        self._episode_fruits = 0
        self._episode_kills = 0

    def reset(self) -> SnakeObservation:
        """
        Reset the environment and start a new episode.

        Returns:
            SnakeObservation with initial game state.
        """
        self._state = State(episode_id=str(uuid4()), step_count=0)
        self._episode_score = 0.0
        self._episode_fruits = 0
        self._episode_kills = 0
        obs = self.env.reset()
        return SnakeObservation(
            grid=self._current_grid(),
            observation=self._to_list(obs),
            episode_score=self._episode_score,
            episode_steps=self._state.step_count,
            episode_fruits=self._episode_fruits,
            episode_kills=self._episode_kills,
            alive=True,
            done=False,
            reward=0.0,
        )

    def step(self, action: SnakeAction) -> SnakeObservation:  # type: ignore[override]
        """
        Execute a step in the environment.

        Args:
            action: SnakeAction containing the action to take.

        Returns:
            SnakeObservation with the result of the action.
        """
        self._state.step_count += 1
        obs, reward, done, info = self.env.step(action.action)
        self._episode_score += reward
        # Persist stats back onto the instance so the cached values stay
        # correct as a fallback on steps where info lacks the keys.
        # (Previously the cached counters were never updated, so the
        # fallback was permanently stale at the reset value.)
        self._episode_fruits = self._stat_from_info(
            info, "episode_fruits", self._episode_fruits
        )
        self._episode_kills = self._stat_from_info(
            info, "episode_kills", self._episode_kills
        )
        return SnakeObservation(
            grid=self._current_grid(),
            observation=self._to_list(obs),
            episode_score=self._episode_score,
            episode_steps=self._state.step_count,
            episode_fruits=self._episode_fruits,
            episode_kills=self._episode_kills,
            alive=not done,
            done=done,
            reward=float(reward),
            metadata={"info": info},
        )

    def _current_grid(self) -> list:
        """Return the underlying grid as nested lists ([] if unavailable)."""
        return self.base_env.grid.tolist() if hasattr(self.base_env, "grid") else []

    @staticmethod
    def _to_list(obs):
        """Convert a numpy observation to nested lists; pass through otherwise."""
        return obs.tolist() if isinstance(obs, np.ndarray) else obs

    @staticmethod
    def _stat_from_info(info: dict, key: str, fallback: int) -> int:
        """Read a per-agent statistic (agent 0) from info, else use fallback."""
        if key in info:
            # marlenv reports per-agent sequences; take the first agent's.
            return int(info[key][0])
        return int(fallback)

    @property
    def state(self) -> State:
        """
        Get the current environment state.

        Returns:
            Current State with episode_id and step_count.
        """
        return self._state
|