|
|
import torch
|
|
|
import torch.nn as nn
|
|
|
import torch.optim as optim
|
|
|
import math
|
|
|
import torch.jit as jit
|
|
|
import random
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Adam(optim.Optimizer):
    """Adam-style optimizer without bias correction and with coupled weight decay.

    For every parameter the update is::

        m <- beta1 * m + (1 - beta1) * grad
        v <- beta2 * v + (1 - beta2) * grad**2
        p <- p - lr * (m / (sqrt(v) + eps) + weight_decay * p)

    Args:
        params: iterable of parameters or parameter-group dicts.
        lr: learning rate stored in every parameter group.
        weight_decay: coefficient of the decay term added to the update.
        betas: ``(beta1, beta2)`` decay rates of the two moment estimates.
    """

    def __init__(self, params, lr=3e-4, weight_decay=0.01, betas=((math.sqrt(5)-1)/2, 0.995)):
        # Only lr and betas live in the param groups so that state dicts saved
        # by earlier versions of this class still load.
        defaults = dict(lr=lr, betas=betas)
        super().__init__(params, defaults)

        # Kept as plain attributes for backward compatibility with code that
        # reads them; `step` takes lr/betas from the param groups instead.
        self.wd = weight_decay
        self.lr = lr
        self.beta1, self.beta2 = betas
        self.beta1_, self.beta2_ = 1-self.beta1, 1-self.beta2
        self.eps = 1e-8

    @torch.no_grad()
    def step(self, closure=None):
        """Apply one update to every parameter that has a gradient.

        Args:
            closure: optional callable that re-evaluates the model and returns
                the loss; it is run with gradients enabled.

        Returns:
            The value returned by ``closure``, or ``None`` when no closure is
            given.
        """
        loss = None
        if closure is not None:
            with torch.enable_grad():
                loss = closure()

        for group in self.param_groups:
            # Read hyper-parameters per group so schedulers and per-group
            # settings take effect.
            lr = group['lr']
            beta1, beta2 = group['betas']

            for p in group['params']:
                if p.grad is None:
                    continue

                grad = p.grad
                state = self.state[p]
                if not state:
                    # Lazily create the first and second moment buffers.
                    state['m'] = torch.zeros_like(p, memory_format=torch.preserve_format)
                    state['v'] = torch.zeros_like(p, memory_format=torch.preserve_format)

                m, v = state['m'], state['v']
                m.mul_(beta1).add_(grad, alpha=1-beta1)
                v.mul_(beta2).addcmul_(grad, grad, value=1-beta2)

                # Weight decay is folded into the step, so it is scaled by lr too.
                p.add_(m/(v.sqrt() + self.eps) + self.wd*p, alpha=-lr)

        return loss
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReHSE(jit.ScriptModule):
    """Loss module returning the mean of ``e * tanh(e / 2)`` over all elements."""

    def __init__(self):
        super().__init__()

    @jit.script_method
    def forward(self, e):
        """Reduce the error tensor ``e`` to a scalar loss."""
        gate = torch.tanh(e/2)
        weighted = e * gate
        return weighted.mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReHAE(jit.ScriptModule):
    """Loss module returning the mean of ``|e| * tanh(e / 2)`` over all elements."""

    def __init__(self):
        super().__init__()

    @jit.script_method
    def forward(self, e):
        """Reduce the error tensor ``e`` to a scalar; the sign of ``e`` is kept by tanh."""
        magnitude = torch.abs(e)
        gate = torch.tanh(e/2)
        return (magnitude * gate).mean()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReSine(jit.ScriptModule):
    """Activation with one learnable scale per hidden unit.

    With ``s = sigmoid(self.s)`` the output is ``y / (1 + exp(-1.5 * y / s))``
    where ``y = s * sin(x / s)``.

    Args:
        hidden_dim: number of units; also sets the range of the uniform
            initialisation of ``self.s`` to ``[-1/sqrt(hidden_dim), 1/sqrt(hidden_dim))``.
    """

    def __init__(self, hidden_dim=256):
        super().__init__()
        k = 1/math.sqrt(hidden_dim)
        initial = 2.0*k*torch.rand(hidden_dim)-k
        self.s = nn.Parameter(data=initial, requires_grad=True)

    @jit.script_method
    def forward(self, x):
        """Apply the activation; ``self.s`` broadcasts against the last dimension of ``x``."""
        scale = torch.sigmoid(self.s)
        wave = scale*torch.sin(x/scale)
        return wave/(1+torch.exp(-1.5*wave/scale))
|
|
|
|
|
|
|
|
|
|
|
|
class SilentDropout(jit.ScriptModule):
    """Dropout applied to gradients only.

    The forward value equals the input; elements whose random draw is not
    greater than ``p`` are routed through ``detach()`` and therefore receive no
    gradient.

    Args:
        p: threshold compared against a uniform draw per element.
    """

    def __init__(self, p=0.5):
        super().__init__()
        self.p = p

    @jit.script_method
    def forward(self, x):
        """Return ``x`` with gradient flow blocked on a random subset of elements."""
        keep = (torch.rand_like(x) > self.p).float()
        blocked = 1.0-keep
        return keep * x + blocked * x.detach()
|
|
|
|
|
|
|
|
|
class Swaddling(jit.ScriptModule):
    """Regulariser combining two log-based terms of ``x`` and ``k`` into a scalar."""

    def __init__(self):
        super().__init__()

    @jit.script_method
    def Omega(self, x):
        """Return ``log((1 + x) / (1 - x))`` element-wise."""
        upper = 1+x
        lower = 1-x
        return torch.log(upper/lower)

    @jit.script_method
    def omega(self, x):
        """Return ``x * log(x)`` element-wise."""
        return x*torch.log(x)

    @jit.script_method
    def forward(self, x, k):
        """Return the mean of ``Omega(x**(1/k)) + k*omega(x) + Omega(k**2)``.

        ``k`` is detached inside the exponent, so that term sends no gradient to ``k``.
        """
        root_term = self.Omega(x**(1/k.detach()))
        entropy_term = k * self.omega(x)
        k_term = self.Omega(k**2)
        return (root_term + entropy_term + k_term).mean()
|
|
|
|
|
|
|
|
|
|
|
|
class FeedForward(jit.ScriptModule):
    """Three-layer perceptron: Linear -> LayerNorm -> Linear -> ReSine -> Linear.

    Args:
        f_in: size of the input features.
        h_dim: width of the two hidden layers.
        f_out: size of the output features.
    """

    def __init__(self, f_in, h_dim, f_out):
        super().__init__()

        # Layers are created in execution order so parameter initialisation
        # draws random numbers in the same sequence as before.
        layers = [
            nn.Linear(f_in, h_dim),
            nn.LayerNorm(h_dim),
            nn.Linear(h_dim, h_dim),
            ReSine(h_dim),
            nn.Linear(h_dim, f_out),
        ]
        self.ffw = nn.Sequential(*layers)

    @jit.script_method
    def forward(self, x):
        """Run ``x`` through the stacked layers."""
        return self.ffw(x)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ActorCritic(jit.ScriptModule):
    """Actor network plus an ensemble of three critic networks.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        h_dim: hidden width of every FeedForward network; each critic emits
            ``h_dim // 4`` values.
        max_action: scale applied to the tanh-squashed action. May be a
            tensor, or any value ``torch.as_tensor`` accepts (float, list,
            NumPy array).
    """

    def __init__(self, state_dim, action_dim, h_dim, max_action=1.0):
        super().__init__()

        self.action_dim = action_dim
        q_nodes = h_dim//4

        # The actor emits three values per action dimension (see `actor`).
        self.a = FeedForward(state_dim, h_dim, 3*action_dim)
        # nn.Parameter only accepts tensors; convert plain numbers/arrays so
        # the default max_action=1.0 works. Tensors are passed through as-is.
        if not isinstance(max_action, torch.Tensor):
            max_action = torch.as_tensor(max_action, dtype=torch.float32)
        self.a_max = nn.Parameter(data=max_action, requires_grad=False)
        self.std = 1/math.e

        self.qA = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qB = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qC = FeedForward(state_dim+action_dim, h_dim, q_nodes)
        self.qnets = nn.ModuleList([self.qA, self.qB, self.qC])

        # Total number of critic outputs across the ensemble.
        self.q_dist = q_nodes*len(self.qnets)
        # Fixed, normalised weights that decrease with index; in `critic_soft`
        # they multiply the ascending-sorted critic outputs, so smaller
        # values receive larger weights.
        indexes = torch.arange(0, self.q_dist, 1)/self.q_dist
        weights = torch.tanh((math.pi*(1-indexes))**math.e)
        self.probs = nn.Parameter(data=weights/torch.sum(weights), requires_grad=False)

        # Clamp bounds for the two auxiliary actor outputs.
        self.e = 1e-3
        self.e_ = 1-self.e

    @jit.script_method
    def actor(self, state, action:bool = True, noise:bool=True):
        """Compute an action and the two auxiliary actor outputs.

        The actor output is squashed with ``tanh(x / 2)`` and split into three
        slices of size ``action_dim``: ``A`` and the absolute values ``S`` and
        ``B``. ``action`` and ``noise`` switch the ``S * A`` term and the
        clamped Gaussian noise term on or off.

        Returns:
            ``(a_max * tanh(S * A + N), S, B)`` with ``S`` and ``B`` clamped
            to ``[e, 1 - e]``.
        """
        ASB = torch.tanh(self.a(state)/2).reshape(-1, 3, self.action_dim)
        A, S, B = ASB[:, 0], ASB[:, 1].abs(), ASB[:, 2].abs()
        N = self.std * torch.randn_like(A).clamp(-math.e, math.e)
        return self.a_max * torch.tanh(float(action) * S * A + float(noise) * N), S.clamp(self.e, self.e_), B.clamp(self.e, self.e_)

    @jit.script_method
    def critic(self, state, action):
        """Return the outputs of all critics concatenated on the last dimension."""
        x = torch.cat([state, action], -1)
        return torch.cat([qnet(x) for qnet in self.qnets], dim=-1)

    @jit.script_method
    def critic_soft(self, state, action):
        """Return the weighted sum of the sorted critic outputs.

        Returns:
            ``(q, q.detach())`` where ``q`` keeps a trailing dimension of size 1.
        """
        q = self.probs * self.critic(state, action).sort(dim=-1)[0]
        q = q.sum(dim=-1, keepdim=True)
        return q, q.detach()
|
|
|
|
|
|
|
|
|
|
|
|
class Nets(jit.ScriptModule):
    """Online/target ActorCritic pair and the combined training loss.

    Args:
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        h_dim: hidden width passed to both ActorCritic networks.
        max_action: action scale passed to both ActorCritic networks.
        device: device both networks and the running average are placed on.
    """

    def __init__(self, state_dim, action_dim, h_dim, max_action, device):
        super(Nets, self).__init__()

        self.online = ActorCritic(state_dim, action_dim, h_dim, max_action=max_action).to(device)
        self.target = ActorCritic(state_dim, action_dim, h_dim, max_action=max_action).to(device)
        self.target.load_state_dict(self.online.state_dict())
        # Target weights change only through `tau_update`. Freezing them stops
        # backward() from computing and accumulating gradients that nothing
        # reads or zeroes; gradients still flow through the target critic to
        # the action produced by the online actor.
        self.target.requires_grad_(False)

        self.rehse = ReHSE()
        self.rehae = ReHAE()
        self.sw = Swaddling()

        # Interpolation factor for the target update.
        self.tau = 0.005
        self.tau_ = 1.0 - self.tau
        # Smoothing factor for the running average of the next-state value.
        self.alpha = (math.sqrt(5)-1)/2
        self.alpha_= 1.0 - self.alpha
        self.q_next_ema = torch.zeros(1, device=device)

    @torch.no_grad()
    def tau_update(self):
        """Move the target critics toward the online critics by a factor ``tau``.

        Only the parameters in ``qnets`` are updated; the target actor is left
        unchanged.
        """
        for target_param, param in zip(self.target.qnets.parameters(), self.online.qnets.parameters()):
            target_param.data.copy_(self.tau_*target_param.data + self.tau*param.data)

    @jit.script_method
    def forward(self, state, action, reward, next_state, not_done_gamma):
        """Compute the joint loss for one batch of transitions.

        The loss is the sum of three terms:
          * minus ReHAE of the relative difference between the target value of
            the online actor's next action and its running average,
          * ReHSE of the online critic outputs minus the bootstrapped target
            ``reward + not_done_gamma * q_next``,
          * the Swaddling term of the two auxiliary actor outputs.

        Side effect: ``self.q_next_ema`` is replaced by the batch mean of the
        updated running average.

        Returns:
            ``(nets_loss, next_scale.detach())``.
        """
        next_action, next_scale, next_beta = self.online.actor(next_state)
        q_next_target, q_next_target_value = self.target.critic_soft(next_state, next_action)
        q_target = reward + not_done_gamma * q_next_target_value
        q_pred = self.online.critic(state, action)

        q_next_ema = self.alpha * self.q_next_ema + self.alpha_ * q_next_target_value
        nets_loss = -self.rehae((q_next_target - q_next_ema)/q_next_ema.abs()) + self.rehse(q_pred-q_target) + self.sw(next_scale, next_beta)
        self.q_next_ema = q_next_ema.mean()

        return nets_loss, next_scale.detach()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class Symphony(object):
    """Agent bundling the replay buffer, the networks and their optimizer.

    Args:
        capacity: number of transitions the replay buffer is created with.
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        h_dim: hidden width of the networks.
        device: device for the buffer and the networks.
        max_action: action scale passed to the networks.
        learning_rate: learning rate of the optimizer.
    """

    def __init__(self, capacity, state_dim, action_dim, h_dim, device, max_action, learning_rate=3e-4):
        self.state_dim = state_dim
        self.action_dim = action_dim
        self.device = device

        self.replay_buffer = ReplayBuffer(capacity, state_dim, action_dim, device)
        self.nets = Nets(state_dim, action_dim, h_dim, max_action, device)
        # Only the online networks are optimised.
        self.nets_optimizer = Adam(self.nets.online.parameters(), lr=learning_rate)
        # The batch size equals the total number of critic outputs.
        self.batch_size = self.nets.online.q_dist

    def select_action(self, state, action = True, noise=True):
        """Return the online actor's action for ``state`` as a flat NumPy array.

        ``action`` and ``noise`` are forwarded to the actor as its two flags.
        """
        observation = torch.tensor(state, dtype=torch.float32, device=self.device).reshape(-1, self.state_dim)
        with torch.no_grad():
            outputs = self.nets.online.actor(observation, action, noise)
        chosen = outputs[0]
        return chosen.cpu().data.numpy().flatten()

    # Alternative kept from the original source for callers that already hold
    # a tensor state and want a tensor back:
    #
    #     def select_action(self, state, action = True, noise=True):
    #         with torch.no_grad(): return self.nets.online.actor(state, action, noise)[0]

    def train(self):
        """Run one optimisation step on a sampled batch.

        Returns:
            The detached second actor output for the batch's next states, as
            returned by ``Nets.forward``.
        """
        # Torch is reseeded from Python's RNG before every step.
        seed = random.randint(0,2**32-1)
        torch.manual_seed(seed)

        batch = self.replay_buffer.sample(self.batch_size)
        state, action, reward, next_state, not_done_gamma = batch

        self.nets_optimizer.zero_grad(set_to_none=True)
        nets_loss, scale = self.nets(state, action, reward, next_state, not_done_gamma)
        nets_loss.backward()
        self.nets_optimizer.step()

        self.nets.tau_update()

        return scale
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class ReplayBuffer:
    """Tensor-backed store of transitions with index-weighted sampling.

    Transitions are kept in insertion order: once the buffer is full every
    `add` rolls the storage so the next write lands in the last row. Sampling
    weights grow with the row index.

    Args:
        capacity: number of rows allocated for each field.
        state_dim: size of the state vector.
        action_dim: size of the action vector.
        device: device the storage tensors are created on.
    """

    def __init__(self, capacity, state_dim, action_dim, device):
        self.capacity, self.length, self.device = capacity, 0, device

        self.states = torch.zeros((self.capacity, state_dim), dtype=torch.float32, device=device)
        self.actions = torch.zeros((self.capacity, action_dim), dtype=torch.float32, device=device)
        self.rewards = torch.zeros((self.capacity, 1), dtype=torch.float32, device=device)
        self.next_states = torch.zeros((self.capacity, state_dim), dtype=torch.float32, device=device)
        self.not_dones_gamma = torch.zeros((self.capacity, 1), dtype=torch.float32, device=device)

        # Divisor applied to every incoming reward; updated by `norm_fill`.
        self.norm = 1.0
        # Sampling distribution over rows; built by `norm_fill` or lazily by
        # `sample`, so sampling no longer fails when `norm_fill` was not called.
        self.probs = None

    def add(self, state, action, reward, next_state, done):
        """Store one transition in the newest row.

        The reward is divided by ``self.norm`` and ``done`` is stored as
        ``0.99 * (1 - done)``.
        """
        if self.length<self.capacity: self.length += 1

        idx = self.length-1

        self.states[idx,:] = torch.tensor(state, dtype=torch.float32, device=self.device)
        self.actions[idx,:] = torch.tensor(action, dtype=torch.float32, device=self.device)
        self.rewards[idx,:] = torch.tensor([reward/self.norm], dtype=torch.float32, device=self.device)
        self.next_states[idx,:] = torch.tensor(next_state, dtype=torch.float32, device=self.device)
        self.not_dones_gamma[idx,:] = torch.tensor([0.99 * (1.0 - float(done))], dtype=torch.float32, device=self.device)

        if self.length>=self.capacity:
            # Rotate the oldest row(s) to the end, where the next add writes.
            # The shift is 2 when the oldest row is a terminal transition.
            shift = 2 if self.not_dones_gamma[0,:].item() == 0.0 else 1
            self.states = torch.roll(self.states, shifts=-shift, dims=0)
            self.actions = torch.roll(self.actions, shifts=-shift, dims=0)
            self.rewards = torch.roll(self.rewards, shifts=-shift, dims=0)
            self.next_states = torch.roll(self.next_states, shifts=-shift, dims=0)
            self.not_dones_gamma = torch.roll(self.not_dones_gamma, shifts=-shift, dims=0)

    @staticmethod
    def _index_probs(length):
        """Return sampling probabilities over ``length`` rows, increasing with index."""
        norm_index = torch.arange(0, length, 1)/length
        weights = torch.tanh((math.pi*norm_index)**math.e)
        total = torch.sum(weights)
        if total.item() <= 0.0:
            # A single row has weight 0; fall back to uniform instead of NaN.
            return torch.full((length,), 1.0/length)
        return weights/total

    def _pad_to_capacity(self):
        """Grow the storage back to ``capacity`` rows with zeros if it is shorter."""
        missing = self.capacity - self.states.shape[0]
        if missing <= 0:
            return

        def pad(t):
            filler = torch.zeros((missing, t.shape[1]), dtype=t.dtype, device=t.device)
            return torch.cat([t, filler], dim=0)

        self.states = pad(self.states)
        self.actions = pad(self.actions)
        self.rewards = pad(self.rewards)
        self.next_states = pad(self.next_states)
        self.not_dones_gamma = pad(self.not_dones_gamma)

    def sample(self, batch_size):
        """Draw ``batch_size`` transitions with replacement.

        Returns:
            Tuple ``(states, actions, rewards, next_states, not_dones_gamma)``.

        Raises:
            ValueError: if the buffer holds no transitions.
        """
        if self.length == 0:
            raise ValueError("cannot sample from an empty replay buffer")

        # Rebuild the distribution whenever it does not match the current
        # number of stored transitions (first use, or growth after norm_fill).
        if self.probs is None or self.probs.numel() != self.length:
            self.probs = self._index_probs(self.length)

        indices = torch.multinomial(self.probs, num_samples=batch_size, replacement=True)

        return (
            self.states[indices],
            self.actions[indices],
            self.rewards[indices],
            self.next_states[indices],
            self.not_dones_gamma[indices]
        )

    def __len__(self):
        """Return the number of stored transitions."""
        return self.length

    def norm_fill(self, times:int):
        """Repeat the stored transitions ``times`` times and normalise rewards.

        Rewards are divided by their mean absolute value and ``self.norm`` is
        scaled by the same factor, so later `add` calls use the same scale and
        repeated calls stay consistent. An all-zero reward set leaves the scale
        unchanged.

        Raises:
            ValueError: if ``times`` is smaller than 1 or the buffer is empty.
        """
        if times < 1:
            raise ValueError("times must be at least 1")
        if self.length == 0:
            raise ValueError("cannot normalise an empty replay buffer")

        print("copying replay data, current length", self.length)

        self.states = self.states[:self.length].repeat(times, 1)
        self.actions = self.actions[:self.length].repeat(times, 1)
        self.rewards = self.rewards[:self.length].repeat(times, 1)
        self.next_states = self.next_states[:self.length].repeat(times, 1)
        self.not_dones_gamma = self.not_dones_gamma[:self.length].repeat(times, 1)

        scale = torch.mean(torch.abs(self.rewards)).item()
        if scale > 0.0:
            self.rewards /= scale
            # Stored rewards were already divided by the previous norm.
            self.norm *= scale

        self.length = times*self.length

        # Keep at least `capacity` rows so `add` can keep appending.
        self._pad_to_capacity()
        self.probs = self._index_probs(self.length)

        print("new replay buffer length: ", self.length)
|
|
|
|