# Source: maze_env-pr-106/src/envs/maze_env/server/maze_environment.py
# Uploaded by burtenshaw (HF Staff) using huggingface_hub; commit be46a16 (verified).
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
"""
Maze Environment Server Implementation.
This module wraps Maze's environment and exposes it
via the OpenEnv Environment interface.
"""
from typing import List, Tuple, Optional
from core.env_server import Environment
from .maze import Maze
from ..models import MazeAction, MazeObservation, MazeState
try:
import numpy as np
except ImportError as e:
raise ImportError(
"Numpy is not installed. "
"Please install it following instructions at: "
"pip install numpy"
) from e
class MazeEnvironment(Environment):
    """
    Maze Environment wrapper for OpenEnv.

    This environment wraps the Maze game and provides a single-agent interface.

    Args:
        maze_array: Maze layout as a 2-D numpy array. ``step`` treats cells
            equal to 1 as walls; ``_compute_legal_actions`` treats cells
            equal to 0 as free.
        start_cell: Start cell, forwarded to the underlying ``Maze``.
        exit_cell: Exit cell as ``(row, col)``; reaching it ends the episode.
            ``None`` disables exit detection in ``step``.
    """

    # Action id -> (row delta, col delta): 0 = UP, 1 = DOWN, 2 = LEFT, 3 = RIGHT
    _MOVES = {
        0: (-1, 0),
        1: (1, 0),
        2: (0, -1),
        3: (0, 1),
    }

    # Reward settings
    _REWARD_EXIT = 10.0  # reward for reaching the exit cell
    _REWARD_MOVE = 0.05  # reward for a valid move that didn't find the exit
    _PENALTY_VISITED = -0.25  # penalty for revisiting a cell
    _PENALTY_IMPOSSIBLE = -0.75  # penalty for invalid move (wall/outside)

    def __init__(
        self,
        maze_array: np.ndarray,
        start_cell: Tuple[int, int] = (0, 0),
        exit_cell: Optional[Tuple[int, int]] = (7, 7),
    ):
        # Create underlying Maze instance; it receives the arguments untouched.
        self.env = Maze(maze=maze_array, start_cell=start_cell, exit_cell=exit_cell)
        self.total_reward = 0
        self.start_cell = start_cell
        # Normalise to a tuple of ints so the ``(row, col) == exit_cell`` test in
        # step() also succeeds when the caller passed a list or numpy integers.
        self.exit_cell = (
            None if exit_cell is None else tuple(int(v) for v in exit_cell)
        )
        # Cells visited during the current episode (cleared by reset()).
        self._visited = set()
        # Stored under a private name: assigning to ``self.state`` directly
        # would shadow the ``state`` accessor defined on the class.
        # env.reset() is called in reset(); no episode state until then.
        self._state: Optional[MazeState] = None

    @property
    def state(self) -> Optional[MazeState]:
        """Return the current MazeState object (``None`` before the first reset/step)."""
        return self._state

    @state.setter
    def state(self, value: Optional[MazeState]) -> None:
        # Kept so that existing code assigning ``env.state = ...`` keeps working.
        self._state = value

    def reset(self) -> MazeObservation:
        """Reset environment and return initial observation (MazeObservation)."""
        # presumably returns np.array([row, col]) or a nested [[row, col]];
        # both shapes are handled below -- verify against Maze.reset().
        observation = self.env.reset()

        # Start a fresh episode: state, reward accumulator and visited cells.
        self._state = MazeState(episode_id="episode_1", step_count=0, done=False)
        self.total_reward = 0
        self._visited = set()

        # Convert numpy to list for JSON-serializable dataclass fields.
        pos_list = (
            observation.tolist()
            if hasattr(observation, "tolist")
            else list(observation)
        )

        # Flatten before computing legal actions so a flat [row, col] and a
        # nested [[row, col]] observation are treated the same way.
        flat_position = np.asarray(observation).ravel().tolist()
        legal_actions = self._compute_legal_actions(flat_position[:2])

        return MazeObservation(
            position=pos_list,
            total_reward=self.total_reward,
            legal_actions=legal_actions,
        )

    def step(self, action: MazeAction) -> MazeObservation:
        """
        Apply one action and return the resulting observation.

        The agent position stored on the underlying Maze object is moved
        directly, and a reward/penalty is applied for the movement outcome:
        exit reached, new cell, revisited cell, or impossible move.

        An action id that is not one of 0-3 maps to a zero displacement, so
        the agent stays on its current cell.
        """
        row, col = self._read_position()
        maze = np.asarray(self.env.maze)

        dr, dc = self._MOVES.get(action.action, (0, 0))
        new_r, new_c = row + dr, col + dc

        # The cell being left counts as visited.
        self._visited.add((row, col))

        # --- Check if move is valid (inside the grid and not a wall) ---
        valid_move = (
            0 <= new_r < maze.shape[0]
            and 0 <= new_c < maze.shape[1]
            and maze[new_r, new_c] != 1
        )

        reward = 0.0
        done = False
        if valid_move:
            row, col = new_r, new_c
            if self.exit_cell is not None and (row, col) == self.exit_cell:
                reward += self._REWARD_EXIT
                done = True
                self._visited = set()
            elif (row, col) in self._visited:
                reward += self._PENALTY_VISITED
            else:
                reward += self._REWARD_MOVE
        else:
            # Invalid move: position is left unchanged.
            reward += self._PENALTY_IMPOSSIBLE

        self._write_position(row, col)
        self.total_reward += reward

        # --- Update state (created lazily if step() is called before reset()) ---
        if self._state is None:
            self._state = MazeState(episode_id="episode_1", step_count=0, done=done)
        self._state.step_count += 1
        self._state.done = done

        pos_list = [row, col]
        return MazeObservation(
            position=pos_list,
            total_reward=self.total_reward,
            legal_actions=self._compute_legal_actions(pos_list),
            done=done,
        )

    def _read_position(self) -> Tuple[int, int]:
        """Read the agent's current cell from the underlying Maze object."""
        if hasattr(self.env, "agent_position"):
            row, col = self.env.agent_position
        elif hasattr(self.env, "_Maze__current_cell"):
            # Name-mangled form of Maze's private ``__current_cell`` attribute.
            row, col = self.env._Maze__current_cell
        else:
            row, col = self.env._Maze__start_cell
        # Built-in ints keep the observation JSON-serializable and hashable
        # consistently inside ``_visited``.
        return int(row), int(col)

    def _write_position(self, row: int, col: int) -> None:
        """Write the agent's cell back to whichever attribute the Maze exposes."""
        if hasattr(self.env, "agent_position"):
            self.env.agent_position = (row, col)
        elif hasattr(self.env, "_Maze__current_cell"):
            self.env._Maze__current_cell = (row, col)

    def _compute_legal_actions(self, pos: List[int]) -> List[int]:
        """
        Compute which actions are legal given the current position [row, col].

        An action is legal when the neighbouring cell lies inside the grid
        and its maze value equals 0.
        (0=UP, 1=DOWN, 2=LEFT, 3=RIGHT)
        """
        if not pos or len(pos) < 2:
            return []
        row, col = int(pos[0]), int(pos[1])
        maze = np.asarray(self.env.maze)
        nrows, ncols = maze.shape

        actions: List[int] = []
        for action_id in (0, 1, 2, 3):
            dr, dc = self._MOVES[action_id]
            r, c = row + dr, col + dc
            if 0 <= r < nrows and 0 <= c < ncols and maze[r, c] == 0:
                actions.append(action_id)
        return actions