| import numpy as np |
| import os |
| import torch |
|
|
| from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs |
| from typing import Optional, Sequence, TypeVar |
|
|
| from dqn.q_net import QNetwork |
| from shared.policy.policy import Policy |
|
|
# Type variable bounded by DQNPolicy. The bound is a string forward reference
# because the class is defined further down. It is not referenced in the code
# visible here; presumably it is meant for annotating methods that return the
# concrete policy (sub)class — verify against its users.
DQNPolicySelf = TypeVar(
    "DQNPolicySelf",
    bound="DQNPolicy",
)
|
|
|
|
class DQNPolicy(Policy):
    """DQN policy: a Q-network plus epsilon-greedy action selection.

    The Q-network is built from the environment's observation and action
    spaces. Actions are chosen greedily (argmax over Q-values) or, when
    ``deterministic`` is False, uniformly at random with probability ``eps``.
    """

    def __init__(
        self,
        env: VecEnv,
        hidden_sizes: Optional[Sequence[int]] = None,
        cnn_feature_dim: int = 512,
        cnn_style: str = "nature",
        cnn_layers_init_orthogonal: Optional[bool] = None,
        **kwargs,
    ) -> None:
        """Build the policy and its Q-network.

        Args:
            env: Vectorized environment. Its ``observation_space`` and
                ``action_space`` are passed to ``QNetwork``.
            hidden_sizes: Hidden-layer sizes forwarded to ``QNetwork``.
                ``None`` (the default) is replaced by a fresh empty list,
                which matches the previous ``[]`` default.
            cnn_feature_dim: Forwarded to ``QNetwork``.
            cnn_style: Forwarded to ``QNetwork``.
            cnn_layers_init_orthogonal: Forwarded to ``QNetwork``.
            **kwargs: Forwarded to the ``Policy`` base-class constructor.
        """
        super().__init__(env, **kwargs)
        # Create a new list per call. A literal ``[]`` default is evaluated
        # once and shared by every instance, so any downstream mutation would
        # leak between policies.
        if hidden_sizes is None:
            hidden_sizes = []
        self.q_net = QNetwork(
            env.observation_space,
            env.action_space,
            hidden_sizes,
            cnn_feature_dim=cnn_feature_dim,
            cnn_style=cnn_style,
            cnn_layers_init_orthogonal=cnn_layers_init_orthogonal,
        )

    def act(
        self, obs: VecEnvObs, eps: float = 0, deterministic: bool = True
    ) -> np.ndarray:
        """Select actions for a batch of observations.

        Args:
            obs: Observations, converted to a tensor via ``self._as_tensor``.
            eps: Exploration probability. It must be 0 when ``deterministic``
                is True and non-negative otherwise.
            deterministic: If True, always act greedily on the Q-values.

        Returns:
            In the exploration branch, ``self.env.num_envs`` samples from
            ``self.env.action_space``. Otherwise, the argmax over dim 1 of the
            Q-network output, moved to CPU. That output is assumed to have
            shape ``(num_envs, n_actions)`` — TODO confirm against QNetwork.

        Raises:
            AssertionError: If ``eps`` is inconsistent with ``deterministic``.
        """
        # Kept as an assert (not ValueError) so callers see the same exception
        # type as before.
        assert (eps == 0) if deterministic else (eps >= 0), (
            f"invalid eps={eps} for deterministic={deterministic}"
        )
        # NOTE: one coin flip decides for the whole batch. Either every env
        # gets a random action or every env acts greedily.
        if not deterministic and np.random.random() < eps:
            return np.array(
                [self.env.action_space.sample() for _ in range(self.env.num_envs)]
            )
        o = self._as_tensor(obs)
        with torch.no_grad():
            return self.q_net(o).argmax(dim=1).cpu().numpy()
|
|