| | from typing import Any, Dict, List, Optional, Tuple, Type, Union |
| |
|
| | import gym |
| | import numpy as np |
| | import torch as th |
| | from torch.nn import functional as F |
| |
|
| | from stable_baselines3.common import logger |
| | from stable_baselines3.common.off_policy_algorithm import OffPolicyAlgorithm |
| | from stable_baselines3.common.preprocessing import maybe_transpose |
| | from stable_baselines3.common.type_aliases import GymEnv, MaybeCallback, Schedule |
| | from stable_baselines3.common.utils import get_linear_fn, is_vectorized_observation, polyak_update |
| | from stable_baselines3.dqn.policies import DQNPolicy |
| |
|
| |
|
class DQN(OffPolicyAlgorithm):
    """
    Deep Q-Network (DQN).

    References: https://arxiv.org/abs/1312.5602 and https://www.nature.com/articles/nature14236
    The default hyperparameters follow the Nature paper, with the exception of the
    optimizer and the learning rate, which come from the Stable Baselines defaults.

    :param policy: Policy class or registered policy name (MlpPolicy, CnnPolicy, ...)
    :param env: Environment to train on (a string id is accepted if it is registered in Gym)
    :param learning_rate: Learning rate, either a constant or a function of the
        remaining progress (which goes from 1 to 0)
    :param buffer_size: Capacity of the replay buffer
    :param learning_starts: Number of steps spent only collecting transitions before
        learning begins
    :param batch_size: Number of transitions in each gradient-update minibatch
    :param tau: Soft-update ("Polyak") coefficient between 0 and 1; the default of 1
        performs a hard update
    :param gamma: Discount factor
    :param train_freq: Train every ``train_freq`` steps, or pass a ``(frequency, unit)``
        tuple such as ``(5, "step")`` or ``(2, "episode")``
    :param gradient_steps: Number of gradient steps after each rollout (see ``train_freq``);
        ``-1`` means as many gradient steps as environment steps taken during the rollout
    :param optimize_memory_usage: Use the memory-efficient replay buffer variant, at the
        price of extra complexity. See
        https://github.com/DLR-RM/stable-baselines3/issues/37#issuecomment-637501195
    :param target_update_interval: Number of environment steps between target-network updates
    :param exploration_fraction: Fraction of the whole training run over which the
        exploration rate is reduced
    :param exploration_initial_eps: Starting probability of taking a random action
    :param exploration_final_eps: Final probability of taking a random action
    :param max_grad_norm: Maximum gradient norm used for gradient clipping
    :param tensorboard_log: Tensorboard log directory (``None`` disables logging)
    :param create_eval_env: Whether to build a second environment for periodic evaluation
        (only possible when ``env`` is given as a string)
    :param policy_kwargs: Extra keyword arguments forwarded to the policy on creation
    :param verbose: Verbosity level: 0 silent, 1 info, 2 debug
    :param seed: Seed for the pseudo-random number generators
    :param device: Device to run on (cpu, cuda, ...); ``"auto"`` uses the GPU when possible
    :param _init_setup_model: Whether to build the networks when the instance is created
    """

    def __init__(
        self,
        policy: Union[str, Type[DQNPolicy]],
        env: Union[GymEnv, str],
        learning_rate: Union[float, Schedule] = 1e-4,
        buffer_size: int = 1000000,
        learning_starts: int = 50000,
        batch_size: Optional[int] = 32,
        tau: float = 1.0,
        gamma: float = 0.99,
        train_freq: Union[int, Tuple[int, str]] = 4,
        gradient_steps: int = 1,
        optimize_memory_usage: bool = False,
        target_update_interval: int = 10000,
        exploration_fraction: float = 0.1,
        exploration_initial_eps: float = 1.0,
        exploration_final_eps: float = 0.05,
        max_grad_norm: float = 10,
        tensorboard_log: Optional[str] = None,
        create_eval_env: bool = False,
        policy_kwargs: Optional[Dict[str, Any]] = None,
        verbose: int = 0,
        seed: Optional[int] = None,
        device: Union[th.device, str] = "auto",
        _init_setup_model: bool = True,
    ):
        super().__init__(
            policy,
            env,
            DQNPolicy,
            learning_rate,
            buffer_size,
            learning_starts,
            batch_size,
            tau,
            gamma,
            train_freq,
            gradient_steps,
            # Exploration is epsilon-greedy (see predict), so no action noise is used.
            action_noise=None,
            sde_support=False,
            supported_action_spaces=(gym.spaces.Discrete,),
            optimize_memory_usage=optimize_memory_usage,
            policy_kwargs=policy_kwargs,
            tensorboard_log=tensorboard_log,
            create_eval_env=create_eval_env,
            verbose=verbose,
            seed=seed,
            device=device,
        )

        # Epsilon-greedy settings; the schedule object is built in _setup_model().
        self.exploration_initial_eps = exploration_initial_eps
        self.exploration_final_eps = exploration_final_eps
        self.exploration_fraction = exploration_fraction
        self.exploration_rate = 0.0
        self.exploration_schedule = None

        # How often the target network is synced, and the gradient-clipping threshold.
        self.target_update_interval = target_update_interval
        self.max_grad_norm = max_grad_norm

        # Shortcuts to the policy's networks, populated by _create_aliases().
        self.q_net = None
        self.q_net_target = None

        if _init_setup_model:
            self._setup_model()

    def _setup_model(self) -> None:
        """Build the base components, alias the Q-networks and create the exploration schedule."""
        super()._setup_model()
        self._create_aliases()
        self.exploration_schedule = get_linear_fn(
            self.exploration_initial_eps,
            self.exploration_final_eps,
            self.exploration_fraction,
        )

    def _create_aliases(self) -> None:
        """Expose the policy's online and target Q-networks as attributes of the algorithm."""
        policy = self.policy
        self.q_net = policy.q_net
        self.q_net_target = policy.q_net_target

    def _on_step(self) -> None:
        """
        Per-step hook, invoked from ``collect_rollouts()`` after each environment step.

        Syncs the target network whenever ``num_timesteps`` is a multiple of
        ``target_update_interval``, then refreshes and logs the exploration rate
        from the schedule.
        """
        due_for_target_sync = self.num_timesteps % self.target_update_interval == 0
        if due_for_target_sync:
            # With tau == 1 this is a hard update of the target weights.
            polyak_update(self.q_net.parameters(), self.q_net_target.parameters(), self.tau)

        self.exploration_rate = self.exploration_schedule(self._current_progress_remaining)
        logger.record("rollout/exploration rate", self.exploration_rate)

    def train(self, gradient_steps: int, batch_size: int = 100) -> None:
        """
        Run ``gradient_steps`` Q-learning updates on minibatches drawn from the replay buffer.

        :param gradient_steps: Number of gradient updates to perform
        :param batch_size: Number of transitions sampled for each update
        """
        optimizer = self.policy.optimizer
        # Apply the current value of the learning-rate schedule.
        self._update_learning_rate(optimizer)

        step_losses = []
        for _ in range(gradient_steps):
            batch = self.replay_buffer.sample(batch_size, env=self._vec_normalize_env)

            with th.no_grad():
                # Highest next-state Q-value according to the target network, as a column.
                max_next_q = self.q_net_target(batch.next_observations).max(dim=1)[0]
                max_next_q = max_next_q.reshape(-1, 1)
                # One-step TD target; (1 - dones) zeroes the bootstrap term for done transitions.
                td_target = batch.rewards + (1 - batch.dones) * self.gamma * max_next_q

            # Q-values of the actions that were actually taken.
            q_all_actions = self.q_net(batch.observations)
            q_taken = q_all_actions.gather(dim=1, index=batch.actions.long())

            # Smooth L1 (Huber) loss between predicted and target Q-values.
            loss = F.smooth_l1_loss(q_taken, td_target)
            step_losses.append(loss.item())

            optimizer.zero_grad()
            loss.backward()
            # Clip the total gradient norm before the optimizer step.
            th.nn.utils.clip_grad_norm_(self.policy.parameters(), self.max_grad_norm)
            optimizer.step()

        self._n_updates += gradient_steps

        logger.record("train/n_updates", self._n_updates, exclude="tensorboard")
        logger.record("train/loss", np.mean(step_losses))

    def predict(
        self,
        observation: np.ndarray,
        state: Optional[np.ndarray] = None,
        mask: Optional[np.ndarray] = None,
        deterministic: bool = False,
    ) -> Tuple[np.ndarray, Optional[np.ndarray]]:
        """
        Choose actions with epsilon-greedy exploration layered over the policy's prediction.

        Unless ``deterministic`` is set, one uniform draw is compared with
        ``exploration_rate``. When it falls below, random actions are sampled from the
        action space: one per observation when the input is batched, otherwise a single
        action. In all other cases the policy's own ``predict`` decides.

        :param observation: The input observation (single or batched)
        :param state: The last states (may be None; used by recurrent policies)
        :param mask: The last masks (may be None; used by recurrent policies)
        :param deterministic: If True, never explore and forward the flag to the policy
        :return: The chosen action(s) and the next state (used by recurrent policies)
        """
        # The short-circuit leaves the numpy RNG untouched when acting deterministically.
        explore = not deterministic and np.random.rand() < self.exploration_rate
        if not explore:
            action, state = self.policy.predict(observation, state, mask, deterministic)
            return action, state

        batched = is_vectorized_observation(
            maybe_transpose(observation, self.observation_space), self.observation_space
        )
        if batched:
            random_actions = [self.action_space.sample() for _ in range(observation.shape[0])]
            action = np.array(random_actions)
        else:
            action = np.array(self.action_space.sample())
        return action, state

    def learn(
        self,
        total_timesteps: int,
        callback: MaybeCallback = None,
        log_interval: int = 4,
        eval_env: Optional[GymEnv] = None,
        eval_freq: int = -1,
        n_eval_episodes: int = 5,
        tb_log_name: str = "DQN",
        eval_log_path: Optional[str] = None,
        reset_num_timesteps: bool = True,
    ) -> OffPolicyAlgorithm:
        """
        Thin wrapper that forwards every argument to the parent's ``learn``.

        Its only addition is that the default Tensorboard run name is ``"DQN"``.
        """
        return super().learn(
            total_timesteps=total_timesteps,
            callback=callback,
            log_interval=log_interval,
            tb_log_name=tb_log_name,
            reset_num_timesteps=reset_num_timesteps,
            eval_env=eval_env,
            eval_freq=eval_freq,
            n_eval_episodes=n_eval_episodes,
            eval_log_path=eval_log_path,
        )

    def _excluded_save_params(self) -> List[str]:
        """Names excluded from saving: the parent's list plus the two Q-network aliases."""
        inherited = super()._excluded_save_params()
        return inherited + ["q_net", "q_net_target"]

    def _get_torch_save_params(self) -> Tuple[List[str], List[str]]:
        """
        Name the torch objects whose state dicts are saved.

        :return: The state-dict attribute names (the policy and its optimizer),
            and an empty second list
        """
        state_dict_names = ["policy", "policy.optimizer"]
        return state_dict_names, []
| |
|