from typing import Optional

from core.env_server import Environment
from core.env_server.interfaces import Transform
from models import GridAction, GridObservation, GridState


class GridWorldEnv(Environment[GridAction, GridObservation, GridState]):
    """Grid-world environment with a single agent and a fixed goal cell.

    Positions are ``(row, col)`` tuples. The agent starts at ``(0, 0)`` and
    the goal sits at ``(4, 4)`` on a grid whose side length is
    ``grid_size`` (5). Every call to :meth:`step` yields a reward of
    ``10.0`` when the agent ends on the goal cell and ``-1.0`` otherwise.
    """

    def __init__(
        self,
        transform: Optional[Transform[GridObservation]] = None,
        rubric=None,
    ):
        """Initialise the grid, the agent/goal positions and the step counter.

        Args:
            transform: Optional observation transform; forwarded unchanged
                to the base-class constructor.
            rubric: Forwarded unchanged to the base-class constructor.
        """
        super().__init__(transform=transform, rubric=rubric)
        self.grid_size = 5
        self.agent_pos = (0, 0)
        self.goal_pos = (4, 4)
        self._steps = 0

    def reset(self) -> GridObservation:
        """Return the agent to ``(0, 0)`` and zero the step counter.

        Returns:
            An observation with ``reward=0.0`` and ``done=False``.
        """
        self._steps = 0
        self.agent_pos = (0, 0)
        return self._make_obs(reward=0.0, done=False)

    def step(self, action: GridAction) -> GridObservation:
        """Apply one move and report the resulting observation.

        ``action.direction`` is compared case-insensitively against
        ``"right"``, ``"left"``, ``"down"`` and ``"up"``. A move that would
        leave the grid, or a direction that matches none of those strings,
        leaves the agent where it is. The step counter is incremented on
        every call regardless of whether the agent moved. There is no check
        for whether the goal was already reached on an earlier call.

        Args:
            action: The action whose ``direction`` selects the move.

        Returns:
            An observation with ``reward=10.0, done=True`` if the agent is
            on the goal cell after the move, else ``reward=-1.0,
            done=False``.
        """
        # Count the call first so the counter advances even for no-op moves.
        self._steps += 1
        row, col = self.agent_pos
        direction = action.direction.lower()

        # "down" grows the row index and "right" grows the column index;
        # each inner guard keeps the moved coordinate on the grid.
        if direction == "right":
            if col < self.grid_size - 1:
                col += 1
        elif direction == "left":
            if col > 0:
                col -= 1
        elif direction == "down":
            if row < self.grid_size - 1:
                row += 1
        elif direction == "up":
            if row > 0:
                row -= 1

        self.agent_pos = (row, col)

        if self.agent_pos == self.goal_pos:
            return self._make_obs(reward=10.0, done=True)
        return self._make_obs(reward=-1.0, done=False)

    def _make_obs(self, reward: float, done: bool) -> GridObservation:
        """Build an observation from the current positions and given outcome.

        Args:
            reward: Reward value to place in the observation.
            done: Episode-finished flag to place in the observation.

        Returns:
            A ``GridObservation`` carrying the current agent position, the
            goal position, ``reward`` and ``done``.
        """
        return GridObservation(
            agent_pos=self.agent_pos,
            goal_pos=self.goal_pos,
            reward=reward,
            done=done,
        )

    def get_state(self) -> GridState:
        """Return a state snapshot with the fixed episode id and step count."""
        steps_so_far = self._steps
        return GridState(episode_id="env_01", steps_taken=steps_so_far)

    @property
    def state(self) -> GridState:
        """Property form of :meth:`get_state`."""
        return self.get_state()