tripathysagar committed on
Commit
b1c091a
·
verified ·
1 Parent(s): 1eab21d

Upload folder using huggingface_hub

Browse files
Files changed (2) hide show
  1. models.py +1 -0
  2. server/maze_environment.py +23 -3
models.py CHANGED
@@ -6,6 +6,7 @@ class MazeAction(Action):
6
 
7
  class MazeObservation(Observation):
8
  position: list = Field(default=[], description="Agent's [row, col]")
 
9
  grid_view: str = Field(default="", description="String view of the maze")
10
 
11
  class MazeState(State):
 
6
 
7
  class MazeObservation(Observation):
8
  position: list = Field(default=[], description="Agent's [row, col]")
9
+ valid_moves: list = Field(default=[], description="List of valid directions")
10
  grid_view: str = Field(default="", description="String view of the maze")
11
 
12
  class MazeState(State):
server/maze_environment.py CHANGED
@@ -53,15 +53,35 @@ class MazeEnvironment(Environment):
53
  if seed: random.seed(seed)
54
  self._generate_new_maze()
55
  self._episode_id = episode_id
56
- return MazeObservation(position=self._agent_pos, grid_view=self._render(), done=False, reward=0)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
57
 
58
  def step(self, action: MazeAction, timeout_s=None, **kwargs) -> MazeObservation:
59
  if action.direction in ["up", "down", "left", "right"]:
60
  self._move(action.direction)
61
  self._step_count += 1
62
  done = self._agent_pos == self._goal_pos
63
- return MazeObservation(position=self._agent_pos, grid_view=self._render(),
64
- done=done, reward=10 if done else -1)
 
 
 
 
 
65
 
66
  def _move(self, d):
67
  i, j = self._agent_pos
 
53
  if seed: random.seed(seed)
54
  self._generate_new_maze()
55
  self._episode_id = episode_id
56
+ return MazeObservation(
57
+ position=self._agent_pos,
58
+ grid_view=self._render(),
59
+ valid_moves=self._get_valid_moves(), # add this
60
+ done=False,
61
+ reward=0
62
+ )
63
+
64
+ def _get_valid_moves(self) -> list:
65
+ moves = []
66
+ i, j = self._agent_pos
67
+ if i > 0 and self._maze[i-1][j]: moves.append("up")
68
+ if i < self.row-1 and self._maze[i+1][j]: moves.append("down")
69
+ if j > 0 and self._maze[i][j-1]: moves.append("left")
70
+ if j < self.col-1 and self._maze[i][j+1]: moves.append("right")
71
+ return moves
72
 
73
  def step(self, action: MazeAction, timeout_s=None, **kwargs) -> MazeObservation:
74
  if action.direction in ["up", "down", "left", "right"]:
75
  self._move(action.direction)
76
  self._step_count += 1
77
  done = self._agent_pos == self._goal_pos
78
+ return MazeObservation(
79
+ position=self._agent_pos,
80
+ grid_view=self._render(),
81
+ valid_moves=self._get_valid_moves(), # add this
82
+ done=done,
83
+ reward=10 if done else -1
84
+ )
85
 
86
  def _move(self, d):
87
  i, j = self._agent_pos