| from typing import Optional, Type |
|
|
| import gym |
| import torch.nn as nn |
|
|
| from rl_algo_impls.shared.encoder.cnn import FlattenedCnnEncoder |
| from rl_algo_impls.shared.module.utils import layer_init |
|
|
|
|
class NatureCnn(FlattenedCnnEncoder):
    """
    Convolutional encoder using the architecture of the DQN Nature paper.

    Reference: Mnih, Volodymyr, et al. "Human-level control through deep
    reinforcement learning." Nature 518.7540 (2015): 529-533.

    The network is three Conv2d stages, each followed by a fresh instance of
    ``activation``:

    * 32 filters, 8x8 kernel, stride 4
    * 64 filters, 4x4 kernel, stride 2
    * 64 filters, 3x3 kernel, stride 1
    """

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        cnn_init_layers_orthogonal: Optional[bool],
        linear_init_layers_orthogonal: bool,
        cnn_flatten_dim: int,
        **kwargs,
    ) -> None:
        """
        Assemble the convolutional stack and hand it to the parent encoder.

        Args:
            obs_space: Observation space; the first entry of its ``shape`` is
                used as the number of input channels of the first convolution.
            activation: Module class instantiated (with no arguments) after
                every convolution; also forwarded to the parent constructor.
            cnn_init_layers_orthogonal: Flag passed to ``layer_init`` for each
                convolution. ``None`` is treated as ``True``.
            linear_init_layers_orthogonal: Forwarded unchanged to the parent
                constructor.
            cnn_flatten_dim: Forwarded unchanged to the parent constructor.
            **kwargs: Extra keyword arguments forwarded to the parent
                constructor.
        """
        # An unspecified (None) flag falls back to True.
        use_orthogonal = (
            True if cnn_init_layers_orthogonal is None else cnn_init_layers_orthogonal
        )

        # One (out_channels, kernel_size, stride) triple per convolution stage.
        conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))

        stages = []
        channels = obs_space.shape[0]
        for out_channels, kernel, step in conv_specs:
            # Build modules in network order: convolution, its init, then the
            # activation, before moving on to the next stage.
            conv = nn.Conv2d(channels, out_channels, kernel_size=kernel, stride=step)
            stages.append(layer_init(conv, use_orthogonal))
            stages.append(activation())
            channels = out_channels

        super().__init__(
            obs_space,
            activation,
            linear_init_layers_orthogonal,
            cnn_flatten_dim,
            nn.Sequential(*stages),
            **kwargs,
        )
|
|