import numpy as np
import torch
import torch.nn as nn

from collections import defaultdict
from dataclasses import dataclass, asdict
from stable_baselines3.common.vec_env.base_vec_env import VecEnv, VecEnvObs
from torch.optim import Adam
from torch.utils.tensorboard.writer import SummaryWriter
from typing import Optional, Sequence, TypeVar

from shared.algorithm import Algorithm
from shared.callbacks.callback import Callback
from shared.gae import compute_rtg_and_advantage, compute_advantage
from shared.trajectory import Trajectory, TrajectoryAccumulator
from vpg.policy import VPGActorCritic

@dataclass
class TrainEpochStats:
    pi_loss: float
    v_loss: float
    envs_with_done: int = 0
    episodes_done: int = 0

    def write_to_tensorboard(self, tb_writer: SummaryWriter, global_step: int) -> None:
        tb_writer.add_scalars("losses", asdict(self), global_step=global_step)

class VPGTrajectoryAccumulator(TrajectoryAccumulator):
    """Trajectory accumulator that also tracks completed episodes per environment."""

    def __init__(self, num_envs: int) -> None:
        super().__init__(num_envs, trajectory_class=Trajectory)
        self.completed_per_env: defaultdict[int, int] = defaultdict(int)

    def on_done(self, env_idx: int, trajectory: Trajectory) -> None:
        self.completed_per_env[env_idx] += 1

VanillaPolicyGradientSelf = TypeVar(
    "VanillaPolicyGradientSelf", bound="VanillaPolicyGradient"
)
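# The TypeVar above lets learn() return the concrete (sub)class instance, so a call
# like VanillaPolicyGradient(...).learn(...) keeps its precise type for chaining.
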
class VanillaPolicyGradient(Algorithm):
    def __init__(
        self,
        policy: VPGActorCritic,
        env: VecEnv,
        device: torch.device,
        tb_writer: SummaryWriter,
        gamma: float = 0.99,
        pi_lr: float = 3e-4,
        val_lr: float = 1e-3,
        train_v_iters: int = 80,
        gae_lambda: float = 0.97,
        max_grad_norm: float = 10.0,
        n_steps: int = 4_000,
        sde_sample_freq: int = -1,
        update_rtg_between_v_iters: bool = False,
    ) -> None:
        super().__init__(policy, env, device, tb_writer)
        self.policy = policy

        self.gamma = gamma
        self.gae_lambda = gae_lambda
        # Separate optimizers for the policy (pi) and value (v) networks
        self.pi_optim = Adam(self.policy.pi.parameters(), lr=pi_lr)
        self.val_optim = Adam(self.policy.v.parameters(), lr=val_lr)
        self.max_grad_norm = max_grad_norm

        self.n_steps = n_steps
        self.train_v_iters = train_v_iters
        self.sde_sample_freq = sde_sample_freq
        self.update_rtg_between_v_iters = update_rtg_between_v_iters

    def learn(
        self: VanillaPolicyGradientSelf,
        total_timesteps: int,
        callback: Optional[Callback] = None,
    ) -> VanillaPolicyGradientSelf:
        """Alternate collecting rollouts and updating the policy/value networks."""
        timesteps_elapsed = 0
        epoch_cnt = 0
        while timesteps_elapsed < total_timesteps:
            epoch_cnt += 1
            accumulator = self._collect_trajectories()
            epoch_stats = self.train(accumulator.all_trajectories)
            epoch_stats.envs_with_done = len(accumulator.completed_per_env)
            epoch_stats.episodes_done = sum(accumulator.completed_per_env.values())
            epoch_steps = accumulator.n_timesteps()
            timesteps_elapsed += epoch_steps
            epoch_stats.write_to_tensorboard(
                self.tb_writer, global_step=timesteps_elapsed
            )
            print(
                f"Epoch: {epoch_cnt} | "
                f"Pi Loss: {round(epoch_stats.pi_loss, 2)} | "
                f"V Loss: {round(epoch_stats.v_loss, 2)} | "
                f"Total Steps: {timesteps_elapsed}"
            )
            if callback:
                callback.on_step(timesteps_elapsed=epoch_steps)
        return self

    def train(self, trajectories: Sequence[Trajectory]) -> TrainEpochStats:
        self.policy.train()
        obs = torch.as_tensor(
            np.concatenate([np.array(t.obs) for t in trajectories]), device=self.device
        )
        act = torch.as_tensor(
            np.concatenate([np.array(t.act) for t in trajectories]), device=self.device
        )
        rtg, adv = compute_rtg_and_advantage(
            trajectories, self.policy, self.gamma, self.gae_lambda, self.device
        )
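        # Note: compute_rtg_and_advantage lives in shared/gae.py (not shown here). Given
        # the gamma / gae_lambda arguments, it is presumably standard GAE:
        #   delta_t = r_t + gamma * V(s_{t+1}) - V(s_t)
        #   A_t     = sum_{l >= 0} (gamma * lambda)^l * delta_{t+l}
        # with the rewards-to-go (rtg) serving as the value-function regression target.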

        # One policy gradient step, then train_v_iters value-function steps
        pi_loss = self._update_pi(obs, act, adv)
        v_loss = 0
        for _ in range(self.train_v_iters):
            if self.update_rtg_between_v_iters:
                # Recompute rewards-to-go using the updated value estimates
                rtg, _ = compute_rtg_and_advantage(
                    trajectories, self.policy, self.gamma, self.gae_lambda, self.device
                )
            v_loss = self._update_v(obs, rtg)

        return TrainEpochStats(pi_loss, v_loss)

    def _collect_trajectories(self) -> VPGTrajectoryAccumulator:
        self.policy.eval()
        obs = self.env.reset()
        accumulator = VPGTrajectoryAccumulator(self.env.num_envs)
        self.policy.reset_noise()
        for i in range(self.n_steps):
            # Periodically resample exploration noise if sde_sample_freq is set
            if self.sde_sample_freq > 0 and i > 0 and i % self.sde_sample_freq == 0:
                self.policy.reset_noise()
            action, value, _, clamped_action = self.policy.step(obs)
            next_obs, reward, done, _ = self.env.step(clamped_action)
            accumulator.step(obs, action, next_obs, reward, done, value)
            obs = next_obs
        return accumulator

    def _update_pi(
        self, obs: torch.Tensor, act: torch.Tensor, adv: torch.Tensor
    ) -> float:
        self.pi_optim.zero_grad()
        _, logp, _ = self.policy.pi(obs, act)
        # Policy gradient loss: maximize E[log pi(a|s) * advantage]
        pi_loss = -(logp * adv).mean()
        pi_loss.backward()
        nn.utils.clip_grad_norm_(self.policy.pi.parameters(), self.max_grad_norm)
        self.pi_optim.step()
        return pi_loss.item()

    def _update_v(self, obs: torch.Tensor, rtg: torch.Tensor) -> float:
        self.val_optim.zero_grad()
        v = self.policy.v(obs)
        # Mean-squared error between value estimates and rewards-to-go
        v_loss = ((v - rtg) ** 2).mean()
        v_loss.backward()
        nn.utils.clip_grad_norm_(self.policy.v.parameters(), self.max_grad_norm)
        self.val_optim.step()
        return v_loss.item()
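

# Example usage sketch (hypothetical wiring; the VPGActorCritic constructor and the
# helper that builds the VecEnv live elsewhere in this repo, so the exact arguments
# below are assumptions, not this module's API):
#
#   device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
#   env = make_vec_env(...)                      # any SB3-style VecEnv
#   policy = VPGActorCritic(env.observation_space, env.action_space).to(device)
#   tb_writer = SummaryWriter("runs/vpg")
#   algo = VanillaPolicyGradient(policy, env, device, tb_writer, n_steps=4_000)
#   algo.learn(total_timesteps=100_000)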