import math

import numpy as np
import torch
import torch.nn.functional as F
from torch import distributions as pyd
from torch import nn

import utils


class TanhTransform(pyd.transforms.Transform):
    """Bijective transform ``y = tanh(x)`` mapping the reals onto (-1, 1)."""

    domain = pyd.constraints.real
    codomain = pyd.constraints.interval(-1.0, 1.0)
    bijective = True
    sign = +1

    def __init__(self, cache_size=1):
        super().__init__(cache_size=cache_size)

    @staticmethod
    def atanh(x):
        """Return the inverse hyperbolic tangent of ``x``.

        Computed as ``0.5 * (log1p(x) - log1p(-x))``, i.e.
        ``0.5 * log((1 + x) / (1 - x))``.
        """
        return 0.5 * (x.log1p() - (-x).log1p())

    def __eq__(self, other):
        # The transform has no parameters, so any two instances are equal.
        return isinstance(other, TanhTransform)

    def _call(self, x):
        return x.tanh()

    def _inverse(self, y):
        return self.atanh(y)

    def log_abs_det_jacobian(self, x, y):
        """Return ``log|d tanh(x) / dx|`` computed from ``x`` alone.

        Uses the identity
        ``log(1 - tanh(x)^2) = 2 * (log(2) - x - softplus(-2x))``,
        which avoids evaluating ``1 - tanh(x)^2`` directly. ``y`` is unused.
        """
        return 2. * (math.log(2.) - x - F.softplus(-2. * x))


class SquashedNormal(pyd.transformed_distribution.TransformedDistribution):
    """Normal distribution pushed through a single ``TanhTransform``.

    Args:
        loc: location passed to the underlying ``Normal``.
        scale: scale passed to the underlying ``Normal``.
    """

    def __init__(self, loc, scale):
        self.loc = loc
        self.scale = scale

        self.base_dist = pyd.Normal(loc, scale)
        transforms = [TanhTransform()]
        super().__init__(self.base_dist, transforms)

    @property
    def mean(self):
        """Return ``loc`` passed through every transform in order.

        This is the transformed location (``tanh(loc)`` here), not the
        expectation of the transformed distribution.
        """
        mu = self.loc
        for tr in self.transforms:
            mu = tr(mu)
        return mu


class DiagGaussianActor(nn.Module):
    """torch.distributions implementation of a diagonal Gaussian policy.

    Args:
        obs_dim, hidden_dim, hidden_depth: forwarded to ``utils.mlp``
            (presumably input width, hidden width and number of hidden
            layers -- confirm against ``utils.mlp``).
        action_dim: the trunk emits ``2 * action_dim`` values, split in
            half into ``mu`` and ``log_std``.
        log_std_bounds: pair ``(log_std_min, log_std_max)`` bounding the
            predicted log standard deviation.
    """

    def __init__(self, obs_dim, action_dim, hidden_dim, hidden_depth,
                 log_std_bounds):
        super().__init__()

        self.log_std_bounds = log_std_bounds
        self.trunk = utils.mlp(obs_dim, hidden_dim, 2 * action_dim,
                               hidden_depth)

        # Latest forward-pass tensors, kept only so ``log`` can report them.
        self.outputs = dict()
        self.apply(utils.weight_init)

    def forward(self, obs):
        """Return the ``SquashedNormal`` action distribution for ``obs``."""
        mu, log_std = self.trunk(obs).chunk(2, dim=-1)

        # constrain log_std inside [log_std_min, log_std_max]
        log_std = torch.tanh(log_std)
        log_std_min, log_std_max = self.log_std_bounds
        # Affinely rescale tanh's (-1, 1) output onto the configured bounds.
        log_std = log_std_min + 0.5 * (log_std_max - log_std_min) * (log_std +
                                                                     1)

        std = log_std.exp()

        self.outputs['mu'] = mu
        self.outputs['std'] = std

        dist = SquashedNormal(mu, std)
        return dist

    def log(self, logger, step):
        """Send cached outputs and the trunk's linear layers to ``logger``.

        Args:
            logger: object providing ``log_histogram`` and ``log_param``.
            step: step value forwarded unchanged to the logger.
        """
        for k, v in self.outputs.items():
            logger.log_histogram(f'train_actor/{k}_hist', v, step)

        # Assumes ``self.trunk`` is iterable over its layers (e.g. an
        # ``nn.Sequential``) -- confirm against ``utils.mlp``.
        for i, m in enumerate(self.trunk):
            # isinstance also matches subclasses of nn.Linear.
            if isinstance(m, nn.Linear):
                logger.log_param(f'train_actor/fc{i}', m, step)