# Agent_Control_with_Language / warehouse_env.py
# Author: ArseniyPerchik; commit 25a1345 (message: "more")
from plot_functions import *
class WarehouseEnv(gym.Env):
    """
    Point-agent navigation environment that follows the gym(nasium) ``Env`` interface.

    The agent moves on a continuous ``SIDE x SIDE`` square with no inertia:
    each step it is displaced by ``v`` units in the commanded direction.
    An episode terminates when the agent comes within ``RADIUS_COVERAGE`` of
    the goal or leaves the square, and is truncated after ``MAX_STEPS`` steps.

    Observation (4 floats, clipped to [-1, 1]):
        agent_x, agent_y - agent position, mapped from [0, SIDE] to [-1, 1]
        rel_x, rel_y     - agent position minus goal position, mapped from
                           [-SIDE, SIDE] to [-1, 1]
    Action (2 floats in [-1, 1]; out-of-range values are clipped):
        alpha - heading, rescaled to [0, 1] and multiplied by 2*pi radians
        v     - speed, rescaled to [0, 1] units per step
    Reward:
        20     -> agent is inside the goal radius (terminated)
        -10    -> agent left the square (terminated)
        -0.001 -> any other step
    """

    metadata = {"render_modes": ["human"], "render_fps": 30}

    def __init__(self, render_mode=None):
        """Create the environment.

        Args:
            render_mode: ``'human'`` to draw with matplotlib; any other value
                (conventionally ``None``) disables rendering.
        """
        super().__init__()
        self.render_mode = render_mode
        self.to_render = self.render_mode == 'human'
        self.ACTIONS: int = 2
        self.N_CHANNELS: int = 4
        self.SIDE: int = 100
        self.RADIUS_COVERAGE: int = 10  # an earlier working version (v1) used 20
        self.MAX_STEPS: int = 500
        self.DIAG: float = math.sqrt(self.SIDE ** 2 + self.SIDE ** 2)
        self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(self.ACTIONS,), dtype=np.float32)
        self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(self.N_CHANNELS,), dtype=np.float64)
        self.field = np.zeros((self.SIDE, self.SIDE))
        # Agent / goal state, populated by reset()
        self.agent_x = None
        self.agent_y = None
        self.goal_x = None
        self.goal_y = None
        self.step_counter = None
        # Start in the "finished" state so that step() before reset() raises
        self.terminated = True
        self.truncated = True
        # to render
        self.fig = None
        self.ax = None
        if self.to_render:
            self.fig, self.ax = plt.subplots(1, 1, figsize=(5, 5))

    @property
    def rel_x(self) -> float:
        """Agent x minus goal x."""
        return self.agent_x - self.goal_x

    @property
    def rel_y(self) -> float:
        """Agent y minus goal y."""
        return self.agent_y - self.goal_y

    def reset(self, seed=None, options=None, agent_x=None, agent_y=None, goal_x=None, goal_y=None):
        """Start a new episode.

        Each coordinate left as ``None`` is drawn uniformly from [0, SIDE]
        with the environment's seeded RNG, so ``reset(seed=s)`` is reproducible.

        Args:
            seed: optional RNG seed, forwarded to ``gym.Env.reset``.
            options: unused; accepted for API compatibility.
            agent_x, agent_y, goal_x, goal_y: optional fixed start/goal coordinates.

        Returns:
            ``(observation, info)`` where ``info`` is an empty dict.
        """
        # Seeds self.np_random as required by the gym(nasium) reset contract
        super().reset(seed=seed)
        rng = self.np_random
        # Sampled one by one, in this order, so that partial overrides work
        self.agent_x = rng.uniform(0, self.SIDE) if agent_x is None else agent_x
        self.agent_y = rng.uniform(0, self.SIDE) if agent_y is None else agent_y
        self.goal_x = rng.uniform(0, self.SIDE) if goal_x is None else goal_x
        self.goal_y = rng.uniform(0, self.SIDE) if goal_y is None else goal_y
        self.step_counter = 0
        self.terminated = False
        self.truncated = False
        info = {}
        return self.build_obs(), info

    def build_obs(self):
        """Return the normalized observation vector (float64, shape ``(N_CHANNELS,)``)."""
        observation = np.array([
            self.agent_x / self.SIDE * 2 - 1,  # [0, SIDE]     -> [-1, 1]
            self.agent_y / self.SIDE * 2 - 1,
            self.rel_x / self.SIDE,            # [-SIDE, SIDE] -> [-1, 1]
            self.rel_y / self.SIDE,
        ], dtype=np.float64)
        # On the terminal step the agent may overshoot the border by up to one
        # step length; clip so the observation always lies in observation_space.
        return np.clip(observation, -1.0, 1.0)

    def build_reward(self):
        """Return the reward for the current state and flag termination.

        Sets ``self.terminated`` when the goal radius is reached or the agent
        leaves the square. ``self.truncated`` is left to the step limit.
        """
        distance = math.hypot(self.rel_x, self.rel_y)
        if distance < self.RADIUS_COVERAGE:
            self.terminated = True
            return 20
        inside = 0 <= self.agent_x <= self.SIDE and 0 <= self.agent_y <= self.SIDE
        if not inside:
            self.terminated = True
            return -10
        return -0.001

    def step(self, action):
        """Advance the simulation by one step.

        Args:
            action: ``(alpha, v)`` in [-1, 1]; values outside are clipped.

        Returns:
            ``(observation, reward, terminated, truncated, info)``.

        Raises:
            RuntimeError: if the episode has ended and ``reset()`` was not called.
        """
        if self.terminated or self.truncated:
            raise RuntimeError('reset the env')
        # --- execute action ---
        input_angle, input_vel = np.clip(np.asarray(action, dtype=np.float64), -1.0, 1.0)
        # rescale from [-1, 1] to [0, 1]
        input_angle = (input_angle + 1) / 2
        input_vel = (input_vel + 1) / 2
        angle_rad = 2 * np.pi * input_angle
        self.agent_x += input_vel * np.cos(angle_rad)
        self.agent_y += input_vel * np.sin(angle_rad)
        # truncated: time limit reached after exactly MAX_STEPS steps
        self.step_counter += 1
        if self.step_counter >= self.MAX_STEPS:
            self.truncated = True
        # build_reward() may set self.terminated, so call it before reading the flag
        reward = self.build_reward()
        observation = self.build_obs()
        info = {}
        return observation, reward, self.terminated, self.truncated, info

    def render(self):
        """Draw the current state with ``plot_env``; a no-op unless render_mode is 'human'."""
        if not self.to_render:
            return
        plot_env(self.ax, info={'env': self})
        plt.tight_layout()
        plt.pause(0.01)

    def close(self):
        """Close the matplotlib figure, if one was created."""
        if self.fig is not None:
            plt.close(self.fig)
            self.fig = None
def main():
    """Run the environment checker on a non-rendering WarehouseEnv instance."""
    # None is the gym(nasium) convention for "no rendering"; '' is not a declared mode
    env = WarehouseEnv(render_mode=None)
    # It will check your custom environment and output additional warnings if needed
    check_env(env)
    env.close()


if __name__ == '__main__':
    main()