# gridworld-env/server/environment.py
# Abhilasha Kakoty
# Initial deploy
# 7078f4d
from typing import Optional
from core.env_server import Environment
from core.env_server.interfaces import Transform
from models import GridAction, GridObservation, GridState
class GridWorldEnv(Environment[GridAction, GridObservation, GridState]):
    """Small grid world in which an agent walks toward a fixed goal cell.

    Positions are ``(row, col)`` tuples: "down"/"up" change the first
    component and "right"/"left" change the second. Both components are
    bounded by ``grid_size`` (5). The agent starts at ``(0, 0)`` and the goal
    is at ``(4, 4)``. Every ``step`` yields a reward of ``10.0`` when the
    agent is on the goal afterwards and ``-1.0`` otherwise.
    """

    def __init__(
        self,
        transform: Optional[Transform[GridObservation]] = None,
        rubric=None,
    ):
        """Initialise the grid and forward the arguments to the base class.

        Args:
            transform: Optional observation transform, handed unchanged to
                the base ``Environment`` constructor.
            rubric: Handed unchanged to the base ``Environment`` constructor.
        """
        # NOTE(review): nothing in this class applies ``transform`` to the
        # observations it builds; presumably the base class does -- confirm.
        super().__init__(transform=transform, rubric=rubric)
        self._steps = 0
        self.grid_size = 5
        self.goal_pos = (4, 4)
        self.agent_pos = (0, 0)

    def reset(self) -> GridObservation:
        """Begin a new episode.

        Moves the agent back to ``(0, 0)`` and zeroes the step counter. The
        goal position is left as it is.

        Returns:
            An observation with ``reward=0.0`` and ``done=False``.
        """
        self._steps = 0
        self.agent_pos = (0, 0)
        return self._make_obs(reward=0.0, done=False)

    def step(self, action: GridAction) -> GridObservation:
        """Advance the episode by one action.

        ``action.direction`` is lower-cased and compared against "right",
        "left", "down" and "up". A move that would leave the grid, or a
        direction that matches none of those four, leaves the agent where it
        is; the step is still counted and still earns the ``-1.0`` reward.
        There is no guard against stepping after the goal has been reached.

        Args:
            action: Action whose ``direction`` attribute names the move.

        Returns:
            An observation with ``done=True`` and ``reward=10.0`` if the
            agent is on the goal after the move, otherwise ``done=False``
            and ``reward=-1.0``.
        """
        # The counter is bumped before anything else, so it also counts
        # calls whose direction turns out to be unusable.
        self._steps += 1
        row, col = self.agent_pos
        direction = action.direction.lower()

        # Each branch moves one cell only when the target is still in range.
        if direction == "right":
            if col < self.grid_size - 1:
                col += 1
        elif direction == "left":
            if col > 0:
                col -= 1
        elif direction == "down":
            if row < self.grid_size - 1:
                row += 1
        elif direction == "up":
            if row > 0:
                row -= 1

        self.agent_pos = (row, col)

        reached = self.agent_pos == self.goal_pos
        if reached:
            reward = 10.0
        else:
            reward = -1.0
        return self._make_obs(reward=reward, done=reached)

    def _make_obs(self, reward: float, done: bool) -> GridObservation:
        """Build an observation from the current positions.

        Args:
            reward: Reward value to attach to the observation.
            done: Whether the observation marks the end of the episode.

        Returns:
            A ``GridObservation`` carrying the current agent and goal
            positions together with ``reward`` and ``done``.
        """
        observation = GridObservation(
            agent_pos=self.agent_pos,
            goal_pos=self.goal_pos,
            reward=reward,
            done=done,
        )
        return observation

    def get_state(self) -> GridState:
        """Return a ``GridState`` snapshot of the episode.

        The episode id is always the fixed string ``"env_01"``;
        ``steps_taken`` is the number of ``step`` calls since the last reset.
        """
        snapshot = GridState(episode_id="env_01", steps_taken=self._steps)
        return snapshot

    @property
    def state(self) -> GridState:
        """Property form of :meth:`get_state`; delegates to it."""
        return self.get_state()