ArseniyPerchik committed
Commit a853d77 · 1 Parent(s): fbd53e3
.gitignore CHANGED
@@ -27,4 +27,5 @@ my_folder
 results
 test-trainer
 .gradio
-secrets.txt
+secrets.txt
+ppo_tensorboard
draft_2.py CHANGED
@@ -1,13 +1,27 @@
-import numpy as np
+import gymnasium as gym

-# angle_deg = 350 # for example
-# angle_rad = np.deg2rad(angle_deg)
-#
-# vector = np.array([np.cos(angle_rad), np.sin(angle_rad)])
-# print(vector)
+from stable_baselines3 import PPO
+from stable_baselines3.common.env_util import make_vec_env
+import torch

-input_angle = 0.5
-angle_rad = 2 * np.pi * input_angle
-vector_2 = np.array([np.cos(angle_rad), np.sin(angle_rad)])
-print(vector_2)
+# Parallel environments
+vec_env = make_vec_env("CartPole-v1", n_envs=4)

+policy_kwargs = dict(activation_fn=torch.nn.ReLU,
+                     net_arch=dict(pi=[32, 32], vf=[32, 32]))
+model = PPO("MlpPolicy", vec_env,
+            verbose=1,
+            policy_kwargs=policy_kwargs,
+            tensorboard_log="./ppo_tensorboard/")
+model.learn(total_timesteps=100000, tb_log_name="CartPole")
+model.save("ppo_cartpole")
+
+del model  # remove to demonstrate saving and loading
+
+model = PPO.load("ppo_cartpole")
+
+obs = vec_env.reset()
+while True:
+    action, _states = model.predict(obs)
+    obs, rewards, dones, info = vec_env.step(action)
+    vec_env.render("human")
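The rewritten draft_2.py is the standard Stable-Baselines3 PPO quickstart: train on four parallel CartPole-v1 envs, save, reload, and roll the policy out in an endless render loop. Before trusting a reloaded checkpoint, it can be scored with evaluate_policy; a minimal sketch, assuming the ppo_cartpole.zip saved above exists on disk:

import gymnasium as gym

from stable_baselines3 import PPO
from stable_baselines3.common.env_util import make_vec_env
from stable_baselines3.common.evaluation import evaluate_policy

# Rebuild the same vectorized env and attach it to the reloaded checkpoint
vec_env = make_vec_env("CartPole-v1", n_envs=4)
model = PPO.load("ppo_cartpole", env=vec_env)

# evaluate_policy rolls out n_eval_episodes episodes and returns the
# mean/std of episodic reward; deterministic=True disables action sampling
mean_reward, std_reward = evaluate_policy(model, vec_env, n_eval_episodes=10, deterministic=True)
print(f"mean reward: {mean_reward:.2f} +/- {std_reward:.2f}")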
globals.py ADDED
@@ -0,0 +1,11 @@
+import matplotlib.pyplot as plt
+import matplotlib
+from matplotlib.patches import Circle
+import math
+import gymnasium as gym
+import numpy as np
+from gymnasium import spaces
+from stable_baselines3.common.env_checker import check_env
+from stable_baselines3 import PPO
+from stable_baselines3.common.env_util import make_vec_env
+import torch
good_policies/sac_warehouse_r_10_working_v1.zip ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:bfd9ad5aea06290c82070061ad1c15f77369e2ce0ada6d2893af143301b38f19
+size 105325
good_policies/sac_warehouse_r_20.zip ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:39aca94cf39aaf179a4cc9189882b18450c6eb90d61708899a87949e8bd78792
+size 105325
plot_functions.py CHANGED
@@ -1,5 +1,4 @@
-import matplotlib.pyplot as plt
-import matplotlib
+from globals import *



@@ -7,7 +6,17 @@ import matplotlib
 def plot_env(ax, info):
     ax.cla()
     env = info['env']
-    ax.plot([1, 1], [1, 2], '.', color='b', alpha=0.5, linewidth=5, markersize=20)
+    agent_x, agent_y = env.agent_x, env.agent_y
+    goal_x, goal_y = env.goal_x, env.goal_y
+
+    # agent
+    ax.plot([agent_x], [agent_y], marker='o', color='b', alpha=0.5, linewidth=5, markersize=15)
+
+    # target
+    ax.plot([goal_x], [goal_y], marker='X', color='orange', alpha=0.5, linewidth=5, markersize=15)
+    circle = Circle((goal_x, goal_y), env.RADIUS_COVERAGE, color='orange', fill=True, alpha=0.3)
+    ax.add_patch(circle)
+
     # ax.set_xlim([min(n_agents_list) - 20, max(n_agents_list) + 20])
     ax.set_xlim([0, 100])
     ax.set_ylim([0, 100])
@@ -16,7 +25,7 @@ def plot_env(ax, info):
     # ax.set_ylabel('Success Rate', fontsize=27)
     # ax.set_title(f'{img_dir[:-4]} Map | time limit: {time_to_think_limit} sec.')
     # set_plot_title(ax, f'{img_dir[:-4]} Map | time limit: {time_to_think_limit} sec.', size=11)
-    ax.set_title(f'Warehouse', fontweight="bold", size=30)
+    ax.set_title(f'Warehouse Env | step {env.step_counter}', fontweight="bold", size=10)
     # set_legend(ax, size=18)
     # labelsize = 20
     # ax.xaxis.set_tick_params(labelsize=labelsize)
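plot_env now reads live state off the env: a blue dot for the agent, an orange X for the goal, and a filled circle of radius env.RADIUS_COVERAGE marking the success zone. A minimal standalone sketch for drawing one frame onto caller-owned axes, assuming render_mode='' keeps WarehouseEnv from opening its own figure (as in train_func):

import matplotlib.pyplot as plt

from plot_functions import plot_env
from warehouse_env import WarehouseEnv

env = WarehouseEnv(render_mode='')
obs, info = env.reset()

# Draw a single frame of the reset state
fig, ax = plt.subplots(1, 1, figsize=(5, 5))
plot_env(ax, info={'env': env})
plt.show()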
register_env.py ADDED
@@ -0,0 +1,12 @@
+from gymnasium.envs.registration import register
+from warehouse_env import WarehouseEnv
+# Register the custom warehouse environment with Gymnasium
+register(
+    # unique identifier for the env `name-version`
+    id="WarehouseEnv",
+    # path to the class for creating the env
+    # Note: entry_point also accepts a class as input (and not only a string)
+    entry_point=WarehouseEnv,
+    # Max number of steps per episode, using a `TimeLimit` wrapper
+    max_episode_steps=500,
+)
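With the registration in place, the env can be built through the standard Gymnasium factory, which also applies the 500-step TimeLimit wrapper on top of the env's own MAX_STEPS counter. A minimal usage sketch; passing render_mode='' (as train_func does) is an assumption, since gym.make forwards extra kwargs to the WarehouseEnv constructor:

import gymnasium as gym

import register_env  # noqa: F401 - importing runs register() as a side effect

env = gym.make("WarehouseEnv", render_mode='')
obs, info = env.reset()
print(env.spec.id, obs.shape)  # WarehouseEnv (4,)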
train_agent.py ADDED
@@ -0,0 +1,69 @@
+from warehouse_env import *
+from stable_baselines3 import SAC
+
+
+def train_func(alg_name='PPO'):
+    env = WarehouseEnv(render_mode='')
+
+    if alg_name == 'PPO':
+        # PPO
+        policy_kwargs = dict(activation_fn=torch.nn.ReLU,
+                             net_arch=dict(pi=[64, 64], vf=[64, 64]))
+        model = PPO("MlpPolicy", env,
+                    verbose=1,
+                    policy_kwargs=policy_kwargs,
+                    tensorboard_log="./ppo_tensorboard/",
+                    # learning_rate=0.0003,
+                    # clip_range=0.1,
+                    )
+        model.learn(total_timesteps=500000, tb_log_name="WarehouseEnv")
+        model.save("ppo_warehouse")
+
+    elif alg_name == 'SAC':
+        # policy_kwargs = dict(net_arch=dict(pi=[256, 256], qf=[400, 300]))
+        # policy_kwargs = dict(net_arch=[512, 512])  # Two shared hidden layers
+        policy_kwargs = dict(net_arch=[32, 32])  # Two shared hidden layers
+        model = SAC("MlpPolicy", env, verbose=1,
+                    tensorboard_log="./ppo_tensorboard/",
+                    # learning_rate=0.0003,
+                    policy_kwargs=policy_kwargs,
+                    )
+        model.learn(total_timesteps=700000, log_interval=4, tb_log_name="sac_WarehouseEnv")
+        model.save("sac_warehouse")
+
+    else:
+        raise RuntimeError('no model')
+
+
+
+
+def exec_func(alg_name='PPO', model_name=None):
+    env = WarehouseEnv(render_mode='human')
+    if alg_name == 'PPO':
+        model_name = "ppo_warehouse" if model_name is None else model_name
+        model = PPO.load(model_name)
+    elif alg_name == 'SAC':
+        model_name = "sac_warehouse" if model_name is None else model_name
+        model = SAC.load(model_name)
+    else:
+        raise RuntimeError('no model')
+    # vec_env = model.get_env()
+    obs, info = env.reset()
+    while True:
+        action, _states = model.predict(obs)
+        obs, rewards, done, trunc, info = env.step(action)
+        env.render()
+        if done or trunc:
+            obs, info = env.reset()
+
+
+def main():
+    # alg_name = 'PPO'
+    alg_name = 'SAC'
+    model_name = 'sac_warehouse_working_v1'
+    # train_func(alg_name)
+    exec_func(alg_name=alg_name, model_name=model_name)
+
+
+if __name__ == '__main__':
+    main()
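exec_func renders forever and samples actions stochastically (SAC's predict samples from the policy by default). For scoring a checkpoint it is common to act deterministically and count episode outcomes instead; a hedged sketch reusing the names from this diff (the 100-episode success counter is illustrative, not part of the commit):

from stable_baselines3 import SAC

from warehouse_env import WarehouseEnv

env = WarehouseEnv(render_mode='')
model = SAC.load("sac_warehouse")

wins, episodes = 0, 0
obs, info = env.reset()
while episodes < 100:
    # deterministic=True uses the policy mean instead of sampling
    action, _states = model.predict(obs, deterministic=True)
    obs, reward, done, trunc, info = env.step(action)
    if done or trunc:
        episodes += 1
        wins += reward == 20  # build_reward() returns 20 on reaching the goal
        obs, info = env.reset()
print(f"success rate: {wins / episodes:.2%}")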
warehouse_env.py CHANGED
@@ -1,13 +1,7 @@
1
- import math
2
- import gymnasium as gym
3
- import numpy as np
4
- from gymnasium import spaces
5
- from stable_baselines3.common.env_checker import check_env
6
- from stable_baselines3 import PPO
7
- from stable_baselines3.common.env_util import make_vec_env
8
  from plot_functions import *
9
 
10
 
 
11
  class WarehouseEnv(gym.Env):
12
  """
13
  WarehouseEnv Environment that follows gym interface.
@@ -32,11 +26,12 @@ class WarehouseEnv(gym.Env):
32
  self.ACTIONS: int = 2
33
  self.N_CHANNELS: int = 4
34
  self.SIDE: int = 100
35
- self.RADIUS_COVERAGE: int = 5
36
- self.MAX_STEPS: int = 200
 
37
  self.DIAG: float = math.sqrt(self.SIDE ** 2 + self.SIDE ** 2)
38
  self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(self.ACTIONS,), dtype=np.float32)
39
- self.observation_space = spaces.Box(low=-1, high=1, shape=(self.N_CHANNELS,), dtype=np.float64)
40
  self.field = np.zeros((self.SIDE, self.SIDE))
41
 
42
  # Agent
@@ -50,7 +45,7 @@ class WarehouseEnv(gym.Env):
50
 
51
  # to render
52
  if self.to_render:
53
- self.fig, self.ax = plt.subplots(2, 2, figsize=(17, 10))
54
 
55
  @property
56
  def rel_x(self) -> int:
@@ -63,14 +58,45 @@ class WarehouseEnv(gym.Env):
63
  def reset(self, seed=None, options=None):
64
  self.agent_x = np.random.uniform(0, self.SIDE)
65
  self.agent_y = np.random.uniform(0, self.SIDE)
 
 
66
  self.goal_x = np.random.uniform(0, self.SIDE)
67
  self.goal_y = np.random.uniform(0, self.SIDE)
68
  self.step_counter = 0
69
  self.terminated = False
70
  self.truncated = False
71
- observation = np.array([self.agent_x / self.SIDE, self.agent_y / self.SIDE, self.rel_x / self.SIDE, self.rel_y / self.SIDE])
72
  info = {}
73
- return observation, info
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74
 
75
  def step(self, action):
76
  if self.terminated:
@@ -87,34 +113,37 @@ class WarehouseEnv(gym.Env):
87
  self.agent_x += input_vel * mov_x
88
  self.agent_y += input_vel * mov_y
89
 
90
- rel_x, rel_y = self.rel_x, self.rel_y
91
- distance = math.sqrt(rel_x**2 + rel_y**2)
92
-
93
- # obs
94
- observation = np.array([self.agent_x / self.SIDE, self.agent_y / self.SIDE, rel_x / self.SIDE, rel_y / self.SIDE])
95
-
96
 
97
  # terminated + reward
98
- if not (0 <= self.agent_x < self.SIDE) or not (0 <= self.agent_y < self.SIDE):
99
- self.terminated = True
100
- reward = -10
101
- elif distance < self.RADIUS_COVERAGE:
102
- self.terminated = True
103
- reward = 10
104
- else:
105
- reward = - (distance / self.DIAG)
 
 
 
 
106
 
107
  # truncated
108
  if self.step_counter > self.MAX_STEPS:
 
109
  self.truncated = True
110
  self.step_counter += 1
111
 
112
  # info
113
  info = {}
114
- return observation, reward, self.terminated, self.truncated, info
115
 
116
  def render(self):
117
- plot_env(self.ax[0, 0], info={'env': self})
118
  plt.tight_layout()
119
  plt.pause(0.01)
120
 
@@ -123,24 +152,9 @@ class WarehouseEnv(gym.Env):
123
 
124
 
125
  def main():
126
- env = WarehouseEnv(render_mode='human')
127
  # It will check your custom environment and output additional warnings if needed
128
- # check_env(env)
129
-
130
- # vec_env = make_vec_env(env, n_envs=4)
131
- # model = PPO("MlpPolicy", env, verbose=1)
132
- # model.learn(total_timesteps=25000)
133
- # model.save("ppo_warehouse")
134
- #
135
- # del model # remove to demonstrate saving and loading
136
-
137
- model = PPO.load("ppo_warehouse")
138
- vec_env = model.get_env()
139
- obs, info = env.reset()
140
- while True:
141
- action, _states = model.predict(obs)
142
- obs, rewards, done, trunc, info = env.step(action)
143
- env.render()
144
 
145
 
146
  if __name__ == '__main__':
 
 
 
 
 
 
 
 
1
  from plot_functions import *
2
 
3
 
4
+
5
  class WarehouseEnv(gym.Env):
6
  """
7
  WarehouseEnv Environment that follows gym interface.
 
26
  self.ACTIONS: int = 2
27
  self.N_CHANNELS: int = 4
28
  self.SIDE: int = 100
29
+ # self.RADIUS_COVERAGE: int = 20 # working v1
30
+ self.RADIUS_COVERAGE: int = 10
31
+ self.MAX_STEPS: int = 500
32
  self.DIAG: float = math.sqrt(self.SIDE ** 2 + self.SIDE ** 2)
33
  self.action_space = spaces.Box(low=-1.0, high=1.0, shape=(self.ACTIONS,), dtype=np.float32)
34
+ self.observation_space = spaces.Box(low=-1.0, high=1.0, shape=(self.N_CHANNELS,), dtype=np.float64)
35
  self.field = np.zeros((self.SIDE, self.SIDE))
36
 
37
  # Agent
 
45
 
46
  # to render
47
  if self.to_render:
48
+ self.fig, self.ax = plt.subplots(1, 1, figsize=(5, 5))
49
 
50
  @property
51
  def rel_x(self) -> int:
 
58
  def reset(self, seed=None, options=None):
59
  self.agent_x = np.random.uniform(0, self.SIDE)
60
  self.agent_y = np.random.uniform(0, self.SIDE)
61
+ # self.agent_x = 50.0
62
+ # self.agent_y = 50.0
63
  self.goal_x = np.random.uniform(0, self.SIDE)
64
  self.goal_y = np.random.uniform(0, self.SIDE)
65
  self.step_counter = 0
66
  self.terminated = False
67
  self.truncated = False
 
68
  info = {}
69
+ return self.build_obs(), info
70
+
71
+ def build_obs(self):
72
+ observation = np.array([
73
+ self.agent_x / self.SIDE * 2 - 1,
74
+ self.agent_y / self.SIDE * 2 - 1,
75
+ self.rel_x / self.SIDE * 2 - 1,
76
+ self.rel_y / self.SIDE * 2 - 1
77
+ # self.goal_x / self.SIDE * 2 - 1,
78
+ # self.goal_y / self.SIDE * 2 - 1
79
+ ])
80
+ return observation
81
+
82
+ def build_reward(self):
83
+ rel_x, rel_y = self.rel_x, self.rel_y
84
+ # rel_x = self.agent_x - self.goal_x
85
+ # rel_y = self.agent_y - self.goal_y
86
+ distance = math.sqrt(rel_x ** 2 + rel_y ** 2)
87
+
88
+ # terminated + reward
89
+ # if not (0 < self.agent_x < self.SIDE) or not (0 <= self.agent_y < self.SIDE):
90
+ if distance < self.RADIUS_COVERAGE:
91
+ self.terminated = True
92
+ self.truncated = True
93
+ return 20
94
+ elif self.agent_x < 0 or self.agent_x > self.SIDE or self.agent_y < 0 or self.agent_y > self.SIDE:
95
+ self.terminated = True
96
+ self.truncated = True
97
+ return -10
98
+ # return -1 * (distance / self.DIAG)
99
+ return -0.001
100
 
101
  def step(self, action):
102
  if self.terminated:
 
113
  self.agent_x += input_vel * mov_x
114
  self.agent_y += input_vel * mov_y
115
 
116
+ # rel_x, rel_y = self.rel_x, self.rel_y
117
+ # rel_x = self.agent_x - self.goal_x
118
+ # rel_y = self.agent_y - self.goal_y
119
+ # distance = math.sqrt(rel_x**2 + rel_y**2)
 
 
120
 
121
  # terminated + reward
122
+ # if not (0 < self.agent_x < self.SIDE) or not (0 <= self.agent_y < self.SIDE):
123
+ # if distance < self.RADIUS_COVERAGE:
124
+ # self.terminated = True
125
+ # self.truncated = True
126
+ # reward = 2
127
+ # print('Win')
128
+ # elif self.agent_x < 0 or self.agent_x > self.SIDE or self.agent_y < 0 or self.agent_y > self.SIDE:
129
+ # self.terminated = True
130
+ # self.truncated = True
131
+ # reward = -2
132
+ # else:
133
+ # reward = -1 * (distance / self.DIAG)
134
 
135
  # truncated
136
  if self.step_counter > self.MAX_STEPS:
137
+ # self.terminated = True
138
  self.truncated = True
139
  self.step_counter += 1
140
 
141
  # info
142
  info = {}
143
+ return self.build_obs(), self.build_reward(), self.terminated, self.truncated, info
144
 
145
  def render(self):
146
+ plot_env(self.ax, info={'env': self})
147
  plt.tight_layout()
148
  plt.pause(0.01)
149
 
 
152
 
153
 
154
  def main():
155
+ env = WarehouseEnv(render_mode='')
156
  # It will check your custom environment and output additional warnings if needed
157
+ check_env(env)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
 
160
  if __name__ == '__main__':
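One detail worth flagging in the new build_obs(): agent_x / SIDE lies in [0, 1], so the * 2 - 1 rescale correctly maps it to [-1, 1], but rel_x / SIDE already spans [-1, 1] (a signed offset between two points in a SIDE-sized square), so rel_x / SIDE * 2 - 1 can reach -3 and fall outside the declared Box(low=-1.0, high=1.0) observation space; the check_env(env) call in main() only flags this if a sampled episode happens to hit such a value. A bounds-respecting variant is a small change; a sketch, assuming rel_x/rel_y are signed agent-goal offsets:

import numpy as np

def build_obs(self):
    return np.array([
        self.agent_x / self.SIDE * 2 - 1,  # [0, SIDE] -> [-1, 1]
        self.agent_y / self.SIDE * 2 - 1,  # [0, SIDE] -> [-1, 1]
        self.rel_x / self.SIDE,            # [-SIDE, SIDE] -> [-1, 1], no shift needed
        self.rel_y / self.SIDE,
    ])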