Upload 2 files
- symphony_S2/symphony.py +372 -0
- symphony_S2/train.py +233 -0
symphony_S2/symphony.py
ADDED
@@ -0,0 +1,372 @@
import torch
import torch.nn as nn
import torch.optim as optim
import math
import torch.jit as jit
import random


#==============================================================================================
#==============================================================================================
#=========================================SYMPHONY=============================================
#==============================================================================================
#==============================================================================================

class Adam(optim.Optimizer):
    def __init__(self, params, lr=3e-4, weight_decay=0.01, betas=((math.sqrt(5)-1)/2, 0.995)):
        defaults = dict(lr=lr, betas=betas)
        super().__init__(params, defaults)
        self.wd = weight_decay
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.beta1_, self.beta2_ = 1-self.beta1, 1-self.beta2
        self.eps = 1e-8 # You can make this configurable if needed

    @torch.no_grad()
    def step(self):
        for group in self.param_groups:
            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad

                state = self.state[p]
                if len(state) == 0:
                    state['m'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['v'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                m = state['m']
                v = state['v']

                # Update biased first moment estimate
                m.mul_(self.beta1).add_(grad, alpha=self.beta1_)
                # Update biased second raw moment estimate
                v.mul_(self.beta2).addcmul_(grad, grad, value=self.beta2_)

                # Update parameters
                p.add_(m/(v.sqrt() + self.eps) + self.wd*p, alpha=-self.lr)

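# Note: unlike torch.optim.Adam, this variant applies no bias correction and
# folds decoupled (AdamW-style) weight decay directly into the update; beta1
# defaults to the golden-ratio conjugate (sqrt(5)-1)/2 ≈ 0.618 rather than 0.9.
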
#Rectified Huber Symmetric Error Loss Function via JIT Module
# nn.Module -> JIT C++ graph
class ReHSE(jit.ScriptModule):
    def __init__(self):
        super(ReHSE, self).__init__()

    @jit.script_method
    def forward(self, e):
        return (e * torch.tanh(e/2)).mean()


#Rectified Huber Asymmetric Error Loss Function via JIT Module
# nn.Module -> JIT C++ graph
class ReHAE(jit.ScriptModule):
    def __init__(self):
        super(ReHAE, self).__init__()

    @jit.script_method
    def forward(self, e):
        return (torch.abs(e) * torch.tanh(e/2)).mean()

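# Both losses behave like a smooth Huber: since tanh(e/2) ≈ e/2 near zero,
# e*tanh(e/2) ≈ e^2/2 for small errors and ≈ |e| for large ones. ReHAE keeps
# the sign of e (|e|*tanh(e/2) is negative for e<0), so its mean acts as a
# signed, advantage-style objective rather than a symmetric penalty.
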
#ReSine Activation Function
# nn.Module -> JIT C++ graph
class ReSine(jit.ScriptModule):
    def __init__(self, hidden_dim=256):
        super(ReSine, self).__init__()
        k = 1/math.sqrt(hidden_dim)
        self.s = nn.Parameter(data=2.0*k*torch.rand(hidden_dim)-k, requires_grad=True)

    @jit.script_method
    def forward(self, x):
        s = torch.sigmoid(self.s)
        x = s*torch.sin(x/s)
        return x/(1+torch.exp(-1.5*x/s))

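# ReSine squashes the input through a learned per-unit sine (amplitude and
# period set by s = sigmoid of a trainable vector), then applies a swish-like
# gate, since x/(1+exp(-1.5*x/s)) = x*sigmoid(1.5*x/s).
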
#SilentDropout (implemented here as GradientDropout)
# nn.Module -> JIT C++ graph
class GradientDropout(jit.ScriptModule):
    def __init__(self, p=0.5):
        super(GradientDropout, self).__init__()
        self.p = p

    @jit.script_method
    def forward(self, x):
        mask = (torch.rand_like(x) > self.p).float()
        return mask * x + (1.0-mask) * x.detach()

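# "Silent" because the forward pass is an exact identity (mask*x + (1-mask)*x);
# only the backward pass is affected: detached entries receive no gradient, so
# each unit is updated from a random subset of the batch.
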
class Swaddling(jit.ScriptModule):
    def __init__(self):
        super(Swaddling, self).__init__()

    @jit.script_method
    def Omega(self, x):
        return torch.log((1+x)/(1-x))

    @jit.script_method
    def omega(self, x):
        return x*torch.log(x)

    @jit.script_method
    def forward(self, x, k):
        return (self.Omega(x**(1/k.detach())) + k * self.omega(x) + self.Omega(k**2)).mean()

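# Swaddling regularizes the actor's scale (x) and beta (k) heads, both clamped
# to (0,1): Omega(t) = log((1+t)/(1-t)) = 2*artanh(t) grows without bound as
# t -> 1, and omega(x) = x*log(x) is an entropy-like term, so the sum appears
# to pull both heads away from the upper boundary of the unit interval.
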
class FeedForward(jit.ScriptModule):
    def __init__(self, f_in, h_dim, f_out):
        super(FeedForward, self).__init__()

        self.ffw = nn.Sequential(
            nn.Linear(f_in, h_dim),
            nn.LayerNorm(h_dim),
            nn.Linear(h_dim, h_dim),
            ReSine(h_dim),
            nn.Linear(h_dim, f_out)
        )

    @jit.script_method
    def forward(self, x):
        return self.ffw(x)

# nn.Module -> JIT C++ graph
class ActorCritic(jit.ScriptModule):
    def __init__(self, state_dim, action_dim, h_dim, max_action=1.0):
        super().__init__()

        self.action_dim = action_dim
        q_nodes = h_dim//4

        self.a = FeedForward(state_dim, h_dim, 3*action_dim)
        self.a_max = nn.Parameter(data=max_action, requires_grad=False)
        self.std = 1/math.e

        self.qA = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qB = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qC = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qnets = nn.ModuleList([self.qA, self.qB, self.qC])

        self.q_dist = q_nodes*len(self.qnets)
        indexes = torch.arange(0, self.q_dist, 1)/self.q_dist
        weights = torch.tanh((math.pi*(1-indexes))**math.pi) - 0.01*torch.exp(-(indexes/0.01)**2)
        self.probs = nn.Parameter(data=weights/torch.sum(weights), requires_grad=False)

        self.e = 1e-3
        self.e_ = 1-self.e

    #========= Actor Forward Pass =========

    @jit.script_method
    def actor(self, state, action:bool = True, noise:bool = True):
        ASB = torch.tanh(self.a(state)/2).reshape(-1, 3, self.action_dim)
        A, S, B = ASB[:, 0], ASB[:, 1].abs(), ASB[:, 2].abs()
        N = self.std * torch.randn_like(A).clamp(-math.e, math.e)
        return self.a_max * torch.tanh(float(action) * S * A + float(noise) * N), S.clamp(self.e, self.e_), B.clamp(self.e, self.e_)

    #========= Critic Forward Pass =========
    # take 3 distributions and concatenate them
    @jit.script_method
    def critic(self, state, action):
        x = torch.cat([state, action], -1)
        return torch.cat([qnet(x) for qnet in self.qnets], dim=-1)

    @jit.script_method
    def critic_soft(self, state, action):
        q = self.probs * self.critic(state, action).sort(dim=-1)[0]
        q = q.sum(dim=-1, keepdim=True)
        return q, q.detach()

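# The actor emits three heads per action dimension (A: direction, S: scale, and
# B: a second shaping head that the Swaddling loss couples with S); exploration
# noise is clipped Gaussian with std 1/e. The soft critic sorts the
# concatenated q_dist = 3*(h_dim//4) atoms and takes a fixed, monotonically
# decaying weighted sum, which appears to down-weight the largest atoms,
# yielding a pessimistic value estimate.
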
class Nets(jit.ScriptModule):
    def __init__(self, state_dim, action_dim, h_dim, max_action, device):
        super(Nets, self).__init__()

        self.online = ActorCritic(state_dim, action_dim, h_dim, max_action=max_action).to(device)
        self.target = ActorCritic(state_dim, action_dim, h_dim, max_action=max_action).to(device)
        self.target.load_state_dict(self.online.state_dict())

        self.rehse = ReHSE()
        self.rehae = ReHAE()
        self.sw = Swaddling()
        self.tau = 0.005
        self.tau_ = 1.0 - self.tau
        self.alpha = (math.sqrt(5)-1)/2
        self.alpha_ = 1.0 - self.alpha
        self.q_next_ema = torch.zeros(1, device=device)

    @torch.no_grad()
    def tau_update(self):
        for target_param, param in zip(self.target.qnets.parameters(), self.online.qnets.parameters()):
            target_param.data.copy_(self.tau_*target_param.data + self.tau*param.data)

    @jit.script_method
    def forward(self, state, action, reward, next_state, not_done_gamma):

        next_action, next_scale, next_beta = self.online.actor(next_state)
        q_next_target, q_next_target_value = self.target.critic_soft(next_state, next_action)
        q_target = reward + not_done_gamma * q_next_target_value
        q_pred = self.online.critic(state, action)

        q_next_ema = self.alpha * self.q_next_ema + self.alpha_ * q_next_target_value
        nets_loss = -self.rehae((q_next_target - q_next_ema)/q_next_ema.abs()) + self.rehse(q_pred-q_target) + self.sw(next_scale, next_beta)
        self.q_next_ema = q_next_ema.mean()

        return nets_loss, next_scale.detach()

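# A single joint loss trains actor and critics together: ReHSE penalizes the
# TD error q_pred - q_target (critic term), while -ReHAE maximizes the soft
# next-state value relative to its own exponential moving average, a
# self-normalized advantage that appears to act as the actor term; Swaddling
# regularizes the scale/beta heads. Only the critic ensemble is
# Polyak-averaged into the target (tau = 0.005).
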
# Define the algorithm
class Symphony(object):
    def __init__(self, capacity, state_dim, action_dim, h_dim, device, max_action, learning_rate=3e-4):

        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device

        self.replay_buffer = ReplayBuffer(capacity, state_dim, action_dim, device)
        self.nets = Nets(state_dim, action_dim, h_dim, max_action, device)
        self.nets_optimizer = Adam(self.nets.online.parameters(), lr=learning_rate)
        self.batch_size = self.nets.online.q_dist

    def select_action(self, state, action=True, noise=True):
        state = torch.tensor(state, dtype=torch.float32, device=self.device).reshape(-1, self.state_dim)
        with torch.no_grad(): action = self.nets.online.actor(state, action, noise)[0]
        return action.cpu().data.numpy().flatten()

    """
    def select_action(self, state, action=True, noise=True):
        with torch.no_grad(): return self.nets.online.actor(state, action, noise)[0]
    """

    def train(self):

        torch.manual_seed(random.randint(0, 2**32-1))

        state, action, reward, next_state, not_done_gamma = self.replay_buffer.sample(self.batch_size)
        self.nets_optimizer.zero_grad(set_to_none=True)

        nets_loss, scale = self.nets(state, action, reward, next_state, not_done_gamma)

        nets_loss.backward()
        self.nets_optimizer.step()
        self.nets.tau_update()

        return scale

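# The batch size is tied to the number of critic atoms (q_dist), and every
# train() call reseeds torch's RNG, apparently to decorrelate batch sampling
# and exploration noise from earlier deterministic seed resets.
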
class ReplayBuffer:
    def __init__(self, capacity, state_dim, action_dim, device):

        self.capacity, self.length, self.device = capacity, 0, device

        self.states = torch.zeros((self.capacity, state_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((self.capacity, action_dim), dtype=torch.float32, device=device)
        self.rewards = torch.zeros((self.capacity, 1), dtype=torch.float32, device=device)
        self.next_states = torch.zeros((self.capacity, state_dim), dtype=torch.float32, device=device)
        self.not_dones_gamma = torch.zeros((self.capacity, 1), dtype=torch.float32, device=device)

        self.norm = 1.0

    def add(self, state, action, reward, next_state, done):

        if self.length < self.capacity: self.length += 1

        idx = self.length-1

        self.states[idx,:] = torch.tensor(state, dtype=torch.float32, device=self.device)
        self.actions[idx,:] = torch.tensor(action, dtype=torch.float32, device=self.device)
        self.rewards[idx,:] = torch.tensor([reward/self.norm], dtype=torch.float32, device=self.device)
        self.next_states[idx,:] = torch.tensor(next_state, dtype=torch.float32, device=self.device)
        self.not_dones_gamma[idx,:] = torch.tensor([0.99 * (1.0 - float(done))], dtype=torch.float32, device=self.device)

        if self.length >= self.capacity:
            shift = 2 if self.not_dones_gamma[0,:].item() == 0.0 else 1
            self.states = torch.roll(self.states, shifts=-shift, dims=0)
            self.actions = torch.roll(self.actions, shifts=-shift, dims=0)
            self.rewards = torch.roll(self.rewards, shifts=-shift, dims=0)
            self.next_states = torch.roll(self.next_states, shifts=-shift, dims=0)
            self.not_dones_gamma = torch.roll(self.not_dones_gamma, shifts=-shift, dims=0)

    def sample(self, batch_size):

        indices = torch.multinomial(self.probs, num_samples=batch_size, replacement=True)

        return (
            self.states[indices],
            self.actions[indices],
            self.rewards[indices],
            self.next_states[indices],
            self.not_dones_gamma[indices]
        )

    def __len__(self):
        return self.length

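    # Note: sample() relies on self.probs, which is only created in norm_fill()
    # below, so training must not start before norm_fill() has been called (the
    # training script calls it once the exploration phase ends). When full, the
    # buffer rolls by 2 instead of 1 whenever the oldest transition is terminal,
    # dropping the terminal transition together with the one that follows it.
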
    #==============================================================
    #==============================================================
    #===========================HELPERS============================
    #==============================================================
    #==============================================================

    # Repeats the explored data `times` times, normalizes rewards by their mean
    # absolute value, and builds the recency-weighted sampling distribution
    # used by sample(). (Restored as a ReplayBuffer method: indentation was lost
    # in the page extraction, and it is invoked as algo.replay_buffer.norm_fill.)
    def norm_fill(self, times:int):

        print("copying replay data, current length", self.length)

        self.states = self.states[:self.length].repeat(times, 1)
        self.actions = self.actions[:self.length].repeat(times, 1)
        self.rewards = self.rewards[:self.length].repeat(times, 1)
        self.next_states = self.next_states[:self.length].repeat(times, 1)
        self.not_dones_gamma = self.not_dones_gamma[:self.length].repeat(times, 1)

        self.norm = torch.mean(torch.abs(self.rewards)).item()

        self.rewards /= self.norm

        self.length = times*self.length

        indexes = torch.arange(0, self.length, 1)/self.length
        weights = torch.tanh((math.pi*indexes)**math.pi) - 0.01*torch.exp(-((indexes-1)/0.01)**2)
        self.probs = weights/torch.sum(weights)

        print("new replay buffer length: ", self.length)
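A minimal usage sketch of the Symphony class (not part of the uploaded files; the environment name, capacity, h_dim, and step counts below are illustrative assumptions):

import gymnasium as gym
import torch
from symphony import Symphony

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make("Pendulum-v1")                      # assumed toy environment
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

algo = Symphony(capacity=10000, state_dim=state_dim, action_dim=action_dim,
                h_dim=256, device=device, max_action=torch.ones(action_dim))

state, _ = env.reset()
for _ in range(1000):                              # exploration phase
    action = algo.select_action(state)
    next_state, reward, done, truncated, _ = env.step(action)
    algo.replay_buffer.add(state, action, reward, next_state, done)
    state = env.reset()[0] if (done or truncated) else next_state

algo.replay_buffer.norm_fill(times=2)              # builds sampling probs
for _ in range(100):
    scale = algo.train()                           # one gradient update per call

norm_fill() must run once before the first train() call, since sample() depends on the probabilities it creates.
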
symphony_S2/train.py
CHANGED
@@ -0,0 +1,233 @@
from symphony import Symphony
import gymnasium as gym

import logging
logging.getLogger().setLevel(logging.CRITICAL)
import torch
import numpy as np
import random
import pickle
import time
import os, re

#############################################
# -----------Helper Functions---------------#
#############################################

# random seeds for reproducing the experiment
def seed_reset():
    r1, r2, r3 = random.randint(0,2**32-1), random.randint(0,2**32-1), random.randint(0,2**32-1)
    torch.manual_seed(r1)
    np.random.seed(r2)
    random.seed(r3)
    return r1, r2, r3

def extract_r1_r2_r3():
    pattern = r'history_(\d+)_(\d+)_(\d+)\.csv'

    # Iterate through the files in the current directory
    for filename in os.listdir():
        # Match the filename against the pattern
        match = re.match(pattern, filename)
        if match:
            # Extract the numbers r1, r2 and r3 from the filename
            return map(int, match.groups())
    return None

#write or append to the history log file
class LogFile(object):
    def __init__(self, log_name_main, log_name_opt):
        self.log_name_main = log_name_main
        self.log_name_opt = log_name_opt
    def write(self, text):
        with open(self.log_name_main, 'a+') as file:
            file.write(text)
    def write_opt(self, text):
        with open(self.log_name_opt, 'a+') as file:
            file.write(text)
    def clean(self):
        with open(self.log_name_main, 'w') as file:
            file.write("step,return\n")
        with open(self.log_name_opt, 'w') as file:
            file.write("ep,return,steps,scale\n")

numbers = extract_r1_r2_r3()

if numbers is not None:
    # recover the random seeds from an existing history file
    r1, r2, r3 = numbers
else:
    # generate new random seeds
    r1, r2, r3 = seed_reset()


print(r1, ", ", r2, ", ", r3)

log_name_main = "history_" + str(r1) + "_" + str(r2) + "_" + str(r3) + ".csv"
log_name_opt = "episodes_" + str(r1) + "_" + str(r2) + "_" + str(r3) + ".csv"
log_file = LogFile(log_name_main, log_name_opt)

def save(algo, total_rewards, total_steps):

    torch.save(algo.nets.online.state_dict(), 'nets_online_model.pt')
    torch.save(algo.nets.target.state_dict(), 'nets_target_model.pt')
    torch.save(algo.nets_optimizer.state_dict(), 'nets_optimizer.pt')
    print("saving... the buffer length = ", algo.replay_buffer.length, end="")
    with open('data', 'wb') as file:
        pickle.dump({'buffer': algo.replay_buffer, 'q_next_ema': algo.nets.q_next_ema, 'total_rewards': total_rewards, 'total_steps': total_steps}, file)
    print(" > done")

def load(algo, Q_learning):

    total_rewards, total_steps = [], 0

    try:
        print("loading models...")
        algo.nets.online.load_state_dict(torch.load('nets_online_model.pt', weights_only=True))
        algo.nets.target.load_state_dict(torch.load('nets_target_model.pt', weights_only=True))
        algo.nets_optimizer.load_state_dict(torch.load('nets_optimizer.pt', weights_only=True))
        print('models loaded')
        #sim_loop(env_valid, 100, True, False, algo, [], total_steps=0)
    except Exception:
        print("problem during loading models")

    try:
        print("loading buffer...")
        with open('data', 'rb') as file:
            data = pickle.load(file)
        algo.replay_buffer = data['buffer']
        algo.nets.q_next_ema = data['q_next_ema']
        total_rewards = data['total_rewards']
        total_steps = data['total_steps']
        if algo.replay_buffer.length >= explore_time and not Q_learning: Q_learning = True

        print('buffer loaded, Q_ema', round(algo.nets.q_next_ema.item(), 2), ', average_reward = ', round(np.mean(total_rewards[-300:]), 2))

    except Exception:
        print("problem during loading buffer")

    return Q_learning, total_rewards, total_steps

#############################################
# ---------------Parameters-----------------#
#############################################

#global parameters
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

print(device)
G = 3                           #gradient updates per environment step
learning_rate = 5e-5
explore_time, times = 20480, 25
capacity = explore_time * times #replay buffer capacity after norm_fill
h_dim = capacity//1000          #hidden width, here 512
limit_step = 1000               #max steps per episode
limit_eval = 1000               #max steps per evaluation
num_episodes = 1000000
start_episode = 1               #number for the identification of the current episode
episode_rewards_all, episode_steps_all, test_rewards, Q_learning, total_steps = [], [], [], False, 0

# environment type.
option = 3
pre_valid = True
if option == 0:   env_name = 'BipedalWalker-v3'
elif option == 1: env_name = 'HalfCheetah-v4'
elif option == 2: env_name = 'Walker2d-v4'
elif option == 3: env_name = 'Humanoid-v4'
elif option == 4: env_name = 'Ant-v4'
elif option == 5: env_name = 'Swimmer-v4'
elif option == 6: env_name = 'Hopper-v4'
elif option == 7: env_name = 'Pusher-v4'

env = gym.make(env_name)
env_test = gym.make(env_name)
env_valid = gym.make(env_name, render_mode="human")

state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]
#max_action = torch.FloatTensor(env.action_space.high) if env.action_space.is_bounded() else torch.ones(action_dim)
max_action = torch.ones(action_dim)

print("action_dim: ", action_dim, "state_dim: ", state_dim, "max_action:", max_action)

algo = Symphony(capacity, state_dim, action_dim, h_dim, device, max_action, learning_rate)

# Loop over episodes: [ State -> loop for one episode: [ Action, Next State, Reward, Done, State = Next State ] ]
def sim_loop(env, episodes, testing, Q_learning, algo, total_rewards, total_steps):

    start_episode = len(total_rewards) + 1

    for episode in range(start_episode, episodes+1):

        Return = 0.0
        state = env.reset()[0]

        for steps in range(1, limit_step+1):

            seed_reset()
            total_steps += 1

            # Activate training once the exploration time is reached (never in testing mode):
            if testing:
                Q_learning = False
            else:
                if algo.replay_buffer.length >= explore_time and not Q_learning:
                    Q_learning = True
                    algo.replay_buffer.norm_fill(times)
                    print("started training")

            # every 2500 total steps: save models, pause training for a testing pass, then resume:
            if Q_learning and total_steps >= 2500 and total_steps % 2500 == 0:
                save(algo, total_rewards, total_steps)
                print("start testing")
                log_file.write(str(total_steps) + ",")
                Return = sim_loop(env_test, 25, True, Q_learning, algo, [], total_steps=0)
                print("end of testing")
                log_file.write(str(round(Return, 2)) + "\n")

            # close to the episode step limit (last 50 steps, e.g. from step 950) the
            # deterministic action is switched off and only noise is left, to harvest
            # terminal transitions:
            active = steps < (limit_step-50) if Q_learning else True
            action = algo.select_action(state, action=active, noise=not testing)
            next_state, reward, done, truncated, info = env.step(action)
            if not testing: algo.replay_buffer.add(state, action, reward, next_state, done)
            Return += reward

            # actual training: G gradient updates per environment step
            if Q_learning:
                for _ in range(G): scale = algo.train()
            if done or truncated: break
            state = next_state

        total_rewards.append(Return)
        average_reward = np.mean(total_rewards[-300:])

        print(f"Ep {episode}: Rtrn = {Return:.2f}, Avg = {average_reward:.2f}| ep steps = {steps} | total_steps = {total_steps}")
        if not testing and Q_learning: log_file.write_opt(str(episode) + "," + str(round(Return, 2)) + "," + str(total_steps) + "," + str(round(scale.mean().item(), 4)) + "\n")

    return np.mean(total_rewards).item()

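# Note: the periodic testing pass reuses sim_loop recursively (testing=True),
# so no transitions are stored and no gradient updates run during testing.
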
# Loading existing models
Q_learning, total_rewards, total_steps = load(algo, Q_learning)
if not Q_learning: log_file.clean()

# Training
sim_loop(env, num_episodes, False, Q_learning, algo, total_rewards, total_steps)
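
For completeness, a minimal evaluation sketch that reloads the saved actor and runs it without noise (a hypothetical companion script, not part of the upload; it assumes the checkpoint names written by save() above and the same env_name/h_dim as the training run):

import gymnasium as gym
import torch
from symphony import Symphony

device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
env = gym.make('Humanoid-v4', render_mode="human")   # must match the training env
state_dim = env.observation_space.shape[0]
action_dim = env.action_space.shape[0]

# capacity=1: the replay buffer is unused at evaluation time; h_dim=512 matches
# capacity//1000 from train.py and must agree with the checkpoint being loaded
algo = Symphony(1, state_dim, action_dim, 512, device, torch.ones(action_dim))
algo.nets.online.load_state_dict(torch.load('nets_online_model.pt', weights_only=True))

state, _ = env.reset()
for _ in range(1000):
    action = algo.select_action(state, noise=False)  # deterministic policy
    state, reward, done, truncated, _ = env.step(action)
    if done or truncated: break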