timurgepard committed
Commit ceb68c0 · verified · 1 Parent(s): a93e7c9

Upload 2 files

Files changed (2)
  1. symphony_S2/symphony.py +372 -0
  2. symphony_S2/train.py +233 -0
symphony_S2/symphony.py ADDED
@@ -0,0 +1,372 @@
import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.jit as jit
import random


#==============================================================================================
#==============================================================================================
#=========================================SYMPHONY=============================================
#==============================================================================================
#==============================================================================================

class Adam(optim.Optimizer):
    def __init__(self, params, lr=3e-4, weight_decay=0.01, betas=((math.sqrt(5)-1)/2, 0.995)):
        defaults = dict(lr=lr, betas=betas)
        super().__init__(params, defaults)
        self.wd = weight_decay
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.beta1_, self.beta2_ = 1-self.beta1, 1-self.beta2
        self.eps = 1e-8  # You can make this configurable if needed

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    state['m'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['v'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                m = state['m']
                v = state['v']

                # Update biased first moment estimate
                m.mul_(self.beta1).add_(grad, alpha=self.beta1_)
                # Update biased second raw moment estimate
                v.mul_(self.beta2).addcmul_(grad, grad, value=self.beta2_)

                # Update parameters
                p.add_(m/(v.sqrt() + self.eps) + self.wd*p, alpha=-self.lr)

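# Note (added for clarity, not part of the original commit): this custom Adam differs from
# torch.optim.Adam in that it skips bias correction of m and v, uses beta1 = (sqrt(5)-1)/2
# (the reciprocal golden ratio), and folds decoupled weight decay into the same update,
# p <- p - lr * (m / (sqrt(v) + eps) + wd * p), similar in spirit to AdamW.
# It is used further below in Symphony.__init__, roughly as (model is a placeholder name):
#
#   optimizer = Adam(model.parameters(), lr=3e-4, weight_decay=0.01)
#   loss.backward(); optimizer.step()
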
# Rectified Huber Symmetric Error Loss Function via JIT Module
# nn.Module -> JIT C++ graph
class ReHSE(jit.ScriptModule):
    def __init__(self):
        super(ReHSE, self).__init__()

    @jit.script_method
    def forward(self, e):
        return (e * torch.tanh(e/2)).mean()


# Rectified Huber Asymmetric Error Loss Function via JIT Module
# nn.Module -> JIT C++ graph
class ReHAE(jit.ScriptModule):
    def __init__(self):
        super(ReHAE, self).__init__()

    @jit.script_method
    def forward(self, e):
        return (torch.abs(e) * torch.tanh(e/2)).mean()

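# Sketch of the two losses (added for clarity, not part of the original commit).
# Since tanh(e/2) ~ e/2 near zero and -> sign(e) for large |e|:
#   ReHSE(e) = e*tanh(e/2)    ~ e**2/2 near 0,   ~ |e| for large |e|   (smooth, Huber-like, symmetric)
#   ReHAE(e) = |e|*tanh(e/2)  ~ e*|e|/2 near 0,  ~ e   for large |e|   (sign-preserving, asymmetric)
# so minimizing -ReHAE(e), as the actor term in Nets.forward below does, roughly pushes e positive.
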
# ReSine Activation Function
# nn.Module -> JIT C++ graph
class ReSine(jit.ScriptModule):
    def __init__(self, hidden_dim=256):
        super(ReSine, self).__init__()
        k = 1/math.sqrt(hidden_dim)
        self.s = nn.Parameter(data=2.0*k*torch.rand(hidden_dim)-k, requires_grad=True)

    @jit.script_method
    def forward(self, x):
        s = torch.sigmoid(self.s)
        x = s*torch.sin(x/s)
        return x/(1+torch.exp(-1.5*x/s))

# GradientDropout ("silent" dropout): forwards every value unchanged,
# but detaches the gradient for a random subset of elements
# nn.Module -> JIT C++ graph
class GradientDropout(jit.ScriptModule):
    def __init__(self, p=0.5):
        super(GradientDropout, self).__init__()
        self.p = p

    @jit.script_method
    def forward(self, x):
        mask = (torch.rand_like(x) > self.p).float()
        return mask * x + (1.0-mask) * x.detach()

class Swaddling(jit.ScriptModule):
    def __init__(self):
        super(Swaddling, self).__init__()

    @jit.script_method
    def Omega(self, x):
        # Omega(x) = log((1+x)/(1-x)) = 2*artanh(x), defined for x in (-1, 1)
        return torch.log((1+x)/(1-x))

    @jit.script_method
    def omega(self, x):
        # omega(x) = x*log(x), defined for x > 0
        return x*torch.log(x)

    @jit.script_method
    def forward(self, x, k):
        return (self.Omega(x**(1/k.detach())) + k * self.omega(x) + self.Omega(k**2)).mean()

class FeedForward(jit.ScriptModule):
    def __init__(self, f_in, h_dim, f_out):
        super(FeedForward, self).__init__()

        self.ffw = nn.Sequential(
            nn.Linear(f_in, h_dim),
            nn.LayerNorm(h_dim),
            nn.Linear(h_dim, h_dim),
            ReSine(h_dim),
            nn.Linear(h_dim, f_out)
        )

    @jit.script_method
    def forward(self, x):
        return self.ffw(x)

# nn.Module -> JIT C++ graph
class ActorCritic(jit.ScriptModule):
    def __init__(self, state_dim, action_dim, h_dim, max_action=1.0):
        super().__init__()

        self.action_dim = action_dim
        q_nodes = h_dim//4

        self.a = FeedForward(state_dim, h_dim, 3*action_dim)
        # accepts a tensor (as in train.py) or a plain float
        self.a_max = nn.Parameter(data=torch.as_tensor(max_action, dtype=torch.float32), requires_grad=False)
        self.std = 1/math.e

        self.qA = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qB = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qC = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qnets = nn.ModuleList([self.qA, self.qB, self.qC])

        self.q_dist = q_nodes*len(self.qnets)
        indexes = torch.arange(0, self.q_dist, 1)/self.q_dist
        weights = torch.tanh((math.pi*(1-indexes))**math.pi) - 0.01*torch.exp(-(indexes/0.01)**2)
        self.probs = nn.Parameter(data=weights/torch.sum(weights), requires_grad=False)

        self.e = 1e-3
        self.e_ = 1-self.e

    #========= Actor Forward Pass =========

    @jit.script_method
    def actor(self, state, action: bool = True, noise: bool = True):
        ASB = torch.tanh(self.a(state)/2).reshape(-1, 3, self.action_dim)
        A, S, B = ASB[:, 0], ASB[:, 1].abs(), ASB[:, 2].abs()
        N = self.std * torch.randn_like(A).clamp(-math.e, math.e)
        return self.a_max * torch.tanh(float(action) * S * A + float(noise) * N), S.clamp(self.e, self.e_), B.clamp(self.e, self.e_)

    #========= Critic Forward Pass =========
    # take 3 distributions and concatenate them
    @jit.script_method
    def critic(self, state, action):
        x = torch.cat([state, action], -1)
        return torch.cat([qnet(x) for qnet in self.qnets], dim=-1)

    @jit.script_method
    def critic_soft(self, state, action):
        q = self.probs * self.critic(state, action).sort(dim=-1)[0]
        q = q.sum(dim=-1, keepdim=True)
        return q, q.detach()

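# Illustrative shape check (added for clarity, not part of the original commit; dims are arbitrary):
#
#   ac = ActorCritic(state_dim=4, action_dim=2, h_dim=32, max_action=torch.ones(2))
#   s = torch.randn(5, 4)
#   a, S, B = ac.actor(s)            # a: (5, 2) in [-1, 1]; S, B: (5, 2) clamped to (e, 1-e)
#   q_all = ac.critic(s, a)          # (5, 24) = 3 heads * (h_dim//4) nodes each
#   q, q_d = ac.critic_soft(s, a)    # (5, 1): probs-weighted sum over the sorted 24 values
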
class Nets(jit.ScriptModule):
    def __init__(self, state_dim, action_dim, h_dim, max_action, device):
        super(Nets, self).__init__()

        self.online = ActorCritic(state_dim, action_dim, h_dim, max_action=max_action).to(device)
        self.target = ActorCritic(state_dim, action_dim, h_dim, max_action=max_action).to(device)
        self.target.load_state_dict(self.online.state_dict())

        self.rehse = ReHSE()
        self.rehae = ReHAE()
        self.sw = Swaddling()
        self.tau = 0.005
        self.tau_ = 1.0 - self.tau
        self.alpha = (math.sqrt(5)-1)/2
        self.alpha_ = 1.0 - self.alpha
        self.q_next_ema = torch.zeros(1, device=device)

    @torch.no_grad()
    def tau_update(self):
        for target_param, param in zip(self.target.qnets.parameters(), self.online.qnets.parameters()):
            target_param.data.copy_(self.tau_*target_param.data + self.tau*param.data)

    @jit.script_method
    def forward(self, state, action, reward, next_state, not_done_gamma):

        next_action, next_scale, next_beta = self.online.actor(next_state)
        q_next_target, q_next_target_value = self.target.critic_soft(next_state, next_action)
        q_target = reward + not_done_gamma * q_next_target_value
        q_pred = self.online.critic(state, action)

        q_next_ema = self.alpha * self.q_next_ema + self.alpha_ * q_next_target_value
        nets_loss = -self.rehae((q_next_target - q_next_ema)/q_next_ema.abs()) + self.rehse(q_pred-q_target) + self.sw(next_scale, next_beta)
        self.q_next_ema = q_next_ema.mean()

        return nets_loss, next_scale.detach()

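# Reading of the combined loss above (added for clarity, not part of the original commit):
#   - actor term:  -ReHAE((Q_target(s', a') - EMA) / |EMA|) rewards actions whose soft target-Q
#     beats its running EMA (the gradient reaches the actor through a' = online.actor(s'));
#   - critic term:  ReHSE(Q_pred - (r + gamma * not_done * Q_target')) is a smooth Huber-like TD loss;
#   - Swaddling(S, B) regularizes the actor's scale and beta heads;
#   tau_update() then Polyak-averages only the critic heads (target.qnets) with tau = 0.005.
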
245
+ # Define the algorithm
246
+ class Symphony(object):
247
+ def __init__(self, capacity, state_dim, action_dim, h_dim, device, max_action, learning_rate=3e-4):
248
+
249
+ self.state_dim = state_dim
250
+ self.action_dim = action_dim
251
+ self.device = device
252
+
253
+
254
+ self.replay_buffer = ReplayBuffer(capacity, state_dim, action_dim, device)
255
+ self.nets = Nets(state_dim, action_dim, h_dim, max_action, device)
256
+ self.nets_optimizer = Adam(self.nets.online.parameters(), lr=learning_rate)
257
+ self.batch_size = self.nets.online.q_dist
258
+
259
+
260
+ def select_action(self, state, action = True, noise=True):
261
+ state = torch.tensor(state, dtype=torch.float32, device=self.device).reshape(-1,self.state_dim)
262
+ with torch.no_grad(): action = self.nets.online.actor(state, action, noise)[0]
263
+ return action.cpu().data.numpy().flatten()
264
+
265
+ """
266
+ def select_action(self, state, action = True, noise=True):
267
+ with torch.no_grad(): return self.nets.online.actor(state, action, noise)[0]
268
+ """
269
+
270
+
271
+
272
+ def train(self):
273
+
274
+ torch.manual_seed(random.randint(0,2**32-1))
275
+
276
+ state, action, reward, next_state, not_done_gamma = self.replay_buffer.sample(self.batch_size)
277
+ self.nets_optimizer.zero_grad(set_to_none=True)
278
+
279
+ nets_loss, scale = self.nets(state, action, reward, next_state, not_done_gamma)
280
+
281
+ nets_loss.backward()
282
+ self.nets_optimizer.step()
283
+ self.nets.tau_update()
284
+
285
+ return scale
286
+
287
+
288
+
289
+
290
+
291
+
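# Minimal usage sketch (added for clarity, not part of the original commit; dims are hypothetical,
# the real driver is symphony_S2/train.py below):
#
#   algo = Symphony(capacity=10_000, state_dim=4, action_dim=2, h_dim=64,
#                   device=torch.device("cpu"), max_action=torch.ones(2))
#   # ... collect transitions with algo.select_action(state) and algo.replay_buffer.add(...)
#   algo.replay_buffer.norm_fill(times=2)   # normalizes rewards and builds the sampling probs
#   scale = algo.train()                    # one gradient step on actor+critics, then tau_update()
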
class ReplayBuffer:
    def __init__(self, capacity, state_dim, action_dim, device):

        self.capacity, self.length, self.device = capacity, 0, device

        self.states = torch.zeros((self.capacity, state_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((self.capacity, action_dim), dtype=torch.float32, device=device)
        self.rewards = torch.zeros((self.capacity, 1), dtype=torch.float32, device=device)
        self.next_states = torch.zeros((self.capacity, state_dim), dtype=torch.float32, device=device)
        self.not_dones_gamma = torch.zeros((self.capacity, 1), dtype=torch.float32, device=device)

        self.norm = 1.0

    def add(self, state, action, reward, next_state, done):

        if self.length < self.capacity: self.length += 1

        idx = self.length-1

        self.states[idx,:] = torch.tensor(state, dtype=torch.float32, device=self.device)
        self.actions[idx,:] = torch.tensor(action, dtype=torch.float32, device=self.device)
        self.rewards[idx,:] = torch.tensor([reward/self.norm], dtype=torch.float32, device=self.device)
        self.next_states[idx,:] = torch.tensor(next_state, dtype=torch.float32, device=self.device)
        self.not_dones_gamma[idx,:] = torch.tensor([0.99 * (1.0 - float(done))], dtype=torch.float32, device=self.device)

        # once full, roll left so the next add() overwrites the oldest entries
        # (by 2 positions if the oldest transition is terminal)
        if self.length >= self.capacity:
            shift = 2 if self.not_dones_gamma[0,:].item() == 0.0 else 1
            self.states = torch.roll(self.states, shifts=-shift, dims=0)
            self.actions = torch.roll(self.actions, shifts=-shift, dims=0)
            self.rewards = torch.roll(self.rewards, shifts=-shift, dims=0)
            self.next_states = torch.roll(self.next_states, shifts=-shift, dims=0)
            self.not_dones_gamma = torch.roll(self.not_dones_gamma, shifts=-shift, dims=0)

    def sample(self, batch_size):

        # self.probs is created in norm_fill(); sampling is only used once training starts
        indices = torch.multinomial(self.probs, num_samples=batch_size, replacement=True)

        return (
            self.states[indices],
            self.actions[indices],
            self.rewards[indices],
            self.next_states[indices],
            self.not_dones_gamma[indices]
        )

    def __len__(self):
        return self.length

    #==============================================================
    #==============================================================
    #===========================HELPERS============================
    #==============================================================
    #==============================================================

    def norm_fill(self, times: int):

        print("copying replay data, current length", self.length)

        self.states = self.states[:self.length].repeat(times, 1)
        self.actions = self.actions[:self.length].repeat(times, 1)
        self.rewards = self.rewards[:self.length].repeat(times, 1)
        self.next_states = self.next_states[:self.length].repeat(times, 1)
        self.not_dones_gamma = self.not_dones_gamma[:self.length].repeat(times, 1)

        self.norm = torch.mean(torch.abs(self.rewards)).item()

        self.rewards /= self.norm

        self.length = times*self.length

        indexes = torch.arange(0, self.length, 1)/self.length
        weights = torch.tanh((math.pi*indexes)**math.pi) - 0.01*torch.exp(-((indexes-1)/0.01)**2)
        self.probs = weights/torch.sum(weights)

        print("new replay buffer length: ", self.length)
symphony_S2/train.py CHANGED
@@ -0,0 +1,233 @@
from symphony import Symphony
import gymnasium as gym

import logging
logging.getLogger().setLevel(logging.CRITICAL)
import torch
import numpy as np
import random
import pickle
import time
import os, re

#############################################
# -----------Helper Functions---------------#
#############################################


# random seeds for reproducing the experiment
def seed_reset():
    r1, r2, r3 = random.randint(0, 2**32-1), random.randint(0, 2**32-1), random.randint(0, 2**32-1)
    torch.manual_seed(r1)
    np.random.seed(r2)
    random.seed(r3)
    return r1, r2, r3


def extract_r1_r2_r3():
    pattern = r'history_(\d+)_(\d+)_(\d+)\.csv'

    # Iterate through the files in the current directory
    for filename in os.listdir():
        # Match the filename against the pattern
        match = re.match(pattern, filename)
        if match:
            # Extract the numbers r1, r2 and r3 from the filename
            return map(int, match.groups())
    return None

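# Example (added for clarity, not part of the original commit): a run that previously wrote
# "history_123_456_789.csv" will be resumed with r1, r2, r3 = 123, 456, 789;
# extract_r1_r2_r3() returns a map object, which is unpacked further below.
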
# write or append to the history log file
class LogFile(object):
    def __init__(self, log_name_main, log_name_opt):
        self.log_name_main = log_name_main
        self.log_name_opt = log_name_opt
    def write(self, text):
        with open(self.log_name_main, 'a+') as file:
            file.write(text)
    def write_opt(self, text):
        with open(self.log_name_opt, 'a+') as file:
            file.write(text)
    def clean(self):
        with open(self.log_name_main, 'w') as file:
            file.write("step,return\n")
        with open(self.log_name_opt, 'w') as file:
            file.write("ep,return,steps,scale\n")

numbers = extract_r1_r2_r3()

if numbers is not None:
    # derive the random seeds from an existing history file
    r1, r2, r3 = numbers
else:
    # generate new random seeds
    r1, r2, r3 = seed_reset()


print(r1, ", ", r2, ", ", r3)

log_name_main = "history_" + str(r1) + "_" + str(r2) + "_" + str(r3) + ".csv"
log_name_opt = "episodes_" + str(r1) + "_" + str(r2) + "_" + str(r3) + ".csv"
log_file = LogFile(log_name_main, log_name_opt)

def save(algo, total_rewards, total_steps):

    torch.save(algo.nets.online.state_dict(), 'nets_online_model.pt')
    torch.save(algo.nets.target.state_dict(), 'nets_target_model.pt')
    torch.save(algo.nets_optimizer.state_dict(), 'nets_optimizer.pt')
    print("saving... the buffer length = ", algo.replay_buffer.length, end="")
    with open('data', 'wb') as file:
        pickle.dump({'buffer': algo.replay_buffer, 'q_next_ema': algo.nets.q_next_ema, 'total_rewards': total_rewards, 'total_steps': total_steps}, file)
    print(" > done")

def load(algo, Q_learning):

    total_rewards, total_steps = [], 0

    try:
        print("loading models...")
        algo.nets.online.load_state_dict(torch.load('nets_online_model.pt', weights_only=True))
        algo.nets.target.load_state_dict(torch.load('nets_target_model.pt', weights_only=True))
        algo.nets_optimizer.load_state_dict(torch.load('nets_optimizer.pt', weights_only=True))
        print('models loaded')
        #sim_loop(env_valid, 100, True, False, algo, [], total_steps=0)
    except Exception:
        print("problem during loading models")

    try:
        print("loading buffer...")
        with open('data', 'rb') as file:
            checkpoint = pickle.load(file)
        algo.replay_buffer = checkpoint['buffer']
        algo.nets.q_next_ema = checkpoint['q_next_ema']
        total_rewards = checkpoint['total_rewards']
        total_steps = checkpoint['total_steps']
        if algo.replay_buffer.length >= explore_time and not Q_learning: Q_learning = True

        print('buffer loaded, Q_ema', round(algo.nets.q_next_ema.item(), 2), ', average_reward = ', round(np.mean(total_rewards[-300:]), 2))

    except Exception:
        print("problem during loading buffer")

    return Q_learning, total_rewards, total_steps

#############################################
# ---------------Parameters-----------------#
#############################################

# global parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)
G = 3
learning_rate = 5e-5
explore_time, times = 20480, 25
capacity = explore_time * times
h_dim = capacity//1000
limit_step = 1000   # max steps per episode
limit_eval = 1000   # max steps per evaluation
num_episodes = 1000000
start_episode = 1   # number identifying the current episode
episode_rewards_all, episode_steps_all, test_rewards, Q_learning, total_steps = [], [], [], False, 0

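# With these defaults (added for clarity, not part of the original commit):
# capacity = 20480 * 25 = 512000 transitions, h_dim = 512000 // 1000 = 512,
# and the Symphony batch size q_dist = 3 * (512 // 4) = 384.
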
# environment type
option = 3
pre_valid = True
if option == 0: env_name = 'BipedalWalker-v3'
elif option == 1: env_name = 'HalfCheetah-v4'
elif option == 2: env_name = 'Walker2d-v4'
elif option == 3: env_name = 'Humanoid-v4'
elif option == 4: env_name = 'Ant-v4'
elif option == 5: env_name = 'Swimmer-v4'
elif option == 6: env_name = 'Hopper-v4'
elif option == 7: env_name = 'Pusher-v4'

env = gym.make(env_name)
env_test = gym.make(env_name)
env_valid = gym.make(env_name, render_mode="human")

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
#max_action = torch.FloatTensor(env.action_space.high) if env.action_space.is_bounded() else torch.ones(action_dim)
max_action = torch.ones(action_dim)

print("action_dim: ", action_dim, "state_dim: ", state_dim, "max_action:", max_action)

algo = Symphony(capacity, state_dim, action_dim, h_dim, device, max_action, learning_rate)

# Episode loop: [ State -> one-episode loop: [ Action, Next State, Reward, Done, State = Next State ] ]
def sim_loop(env, episodes, testing, Q_learning, algo, total_rewards, total_steps):

    start_episode = len(total_rewards) + 1

    for episode in range(start_episode, episodes+1):

        Return = 0.0
        state = env.reset()[0]

        for steps in range(1, limit_step+1):

            seed_reset()
            total_steps += 1

            # Activate training once the exploration phase is over and we are not in testing mode:
            if testing:
                Q_learning = False
            else:
                if algo.replay_buffer.length >= explore_time and not Q_learning:
                    Q_learning = True
                    algo.replay_buffer.norm_fill(times)
                    print("started training")

            # every 2500 total steps: save models, pause training for a testing round, then resume:
            if Q_learning and total_steps >= 2500 and total_steps % 2500 == 0:
                save(algo, total_rewards, total_steps)
                print("start testing")
                log_file.write(str(total_steps) + ",")
                Return = sim_loop(env_test, 25, True, Q_learning, algo, [], total_steps=0)
                print("end of testing")
                log_file.write(str(round(Return, 2)) + "\n")

            # when steps is close to the episode limit (e.g. 950) we switch the action off and keep only noise to obtain a terminal transition:
            active = steps < (limit_step-50) if Q_learning else True
            action = algo.select_action(state, action=active, noise=not testing)
            next_state, reward, done, truncated, info = env.step(action)
            if not testing: algo.replay_buffer.add(state, action, reward, next_state, done)
            Return += reward

            # actual training: G gradient steps per environment step
            if Q_learning:
                for _ in range(G): scale = algo.train()
            if done or truncated: break
            state = next_state

        total_rewards.append(Return)
        average_reward = np.mean(total_rewards[-300:])

        print(f"Ep {episode}: Rtrn = {Return:.2f}, Avg = {average_reward:.2f} | ep steps = {steps} | total_steps = {total_steps}")
        if not testing and Q_learning: log_file.write_opt(str(episode) + "," + str(round(Return, 2)) + "," + str(total_steps) + "," + str(round(scale.mean().item(), 4)) + "\n")

    return np.mean(total_rewards).item()



# Loading existing models
Q_learning, total_rewards, total_steps = load(algo, Q_learning)
if not Q_learning: log_file.clean()

# Training
sim_loop(env, num_episodes, False, Q_learning, algo, total_rewards, total_steps)