| from stable_baselines3.common.vec_env.base_vec_env import VecEnv | |
| from typing import Optional, Sequence | |
| from gym.spaces import Box, Discrete | |
| from shared.policy.on_policy import ActorCritic, default_hidden_sizes | |
class PPOActorCritic(ActorCritic):
    """``ActorCritic`` subclass that fills in default hidden-layer sizes.

    Any hidden-size sequence left as ``None`` is replaced with
    ``default_hidden_sizes(env.observation_space)`` before construction is
    delegated to ``ActorCritic.__init__``.
    """

    def __init__(
        self,
        env: VecEnv,
        pi_hidden_sizes: Optional[Sequence[int]] = None,
        v_hidden_sizes: Optional[Sequence[int]] = None,
        **kwargs,
    ) -> None:
        """Resolve default hidden sizes, then initialize the base class.

        Args:
            env: Vectorized environment; its ``observation_space`` is used to
                derive default hidden sizes.
            pi_hidden_sizes: Hidden-layer sizes for the ``pi`` network, or
                ``None`` to use ``default_hidden_sizes``.
            v_hidden_sizes: Hidden-layer sizes for the ``v`` network, or
                ``None`` to use ``default_hidden_sizes``.
            **kwargs: Forwarded unchanged to ``ActorCritic.__init__``.
        """
        # Each default is computed by its own call (pi first, then v) so the
        # two arguments never share a single returned object.
        if pi_hidden_sizes is None:
            pi_hidden_sizes = default_hidden_sizes(env.observation_space)
        if v_hidden_sizes is None:
            v_hidden_sizes = default_hidden_sizes(env.observation_space)

        # The base class receives env and both size sequences positionally.
        super().__init__(env, pi_hidden_sizes, v_hidden_sizes, **kwargs)