import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import distributions as pyd
from torch import nn

import utils


class TanhTransform(pyd.transforms.Transform):
    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        # atanh(x) = 0.5 * (log(1 + x) - log(1 - x)), written with log1p
        # for better numerical behavior near zero.
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        # No clamping at the boundary: with cache_size=1 the forward value
        # is cached, so the exact inverse is rarely evaluated.
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        # log |d tanh(x)/dx| = log(1 - tanh(x)^2)
        #                    = 2 * (log 2 - x - softplus(-2x)),
        # a numerically stable rewrite that avoids computing 1 - y^2.
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))
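

# Minimal sanity-check sketch: TanhTransform should round-trip through its
# inverse, and log_abs_det_jacobian should agree with log(1 - tanh(x)^2)
# computed directly.
def _check_tanh_transform(n=1024):
    t = TanhTransform()
    x = torch.linspace(-3.0, 3.0, n)
    y = t(x)
    assert torch.allclose(t.inv(y), x, atol=1e-4)
    direct = torch.log1p(-y.pow(2))
    assert torch.allclose(t.log_abs_det_jacobian(x, y), direct, atol=1e-4)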


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """A diagonal Gaussian squashed through tanh so samples lie in (-1, 1)."""

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        # Push the base distribution's mean through the tanh squashing.
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu
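

# A hedged usage sketch: SquashedNormal behaves like any torch distribution,
# with samples squashed into (-1, 1) and tanh-corrected log-densities.
#
#   dist = SquashedNormal(loc=torch.zeros(3), scale=torch.ones(3))
#   a = dist.rsample()               # reparameterized sample in (-1, 1)
#   logp = dist.log_prob(a).sum(-1)  # includes the tanh Jacobian correction
#   m = dist.mean                    # tanh(loc)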


class DiagGaussianActor(nn.Module):
    """torch.distributions implementation of a diagonal Gaussian policy."""

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth,
                 log_std_bounds):
        super().__init__()

        self.log_std_bounds = log_std_bounds
        # The trunk predicts both the mean and the log-std, hence 2 * action_dim.
        self.trunk = utils.mlp(obs_dim, hidden_dim, 2 * action_dim,
                               hidden_depth)

        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs):
        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # Constrain log_std to [log_std_min, log_std_max] with a smooth tanh
        # rescaling instead of a hard clamp, so gradients keep flowing at the
        # boundaries.
        log_std = torch.tanh(log_std)
        log_std_min, log_std_max = self.log_std_bounds
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std + 1)

        std = log_std.exp()

        self.outputs['mu'] = mu
        self.outputs['std'] = std

        dist = SquashedNormal(mu, std)
        return dist

    def log(self, logger, step):
        for k, v in self.outputs.items():
            logger.log_histogram(f'train_actor/{k}_hist', v, step)

        for i, m in enumerate(self.trunk):
            if type(m) == nn.Linear:
                logger.log_param(f'train_actor/fc{i}', m, step)
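

# A hedged usage sketch (assumes `utils.mlp` returns an nn.Sequential trunk
# and `utils.weight_init` initializes its layers; the hyperparameters below
# are illustrative):
#
#   actor = DiagGaussianActor(obs_dim=17, action_dim=6, hidden_dim=256,
#                             hidden_depth=2, log_std_bounds=[-5, 2])
#   dist = actor(torch.randn(32, 17))          # batch of 32 observations
#   action = dist.rsample()                    # actions in (-1, 1)
#   log_prob = dist.log_prob(action).sum(-1, keepdim=True)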