from plot_functions import *


class WarehouseEnv(gym.Env):
    """
    WarehouseEnv: 2-D point-navigation environment that follows the gym interface.
    No inertia: each step moves the agent directly by the commanded velocity.

    State (observation, every component scaled to [-1, 1]):
        x_a, y_a     - current agent position in [0, SIDE], mapped as p / SIDE * 2 - 1
        x_rel, y_rel - agent position minus goal position in [-SIDE, SIDE], mapped as r / SIDE
    Action (every component in [-1, 1], rescaled internally to [0, 1]):
        alpha - heading, as a fraction of a full turn
        v     - distance moved this step (at most 1 unit)
    Reward:
        +20    -> agent within RADIUS_COVERAGE of the goal (episode terminates)
        -10    -> agent left the [0, SIDE] x [0, SIDE] field (episode terminates)
        -0.001 -> any other step
    Episodes are truncated once MAX_STEPS steps have been taken.
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        self.to_render = self.render_mode == 'human'

        self.ACTIONS: int = 2
        self.N_CHANNELS: int = 4
        self.SIDE: int = 100
        # NOTE: a radius of 20 was used in the earlier "working v1" setup.
        self.RADIUS_COVERAGE: int = 10
        self.MAX_STEPS: int = 500
        self.DIAG: float = math.sqrt(self.SIDE ** 2 + self.SIDE ** 2)

        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(self.ACTIONS,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(self.N_CHANNELS,), dtype=np.float64)
        self.field = np.zeros((self.SIDE, self.SIDE))

        # Agent and goal positions; assigned by reset().
        self.agent_x = None
        self.agent_y = None
        self.goal_x = None
        self.goal_y = None
        self.step_counter = None
        # Both flags start True so step() refuses to run before the first reset().
        self.terminated = True
        self.truncated = True

        # The figure is only created in 'human' mode; render() checks this.
        if self.to_render:
            self.fig, self.ax = plt.subplots(1, 1, figsize=(5, 5))

    @property
    def rel_x(self) -> float:
        """Agent x minus goal x (in [-SIDE, SIDE] while both are inside the field)."""
        return self.agent_x - self.goal_x

    @property
    def rel_y(self) -> float:
        """Agent y minus goal y (in [-SIDE, SIDE] while both are inside the field)."""
        return self.agent_y - self.goal_y

    def reset(self, seed=None, options=None, agent_x=None, agent_y=None, goal_x=None, goal_y=None):
        """Start a new episode.

        Any of agent_x / agent_y / goal_x / goal_y that is not given is drawn
        uniformly from [0, SIDE) with the environment's seeded RNG, so
        ``reset(seed=...)`` is reproducible. ``options`` is accepted for API
        compatibility and ignored.

        Returns:
            (observation, info) where info is an empty dict.
        """
        # Seeds self.np_random when a seed is given, as the gym API requires.
        super().reset(seed=seed)
        rng = self.np_random

        def _coord(value):
            # Use the caller-supplied coordinate, otherwise sample one inside the field.
            return float(value) if value is not None else float(rng.uniform(0, self.SIDE))

        # Evaluated in this order so RNG draws happen agent_x, agent_y, goal_x, goal_y.
        self.agent_x = _coord(agent_x)
        self.agent_y = _coord(agent_y)
        self.goal_x = _coord(goal_x)
        self.goal_y = _coord(goal_y)

        self.step_counter = 0
        self.terminated = False
        self.truncated = False
        info = {}
        return self.build_obs(), info

    def build_obs(self):
        """Return the observation vector scaled into the [-1, 1] observation space.

        Absolute coordinates in [0, SIDE] map via p / SIDE * 2 - 1; relative
        coordinates in [-SIDE, SIDE] map via r / SIDE. The result is clipped because
        on the step that leaves the field the agent is slightly outside [0, SIDE].
        """
        observation = np.array([
            self.agent_x / self.SIDE * 2 - 1,
            self.agent_y / self.SIDE * 2 - 1,
            self.rel_x / self.SIDE,
            self.rel_y / self.SIDE,
        ], dtype=np.float64)
        return np.clip(observation, self.observation_space.low, self.observation_space.high)

    def build_reward(self):
        """Return the reward for the current position, setting ``terminated`` on episode end.

        Reaching the goal radius gives +20 and leaving the field gives -10; both end
        the episode. Otherwise a small -0.001 step penalty is returned.
        """
        distance = math.hypot(self.rel_x, self.rel_y)
        if distance < self.RADIUS_COVERAGE:
            self.terminated = True
            return 20
        if not (0 <= self.agent_x <= self.SIDE and 0 <= self.agent_y <= self.SIDE):
            self.terminated = True
            return -10
        return -0.001

    def step(self, action):
        """Advance the environment by one action.

        Args:
            action: (alpha, v) pair in [-1, 1]; out-of-range values are clipped.

        Returns:
            (observation, reward, terminated, truncated, info) per the gym API.

        Raises:
            RuntimeError: if the episode already ended and reset() was not called.
        """
        if self.terminated or self.truncated:
            raise RuntimeError('reset the env')

        # --- execute action ---
        action = np.clip(np.asarray(action, dtype=np.float64), -1.0, 1.0)
        input_angle, input_vel = action
        # Rescale both components from [-1, 1] to [0, 1].
        input_angle = (input_angle + 1) / 2
        input_vel = (input_vel + 1) / 2

        angle_rad = 2 * np.pi * input_angle
        self.agent_x += input_vel * np.cos(angle_rad)
        self.agent_y += input_vel * np.sin(angle_rad)

        self.step_counter += 1
        reward = self.build_reward()
        # Time limit: at most MAX_STEPS steps per episode.
        if self.step_counter >= self.MAX_STEPS:
            self.truncated = True

        info = {}
        return self.build_obs(), reward, self.terminated, self.truncated, info

    def render(self):
        """Draw the current state onto the figure created in __init__ ('human' mode only)."""
        if not self.to_render:
            # No figure exists outside 'human' mode, so there is nothing to draw on.
            return None
        plot_env(self.ax, info={'env': self})
        plt.tight_layout()
        plt.pause(0.01)

    def close(self):
        """Release the matplotlib figure, if one was created."""
        if self.to_render:
            plt.close(self.fig)


def main():
    """Run the environment checker on a non-rendering instance."""
    # None (rather than '') is the gym convention for "no rendering".
    env = WarehouseEnv(render_mode=None)
    # It will check your custom environment and output additional warnings if needed
    check_env(env)
    env.close()


if __name__ == '__main__':
    main()