| | from typing import Optional, Tuple, Type, Union |
| |
|
| | import gym |
| | import torch |
| | import torch.nn as nn |
| |
|
| | from rl_algo_impls.shared.encoder.cnn import CnnEncoder, EncoderOutDim |
| | from rl_algo_impls.shared.module.utils import layer_init |
| |
|
| |
|
class GridnetEncoder(CnnEncoder):
    """
    Convolutional encoder half of the encoder-decoder used for Gym-MicroRTS.

    The network is four identical stages applied back to back. Each stage is
    a 3x3 convolution (padding 1), a 3x3 max-pool with stride 2 and padding 1,
    and then the activation. The convolutions widen the channels to
    32, 64, 128 and 256, while every max-pool roughly halves the spatial size.
    """

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module] = nn.ReLU,
        cnn_init_layers_orthogonal: Optional[bool] = None,
        **kwargs
    ) -> None:
        """
        Build the convolutional stack and record its output shape.

        Args:
            obs_space: Observation space; ``obs_space.shape[0]`` is used as
                the number of input channels and ``obs_space.sample()`` is
                used to probe the output shape.
            activation: Module class instantiated (with no arguments) after
                every pooling layer.
            cnn_init_layers_orthogonal: Flag handed to ``layer_init`` for each
                convolution. ``None`` is treated as ``True``.
            **kwargs: Forwarded unchanged to the parent constructor.
        """
        # An unspecified flag means "use orthogonal initialization".
        orthogonal = (
            True if cnn_init_layers_orthogonal is None else cnn_init_layers_orthogonal
        )
        super().__init__(obs_space, **kwargs)

        # Channel count entering the first stage, then leaving each stage.
        widths = (obs_space.shape[0], 32, 64, 128, 256)
        stages = []
        for c_in, c_out in zip(widths[:-1], widths[1:]):
            conv = layer_init(
                nn.Conv2d(c_in, c_out, kernel_size=3, padding=1),
                orthogonal,
            )
            stages.append(conv)
            stages.append(nn.MaxPool2d(3, stride=2, padding=1))
            stages.append(activation())
        self.encoder = nn.Sequential(*stages)

        # Push one sampled observation through the stack (no gradients
        # needed) to learn the output shape instead of computing it by hand.
        # NOTE(review): relies on the parent's preprocess() yielding a batched
        # float tensor; that method is not defined in this file.
        with torch.no_grad():
            sample = torch.as_tensor(obs_space.sample())
            probe = self.encoder(self.preprocess(sample))
        # Drop the leading (batch) dimension.
        self._out_dim = probe.shape[1:]

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Run the parent's forward pass on ``obs``, then the conv stack."""
        prepared = super().forward(obs)
        return self.encoder(prepared)

    @property
    def out_dim(self) -> EncoderOutDim:
        """Shape of the encoder output without the leading dimension."""
        return self._out_dim
| |
|