import gym
import torch
import torch.nn as nn

from typing import Sequence, Type

from shared.module.feature_extractor import FeatureExtractor
from shared.module.module import mlp


class CriticHead(nn.Module):
    """Value head: an MLP mapping extracted features to a scalar state-value estimate."""

    def __init__(
        self,
        hidden_sizes: Sequence[int] = (32,),
        activation: Type[nn.Module] = nn.Tanh,
        init_layers_orthogonal: bool = True,
    ) -> None:
        super().__init__()
        # Append an output width of 1: the critic emits a single value per state.
        layer_sizes = tuple(hidden_sizes) + (1,)
        self._fc = mlp(
            layer_sizes,
            activation,
            init_layers_orthogonal=init_layers_orthogonal,
            # Gain of 1.0 for the final layer's orthogonal init, a common choice for value heads.
            final_layer_gain=1.0,
        )

    def forward(self, obs: torch.Tensor) -> torch.Tensor:
        v = self._fc(obs)
        # Drop the trailing singleton dimension: (batch, 1) -> (batch,).
        return v.squeeze(-1)
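

# A minimal usage sketch (illustrative, not part of the original module). It assumes
# mlp builds Linear layers between consecutive entries of layer_sizes, so the head's
# input width is hidden_sizes[0]; the (8, 64) feature batch below is made up.
if __name__ == "__main__":
    head = CriticHead(hidden_sizes=(64, 32))  # Linear(64, 32) -> Tanh -> Linear(32, 1)
    features = torch.randn(8, 64)  # stand-in for FeatureExtractor output
    values = head(features)
    assert values.shape == (8,)  # squeeze(-1) removed the value dimension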