VPG agent playing SpaceInvadersNoFrameskip-v4, from https://github.com/sgoodfriend/rl-algo-impls/tree/e8bc541d8b5e67bb4d3f2075282463fb61f5f2c6
41a6762
| import gym | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from abc import ABC, abstractmethod | |
| from gym.spaces import Box, Discrete | |
| from stable_baselines3.common.preprocessing import get_flattened_obs_dim | |
| from typing import Dict, Optional, Sequence, Type | |
| from shared.module.module import layer_init | |
class CnnFeatureExtractor(nn.Module, ABC):
    """Base class for CNN feature extractors over image observations.

    Subclasses build their layers from ``in_channels`` with the given
    activation class; this base only performs ``nn.Module`` initialization
    and ignores the arguments (they document the common constructor shape).
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        **kwargs,
    ) -> None:
        super().__init__()
class NatureCnn(CnnFeatureExtractor):
    """
    CNN from DQN Nature paper: Mnih, Volodymyr, et al.
    "Human-level control through deep reinforcement learning."
    Nature 518.7540 (2015): 529-533.
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        **kwargs,
    ) -> None:
        # The Nature architecture defaults to orthogonal initialization.
        if init_layers_orthogonal is None:
            init_layers_orthogonal = True
        super().__init__(in_channels, activation, init_layers_orthogonal)
        # (out_channels, kernel_size, stride) for each conv stage, in order.
        conv_specs = ((32, 8, 4), (64, 4, 2), (64, 3, 1))
        layers = []
        channels = in_channels
        for out_channels, kernel_size, stride in conv_specs:
            layers.append(
                layer_init(
                    nn.Conv2d(
                        channels, out_channels, kernel_size=kernel_size, stride=stride
                    ),
                    init_layers_orthogonal,
                )
            )
            layers.append(activation())
            channels = out_channels
        layers.append(nn.Flatten())
        self.cnn = nn.Sequential(*layers)

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Map a batch of image observations to flattened conv features."""
        return self.cnn(obs)
class ResidualBlock(nn.Module):
    """Pre-activation residual block: two 3x3 convs whose output is added
    back onto the block input (as used in the IMPALA architecture)."""

    def __init__(
        self,
        channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = False,
    ) -> None:
        super().__init__()

        def conv3x3() -> nn.Module:
            # Channel-preserving 3x3 conv; padding=1 keeps spatial dims.
            return layer_init(
                nn.Conv2d(channels, channels, 3, padding=1), init_layers_orthogonal
            )

        self.residual = nn.Sequential(
            activation(),
            conv3x3(),
            activation(),
            conv3x3(),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # Skip connection: add the residual branch back onto the input.
        out = self.residual(x)
        return out + x
class ConvSequence(nn.Module):
    """One IMPALA conv stage: a 3x3 conv, a 2x spatial downsample via
    max-pooling, then two residual blocks at the new channel width."""

    def __init__(
        self,
        in_channels: int,
        out_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: bool = False,
    ) -> None:
        super().__init__()
        entry_conv = layer_init(
            nn.Conv2d(in_channels, out_channels, 3, padding=1),
            init_layers_orthogonal,
        )
        # Halves the spatial resolution (stride 2, same-ish padding).
        downsample = nn.MaxPool2d(3, stride=2, padding=1)
        self.seq = nn.Sequential(
            entry_conv,
            downsample,
            ResidualBlock(out_channels, activation, init_layers_orthogonal),
            ResidualBlock(out_channels, activation, init_layers_orthogonal),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply conv, pool, and both residual blocks in sequence."""
        return self.seq(x)
class ImpalaCnn(CnnFeatureExtractor):
    """
    IMPALA-style CNN architecture
    """

    def __init__(
        self,
        in_channels: int,
        activation: Type[nn.Module] = nn.ReLU,
        init_layers_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
        **kwargs,
    ) -> None:
        # Unlike NatureCnn, IMPALA defaults to non-orthogonal initialization.
        if init_layers_orthogonal is None:
            init_layers_orthogonal = False
        super().__init__(in_channels, activation, init_layers_orthogonal)
        # Pair each stage's input width with its output width:
        # (in_channels, c0), (c0, c1), (c1, c2), ...
        widths = (in_channels,) + tuple(impala_channels)
        stages = [
            ConvSequence(c_in, c_out, activation, init_layers_orthogonal)
            for c_in, c_out in zip(widths, impala_channels)
        ]
        self.seq = nn.Sequential(*stages, activation(), nn.Flatten())

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        """Map a batch of image observations to flattened conv features."""
        return self.seq(obs)
# Registry mapping the `cnn_style` config string to its extractor class;
# looked up by FeatureExtractor when the observation space is an image Box.
CNN_EXTRACTORS_BY_STYLE: Dict[str, Type[CnnFeatureExtractor]] = {
    "nature": NatureCnn,
    "impala": ImpalaCnn,
}
class FeatureExtractor(nn.Module):
    """Maps raw observations to a flat feature vector.

    Supported observation spaces:
      * Box with a 3-D shape (channels, height, width): routed through the
        CNN selected by ``cnn_style``, then a linear head of ``cnn_feature_dim``.
      * Box with a 1-D shape: passed through unchanged (flattened).
      * Discrete: one-hot encoded.
    Any other space raises ``ValueError`` / ``NotImplementedError``.
    """

    def __init__(
        self,
        obs_space: gym.Space,
        activation: Type[nn.Module],
        init_layers_orthogonal: bool = False,
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        impala_channels: Sequence[int] = (16, 32, 32),
    ) -> None:
        super().__init__()
        if isinstance(obs_space, Box):
            # Conv2D: (channels, height, width)
            if len(obs_space.shape) == 3:
                cnn = CNN_EXTRACTORS_BY_STYLE[cnn_style](
                    obs_space.shape[0],
                    activation,
                    init_layers_orthogonal=cnn_layers_init_orthogonal,
                    impala_channels=impala_channels,
                )

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # Add a batch dim for a single (unbatched) observation,
                    # then scale pixels into [0, 1].
                    # NOTE(review): assumes pixel values in 0-255 — confirm
                    # against the environments this is used with.
                    if len(obs.shape) == 3:
                        obs = obs.unsqueeze(0)
                    return obs.float() / 255.0

                # Probe the CNN with one sampled observation to discover the
                # flattened feature width for sizing the linear head.
                with torch.no_grad():
                    cnn_out = cnn(preprocess(torch.as_tensor(obs_space.sample())))
                self.preprocess = preprocess
                self.feature_extractor = nn.Sequential(
                    cnn,
                    layer_init(
                        nn.Linear(cnn_out.shape[1], cnn_feature_dim),
                        init_layers_orthogonal,
                    ),
                    activation(),
                )
                self.out_dim = cnn_feature_dim
            elif len(obs_space.shape) == 1:

                def preprocess(obs: torch.Tensor) -> torch.Tensor:
                    # Promote a single flat observation to a batch of one.
                    if len(obs.shape) == 1:
                        obs = obs.unsqueeze(0)
                    return obs.float()

                self.preprocess = preprocess
                self.feature_extractor = nn.Flatten()
                self.out_dim = get_flattened_obs_dim(obs_space)
            else:
                raise ValueError(f"Unsupported observation space: {obs_space}")
        elif isinstance(obs_space, Discrete):
            # One-hot encode; F.one_hot requires an integer (long) tensor.
            self.preprocess = lambda x: F.one_hot(x, obs_space.n).float()
            self.feature_extractor = nn.Flatten()
            self.out_dim = obs_space.n
        else:
            raise NotImplementedError

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        # preprocess is assigned in every supported branch, so this guard is
        # effectively always true; kept for safety.
        if self.preprocess:
            obs = self.preprocess(obs)
        return self.feature_extractor(obs)