| | """ |
| | 0. multi-threaded actor |
| | python sebulba_ppo_envpool.py --actor-device-ids 0 --num-actor-threads 2 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| | python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| | |
| | 🔥 core settings: |
| | |
| | * test throughput |
| | * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l1_timeout --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| | * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l12_timeout --actor-device-ids 0 --learner-device-ids 1 2 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| | * this will help us diagnose the throughput issue |
| | * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l1 --actor-device-ids 0 --learner-device-ids 1 --profile --total-timesteps 500000 --track |
| | * python sebulba_ppo_envpool.py --exp-name sebula_thpt_a0_l12 --actor-device-ids 0 --learner-device-ids 1 2 --profile --total-timesteps 500000 --track |
| | * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --num-actor-threads 2 --track |
| | * Best performance so far |
| | * python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_faster --actor-device-ids 0 --learner-device-ids 0 1 --total-timesteps 500000 --track |
| | * python sebulba_ppo_envpool.py --actor-device-ids 0 --learner-device-ids 1 2 3 4 --params-queue-timeout 0.02 --track |
| | |
| | # 1. rollout is faster than training |
| | |
| | ## throughput |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| | |
| | ## shared: actor on GPU0 and learner on GPU0 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_rollout_is_faster --actor-device-ids 0 --learner-device-ids 0 --total-timesteps 500000 --track |
| | |
| | ## separate: actor on GPU0 and learner on GPU1 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 --total-timesteps 500000 --track |
| | |
| | ## shared: actor on GPU0 and learner on GPU0,1 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_faster --actor-device-ids 0 --learner-device-ids 0 1 --total-timesteps 500000 --track |
| | |
| | ## separate: actor on GPU0 and learner on GPU1,2 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_faster --actor-device-ids 0 --learner-device-ids 1 2 --total-timesteps 500000 --track |
| | |
| | |
| | # 1.1 rollout is faster than training w/ timeout |
| | |
| | ## shared: actor on GPU0 and learner on GPU0 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 0 --params-queue-timeout 0.02 --total-timesteps 500000 --track |
| | |
| | ## separate: actor on GPU0 and learner on GPU1 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --total-timesteps 500000 --track |
| | |
| | ## shared: actor on GPU0 and learner on GPU0,1 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 0 1 --params-queue-timeout 0.02 --total-timesteps 500000 --track |
| | |
| | ## separate: actor on GPU0 and learner on GPU1,2 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_faster_timeout --actor-device-ids 0 --learner-device-ids 1 2 --params-queue-timeout 0.02 --total-timesteps 500000 --track |
| | |
| | # 1.2. rollout is much faster than training w/ timeout |
| | |
| | ## throughput |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 1 --update-epochs 8 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| | |
| | ## shared: actor on GPU0 and learner on GPU0,1 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 0 1 --update-epochs 8 --params-queue-timeout 0.02 --total-timesteps 500000 --track |
| | |
| | ## separate: actor on GPU0 and learner on GPU1,2 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_rollout_is_much_faster_timeout --actor-device-ids 0 --learner-device-ids 1 2 --update-epochs 8 --params-queue-timeout 0.02 --total-timesteps 500000 --track |
| | |
| | # 2. training is faster than rollout |
| | |
| | ## throughput |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_thpt_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 --params-queue-timeout 0.02 --profile --test-actor-learner-throughput --total-timesteps 500000 --track |
| | |
| | ## shared: actor on GPU0 and learner on GPU0 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_1gpu_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 0 --total-timesteps 500000 --track |
| | |
| | ## separate: actor on GPU0 and learner on GPU1 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l1_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 --total-timesteps 500000 --track |
| | |
| | ## shared: actor on GPU0 and learner on GPU0,1 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l01_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 0 1 --total-timesteps 500000 --track |
| | |
| | ## separate: actor on GPU0 and learner on GPU1,2 |
| | python sebulba_ppo_envpool.py --exp-name sebulba_ppo_envpool_a0_l12_training_is_faster --update-epochs 1 --async-batch-size 64 --actor-device-ids 0 --learner-device-ids 1 2 --total-timesteps 500000 --track |
| | |
| | """ |
| | |
| | |
import argparse
import os
import random
import time
import uuid
from collections import deque
from distutils.util import strtobool
from functools import partial
from typing import Sequence

os.environ["XLA_PYTHON_CLIENT_MEM_FRACTION"] = "0.6"
os.environ["XLA_FLAGS"] = "--xla_cpu_multi_thread_eigen=false intra_op_parallelism_threads=1"
import multiprocessing as mp
import queue
import threading

import envpool
import flax
import flax.linen as nn
import gym
import jax
import jax.numpy as jnp
import numpy as np
import optax
from flax.linen.initializers import constant, orthogonal
from flax.training.train_state import TrainState
from torch.utils.tensorboard import SummaryWriter


def parse_args():
    parser = argparse.ArgumentParser()
    parser.add_argument("--exp-name", type=str, default=os.path.splitext(os.path.basename(__file__))[0],
        help="the name of this experiment")
    parser.add_argument("--seed", type=int, default=1,
        help="seed of the experiment")
    parser.add_argument("--torch-deterministic", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, `torch.backends.cudnn.deterministic=False`")
    parser.add_argument("--cuda", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="if toggled, cuda will be enabled by default")
    parser.add_argument("--track", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="if toggled, this experiment will be tracked with Weights and Biases")
    parser.add_argument("--wandb-project-name", type=str, default="cleanRL",
        help="the wandb's project name")
    parser.add_argument("--wandb-entity", type=str, default=None,
        help="the entity (team) of wandb's project")
    parser.add_argument("--capture-video", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to capture videos of the agent performances (check out `videos` folder)")
    parser.add_argument("--save-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to save model into the `runs/{run_name}` folder")
    parser.add_argument("--upload-model", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True,
        help="whether to upload the saved model to huggingface")
    parser.add_argument("--hf-entity", type=str, default="",
        help="the user or org name of the model repository from the Hugging Face Hub")

    parser.add_argument("--env-id", type=str, default="Breakout-v5",
        help="the id of the environment")
    parser.add_argument("--total-timesteps", type=int, default=50000000,
        help="total timesteps of the experiments")
    parser.add_argument("--learning-rate", type=float, default=2.5e-4,
        help="the learning rate of the optimizer")
    parser.add_argument("--num-envs", type=int, default=64,
        help="the number of parallel game environments")
    parser.add_argument("--async-batch-size", type=int, default=16,
        help="the envpool's batch size in the async mode")
    parser.add_argument("--num-steps", type=int, default=128,
        help="the number of steps to run in each environment per policy rollout")
    parser.add_argument("--anneal-lr", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggle learning rate annealing for policy and value networks")
    parser.add_argument("--gamma", type=float, default=0.99,
        help="the discount factor gamma")
    parser.add_argument("--gae-lambda", type=float, default=0.95,
        help="the lambda for the general advantage estimation")
    parser.add_argument("--num-minibatches", type=int, default=4,
        help="the number of mini-batches")
    parser.add_argument("--update-epochs", type=int, default=4,
        help="the K epochs to update the policy")
    parser.add_argument("--norm-adv", type=lambda x: bool(strtobool(x)), default=True, nargs="?", const=True,
        help="Toggles advantages normalization")
    parser.add_argument("--clip-coef", type=float, default=0.1,
        help="the surrogate clipping coefficient")
    parser.add_argument("--ent-coef", type=float, default=0.01,
        help="coefficient of the entropy")
    parser.add_argument("--vf-coef", type=float, default=0.5,
        help="coefficient of the value function")
    parser.add_argument("--max-grad-norm", type=float, default=0.5,
        help="the maximum norm for the gradient clipping")
    parser.add_argument("--target-kl", type=float, default=None,
        help="the target KL divergence threshold")

| | parser.add_argument("--actor-device-ids", type=int, nargs="+", default=[0], |
| | help="the device ids that actor workers will use") |
| | parser.add_argument("--learner-device-ids", type=int, nargs="+", default=[0], |
| | help="the device ids that actor workers will use") |
| | parser.add_argument("--num-actor-threads", type=int, default=1, |
| | help="the number of actor threads") |
| | parser.add_argument("--profile", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, |
| | help="whether to call block_until_ready() for profiling") |
| | parser.add_argument("--test-actor-learner-throughput", type=lambda x: bool(strtobool(x)), default=False, nargs="?", const=True, |
| | help="whether to test actor-learner throughput by removing the actor-learner communication") |
| | parser.add_argument("--params-queue-timeout", type=float, default=None, |
| | help="the timeout for the `params_queue.get()` operation in the actor thread to pull params;" + \ |
| | "by default it's `None`; if you set a timeout, it will likely make the actor run faster but will introduce some side effects," + \ |
| | "such as the actor will not be able to pull the latest params from the learner and will use the old params instead") |
| | args = parser.parse_args() |
| | args.batch_size = int(args.num_envs * args.num_steps) |
| | args.minibatch_size = int(args.batch_size // args.num_minibatches) |
| | args.num_updates = args.total_timesteps // args.batch_size |
| | args.async_update = int(args.num_envs / args.async_batch_size) |
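    # sanity check of the derived quantities, assuming the default num_envs=64,
    # num_steps=128, num_minibatches=4, async_batch_size=16:
    #   batch_size     = 64 * 128  = 8192 transitions per update
    #   minibatch_size = 8192 // 4 = 2048
    #   async_update   = 64 / 16   = 4 envpool recv/send rounds per "step"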
    assert len(args.actor_device_ids) == 1, "only 1 actor_device_ids is supported now"

    return args


LEARNER_WARMUP_TIME = 10
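# NOTE: the actor sleeps LEARNER_WARMUP_TIME seconds after each of the first
# three updates, presumably to give the learner time to finish JIT compilation
# before steady-state throughput is measured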


def make_env(env_id, seed, num_envs, async_batch_size=1, num_threads=None, thread_affinity_offset=-1):
    def thunk():
        envs = envpool.make(
            env_id,
            env_type="gym",
            num_envs=num_envs,
            num_threads=num_threads if num_threads is not None else async_batch_size,
            thread_affinity_offset=thread_affinity_offset,
            batch_size=async_batch_size,
            episodic_life=False,
            repeat_action_probability=0.25,
            noop_max=1,
            full_action_space=True,
            max_episode_steps=int(108000 / 4),
            reward_clip=True,
            seed=seed,
        )
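        # in envpool's async mode, recv() returns results for only
        # `async_batch_size` of the `num_envs` environments at a time, and
        # send() dispatches actions for exactly that subset (see `env_id`)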
        envs.num_envs = num_envs
        envs.single_action_space = envs.action_space
        envs.single_observation_space = envs.observation_space
        envs.is_vector_env = True
        return envs

    return thunk


class ResidualBlock(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        inputs = x
        x = nn.relu(x)
        x = nn.Conv(
            self.channels,
            kernel_size=(3, 3),
        )(x)
        x = nn.relu(x)
        x = nn.Conv(
            self.channels,
            kernel_size=(3, 3),
        )(x)
        return x + inputs


class ConvSequence(nn.Module):
    channels: int

    @nn.compact
    def __call__(self, x):
        x = nn.Conv(
            self.channels,
            kernel_size=(3, 3),
        )(x)
        x = nn.max_pool(x, window_shape=(3, 3), strides=(2, 2), padding="SAME")
        x = ResidualBlock(self.channels)(x)
        x = ResidualBlock(self.channels)(x)
        return x


class Network(nn.Module):
    channelss: Sequence[int] = (16, 32, 32)

    @nn.compact
    def __call__(self, x):
        x = jnp.transpose(x, (0, 2, 3, 1))
        x = x / 255.0
        for channels in self.channelss:
            x = ConvSequence(channels)(x)
        x = nn.relu(x)
        x = x.reshape((x.shape[0], -1))
        x = nn.Dense(256, kernel_init=orthogonal(np.sqrt(2)), bias_init=constant(0.0))(x)
        x = nn.relu(x)
        return x


class Critic(nn.Module):
    @nn.compact
    def __call__(self, x):
        return nn.Dense(1, kernel_init=orthogonal(1), bias_init=constant(0.0))(x)


class Actor(nn.Module):
    action_dim: int

    @nn.compact
    def __call__(self, x):
        return nn.Dense(self.action_dim, kernel_init=orthogonal(0.01), bias_init=constant(0.0))(x)


@flax.struct.dataclass
class AgentParams:
    network_params: flax.core.FrozenDict
    actor_params: flax.core.FrozenDict
    critic_params: flax.core.FrozenDict


@partial(jax.jit, static_argnums=(3,))
def get_action_and_value(
    params: TrainState,
    next_obs: np.ndarray,
    key: jax.random.PRNGKey,
    action_dim: int,
):
    hidden = Network().apply(params.network_params, next_obs)
    logits = Actor(action_dim).apply(params.actor_params, hidden)
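    # sample actions with the Gumbel-max trick: adding Gumbel noise
    # (-log(-log(u)), u ~ Uniform(0, 1)) to the logits and taking the argmax
    # draws from the categorical distribution softmax(logits)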
    key, subkey = jax.random.split(key)
    u = jax.random.uniform(subkey, shape=logits.shape)
    action = jnp.argmax(logits - jnp.log(-jnp.log(u)), axis=1)
    logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
    value = Critic().apply(params.critic_params, hidden)
    return action, logprob, value.squeeze(), key


@jax.jit
def prepare_data(
    obs: list,
    dones: list,
    values: list,
    actions: list,
    logprobs: list,
    env_ids: list,
    rewards: list,
):
    obs = jnp.asarray(obs)
    dones = jnp.asarray(dones)
    values = jnp.asarray(values)
    actions = jnp.asarray(actions)
    logprobs = jnp.asarray(logprobs)
    env_ids = jnp.asarray(env_ids)
    rewards = jnp.asarray(rewards)
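
    # realign rewards with the actions that caused them: in envpool's async
    # mode the reward returned by a recv() belongs to the action sent at that
    # env's previous appearance in the batch stream, so the scan below records,
    # for every flat index, the index of the same env's next appearance
    # (next_index_ranges), and rewards are then gathered from those indices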
    T, B = env_ids.shape
    index_ranges = jnp.arange(T * B, dtype=jnp.int32)
    next_index_ranges = jnp.zeros_like(index_ranges, dtype=jnp.int32)
    last_env_ids = jnp.zeros(args.num_envs, dtype=jnp.int32) - 1

    def f(carry, x):
        last_env_ids, next_index_ranges = carry
        env_id, index_range = x
        next_index_ranges = next_index_ranges.at[last_env_ids[env_id]].set(
            jnp.where(last_env_ids[env_id] != -1, index_range, next_index_ranges[last_env_ids[env_id]])
        )
        last_env_ids = last_env_ids.at[env_id].set(index_range)
        return (last_env_ids, next_index_ranges), None

    (last_env_ids, next_index_ranges), _ = jax.lax.scan(
        f,
        (last_env_ids, next_index_ranges),
        (env_ids.reshape(-1), index_ranges),
    )

    rewards = rewards.reshape(-1)[next_index_ranges].reshape((args.num_steps) * args.async_update, args.async_batch_size)
    advantages, returns, _, final_env_ids = compute_gae(env_ids, rewards, values, dones)

    b_obs = obs.reshape((-1,) + obs.shape[2:])
    b_actions = actions.reshape(-1)
    b_logprobs = logprobs.reshape(-1)
    b_advantages = advantages.reshape(-1)
    b_returns = returns.reshape(-1)
    return b_obs, b_actions, b_logprobs, b_advantages, b_returns


def rollout(
    i,
    num_threads,
    thread_affinity_offset,
    key: jax.random.PRNGKey,
    args,
    rollout_queue,
    params_queue: queue.Queue,
    writer,
    learner_devices,
):
    envs = make_env(args.env_id, args.seed, args.num_envs, args.async_batch_size, num_threads, thread_affinity_offset)()
    len_actor_device_ids = len(args.actor_device_ids)
    global_step = 0
    start_time = time.time()

    episode_returns = np.zeros((args.num_envs,), dtype=np.float32)
    returned_episode_returns = np.zeros((args.num_envs,), dtype=np.float32)
    episode_lengths = np.zeros((args.num_envs,), dtype=np.float32)
    returned_episode_lengths = np.zeros((args.num_envs,), dtype=np.float32)
    envs.async_reset()

    params_queue_get_time = deque(maxlen=10)
    rollout_time = deque(maxlen=10)
    data_transfer_time = deque(maxlen=10)
    rollout_queue_put_time = deque(maxlen=10)
    params_timeout_count = 0
    for update in range(1, args.num_updates + 2):
        update_time_start = time.time()
        obs = []
        dones = []
        actions = []
        logprobs = []
        values = []
        env_ids = []
        rewards = []
        truncations = []
        terminations = []
        env_recv_time = 0
        inference_time = 0
        storage_time = 0
        env_send_time = 0
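
        # pull the latest params from the learner; with a finite
        # --params-queue-timeout the actor does not block: on timeout it keeps
        # acting with the stale params it already has (tracked via
        # params_timeout_count), trading policy freshness for throughput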
        params_queue_get_time_start = time.time()
        try:
            params = params_queue.get(timeout=args.params_queue_timeout)
        except queue.Empty:
            params_timeout_count += 1
        params_queue_get_time.append(time.time() - params_queue_get_time_start)
        writer.add_scalar("stats/params_queue_get_time", np.mean(params_queue_get_time), global_step)
        writer.add_scalar("stats/params_queue_timeout_count", params_timeout_count, global_step)
        rollout_time_start = time.time()
        for _ in range(args.async_update, (args.num_steps + 1) * args.async_update):
            env_recv_time_start = time.time()
            next_obs, next_reward, next_done, info = envs.recv()
            env_recv_time += time.time() - env_recv_time_start
            global_step += len(next_done) * args.num_actor_threads * len_actor_device_ids
            env_id = info["env_id"]

            inference_time_start = time.time()
            action, logprob, value, key = get_action_and_value(params, next_obs, key, envs.single_action_space.n)
            inference_time += time.time() - inference_time_start

            env_send_time_start = time.time()
            envs.send(np.array(action), env_id)
            env_send_time += time.time() - env_send_time_start
            storage_time_start = time.time()
            obs.append(next_obs)
            dones.append(next_done)
            values.append(value)
            actions.append(action)
            logprobs.append(logprob)
            env_ids.append(env_id)
            rewards.append(next_reward)
            truncations.append(info["TimeLimit.truncated"])
            terminations.append(info["terminated"])
            episode_returns[env_id] += info["reward"]
            returned_episode_returns[env_id] = np.where(
                info["terminated"] + info["TimeLimit.truncated"], episode_returns[env_id], returned_episode_returns[env_id]
            )
            episode_returns[env_id] *= (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"])
            episode_lengths[env_id] += 1
            returned_episode_lengths[env_id] = np.where(
                info["terminated"] + info["TimeLimit.truncated"], episode_lengths[env_id], returned_episode_lengths[env_id]
            )
            episode_lengths[env_id] *= (1 - info["terminated"]) * (1 - info["TimeLimit.truncated"])
            storage_time += time.time() - storage_time_start
        if args.profile:
            action.block_until_ready()
        rollout_time.append(time.time() - rollout_time_start)
        writer.add_scalar("stats/rollout_time", np.mean(rollout_time), global_step)

        avg_episodic_return = np.mean(returned_episode_returns)
        writer.add_scalar("charts/avg_episodic_return", avg_episodic_return, global_step)
        writer.add_scalar("charts/avg_episodic_length", np.mean(returned_episode_lengths), global_step)
        if i == 0:
            print(f"global_step={global_step}, avg_episodic_return={avg_episodic_return}")
            print("SPS:", int(global_step / (time.time() - start_time)))
        writer.add_scalar("charts/SPS", int(global_step / (time.time() - start_time)), global_step)

        writer.add_scalar("stats/truncations", np.sum(truncations), global_step)
        writer.add_scalar("stats/terminations", np.sum(terminations), global_step)
        writer.add_scalar("stats/env_recv_time", env_recv_time, global_step)
        writer.add_scalar("stats/inference_time", inference_time, global_step)
        writer.add_scalar("stats/storage_time", storage_time, global_step)
        writer.add_scalar("stats/env_send_time", env_send_time, global_step)

        data_transfer_time_start = time.time()
        b_obs, b_actions, b_logprobs, b_advantages, b_returns = prepare_data(
            obs,
            dones,
            values,
            actions,
            logprobs,
            env_ids,
            rewards,
        )
        payload = (
            global_step,
            update,
            jnp.array_split(b_obs, len(learner_devices)),
            jnp.array_split(b_actions, len(learner_devices)),
            jnp.array_split(b_logprobs, len(learner_devices)),
            jnp.array_split(b_advantages, len(learner_devices)),
            jnp.array_split(b_returns, len(learner_devices)),
        )
        if args.profile:
            payload[2][0].block_until_ready()
        data_transfer_time.append(time.time() - data_transfer_time_start)
        writer.add_scalar("stats/data_transfer_time", np.mean(data_transfer_time), global_step)
        if update == 1 or not args.test_actor_learner_throughput:
            rollout_queue_put_time_start = time.time()
            rollout_queue.put(payload)
            rollout_queue_put_time.append(time.time() - rollout_queue_put_time_start)
            writer.add_scalar("stats/rollout_queue_put_time", np.mean(rollout_queue_put_time), global_step)
        if update in (1, 2, 3):
            time.sleep(LEARNER_WARMUP_TIME)

        writer.add_scalar(
            "charts/SPS_update",
            int(
                args.num_envs
                * args.num_steps
                * args.num_actor_threads
                * len_actor_device_ids
                / (time.time() - update_time_start)
            ),
            global_step,
        )


@partial(jax.jit, static_argnums=(3,))
def get_action_and_value2(
    params: flax.core.FrozenDict,
    x: np.ndarray,
    action: np.ndarray,
    action_dim: int,
):
    hidden = Network().apply(params.network_params, x)
    logits = Actor(action_dim).apply(params.actor_params, hidden)
    logprob = jax.nn.log_softmax(logits)[jnp.arange(action.shape[0]), action]
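    # entropy from normalized logits: after subtracting logsumexp the logits
    # are log-probabilities, so H = -sum(p * log p) with p = softmax(logits);
    # clipping at the dtype's minimum keeps 0 * -inf from producing NaNs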
    logits = logits - jax.scipy.special.logsumexp(logits, axis=-1, keepdims=True)
    logits = logits.clip(min=jnp.finfo(logits.dtype).min)
    p_log_p = logits * jax.nn.softmax(logits)
    entropy = -p_log_p.sum(-1)
    value = Critic().apply(params.critic_params, hidden).squeeze()
    return logprob, entropy, value


@jax.jit
def compute_gae(
    env_ids: np.ndarray,
    rewards: np.ndarray,
    values: np.ndarray,
    dones: np.ndarray,
):
    dones = jnp.asarray(dones)
    values = jnp.asarray(values)
    env_ids = jnp.asarray(env_ids)
    rewards = jnp.asarray(rewards)

    _, B = env_ids.shape
    final_env_id_checked = jnp.zeros(args.num_envs, jnp.int32) - 1
    final_env_ids = jnp.zeros(B, jnp.int32)
    advantages = jnp.zeros(B)
    lastgaelam = jnp.zeros(args.num_envs)
    lastdones = jnp.zeros(args.num_envs) + 1
    lastvalues = jnp.zeros(args.num_envs)
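
    # GAE over an interleaved stream of env ids: scanning the batch in reverse,
    # per-env carries (lastvalues, lastdones, lastgaelam) are indexed by env id;
    # final_env_id_checked == -1 marks an env's chronologically last entry, for
    # which delta is zeroed instead of bootstrapping from a missing next value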
    def compute_gae_once(carry, x):
        lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked = carry
        (
            done,
            value,
            eid,
            reward,
        ) = x
        nextnonterminal = 1.0 - lastdones[eid]
        nextvalues = lastvalues[eid]
        delta = jnp.where(final_env_id_checked[eid] == -1, 0, reward + args.gamma * nextvalues * nextnonterminal - value)
        advantages = delta + args.gamma * args.gae_lambda * nextnonterminal * lastgaelam[eid]
        final_env_ids = jnp.where(final_env_id_checked[eid] == 1, 1, 0)
        final_env_id_checked = final_env_id_checked.at[eid].set(
            jnp.where(final_env_id_checked[eid] == -1, 1, final_env_id_checked[eid])
        )

        lastgaelam = lastgaelam.at[eid].set(advantages)
        lastdones = lastdones.at[eid].set(done)
        lastvalues = lastvalues.at[eid].set(value)
        return (lastvalues, lastdones, advantages, lastgaelam, final_env_ids, final_env_id_checked), (
            advantages,
            final_env_ids,
        )

    (_, _, _, _, final_env_ids, final_env_id_checked), (advantages, final_env_ids) = jax.lax.scan(
        compute_gae_once,
        (
            lastvalues,
            lastdones,
            advantages,
            lastgaelam,
            final_env_ids,
            final_env_id_checked,
        ),
        (
            dones,
            values,
            env_ids,
            rewards,
        ),
        reverse=True,
    )
    return advantages, advantages + values, final_env_id_checked, final_env_ids


def ppo_loss(params, x, a, logp, mb_advantages, mb_returns, action_dim):
    newlogprob, entropy, newvalue = get_action_and_value2(params, x, a, action_dim)
    logratio = newlogprob - logp
    ratio = jnp.exp(logratio)
    approx_kl = ((ratio - 1) - logratio).mean()
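    # approx_kl is the k3 estimator from http://joschu.net/blog/kl-approx.html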

    if args.norm_adv:
        mb_advantages = (mb_advantages - mb_advantages.mean()) / (mb_advantages.std() + 1e-8)

    pg_loss1 = -mb_advantages * ratio
    pg_loss2 = -mb_advantages * jnp.clip(ratio, 1 - args.clip_coef, 1 + args.clip_coef)
    pg_loss = jnp.maximum(pg_loss1, pg_loss2).mean()

    v_loss = 0.5 * ((newvalue - mb_returns) ** 2).mean()

    entropy_loss = entropy.mean()
    loss = pg_loss - args.ent_coef * entropy_loss + v_loss * args.vf_coef
    return loss, (pg_loss, v_loss, entropy_loss, jax.lax.stop_gradient(approx_kl))


@partial(jax.jit, static_argnums=(6,))
def single_device_update(
    agent_state: TrainState,
    b_obs,
    b_actions,
    b_logprobs,
    b_advantages,
    b_returns,
    action_dim,
    key: jax.random.PRNGKey,
):
    ppo_loss_grad_fn = jax.value_and_grad(ppo_loss, has_aux=True)

    def update_epoch(carry, _):
        agent_state, key = carry
        key, subkey = jax.random.split(key)

        def convert_data(x: jnp.ndarray):
            x = jax.random.permutation(subkey, x)
            x = jnp.reshape(x, (args.num_minibatches, -1) + x.shape[1:])
            return x

        def update_minibatch(agent_state, minibatch):
            mb_obs, mb_actions, mb_logprobs, mb_advantages, mb_returns = minibatch
            (loss, (pg_loss, v_loss, entropy_loss, approx_kl)), grads = ppo_loss_grad_fn(
                agent_state.params,
                mb_obs,
                mb_actions,
                mb_logprobs,
                mb_advantages,
                mb_returns,
                action_dim,
            )
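            # average gradients across learner devices (data-parallel
            # synchronous SGD): every device then applies the same update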
            grads = jax.lax.pmean(grads, axis_name="devices")
            agent_state = agent_state.apply_gradients(grads=grads)
            return agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads)

        agent_state, (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads) = jax.lax.scan(
            update_minibatch,
            agent_state,
            (
                convert_data(b_obs),
                convert_data(b_actions),
                convert_data(b_logprobs),
                convert_data(b_advantages),
                convert_data(b_returns),
            ),
        )
        return (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, grads)

    (agent_state, key), (loss, pg_loss, v_loss, entropy_loss, approx_kl, _) = jax.lax.scan(
        update_epoch, (agent_state, key), (), length=args.update_epochs
    )
    return agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key


if __name__ == "__main__":
    devices = jax.devices("gpu")
    args = parse_args()
    run_name = f"{args.env_id}__{args.exp_name}__{args.seed}__{uuid.uuid4()}"
    if args.track:
        import wandb

        wandb.init(
            project=args.wandb_project_name,
            entity=args.wandb_entity,
            sync_tensorboard=True,
            config=vars(args),
            name=run_name,
            monitor_gym=True,
            save_code=True,
        )
    writer = SummaryWriter(f"runs/{run_name}")
    writer.add_text(
        "hyperparameters",
        "|param|value|\n|-|-|\n%s" % ("\n".join([f"|{key}|{value}|" for key, value in vars(args).items()])),
    )

    random.seed(args.seed)
    np.random.seed(args.seed)
    key = jax.random.PRNGKey(args.seed)
    key, network_key, actor_key, critic_key = jax.random.split(key, 4)

    envs = make_env(args.env_id, args.seed, args.num_envs, args.async_batch_size)()
    assert isinstance(envs.single_action_space, gym.spaces.Discrete), "only discrete action space is supported"

    def linear_schedule(count):
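        # count is the number of gradient updates taken so far; one training
        # iteration performs num_minibatches * update_epochs of them, so this
        # anneals the learning rate linearly over num_updates iterations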
        frac = 1.0 - (count // (args.num_minibatches * args.update_epochs)) / args.num_updates
        return args.learning_rate * frac

    network = Network()
    actor = Actor(action_dim=envs.single_action_space.n)
    critic = Critic()
    network_params = network.init(network_key, np.array([envs.single_observation_space.sample()]))
    agent_state = TrainState.create(
        apply_fn=None,
        params=AgentParams(
            network_params,
            actor.init(actor_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
            critic.init(critic_key, network.apply(network_params, np.array([envs.single_observation_space.sample()]))),
        ),
        tx=optax.chain(
            optax.clip_by_global_norm(args.max_grad_norm),
            optax.inject_hyperparams(optax.adam)(
                learning_rate=linear_schedule if args.anneal_lr else args.learning_rate, eps=1e-5
            ),
        ),
    )
    learner_devices = [devices[d_id] for d_id in args.learner_device_ids]
    actor_devices = [devices[d_id] for d_id in args.actor_device_ids]
    agent_state = flax.jax_utils.replicate(agent_state, devices=learner_devices)
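
    # data-parallel learner: agent_state is replicated across learner devices
    # and each device gets one shard of the batch (in_axes=0 for the data
    # arguments); the PRNG key is broadcast (None) and action_dim is static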
    multi_device_update = jax.pmap(
        single_device_update,
        axis_name="devices",
        devices=learner_devices,
        in_axes=(0, 0, 0, 0, 0, 0, None, None),
        out_axes=(0, 0, 0, 0, 0, 0, None),
        static_broadcasted_argnums=(6,),
    )

    rollout_queue = queue.Queue(maxsize=2)
    params_queues = []
    num_cpus = mp.cpu_count()
    fair_num_cpus = num_cpus // len(args.actor_device_ids)
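
    # only the first actor thread logs to tensorboard; the rest get a no-op
    # writer with the same add_scalar(tag, value, step) interface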
    class DummyWriter:
        def add_scalar(self, tag, value, step):
            pass

    dummy_writer = DummyWriter()
    for d_idx, d_id in enumerate(args.actor_device_ids):
        for j in range(args.num_actor_threads):
            params_queue = queue.Queue(maxsize=2)
            params_queue.put(jax.device_put(flax.jax_utils.unreplicate(agent_state.params), devices[d_id]))
            threading.Thread(
                target=rollout,
                args=(
                    j,
                    fair_num_cpus if args.num_actor_threads > 1 else None,
                    j * args.num_actor_threads if args.num_actor_threads > 1 else -1,
                    jax.device_put(key, devices[d_id]),
                    args,
                    rollout_queue,
                    params_queue,
                    writer if d_idx == 0 and j == 0 else dummy_writer,
                    learner_devices,
                ),
            ).start()
            params_queues.append(params_queue)

    rollout_queue_get_time = deque(maxlen=10)
    learner_update = 0
    while True:
        learner_update += 1
        if learner_update == 1 or not args.test_actor_learner_throughput:
            rollout_queue_get_time_start = time.time()
            global_step, update, b_obs, b_actions, b_logprobs, b_advantages, b_returns = rollout_queue.get()
            rollout_queue_get_time.append(time.time() - rollout_queue_get_time_start)
            writer.add_scalar("stats/rollout_queue_get_time", np.mean(rollout_queue_get_time), global_step)
        training_time_start = time.time()
        (agent_state, loss, pg_loss, v_loss, entropy_loss, approx_kl, key) = multi_device_update(
            agent_state,
            jax.device_put_sharded(b_obs, learner_devices),
            jax.device_put_sharded(b_actions, learner_devices),
            jax.device_put_sharded(b_logprobs, learner_devices),
            jax.device_put_sharded(b_advantages, learner_devices),
            jax.device_put_sharded(b_returns, learner_devices),
            envs.single_action_space.n,
            key,
        )
        if learner_update == 1 or not args.test_actor_learner_throughput:
            for d_idx, d_id in enumerate(args.actor_device_ids):
                for j in range(args.num_actor_threads):
                    params_queues[d_idx * args.num_actor_threads + j].put(
                        jax.device_put(flax.jax_utils.unreplicate(agent_state.params), devices[d_id])
                    )
        if args.profile:
            v_loss[-1, -1, -1].block_until_ready()
        writer.add_scalar("stats/training_time", time.time() - training_time_start, global_step)
        writer.add_scalar("stats/rollout_queue_size", rollout_queue.qsize(), global_step)
        writer.add_scalar("stats/params_queue_size", params_queue.qsize(), global_step)
        print(global_step, update, rollout_queue.qsize(), f"training time: {time.time() - training_time_start}s")
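
        # the loss terms have shape (num_learner_devices, update_epochs,
        # num_minibatches); log the value from the last device/epoch/minibatch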
        writer.add_scalar("charts/learning_rate", agent_state.opt_state[1].hyperparams["learning_rate"][0].item(), global_step)
        writer.add_scalar("losses/value_loss", v_loss[-1, -1, -1].item(), global_step)
        writer.add_scalar("losses/policy_loss", pg_loss[-1, -1, -1].item(), global_step)
        writer.add_scalar("losses/entropy", entropy_loss[-1, -1, -1].item(), global_step)
        writer.add_scalar("losses/approx_kl", approx_kl[-1, -1, -1].item(), global_step)
        writer.add_scalar("losses/loss", loss[-1, -1, -1].item(), global_step)
        if update > args.num_updates:
            break

    if args.save_model:
        agent_state = flax.jax_utils.unreplicate(agent_state)
        model_path = f"runs/{run_name}/{args.exp_name}.cleanrl_model"
        with open(model_path, "wb") as f:
            f.write(
                flax.serialization.to_bytes(
                    [
                        vars(args),
                        [
                            agent_state.params.network_params,
                            agent_state.params.actor_params,
                            agent_state.params.critic_params,
                        ],
                    ]
                )
            )
        print(f"model saved to {model_path}")
        from cleanrl_utils.evals.ppo_envpool_jax_eval import evaluate

        episodic_returns = evaluate(
            model_path,
            make_env,
            args.env_id,
            eval_episodes=10,
            run_name=f"{run_name}-eval",
            Model=(Network, Actor, Critic),
        )
        for idx, episodic_return in enumerate(episodic_returns):
            writer.add_scalar("eval/episodic_return", episodic_return, idx)

        if args.upload_model:
            from cleanrl_utils.huggingface import push_to_hub

            repo_name = f"{args.env_id}-{args.exp_name}-seed{args.seed}"
            repo_id = f"{args.hf_entity}/{repo_name}" if args.hf_entity else repo_name
            push_to_hub(
                args,
                episodic_returns,
                repo_id,
                "PPO",
                f"runs/{run_name}",
                f"videos/{run_name}-eval",
                extra_dependencies=["jax", "envpool", "atari"],
            )

    envs.close()
    writer.close()