import gym
import torch
import torch.nn as nn

from abc import ABC, abstractmethod
from gym.spaces import Box, Discrete
from torch.distributions import Categorical, Distribution, Normal
from typing import NamedTuple, Optional, Sequence, Type, TypeVar, Union

from shared.module.feature_extractor import FeatureExtractor
from shared.module.module import mlp


class PiForward(NamedTuple):
    # Result of a policy forward pass: the action distribution, plus the
    # log-probability and entropy terms when an action was provided.
    pi: Distribution
    logp_a: Optional[torch.Tensor]
    entropy: Optional[torch.Tensor]


class Actor(nn.Module, ABC):
    @abstractmethod
    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
        ...


class CategoricalActorHead(Actor):
    # Actor head for Discrete action spaces: an MLP whose outputs are the
    # logits of a Categorical distribution.
    def __init__(
        self,
        act_dim: int,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        layer_sizes = tuple(hidden_sizes) + (act_dim,)
        self._fc = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=0.01,
        )

    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
        logits = self._fc(obs)
        pi = Categorical(logits=logits)
        logp_a = None
        entropy = None
        if a is not None:
            logp_a = pi.log_prob(a)
            entropy = pi.entropy()
        return PiForward(pi, logp_a, entropy)


class GaussianDistribution(Normal):
    def log_prob(self, a: torch.Tensor) -> torch.Tensor:
        # Sum per-dimension log-probs: action dimensions are independent.
        return super().log_prob(a).sum(dim=-1)

    def sample(self) -> torch.Tensor:
        # Reparameterized sampling, so gradients can flow through the sample.
        return self.rsample()


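# Note the asymmetry above: log_prob sums across action dimensions while
# entropy() (inherited from Normal) stays per-dimension; downstream losses
# typically reduce both with a mean, so only the relative scale differs.

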
class GaussianActorHead(Actor):
    # Actor head for Box action spaces: an MLP outputs the mean of a diagonal
    # Gaussian whose state-independent log std is a learned parameter.
    def __init__(
        self,
        act_dim: int,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
        log_std_init: float = -0.5,
    ) -> None:
        super().__init__()
        layer_sizes = tuple(hidden_sizes) + (act_dim,)
        self.mu_net = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=0.01,
        )
        self.log_std = nn.Parameter(
            torch.ones(act_dim, dtype=torch.float32) * log_std_init
        )

    def _distribution(self, obs: torch.Tensor) -> Distribution:
        mu = self.mu_net(obs)
        std = torch.exp(self.log_std)
        return GaussianDistribution(mu, std)

    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
        pi = self._distribution(obs)
        logp_a = None
        entropy = None
        if a is not None:
            logp_a = pi.log_prob(a)
            entropy = pi.entropy()  # per-dimension entropy from Normal
        return PiForward(pi, logp_a, entropy)


class TanhBijector:
    def __init__(self, epsilon: float = 1e-6) -> None:
        self.epsilon = epsilon

    @staticmethod
    def forward(x: torch.Tensor) -> torch.Tensor:
        return torch.tanh(x)

    @staticmethod
    def inverse(y: torch.Tensor) -> torch.Tensor:
        # Clamp to the open interval (-1, 1) so atanh doesn't return +/-inf.
        eps = torch.finfo(y.dtype).eps
        clamped_y = y.clamp(min=-1.0 + eps, max=1.0 - eps)
        return torch.atanh(clamped_y)

    def log_prob_correction(self, x: torch.Tensor) -> torch.Tensor:
        return torch.log(1.0 - torch.tanh(x) ** 2 + self.epsilon)


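# Why log_prob_correction has that form: for a squashed action a = tanh(x), the
# change-of-variables formula gives
#     log p(a) = log p(x) - log |d tanh(x) / dx| = log p(x) - log(1 - tanh(x)^2),
# and epsilon keeps the log finite when tanh saturates near +/-1. This is the
# term StateDependentNoiseDistribution.log_prob subtracts below.

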
class StateDependentNoiseDistribution(Normal):
    def __init__(
        self,
        loc,
        scale,
        latent_sde: torch.Tensor,
        exploration_mat: torch.Tensor,
        exploration_matrices: torch.Tensor,
        bijector: Optional[TanhBijector] = None,
        validate_args=None,
    ):
        super().__init__(loc, scale, validate_args)
        self.latent_sde = latent_sde
        self.exploration_mat = exploration_mat
        self.exploration_matrices = exploration_matrices
        self.bijector = bijector

    def log_prob(self, a: torch.Tensor) -> torch.Tensor:
        # Undo the squashing (if any), evaluate the Gaussian log-density, then
        # apply the tanh change-of-variables correction.
        gaussian_a = self.bijector.inverse(a) if self.bijector else a
        log_prob = super().log_prob(gaussian_a).sum(dim=-1)
        if self.bijector:
            log_prob -= torch.sum(
                self.bijector.log_prob_correction(gaussian_a), dim=-1
            )
        return log_prob

    def sample(self) -> torch.Tensor:
        noise = self._get_noise()
        actions = self.mean + noise
        return self.bijector.forward(actions) if self.bijector else actions

    def _get_noise(self) -> torch.Tensor:
        # Use the single shared exploration matrix when the batch size doesn't
        # match the per-sample matrices (e.g. a single observation at rollout).
        if len(self.latent_sde) == 1 or len(self.latent_sde) != len(
            self.exploration_matrices
        ):
            return torch.mm(self.latent_sde, self.exploration_mat)
        # Otherwise, one exploration matrix per batch element:
        # (batch_size, 1, n_features) @ (batch_size, n_features, n_actions)
        latent_sde = self.latent_sde.unsqueeze(dim=1)
        noise = torch.bmm(latent_sde, self.exploration_matrices)
        return noise.squeeze(dim=1)

    @property
    def mode(self) -> torch.Tensor:
        mean = super().mode
        return self.bijector.forward(mean) if self.bijector else mean


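# In generalized state-dependent exploration (gSDE), action noise is a linear
# function of the policy features: noise = latent_sde @ W with weights
# W ~ N(0, sigma) that are resampled only occasionally (see sample_weights
# below). Conditioned on the features, action dimension i then has variance
# sum_j latent_j^2 * sigma_ij^2, which is exactly how
# StateDependentNoiseActorHead._distribution builds the scale it passes here.

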
StateDependentNoiseActorHeadSelf = TypeVar(
    "StateDependentNoiseActorHeadSelf", bound="StateDependentNoiseActorHead"
)


class StateDependentNoiseActorHead(Actor):
    def __init__(
        self,
        act_dim: int,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
        log_std_init: float = -0.5,
        full_std: bool = True,
        squash_output: bool = False,
        learn_std: bool = False,
    ) -> None:
        super().__init__()
        self.act_dim = act_dim
        layer_sizes = tuple(hidden_sizes) + (self.act_dim,)
        # All but the final layer produce the latent features; the final layer
        # maps those features to the action mean.
        if len(layer_sizes) == 2:
            self.latent_net = nn.Identity()
        elif len(layer_sizes) > 2:
            self.latent_net = mlp(
                layer_sizes[:-1],
                activation,
                output_activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
            )
        else:
            raise ValueError("hidden_sizes must be of at least length 1")
        self.mu_net = mlp(
            layer_sizes[-2:],
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            final_layer_gain=0.01,
        )
        self.full_std = full_std
        # One std entry per (feature, action) pair if full_std, else one per feature.
        std_dim = (hidden_sizes[-1], act_dim if self.full_std else 1)
        self.log_std = nn.Parameter(
            torch.ones(std_dim, dtype=torch.float32) * log_std_init
        )
        self.bijector = TanhBijector() if squash_output else None
        self.learn_std = learn_std
        self.device = None

        self.exploration_mat = None
        self.exploration_matrices = None
        self.sample_weights()

    def to(
        self: StateDependentNoiseActorHeadSelf,
        device: Optional[torch.device] = None,
        dtype: Optional[Union[torch.dtype, str]] = None,
        non_blocking: bool = False,
    ) -> StateDependentNoiseActorHeadSelf:
        # Remember the device so tensors created in _get_std end up on it.
        super().to(device, dtype, non_blocking)
        self.device = device
        return self

    def _distribution(self, obs: torch.Tensor) -> Distribution:
        latent = self.latent_net(obs)
        mu = self.mu_net(latent)
        # Detach the latent features unless the std is learned through them.
        latent_sde = latent if self.learn_std else latent.detach()
        # Var[noise_i | latent] = sum_j latent_j^2 * sigma_ij^2
        variance = torch.mm(latent_sde**2, self._get_std() ** 2)
        assert self.exploration_mat is not None
        assert self.exploration_matrices is not None
        return StateDependentNoiseDistribution(
            mu,
            torch.sqrt(variance + 1e-6),
            latent_sde,
            self.exploration_mat,
            self.exploration_matrices,
            self.bijector,
        )

    def _get_std(self) -> torch.Tensor:
        std = torch.exp(self.log_std)
        if self.full_std:
            return std
        # With a single std per feature, broadcast it across action dimensions.
        ones = torch.ones(self.log_std.shape[0], self.act_dim)
        if self.device:
            ones = ones.to(self.device)
        return ones * std

    def forward(self, obs: torch.Tensor, a: Optional[torch.Tensor] = None) -> PiForward:
        pi = self._distribution(obs)
        logp_a = None
        entropy = None
        if a is not None:
            logp_a = pi.log_prob(a)
            # The squashed distribution has no closed-form entropy, so use the
            # negative log-probability as a single-sample estimate.
            entropy = -logp_a
        return PiForward(pi, logp_a, entropy)

    def sample_weights(self, batch_size: int = 1) -> None:
        std = self._get_std()
        weights_dist = Normal(torch.zeros_like(std), std)
        # One shared exploration matrix plus a batch of per-sample matrices.
        self.exploration_mat = weights_dist.rsample()
        self.exploration_matrices = weights_dist.rsample(torch.Size((batch_size,)))


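# Callers are expected to invoke sample_weights() to refresh the exploration
# noise, typically once per rollout rather than once per step; holding the
# weights fixed between resamples is what makes gSDE exploration smooth
# compared with independent per-step Gaussian noise.

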
def actor_head(
    action_space: gym.Space,
    hidden_sizes: Sequence[int],
    init_layers_orthogonal: bool,
    activation: Type[nn.Module],
    log_std_init: float = -0.5,
    use_sde: bool = False,
    full_std: bool = True,
    squash_output: bool = False,
) -> Actor:
    # Dispatch on the action space: Discrete -> Categorical head,
    # Box -> gSDE head if use_sde else Gaussian head.
    assert not use_sde or isinstance(
        action_space, Box
    ), "use_sde only valid if Box action_space"
    assert not squash_output or use_sde, "squash_output only valid if use_sde"
    if isinstance(action_space, Discrete):
        return CategoricalActorHead(
            action_space.n,
            hidden_sizes=hidden_sizes,
            activation=activation,
            init_layers_orthogonal=init_layers_orthogonal,
        )
    elif isinstance(action_space, Box):
        if use_sde:
            return StateDependentNoiseActorHead(
                action_space.shape[0],
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
                log_std_init=log_std_init,
                full_std=full_std,
                squash_output=squash_output,
            )
        else:
            return GaussianActorHead(
                action_space.shape[0],
                hidden_sizes=hidden_sizes,
                activation=activation,
                init_layers_orthogonal=init_layers_orthogonal,
                log_std_init=log_std_init,
            )
    else:
        raise ValueError(f"Unsupported action space: {action_space}")

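# A minimal usage sketch (hypothetical; assumes mlp builds Linear layers between
# consecutive sizes, so these heads consume *features* whose width matches
# hidden_sizes[0], e.g. the output of a FeatureExtractor, not raw observations):
#
#     action_space = gym.make("CartPole-v1").action_space  # Discrete(2)
#     actor = actor_head(
#         action_space,
#         hidden_sizes=(32,),
#         init_layers_orthogonal=True,
#         activation=nn.Tanh,
#     )
#     features = torch.randn(8, 32)  # stand-in for FeatureExtractor output
#     pi, logp_a, entropy = actor(features, a=torch.zeros(8, dtype=torch.long))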