Spaces:
Sleeping
Sleeping
| from typing import Union, Optional | |
| from easydict import EasyDict | |
| import torch | |
| import torch.nn as nn | |
| import treetensor.torch as ttorch | |
| from copy import deepcopy | |
| from ding.utils import SequenceType, squeeze | |
| from ding.model.common import ReparameterizationHead, RegressionHead, MultiHead, \ | |
| FCEncoder, ConvEncoder, IMPALAConvEncoder, PopArtVHead | |
| from ding.torch_utils import MLP, fc_block | |
class DiscretePolicyHead(nn.Module):
    """
    Policy head for discrete action spaces: an MLP trunk (hidden_size -> hidden_size)
    followed by a final linear projection to per-action logits.
    """

    def __init__(
            self,
            hidden_size: int,
            output_size: int,
            layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
    ) -> None:
        """
        Arguments:
            - hidden_size: width of the input feature and of every trunk layer.
            - output_size: number of discrete actions (logit dimension).
            - layer_num: number of layers in the MLP trunk.
            - activation: activation module used inside the trunk.
            - norm_type: normalization type forwarded to ``MLP`` (e.g. ``'LN'``), or ``None``.
        """
        super(DiscretePolicyHead, self).__init__()
        # Build the trunk first, then append the plain linear output block.
        trunk = MLP(
            hidden_size,
            hidden_size,
            hidden_size,
            layer_num,
            layer_fn=nn.Linear,
            activation=activation,
            norm_type=norm_type,
        )
        head = fc_block(hidden_size, output_size)
        self.main = nn.Sequential(trunk, head)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Map an encoded observation feature to raw (unnormalized) action logits."""
        return self.main(x)
class PPOFModel(nn.Module):
    """
    Actor-critic network for PPO supporting ``discrete``, ``continuous`` and ``hybrid``
    action spaces, an optionally shared observation encoder, and an optional PopArt
    value head. ``forward`` dispatches to one of the computation modes in ``mode``.
    """
    # Valid values for the ``mode`` argument of ``forward``.
    mode = ['compute_actor', 'compute_critic', 'compute_actor_critic']

    def __init__(
            self,
            obs_shape: Union[int, SequenceType],
            action_shape: Union[int, SequenceType, EasyDict],
            action_space: str = 'discrete',
            share_encoder: bool = True,
            encoder_hidden_size_list: SequenceType = [128, 128, 64],
            actor_head_hidden_size: int = 64,
            actor_head_layer_num: int = 1,
            critic_head_hidden_size: int = 64,
            critic_head_layer_num: int = 1,
            activation: Optional[nn.Module] = nn.ReLU(),
            norm_type: Optional[str] = None,
            sigma_type: Optional[str] = 'independent',
            fixed_sigma_value: Optional[float] = 0.3,
            bound_type: Optional[str] = None,
            encoder: Optional[torch.nn.Module] = None,
            popart_head: bool = False,
    ) -> None:
        """
        Arguments:
            - obs_shape: observation shape; an int or 1-dim sequence selects ``FCEncoder``,
              a 3-dim sequence selects ``ConvEncoder``; other shapes raise ``RuntimeError``.
            - action_shape: action shape; for the hybrid space, an ``EasyDict`` with
              ``action_type_shape`` (discrete) and ``action_args_shape`` (continuous) fields.
            - action_space: one of ``'discrete'``, ``'continuous'``, ``'hybrid'``.
            - share_encoder: whether actor and critic share a single encoder; if so,
              ``actor_head_hidden_size`` must equal ``critic_head_hidden_size``.
            - encoder_hidden_size_list: hidden sizes of the pre-defined encoder
              (read-only; the default list is never mutated).
            - activation / norm_type: forwarded to encoder and heads.
            - sigma_type / fixed_sigma_value / bound_type: ``ReparameterizationHead``
              options for continuous (and the continuous part of hybrid) actions.
            - encoder: optional user-provided encoder module; when ``share_encoder`` is
              ``False`` the critic gets a ``deepcopy`` of it.
            - popart_head: use ``PopArtVHead`` instead of ``RegressionHead`` for the value.
        """
        super(PPOFModel, self).__init__()
        obs_shape = squeeze(obs_shape)
        action_shape = squeeze(action_shape)
        self.obs_shape, self.action_shape = obs_shape, action_shape
        self.share_encoder = share_encoder

        # Encoder Type
        # NOTE(review): ``outsize`` is currently unused — the encoder output size is
        # determined by ``encoder_hidden_size_list`` instead.
        def new_encoder(outsize):
            if isinstance(obs_shape, int) or len(obs_shape) == 1:
                return FCEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            elif len(obs_shape) == 3:
                return ConvEncoder(
                    obs_shape=obs_shape,
                    hidden_size_list=encoder_hidden_size_list,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                raise RuntimeError(
                    "not support obs_shape for pre-defined encoder: {}, please customize your own encoder".
                    format(obs_shape)
                )

        if self.share_encoder:
            assert actor_head_hidden_size == critic_head_hidden_size, \
                "actor and critic network head should have same size."
            if encoder:
                if isinstance(encoder, torch.nn.Module):
                    self.encoder = encoder
                else:
                    raise ValueError("illegal encoder instance.")
            else:
                self.encoder = new_encoder(actor_head_hidden_size)
        else:
            if encoder:
                if isinstance(encoder, torch.nn.Module):
                    self.actor_encoder = encoder
                    # Critic gets an independent copy so the two towers do not share weights.
                    self.critic_encoder = deepcopy(encoder)
                else:
                    raise ValueError("illegal encoder instance.")
            else:
                self.actor_encoder = new_encoder(actor_head_hidden_size)
                self.critic_encoder = new_encoder(critic_head_hidden_size)

        # Head Type
        if not popart_head:
            self.critic_head = RegressionHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )
        else:
            self.critic_head = PopArtVHead(
                critic_head_hidden_size, 1, critic_head_layer_num, activation=activation, norm_type=norm_type
            )
        self.action_space = action_space
        assert self.action_space in ['discrete', 'continuous', 'hybrid'], self.action_space
        if self.action_space == 'continuous':
            self.multi_head = False
            self.actor_head = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type
            )
        elif self.action_space == 'discrete':
            actor_head_cls = DiscretePolicyHead
            # A non-int action_shape means several independent discrete sub-actions.
            multi_head = not isinstance(action_shape, int)
            self.multi_head = multi_head
            if multi_head:
                self.actor_head = MultiHead(
                    actor_head_cls,
                    actor_head_hidden_size,
                    action_shape,
                    layer_num=actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
            else:
                self.actor_head = actor_head_cls(
                    actor_head_hidden_size,
                    action_shape,
                    actor_head_layer_num,
                    activation=activation,
                    norm_type=norm_type
                )
        elif self.action_space == 'hybrid':  # HPPO
            # hybrid action space: action_type(discrete) + action_args(continuous),
            # such as {'action_type_shape': torch.LongTensor([0]), 'action_args_shape': torch.FloatTensor([0.1, -0.27])}
            action_shape.action_args_shape = squeeze(action_shape.action_args_shape)
            action_shape.action_type_shape = squeeze(action_shape.action_type_shape)
            actor_action_args = ReparameterizationHead(
                actor_head_hidden_size,
                action_shape.action_args_shape,
                actor_head_layer_num,
                sigma_type=sigma_type,
                fixed_sigma_value=fixed_sigma_value,
                activation=activation,
                norm_type=norm_type,
                bound_type=bound_type,
            )
            actor_action_type = DiscretePolicyHead(
                actor_head_hidden_size,
                action_shape.action_type_shape,
                actor_head_layer_num,
                activation=activation,
                norm_type=norm_type,
            )
            self.actor_head = nn.ModuleList([actor_action_type, actor_action_args])

        # must use list, not nn.ModuleList
        if self.share_encoder:
            self.actor = [self.encoder, self.actor_head]
            self.critic = [self.encoder, self.critic_head]
        else:
            self.actor = [self.actor_encoder, self.actor_head]
            self.critic = [self.critic_encoder, self.critic_head]
        # Convenient for calling some apis (e.g. self.critic.parameters()),
        # but may cause misunderstanding when `print(self)`
        self.actor = nn.ModuleList(self.actor)
        self.critic = nn.ModuleList(self.critic)

    def forward(self, inputs: ttorch.Tensor, mode: str) -> ttorch.Tensor:
        """Dispatch to ``compute_actor`` / ``compute_critic`` / ``compute_actor_critic``."""
        assert mode in self.mode, "not support forward mode: {}/{}".format(mode, self.mode)
        return getattr(self, mode)(inputs)

    def compute_actor(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Encode the observation and run the actor head only. Output layout depends on
        ``action_space``: discrete -> logit tensor; continuous -> ``{'mu', 'sigma'}``;
        hybrid -> ``{'action_type', 'action_args'}``.
        """
        if self.share_encoder:
            x = self.encoder(x)
        else:
            x = self.actor_encoder(x)
        if self.action_space == 'discrete':
            return self.actor_head(x)
        elif self.action_space == 'continuous':
            x = self.actor_head(x)  # mu, sigma
            return ttorch.as_tensor(x)
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](x)
            action_args = self.actor_head[1](x)
            return ttorch.as_tensor({'action_type': action_type, 'action_args': action_args})

    def compute_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Encode the observation and run the critic head only, returning the value tensor.
        """
        if self.share_encoder:
            x = self.encoder(x)
        else:
            x = self.critic_encoder(x)
        x = self.critic_head(x)
        # BUGFIX: the head returns a dict; extract the value tensor to match this
        # method's return annotation and ``compute_actor_critic`` (which uses
        # ``value['pred']``) instead of leaking the raw head output.
        return x['pred']

    def compute_actor_critic(self, x: ttorch.Tensor) -> ttorch.Tensor:
        """
        Run both towers in one call (shares the encoder forward when possible) and
        return ``{'logit': ..., 'value': ...}``.
        """
        if self.share_encoder:
            # Single encoder pass reused by both heads.
            actor_embedding = critic_embedding = self.encoder(x)
        else:
            actor_embedding = self.actor_encoder(x)
            critic_embedding = self.critic_encoder(x)
        value = self.critic_head(critic_embedding)
        if self.action_space == 'discrete':
            logit = self.actor_head(actor_embedding)
            return ttorch.as_tensor({'logit': logit, 'value': value['pred']})
        elif self.action_space == 'continuous':
            x = self.actor_head(actor_embedding)
            return ttorch.as_tensor({'logit': x, 'value': value['pred']})
        elif self.action_space == 'hybrid':
            action_type = self.actor_head[0](actor_embedding)
            action_args = self.actor_head[1](actor_embedding)
            return ttorch.as_tensor(
                {
                    'logit': {
                        'action_type': action_type,
                        'action_args': action_args
                    },
                    'value': value['pred']
                }
            )