| """ |
| * 🥼 Test throughput (see docs): |
| * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| * this will help us diagnose the throughput issue |
| * python sebulba_ppo_envpool.py --actor-device-ids 0 --num-actor-threads 2 --learner-device-ids 1 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| * 🔥 Best performance so far (more GPUs -> faster) |
| * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 0 --track |
| * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 0 1 --track |
| * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 --num-envs 60 --async-batch-size 20 --track |
| * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --track |
| * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 --num-envs 60 --async-batch-size 20 --track |
| * (this actually doesn't work that well) python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 5 6 7 --num-envs 70 --async-batch-size 35 --track |
| """ |
| |
| import argparse |
| import os |
| import random |
| import time |
| import uuid |
| import warnings |
| from collections import deque |
| from distutils.util import strtobool |
| from functools import partial |
| from typing import Sequence |
|
|
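# NOTE: these XLA settings must be exported before `import jax` so the runtime picks them up;
# capping the memory fraction leaves GPU memory headroom for other processes (e.g., an actor
# thread sharing a learner GPU).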
os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
| import multiprocessing as mp |
| import queue |
| import threading |
|
|
| import envpool |
| import flax |
| import flax.linen as nn |
| import gym |
| import jax |
| import jax.numpy as jnp |
| import numpy as np |
| import optax |
| from flax.linen.initializers import constant, orthogonal |
| from flax.training.train_state import TrainState |
| from torch.utils.tensorboard import SummaryWriter |
|
|
|
|
| def parse_args(): |
| |
| parser = argparse.ArgumentParser() |
| parser.add_argument("--exp-name", type=str, default=os.path.basename(__file__).rstrip(".py"), |
| help="the name of this experiment") |
| parser.add_argument("--seed", type=int, default=1, |
| help="seed of the experiment") |
| parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, |
| help="if toggled, `torch.backends.cudnn.deterministic=False`") |
| parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, |
| help="if toggled, cuda will be enabled by default") |
| parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, |
| help="if toggled, this experiment will be tracked with Weights and Biases") |
| parser.add_argument("--wandb-project-name", type=str, default="cleanRL", |
| help="the wandb's project name") |
| parser.add_argument("--wandb-entity", type=str, default=None, |
| help="the entity (team) of wandb's project") |
| parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, |
| help="weather to capture videos of the agent performances (check out `videos` folder)") |
| parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, |
| help="whether to save model into the `runs/{run_name}` folder") |
| parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, |
| help="whether to upload the saved model to huggingface") |
| parser.add_argument("--hf-entity", type=str, default="", |
| help="the user or org name of the model repository from the Hugging Face Hub") |
|
|
| |
| parser.add_argument("--env-id", type=str, default="Breakout-v5", |
| help="the id of the environment") |
| parser.add_argument("--total-timesteps", type=int, default=50000000, |
| help="total timesteps of the experiments") |
| parser.add_argument("--learning-rate", type=float, default=2.5e-4, |
| help="the learning rate of the optimizer") |
| parser.add_argument("--num-envs", type=int, default=64, |
| help="the number of parallel game environments") |
| parser.add_argument("--async-batch-size", type=int, default=16, |
| help="the envpool's batch size in the async mode") |
| parser.add_argument("--num-steps", type=int, default=128, |
| help="the number of steps to run in each environment per policy rollout") |
| parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, |
| help="Toggle learning rate annealing for policy and value networks") |
| parser.add_argument("--gamma", type=float, default=0.99, |
| help="the discount factor gamma") |
| parser.add_argument("--gae-lambda", type=float, default=0.95, |
| help="the lambda for the general advantage estimation") |
| parser.add_argument("--num-minibatches", type=int, default=4, |
| help="the number of mini-batches") |
| parser.add_argument("--update-epochs", type=int, default=4, |
| help="the K epochs to update the policy") |
| parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True, |
| help="Toggles advantages normalization") |
| parser.add_argument("--clip-coef", type=float, default=0.1, |
| help="the surrogate clipping coefficient") |
| parser.add_argument("--ent-coef", type=float, default=0.01, |
| help="coefficient of the entropy") |
| parser.add_argument("--vf-coef", type=float, default=0.5, |
| help="coefficient of the value function") |
| parser.add_argument("--max-grad-norm", type=float, default=0.5, |
| help="the maximum norm for the gradient clipping") |
| parser.add_argument("--target-kl", type=float, default=None, |
| help="the target KL divergence threshold") |
|
|
| parser.add_argument("--actor-device-ids", type=int, nargs="+", default=[0], |
| help="the device ids that actor workers will use") |
| parser.add_argument("--learner-device-ids", type=int, nargs="+", default=[0], |
| help="the device ids that actor workers will use") |
| parser.add_argument("--num-actor-threads", type=int, default=1, |
| help="the number of actor threads") |
| parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, |
| help="whether to call block_until_ready() for profiling") |
| parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, |
| help="whether to test actor-learner throughput by removing the actor-learner communication") |
| args = parser.parse_args() |
| args.batch_size = int(args.num_envs * args.num_steps) |
| args.minibatch_size = int(args.batch_size // args.num_minibatches) |
| args.num_updates = args.total_timesteps // args.batch_size |
| args.async_update = int(args.num_envs / args.async_batch_size) |
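    # async_update = the number of recv()/send() rounds needed to step every one of `num_envs` envs once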
| assert len(args.actor_device_ids) == 1, "only 1 actor_device_ids is supported now" |
| if args.num_actor_threads > 1: |
| warnings.warn("⚠️ !!!! `num_actor_threads` > 1 is not tested with learning; see docs for detail") |
| |
| return args |
|
|
|
|
| def make_env(env_id, seed, num_envs, async_batch_size=1, num_threads=None, thread_affinity_offset=-1): |
| def thunk(): |
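        # EnvPool async mode: each recv()/send() pair touches only `batch_size` envs; recv()
        # returns as soon as that many envs have finished stepping, so fast envs are not
        # blocked by slow ones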
| envs = envpool.make( |
| env_id, |
| env_type="gym", |
| num_envs=num_envs, |
| num_threads=num_threads if num_threads is not None else async_batch_size, |
| thread_affinity_offset=thread_affinity_offset, |
| batch_size=async_batch_size, |
| episodic_life=True, |
| repeat_action_probability=0, |
| noop_max=30, |
| full_action_space=False, |
| max_episode_steps=int(108000 / 4), |
| reward_clip=True, |
| seed=seed, |
| ) |
| envs.num_envs = num_envs |
| envs.single_action_space = envs.action_space |
| envs.single_observation_space = envs.observation_space |
| envs.is_vector_env = True |
| return envs |
|
|
| return thunk |
|
|
|
|
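# The encoder below follows the IMPALA CNN: three ConvSequences, each a 3x3 conv, a strided
# max pool, and two residual blocks, with channel sizes given by `Network.channelss`.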
| class ResidualBlock(nn.Module): |
| channels: int |
|
|
| @nn.compact |
| def __call__(self, x): |
| inputs = x |
| x = nn.relu(x) |
| x = nn.Conv( |
| self.channels, |
| kernel_size=(3, 3), |
| )(x) |
| x = nn.relu(x) |
| x = nn.Conv( |
| self.channels, |
| kernel_size=(3, 3), |
| )(x) |
| return x + inputs |
|
|
|
|
| class ConvSequence(nn.Module): |
| channels: int |
|
|
| @nn.compact |
| def __call__(self, x): |
| x = nn.Conv( |
| self.channels, |
| kernel_size=(3, 3), |
| )(x) |
| x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME") |
| x = ResidualBlock(self.channels)(x) |
| x = ResidualBlock(self.channels)(x) |
| return x |
|
|
|
|
| class Network(nn.Module): |
| channelss: Sequence[int] = (16, 32, 32) |
|
|
| @nn.compact |
| def __call__(self, x): |
| x = jnp.transpose(x, (0, 2, 3, 1)) |
        x = x / 255.0
| for channels in self.channelss: |
| x = ConvSequence(channels)(x) |
| x = nn.relu(x) |
| x = x.reshape((x.shape[0], -1)) |
| x = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x) |
| x = nn.relu(x) |
| return x |
|
|
|
|
| class Critic(nn.Module): |
| @nn.compact |
| def __call__(self, x): |
| return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x) |
|
|
|
|
| class Actor(nn.Module): |
| action_dim: int |
|
|
| @nn.compact |
| def __call__(self, x): |
| return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x) |
|
|
|
|
| @flax.struct.dataclass |
| class AgentParams: |
| network_params: flax.core.FrozenDict |
| actor_params: flax.core.FrozenDict |
| critic_params: flax.core.FrozenDict |
|
|
|
|
@partial(jax.jit, static_argnums=(3,))
| def get_action_and_value( |
| params: TrainState, |
| next_obs: np.ndarray, |
| key: jax.random.PRNGKey, |
| action_dim: int, |
| ): |
| hidden = Network().apply(params.network_params, next_obs) |
| logits = Actor(action_dim).apply(params.actor_params, hidden) |
| |
| |
| key, subkey = jax.random.split(key) |
| u = jax.random.uniform(subkey, shape=logits.shape) |
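    # Gumbel-max trick: argmax(logits - log(-log(u))) with u ~ Uniform(0, 1) draws a sample
    # from Categorical(logits)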
| action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1) |
| logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] |
| value = Critic().apply(params.critic_params, hidden) |
| return action, logprob, value.squeeze(), key |
|
|
|
|
| @jax.jit |
| def prepare_data( |
| obs: list, |
| dones: list, |
| values: list, |
| actions: list, |
| logprobs: list, |
| env_ids: list, |
| rewards: list, |
| ): |
| obs = jnp.asarray(obs) |
| dones = jnp.asarray(dones) |
| values = jnp.asarray(values) |
| actions = jnp.asarray(actions) |
| logprobs = jnp.asarray(logprobs) |
| env_ids = jnp.asarray(env_ids) |
| rewards = jnp.asarray(rewards) |
|
|
| |
| T, B = env_ids.shape |
| index_ranges = jnp.arange(T * B, dtype=jnp.int32) |
| next_index_ranges = jnp.zeros_like(index_ranges, dtype=jnp.int32) |
| last_env_ids = jnp.zeros(args.num_envs, dtype=jnp.int32) - 1 |
|
|
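    # In async rollouts the reward returned by a recv() belongs to the action sent at this env's
    # previous appearance. The scan below records, for every flattened index, the index at which
    # the same env id occurs next, so rewards can be shifted back into alignment with their
    # observations/actions.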
| def f(carry, x): |
| last_env_ids, next_index_ranges = carry |
| env_id, index_range = x |
| next_index_ranges = next_index_ranges.at[last_env_ids[env_id]].set( |
| jnp.where(last_env_ids[env_id] != -1, index_range, next_index_ranges[last_env_ids[env_id]]) |
| ) |
| last_env_ids = last_env_ids.at[env_id].set(index_range) |
| return (last_env_ids, next_index_ranges), None |
|
|
| (last_env_ids, next_index_ranges), _ = jax.lax.scan( |
| f, |
| (last_env_ids, next_index_ranges), |
| (env_ids.reshape(-1), index_ranges), |
| ) |
|
|
| |
| rewards = rewards.reshape(-1)[next_index_ranges].reshape((args.num_steps) * args.async_update, args.async_batch_size) |
| advantages, returns, _, final_env_ids = compute_gae(env_ids, rewards, values, dones) |
| |
| b_obs = obs.reshape((-1,) + obs.shape[2:]) |
| b_actions = actions.reshape(-1) |
| b_logprobs = logprobs.reshape(-1) |
| b_advantages = advantages.reshape(-1) |
| b_returns = returns.reshape(-1) |
| return b_obs, b_actions, b_logprobs, b_advantages, b_returns |
|
|
|
|
| def rollout( |
| i, |
| num_threads, |
| thread_affinity_offset, |
| key: jax.random.PRNGKey, |
| args, |
| rollout_queue, |
| params_queue: queue.Queue, |
| writer, |
| learner_devices, |
| ): |
| envs = make_env(args.env_id, args.seed, args.num_envs, args.async_batch_size, num_threads, thread_affinity_offset)() |
| len_actor_device_ids = len(args.actor_device_ids) |
| global_step = 0 |
| |
| start_time = time.time() |
|
|
| |
| episode_returns = np.zeros((args.num_envs,), dtype=np.float32) |
| returned_episode_returns = np.zeros((args.num_envs,), dtype=np.float32) |
| episode_lengths = np.zeros((args.num_envs,), dtype=np.float32) |
| returned_episode_lengths = np.zeros((args.num_envs,), dtype=np.float32) |
| envs.async_reset() |
|
|
| params_queue_get_time = deque(maxlen=10) |
| rollout_time = deque(maxlen=10) |
| data_transfer_time = deque(maxlen=10) |
| rollout_queue_put_time = deque(maxlen=10) |
| actor_policy_version = 0 |
| for update in range(1, args.num_updates + 2): |
| |
| |
| |
| |
| |
| |
| update_time_start = time.time() |
| obs = [] |
| dones = [] |
| actions = [] |
| logprobs = [] |
| values = [] |
| env_ids = [] |
| rewards = [] |
| truncations = [] |
| terminations = [] |
| env_recv_time = 0 |
| inference_time = 0 |
| storage_time = 0 |
| env_send_time = 0 |
|
|
| |
| |
| |
| params_queue_get_time_start = time.time() |
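        # NOTE: `update != 2` is deliberate: while the learner is still training on the first
        # rollout, the actor reuses the params fetched in update 1 instead of blocking on the
        # queue; this overlaps rollout and training and keeps the actor at most one policy
        # version behind the learner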
| if update != 2: |
| params = params_queue.get() |
| actor_policy_version += 1 |
| params_queue_get_time.append(time.time() - params_queue_get_time_start) |
| writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step) |
| rollout_time_start = time.time() |
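        # each iteration processes one async batch of `async_batch_size` envs; the range bounds
        # work out to exactly num_steps * async_update recv/send rounds per update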
| for _ in range( |
| args.async_update, (args.num_steps + 1) * args.async_update |
| ): |
| env_recv_time_start = time.time() |
| next_obs, next_reward, next_done, info = envs.recv() |
| env_recv_time += time.time() - env_recv_time_start |
| global_step += len(next_done) * args.num_actor_threads * len_actor_device_ids |
| env_id = info["env_id"] |
|
|
| inference_time_start = time.time() |
| action, logprob, value, key = get_action_and_value(params, next_obs, key, envs.single_action_space.n) |
| inference_time += time.time() - inference_time_start |
|
|
| env_send_time_start = time.time() |
| envs.send(np.array(action), env_id) |
| env_send_time += time.time() - env_send_time_start |
| storage_time_start = time.time() |
| obs.append(next_obs) |
| dones.append(next_done) |
| values.append(value) |
| actions.append(action) |
| logprobs.append(logprob) |
| env_ids.append(env_id) |
| rewards.append(next_reward) |
| truncations.append(info["TimeLimit.truncated"]) |
| terminations.append(info["terminated"]) |
| episode_returns[env_id] += info["reward"] |
| returned_episode_returns[env_id] = np.where( |
| info["terminated"] + info["TimeLimit.truncated"], episode_returns[env_id], returned_episode_returns[env_id] |
| ) |
| episode_returns[env_id] *= (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]) |
| episode_lengths[env_id] += 1 |
| returned_episode_lengths[env_id] = np.where( |
| info["terminated"] + info["TimeLimit.truncated"], episode_lengths[env_id], returned_episode_lengths[env_id] |
| ) |
| episode_lengths[env_id] *= (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"]) |
| storage_time += time.time() - storage_time_start |
| if args.profile: |
| action.block_until_ready() |
| rollout_time.append(time.time() - rollout_time_start) |
| writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step) |
|
|
| avg_episodic_return = np.mean(returned_episode_returns) |
| writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step) |
| writer.add_scalar("charts/avg_episodic_length", np.mean(returned_episode_lengths), global_step) |
| if i == 0: |
| print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}") |
| print("SPS:", int(global_step / (time.time() - start_time))) |
| writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step) |
|
|
| writer.add_scalar("stats/truncations", np.sum(truncations), global_step) |
| writer.add_scalar("stats/terminations", np.sum(terminations), global_step) |
| writer.add_scalar("stats/env_recv_time", env_recv_time, global_step) |
| writer.add_scalar("stats/inference_time", inference_time, global_step) |
| writer.add_scalar("stats/storage_time", storage_time, global_step) |
| writer.add_scalar("stats/env_send_time", env_send_time, global_step) |
|
|
| data_transfer_time_start = time.time() |
| b_obs, b_actions, b_logprobs, b_advantages, b_returns = prepare_data( |
| obs, |
| dones, |
| values, |
| actions, |
| logprobs, |
| env_ids, |
| rewards, |
| ) |
| payload = ( |
| global_step, |
| actor_policy_version, |
| update, |
| jnp.array_split(b_obs, len(learner_devices)), |
| jnp.array_split(b_actions, len(learner_devices)), |
| jnp.array_split(b_logprobs, len(learner_devices)), |
| jnp.array_split(b_advantages, len(learner_devices)), |
| jnp.array_split(b_returns, len(learner_devices)), |
| ) |
| if args.profile: |
            payload[3][0].block_until_ready()  # payload[3] holds the b_obs shards; payload[2] is the int `update`
| data_transfer_time.append(time.time() - data_transfer_time_start) |
| writer.add_scalar("stats/data_transfer_time", np.mean(data_transfer_time), global_step) |
| if update == 1 or not args.test_actor_learner_throughput: |
| rollout_queue_put_time_start = time.time() |
| rollout_queue.put(payload) |
| rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start) |
| writer.add_scalar("stats/rollout_queue_put_time", np.mean(rollout_queue_put_time), global_step) |
|
|
| writer.add_scalar( |
| "charts/SPS_update", |
| int( |
| args.num_envs |
| * args.num_steps |
| * args.num_actor_threads |
| * len_actor_device_ids |
| / (time.time() - update_time_start) |
| ), |
| global_step, |
| ) |
|
|
|
|
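# Training-time counterpart of `get_action_and_value`: re-evaluates the stored actions under the
# current params (returning logprob, entropy, value) instead of sampling new actions.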
@partial(jax.jit, static_argnums=(3,))
| def get_action_and_value2( |
| params: flax.core.FrozenDict, |
| x: np.ndarray, |
| action: np.ndarray, |
| action_dim: int, |
| ): |
| hidden = Network().apply(params.network_params, x) |
| logits = Actor(action_dim).apply(params.actor_params, hidden) |
| logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action] |
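    # normalize logits into log-probabilities and compute entropy = -sum(p * log p); the clip
    # guards against 0 * (-inf) producing NaN for zero-probability actions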
| logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True) |
| logits = logits.clip(min=jnp.finfo(logits.dtype).min) |
| p_log_p = logits * jax.nn.softmax(logits) |
| entropy = -p_log_p.sum(-1) |
| value = Critic().apply(params.critic_params, hidden).squeeze() |
| return logprob, entropy, value |
|
|
|
|
| @jax.jit |
| def compute_gae( |
| env_ids: np.ndarray, |
| rewards: np.ndarray, |
| values: np.ndarray, |
| dones: np.ndarray, |
| ): |
| dones = jnp.asarray(dones) |
| values = jnp.asarray(values) |
| env_ids = jnp.asarray(env_ids) |
| rewards = jnp.asarray(rewards) |
|
|
| _, B = env_ids.shape |
| final_env_id_checked = jnp.zeros(args.num_envs, jnp.int32) - 1 |
| final_env_ids = jnp.zeros(B, jnp.int32) |
| advantages = jnp.zeros(B) |
| lastgaelam = jnp.zeros(args.num_envs) |
| lastdones = jnp.zeros(args.num_envs) + 1 |
| lastvalues = jnp.zeros(args.num_envs) |
|
|
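    # Async rollouts interleave transitions from different envs, so GAE runs as a reverse scan
    # keyed by env id: the last* arrays carry each env's bootstrap value/done across the
    # interleaving, and each env's most recent transition (final_env_id_checked still -1) gets
    # delta = 0, so its value is used purely for bootstrapping.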
| def compute_gae_once(carry, x): |
| lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked = carry |
| ( |
| done, |
| value, |
| eid, |
| reward, |
| ) = x |
| nextnonterminal = 1.0 - lastdones[eid] |
| nextvalues = lastvalues[eid] |
| delta = jnp.where(final_env_id_checked[eid] == -1, 0, reward + args.gamma * nextvalues * nextnonterminal - value) |
| advantages = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam[eid] |
| final_env_ids = jnp.where(final_env_id_checked[eid] == 1, 1, 0) |
| final_env_id_checked = final_env_id_checked.at[eid].set( |
| jnp.where(final_env_id_checked[eid] == -1, 1, final_env_id_checked[eid]) |
| ) |
|
|
| |
| lastgaelam = lastgaelam.at[eid].set(advantages) |
| lastdones = lastdones.at[eid].set(done) |
| lastvalues = lastvalues.at[eid].set(value) |
| return (lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked), ( |
| advantages, |
| final_env_ids, |
| ) |
|
|
| (_, _, _, _, final_env_ids, final_env_id_checked), (advantages, final_env_ids) = jax.lax.scan( |
| compute_gae_once, |
| ( |
| lastvalues, |
| lastdones, |
| advantages, |
| lastgaelam, |
| final_env_ids, |
| final_env_id_checked, |
| ), |
| ( |
| dones, |
| values, |
| env_ids, |
| rewards, |
| ), |
| reverse=True, |
| ) |
| return advantages, advantages + values, final_env_id_checked, final_env_ids |
|
|
|
|
| def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, action_dim): |
| newlogprob, entropy, newvalue = get_action_and_value2(params, x, a, action_dim) |
| logratio = newlogprob - logp |
| ratio = jnp.exp(logratio) |
| approx_kl = ((ratio - 1) - logratio).mean() |
|
|
| if args.norm_adv: |
| mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8) |
|
|
| |
| pg_loss1 = -mb_advantages * ratio |
| pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef) |
| pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean() |
|
|
| |
| v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean() |
|
|
| entropy_loss = entropy.mean() |
| loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef |
| return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl)) |
|
|
|
|
@partial(jax.jit, static_argnums=(6,))
| def single_device_update( |
| agent_state: TrainState, |
| b_obs, |
| b_actions, |
| b_logprobs, |
| b_advantages, |
| b_returns, |
| action_dim, |
| key: jax.random.PRNGKey, |
| ): |
| ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True) |
|
|
| def update_epoch(carry, _): |
| agent_state, key = carry |
| key, subkey = jax.random.split(key) |
|
|
| |
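        # the same `subkey` permutes every field identically, keeping (obs, action, logprob,
        # advantage, return) tuples aligned after shuffling into minibatches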
| def convert_data(x: jnp.ndarray): |
| x = jax.random.permutation(subkey, x) |
| x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:]) |
| return x |
|
|
| def update_minibatch(agent_state, minibatch): |
| mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns = minibatch |
| (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn( |
| agent_state.params, |
| mb_obs, |
| mb_actions, |
| mb_logprobs, |
| mb_advantages, |
| mb_returns, |
| action_dim, |
| ) |
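            # data parallelism: average gradients across all learner devices before the update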
| grads = jax.lax.pmean(grads, axis_name="devices") |
| agent_state = agent_state.apply_gradients(grads=grads) |
| return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) |
|
|
| agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan( |
| update_minibatch, |
| agent_state, |
| ( |
| convert_data(b_obs), |
| convert_data(b_actions), |
| convert_data(b_logprobs), |
| convert_data(b_advantages), |
| convert_data(b_returns), |
| ), |
| ) |
| return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) |
|
|
| (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, _) = jax.lax.scan( |
| update_epoch, (agent_state, key), (), length=args.update_epochs |
| ) |
| return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key |
|
|
|
|
| if __name__ == "__main__": |
| devices = jax.devices("gpu") |
| args = parse_args() |
| run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{uuid.uuid4()}" |
| if args.track: |
| import wandb |
|
|
| wandb.init( |
| project=args.wandb_project_name, |
| entity=args.wandb_entity, |
| sync_tensorboard=True, |
| config=vars(args), |
| name=run_name, |
| monitor_gym=True, |
| save_code=True, |
| ) |
| print(devices) |
| writer = SummaryWriter(f"runs/{run_name}") |
| writer.add_text( |
| "hyperparameters", |
| "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])), |
| ) |
|
|
| |
| random.seed(args.seed) |
| np.random.seed(args.seed) |
| key = jax.random.PRNGKey(args.seed) |
| key, network_key, actor_key, critic_key = jax.random.split(key, 4) |
|
|
| |
| envs = make_env(args.env_id, args.seed, args.num_envs, args.async_batch_size)() |
| assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported" |
|
|
| def linear_schedule(count): |
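        # `count` increments once per gradient step; one update performs
        # num_minibatches * update_epochs steps, so the floor division anneals once per update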
| |
| |
| frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates |
| return args.learning_rate * frac |
|
|
| network = Network() |
| actor = Actor(action_dim=envs.single_action_space.n) |
| critic = Critic() |
| network_params = network.init(network_key, np.array([envs.single_observation_space.sample()])) |
| agent_state = TrainState.create( |
| apply_fn=None, |
| params=AgentParams( |
| network_params, |
| actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))), |
| critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))), |
| ), |
| tx=optax.chain( |
| optax.clip_by_global_norm(args.max_grad_norm), |
| optax.inject_hyperparams(optax.adam)( |
| learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5 |
| ), |
| ), |
| ) |
| learner_devices = [devices[d_id] for d_id in args.learner_device_ids] |
| actor_devices = [devices[d_id] for d_id in args.actor_device_ids] |
| agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices) |
|
|
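    # data-parallel learner step: agent_state and each data field are sharded across the learner
    # devices along axis 0 (axis name "devices"); action_dim is static, and the PRNG key is
    # broadcast (in_axes=None) and returned once (out_axes=None)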
| multi_device_update = jax.pmap( |
| single_device_update, |
| axis_name="devices", |
| devices=learner_devices, |
| in_axes=(0, 0, 0, 0, 0, 0, None, None), |
| out_axes=(0, 0, 0, 0, 0, 0, None), |
        static_broadcasted_argnums=(6,),
| ) |
|
|
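    # maxsize=1 gives backpressure in both directions: actors stall if the learner falls more
    # than one rollout behind, and each params_queue (also maxsize=1) keeps its actor at most
    # one policy version behind the learner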
| rollout_queue = queue.Queue(maxsize=1) |
| params_queues = [] |
| num_cpus = mp.cpu_count() |
| fair_num_cpus = num_cpus // len(args.actor_device_ids) |
|
|
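    # all but the first rollout thread receive this no-op writer so metrics are logged only once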
| class DummyWriter: |
        def add_scalar(self, tag, scalar_value, global_step):
| pass |
|
|
| dummy_writer = DummyWriter() |
| for d_idx, d_id in enumerate(args.actor_device_ids): |
| for j in range(args.num_actor_threads): |
| params_queue = queue.Queue(maxsize=1) |
| params_queue.put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), devices[d_id])) |
| threading.Thread( |
| target=rollout, |
| args=( |
| j, |
| fair_num_cpus if args.num_actor_threads > 1 else None, |
| j * args.num_actor_threads if args.num_actor_threads > 1 else -1, |
| jax.device_put(key, devices[d_id]), |
| args, |
| rollout_queue, |
| params_queue, |
| writer if d_idx == 0 and j == 0 else dummy_writer, |
| learner_devices, |
| ), |
| ).start() |
| params_queues.append(params_queue) |
|
|
| rollout_queue_get_time = deque(maxlen=10) |
| learner_policy_version = 0 |
| while True: |
| learner_policy_version += 1 |
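        # under --test-actor-learner-throughput only the first rollout is consumed; the learner
        # then keeps re-training on the same batch so actor and learner throughput can be
        # measured independently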
| if learner_policy_version == 1 or not args.test_actor_learner_throughput: |
| rollout_queue_get_time_start = time.time() |
| ( |
| global_step, |
| actor_policy_version, |
| update, |
| b_obs, |
| b_actions, |
| b_logprobs, |
| b_advantages, |
| b_returns, |
| ) = rollout_queue.get() |
| rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start) |
| writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step) |
|
|
| training_time_start = time.time() |
| (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key) = multi_device_update( |
| agent_state, |
| jax.device_put_sharded(b_obs, learner_devices), |
| jax.device_put_sharded(b_actions, learner_devices), |
| jax.device_put_sharded(b_logprobs, learner_devices), |
| jax.device_put_sharded(b_advantages, learner_devices), |
| jax.device_put_sharded(b_returns, learner_devices), |
| envs.single_action_space.n, |
| key, |
| ) |
| if learner_policy_version == 1 or not args.test_actor_learner_throughput: |
| for d_idx, d_id in enumerate(args.actor_device_ids): |
| for j in range(args.num_actor_threads): |
| params_queues[d_idx * args.num_actor_threads + j].put( |
| jax.device_put(flax.jax_utils.unreplicate(agent_state.params), devices[d_id]) |
| ) |
| if args.profile: |
| v_loss[-1, -1, -1].block_until_ready() |
| writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step) |
| writer.add_scalar("stats/rollout_queue_size", rollout_queue.qsize(), global_step) |
| writer.add_scalar("stats/params_queue_size", params_queue.qsize(), global_step) |
| print( |
| global_step, |
| f"actor_policy_version={actor_policy_version}, actor_update={update}, learner_policy_version={learner_policy_version}, training time: {time.time() - training_time_start}s", |
| ) |
|
|
| |
| writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"][0].item(), global_step) |
| writer.add_scalar("losses/value_loss", v_loss[-1, -1, -1].item(), global_step) |
| writer.add_scalar("losses/policy_loss", pg_loss[-1, -1, -1].item(), global_step) |
| writer.add_scalar("losses/entropy", entropy_loss[-1, -1, -1].item(), global_step) |
| writer.add_scalar("losses/approx_kl", approx_kl[-1, -1, -1].item(), global_step) |
| writer.add_scalar("losses/loss", loss[-1, -1, -1].item(), global_step) |
| if update >= args.num_updates: |
| break |
|
|
| if args.save_model: |
| agent_state = flax.jax_utils.unreplicate(agent_state) |
| model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model" |
| with open(model_path, "wb") as f: |
| f.write( |
| flax.serialization.to_bytes( |
| [ |
| vars(args), |
| [ |
| agent_state.params.network_params, |
| agent_state.params.actor_params, |
| agent_state.params.critic_params, |
| ], |
| ] |
| ) |
| ) |
| print(f"model saved to {model_path}") |
| from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate |
|
|
| episodic_returns = evaluate( |
| model_path, |
| make_env, |
| args.env_id, |
| eval_episodes=10, |
| run_name=f"{run_name}-eval", |
| Model=(Network, Actor, Critic), |
| ) |
| for idx, episodic_return in enumerate(episodic_returns): |
| writer.add_scalar("eval/episodic_return", episodic_return, idx) |
|
|
| if args.upload_model: |
| from cleanrl_utils.huggingface import push_to_hub |
|
|
| repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}" |
| repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name |
| push_to_hub( |
| args, |
| episodic_returns, |
| repo_id, |
| "PPO", |
| f"runs/{run_name}", |
| f"videos/{run_name}-eval", |
| extra_dependencies=["jax", "envpool", "atari"], |
| ) |
|
|
| envs.close() |
| writer.close() |
|
|