from functools import partial
import json
import os
import re
import time
from enum import IntEnum
from typing import Tuple

import chex
import jax
import jax.numpy as jnp
import numpy as np
import optax
import orbax.checkpoint as ocp
import wandb
from flax import core, struct
from flax.training.train_state import TrainState
from flax.training.train_state import TrainState as BaseTrainState

from jaxued.environments.underspecified_env import EnvParams, EnvState, Observation, UnderspecifiedEnv
from jaxued.level_sampler import LevelSampler
from jaxued.utils import compute_max_returns, max_mc, positive_value_loss

from kinetix.environment.env import PixelObservations, make_kinetix_env_from_name
from kinetix.environment.env_state import StaticEnvParams
from kinetix.environment.utils import permute_pcg_state
from kinetix.environment.wrappers import (
    UnderspecifiedToGymnaxWrapper,
    LogWrapper,
    DenseRewardWrapper,
    AutoReplayWrapper,
)
from kinetix.models import make_network_from_config
from kinetix.models.actor_critic import ScannedRNN
from kinetix.pcg.pcg import env_state_to_pcg_state
from kinetix.render.renderer_pixels import make_render_pixels
from kinetix.util.saving import (
    expand_pcg_state,
    get_pcg_state_from_json,
    load_pcg_state_pickle,
    load_world_state_pickle,
    stack_list_of_pytrees,
    import_env_state_from_json,
    load_from_json_file,
)

BASE_DIR = "worlds"

DEFAULT_EVAL_LEVELS = [
    "easy.cartpole",
    "easy.flappy_bird",
    "easy.unicycle",
    "easy.car_left",
    "easy.car_right",
    "easy.pinball",
    "easy.swing_up",
    "easy.thruster",
]

def get_eval_levels(eval_levels, static_env_params):
    # Level names ending in ".permuteN" are loaded from the base JSON file (the suffix is
    # stripped) and then permuted with a fixed PRNG key, so evaluation stays deterministic.
    should_permute = [".permute" in l for l in eval_levels]
    eval_levels = [re.sub(r"\.permute\d+", "", l) for l in eval_levels]
    ls = [
        get_pcg_state_from_json(os.path.join(BASE_DIR, l + ("" if l.endswith(".json") else ".json")))
        for l in eval_levels
    ]
    ls = [expand_pcg_state(l, static_env_params) for l in ls]
    new_ls = []
    rng = jax.random.PRNGKey(0)
    for sp, l in zip(should_permute, ls):
        rng, _rng = jax.random.split(rng)
        if sp:
            l = permute_pcg_state(_rng, l, static_env_params)
        new_ls.append(l)
    return stack_list_of_pytrees(new_ls)

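# Hedged usage sketch (added for illustration; not part of the original module). It shows how the
# ".permuteN" naming convention above is intended to be used: the suffixed entry resolves to the same
# worlds/easy.cartpole.json file but gets its PCG state permuted before stacking. The static_env_params
# argument is assumed to come from the Kinetix config elsewhere in this codebase.
def _example_get_eval_levels(static_env_params):
    names = ["easy.cartpole", "easy.cartpole.permute1"]
    return get_eval_levels(names, static_env_params)
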
def evaluate_rnn(  # from jaxued
    rng: chex.PRNGKey,
    env: UnderspecifiedEnv,
    env_params: EnvParams,
    train_state: TrainState,
    init_hstate: chex.ArrayTree,
    init_obs: Observation,
    init_env_state: EnvState,
    max_episode_length: int,
    keep_states=True,
    return_trajectories=False,
) -> Tuple[chex.ArrayTree, chex.Array, chex.Array, chex.Array, dict]:
    """Runs the RNN policy on the environment, given an initial state and observation.

    Args:
        rng (chex.PRNGKey):
        env (UnderspecifiedEnv):
        env_params (EnvParams):
        train_state (TrainState):
        init_hstate (chex.ArrayTree): Shape (num_levels, ...)
        init_obs (Observation): Shape (num_levels, ...)
        init_env_state (EnvState): Shape (num_levels, ...)
        max_episode_length (int):

    Returns:
        (states, rewards, done_idx, episode_lengths, infos), with shapes
        ((NUM_STEPS, NUM_LEVELS, ...), (NUM_STEPS, NUM_LEVELS), (NUM_LEVELS,), (NUM_LEVELS,), ...).
        If `return_trajectories` is True, a second element (dones, rewards) is returned as well.
    """
    num_levels = jax.tree_util.tree_flatten(init_obs)[0][0].shape[0]

    def step(carry, _):
        rng, hstate, obs, state, done, mask, episode_length = carry
        rng, rng_action, rng_step = jax.random.split(rng, 3)
        x = jax.tree.map(lambda x: x[None, ...], (obs, done))
        hstate, pi, _ = train_state.apply_fn(train_state.params, hstate, x)
        action = pi.sample(seed=rng_action).squeeze(0)
        obs, next_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            jax.random.split(rng_step, num_levels), state, action, env_params
        )
        next_mask = mask & ~done
        episode_length += mask
        if keep_states:
            return (rng, hstate, obs, next_state, done, next_mask, episode_length), (state, reward, done, info)
        else:
            return (rng, hstate, obs, next_state, done, next_mask, episode_length), (None, reward, done, info)

    (_, _, _, _, _, _, episode_lengths), (states, rewards, dones, infos) = jax.lax.scan(
        step,
        (
            rng,
            init_hstate,
            init_obs,
            init_env_state,
            jnp.zeros(num_levels, dtype=bool),
            jnp.ones(num_levels, dtype=bool),
            jnp.zeros(num_levels, dtype=jnp.int32),
        ),
        None,
        length=max_episode_length,
    )
    done_idx = jnp.argmax(dones, axis=0)

    to_return = (states, rewards, done_idx, episode_lengths, infos)
    if return_trajectories:
        return to_return, (dones, rewards)
    return to_return

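# Small self-contained sketch (an addition, not from the original file): it illustrates how
# `done_idx = jnp.argmax(dones, axis=0)` above picks the index of the first True along the time axis
# for each level. Note that argmax over an all-False column returns 0, so episode_lengths (not
# done_idx) is the reliable signal for levels that never terminate.
def _example_done_idx():
    dones = jnp.array(
        [[False, False],
         [True, False],
         [True, False]]
    )  # (NUM_STEPS=3, NUM_LEVELS=2); the second level never terminates
    return jnp.argmax(dones, axis=0)  # -> [1, 0]
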
def general_eval(
    rng: chex.PRNGKey,
    eval_env: UnderspecifiedEnv,
    env_params: EnvParams,
    train_state: TrainState,
    levels: EnvState,
    num_eval_steps: int,
    num_levels: int,
    keep_states=True,
    return_trajectories=False,
):
    """Evaluates the current policy on the given set of evaluation levels.

    Returns (states, cum_rewards, done_idx, episode_lengths, infos), with shapes
    (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,), (num_eval_levels,), ....
    If `return_trajectories` is True, a second element (dones, rewards) is returned as well.
    """
    rng, rng_reset = jax.random.split(rng)
    init_obs, init_env_state = jax.vmap(eval_env.reset_to_level, (0, 0, None))(
        jax.random.split(rng_reset, num_levels), levels, env_params
    )
    init_hstate = ScannedRNN.initialize_carry(num_levels)
    (states, rewards, done_idx, episode_lengths, infos), (dones, reward) = evaluate_rnn(
        rng,
        eval_env,
        env_params,
        train_state,
        init_hstate,
        init_obs,
        init_env_state,
        num_eval_steps,
        keep_states=keep_states,
        return_trajectories=True,
    )
    # Only count rewards obtained before each episode terminated.
    mask = jnp.arange(num_eval_steps)[..., None] < episode_lengths
    cum_rewards = (rewards * mask).sum(axis=0)
    to_return = (
        states,
        cum_rewards,
        done_idx,
        episode_lengths,
        infos,
    )  # (num_steps, num_eval_levels, ...), (num_eval_levels,), (num_eval_levels,), (num_eval_levels,)
    if return_trajectories:
        return to_return, (dones, reward)
    return to_return

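# Illustrative sketch (added; toy shapes assumed): it shows how the broadcasted step mask above zeroes
# out rewards collected after an episode has ended before summing over the time axis.
def _example_cum_reward_mask():
    rewards = jnp.array([[1.0, 1.0], [1.0, 1.0], [1.0, 1.0]])  # (num_steps=3, num_levels=2)
    episode_lengths = jnp.array([2, 3])
    mask = jnp.arange(3)[..., None] < episode_lengths  # (3, 2) boolean mask
    return (rewards * mask).sum(axis=0)  # -> [2.0, 3.0]
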
def compute_gae(
    gamma: float,
    lambd: float,
    last_value: chex.Array,
    values: chex.Array,
    rewards: chex.Array,
    dones: chex.Array,
) -> Tuple[chex.Array, chex.Array]:
    """This takes in arrays of shape (NUM_STEPS, NUM_ENVS) and returns the advantages and targets.

    Args:
        gamma (float):
        lambd (float):
        last_value (chex.Array): Shape (NUM_ENVS,)
        values (chex.Array): Shape (NUM_STEPS, NUM_ENVS)
        rewards (chex.Array): Shape (NUM_STEPS, NUM_ENVS)
        dones (chex.Array): Shape (NUM_STEPS, NUM_ENVS)

    Returns:
        Tuple[chex.Array, chex.Array]: advantages, targets; each of shape (NUM_STEPS, NUM_ENVS)
    """

    def compute_gae_at_timestep(carry, x):
        gae, next_value = carry
        value, reward, done = x
        delta = reward + gamma * next_value * (1 - done) - value
        gae = delta + gamma * lambd * (1 - done) * gae
        return (gae, value), gae

    _, advantages = jax.lax.scan(
        compute_gae_at_timestep,
        (jnp.zeros_like(last_value), last_value),
        (values, rewards, dones),
        reverse=True,
        unroll=16,
    )
    return advantages, advantages + values

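# Worked example (added for clarity; not from the original file): cross-checks compute_gae against the
# backward recursion A_t = delta_t + gamma * lambda * (1 - done_t) * A_{t+1} on a tiny hand-made rollout.
def _example_compute_gae():
    gamma, lambd = 0.99, 0.95
    values = jnp.array([[0.5], [0.4], [0.3]])  # (NUM_STEPS=3, NUM_ENVS=1)
    rewards = jnp.array([[1.0], [0.0], [1.0]])
    dones = jnp.array([[0.0], [0.0], [1.0]])
    last_value = jnp.array([0.2])
    adv, targets = compute_gae(gamma, lambd, last_value, values, rewards, dones)
    # Manual check of the final step: delta_2 = 1.0 + 0.99 * 0.2 * (1 - 1) - 0.3 = 0.7 = A_2,
    # and target_2 = A_2 + V_2 = 1.0.
    return adv, targets
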
def sample_trajectories_rnn(
    rng: chex.PRNGKey,
    env: UnderspecifiedEnv,
    env_params: EnvParams,
    train_state: TrainState,
    init_hstate: chex.ArrayTree,
    init_obs: Observation,
    init_env_state: EnvState,
    num_envs: int,
    max_episode_length: int,
    return_states: bool = False,
) -> Tuple[
    Tuple[chex.PRNGKey, TrainState, chex.ArrayTree, Observation, EnvState, chex.Array],
    Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict],
]:
    """Samples trajectories from the environment using the agent specified by the `train_state`.

    Args:
        rng (chex.PRNGKey): Singleton
        env (UnderspecifiedEnv):
        env_params (EnvParams):
        train_state (TrainState): Singleton
        init_hstate (chex.ArrayTree): The initial RNN hidden state; has to have shape (NUM_ENVS, ...)
        init_obs (Observation): The initial observation, shape (NUM_ENVS, ...)
        init_env_state (EnvState): The initial env state, shape (NUM_ENVS, ...)
        num_envs (int): The number of envs that are vmapped over.
        max_episode_length (int): The maximum episode length, i.e., the number of steps to do the rollouts for.
        return_states (bool, optional): If True, the per-step env states are appended to the trajectory tuple. Defaults to False.

    Returns:
        Tuple[Tuple[chex.PRNGKey, TrainState, chex.ArrayTree, Observation, EnvState, chex.Array], Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict]]:
        (rng, train_state, hstate, last_obs, last_env_state, last_value), traj, where traj is
        (obs, action, reward, done, log_prob, value, info). The first element of the tuple consists of arrays
        with shape (NUM_ENVS, ...) (except `rng` and `train_state`, which are singletons). The second element
        has shape (NUM_STEPS, NUM_ENVS, ...) and contains the trajectory.
    """

    def sample_step(carry, _):
        rng, train_state, hstate, obs, env_state, last_done = carry
        prev_state = env_state
        rng, rng_action, rng_step = jax.random.split(rng, 3)
        x = jax.tree.map(lambda x: x[None, ...], (obs, last_done))
        hstate, pi, value = train_state.apply_fn(train_state.params, hstate, x)
        action = pi.sample(seed=rng_action)
        log_prob = pi.log_prob(action)
        value, action, log_prob = jax.tree.map(lambda x: x.squeeze(0), (value, action, log_prob))
        next_obs, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            jax.random.split(rng_step, num_envs), env_state, action, env_params
        )
        carry = (rng, train_state, hstate, next_obs, env_state, done)
        step = (obs, action, reward, done, log_prob, value, info)
        if return_states:
            step += (prev_state,)
        return carry, step

    (rng, train_state, hstate, last_obs, last_env_state, last_done), traj = jax.lax.scan(
        sample_step,
        (
            rng,
            train_state,
            init_hstate,
            init_obs,
            init_env_state,
            jnp.zeros(num_envs, dtype=bool),
        ),
        None,
        length=max_episode_length,
    )
    # Bootstrap value for the final observation, used by GAE.
    x = jax.tree.map(lambda x: x[None, ...], (last_obs, last_done))
    _, _, last_value = train_state.apply_fn(train_state.params, hstate, x)

    return (rng, train_state, hstate, last_obs, last_env_state, last_value.squeeze(0)), traj

def update_actor_critic_rnn(
    rng: chex.PRNGKey,
    train_state: TrainState,
    init_hstate: chex.ArrayTree,
    batch: chex.ArrayTree,
    num_envs: int,
    n_steps: int,
    n_minibatch: int,
    n_epochs: int,
    clip_eps: float,
    entropy_coeff: float,
    critic_coeff: float,
    update_grad: bool = True,
) -> Tuple[Tuple[chex.PRNGKey, TrainState], chex.ArrayTree]:
    """Takes in a rollout and PPO hyperparameters, and updates the train state.

    Args:
        rng (chex.PRNGKey):
        train_state (TrainState):
        init_hstate (chex.ArrayTree):
        batch (chex.ArrayTree): obs, actions, dones, log_probs, values, targets, advantages
        num_envs (int):
        n_steps (int):
        n_minibatch (int):
        n_epochs (int):
        clip_eps (float):
        entropy_coeff (float):
        critic_coeff (float):
        update_grad (bool, optional): If False, the train state does not actually get updated. Defaults to True.

    Returns:
        Tuple[Tuple[chex.PRNGKey, TrainState], chex.ArrayTree]: A new rng, the updated train_state, and the
        per-minibatch losses and gradient norms. The losses have structure (loss, (l_vf, l_clip, entropy)).
    """
    obs, actions, dones, log_probs, values, targets, advantages = batch
    last_dones = jnp.roll(dones, 1, axis=0).at[0].set(False)
    batch = obs, actions, last_dones, log_probs, values, targets, advantages

    def update_epoch(carry, _):
        def update_minibatch(train_state, minibatch):
            init_hstate, obs, actions, last_dones, log_probs, values, targets, advantages = minibatch

            def loss_fn(params):
                _, pi, values_pred = train_state.apply_fn(params, init_hstate, (obs, last_dones))
                log_probs_pred = pi.log_prob(actions)
                entropy = pi.entropy().mean()

                ratio = jnp.exp(log_probs_pred - log_probs)
                A = (advantages - advantages.mean()) / (advantages.std() + 1e-8)
                l_clip = (-jnp.minimum(ratio * A, jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps) * A)).mean()

                values_pred_clipped = values + (values_pred - values).clip(-clip_eps, clip_eps)
                l_vf = 0.5 * jnp.maximum((values_pred - targets) ** 2, (values_pred_clipped - targets) ** 2).mean()

                loss = l_clip + critic_coeff * l_vf - entropy_coeff * entropy
                return loss, (l_vf, l_clip, entropy)

            grad_fn = jax.value_and_grad(loss_fn, has_aux=True)
            loss, grads = grad_fn(train_state.params)
            if update_grad:
                train_state = train_state.apply_gradients(grads=grads)
            grad_norm = jnp.linalg.norm(
                jnp.concatenate(jax.tree_util.tree_map(lambda x: x.flatten(), jax.tree_util.tree_flatten(grads)[0]))
            )
            return train_state, (loss, grad_norm)

        rng, train_state = carry
        rng, rng_perm = jax.random.split(rng)
        permutation = jax.random.permutation(rng_perm, num_envs)
        minibatches = (
            jax.tree.map(
                lambda x: jnp.take(x, permutation, axis=0).reshape(n_minibatch, -1, *x.shape[1:]),
                init_hstate,
            ),
            *jax.tree.map(
                lambda x: jnp.take(x, permutation, axis=1)
                .reshape(x.shape[0], n_minibatch, -1, *x.shape[2:])
                .swapaxes(0, 1),
                batch,
            ),
        )
        train_state, (losses, grads) = jax.lax.scan(update_minibatch, train_state, minibatches)
        return (rng, train_state), (losses, grads)

    return jax.lax.scan(update_epoch, (rng, train_state), None, n_epochs)

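# Toy sketch (added for illustration only): evaluates the clipped surrogate term used inside loss_fn above
# on hand-picked ratios and normalised advantages, with clip_eps assumed to be 0.2. It is not part of the
# training path; it only makes the min/clip interplay concrete.
def _example_clipped_objective():
    clip_eps = 0.2
    ratio = jnp.array([0.5, 1.0, 1.5])
    A = jnp.array([1.0, 1.0, -1.0])
    l_clip = (-jnp.minimum(ratio * A, jnp.clip(ratio, 1 - clip_eps, 1 + clip_eps) * A)).mean()
    # Per-element surrogates: min(0.5, 0.8) = 0.5, min(1.0, 1.0) = 1.0, min(-1.5, -1.2) = -1.5,
    # so l_clip = -(0.5 + 1.0 - 1.5) / 3 = 0.0.
    return l_clip
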
def sample_trajectories_and_learn(
    env: UnderspecifiedEnv,
    env_params: EnvParams,
    config: dict,
    rng: chex.PRNGKey,
    train_state: TrainState,
    init_hstate: chex.Array,
    init_obs: Observation,
    init_env_state: EnvState,
    update_grad: bool = True,
    return_states: bool = False,
) -> Tuple[
    Tuple[chex.PRNGKey, TrainState, Observation, EnvState],
    Tuple[
        Observation,
        chex.Array,
        chex.Array,
        chex.Array,
        chex.Array,
        chex.Array,
        dict,
        chex.Array,
        chex.Array,
        chex.ArrayTree,
        chex.Array,
    ],
]:
    """This function loops the following:
        - rollout for config['num_steps'] steps
        - learn / update the policy
    and repeats it for config['outer_rollout_steps'] iterations.

    What it returns is a new carry (rng, train_state, init_obs, init_env_state) and the concatenated rollouts.
    The rollouts have a leading time dimension of config['num_steps'] * config['outer_rollout_steps']. In other
    words, the trajectories returned by this function are the same as if we had rolled out for
    config['num_steps'] * config['outer_rollout_steps'] steps, except that the agent performs PPO updates in between.

    Args:
        env (UnderspecifiedEnv):
        env_params (EnvParams):
        config (dict):
        rng (chex.PRNGKey):
        train_state (TrainState):
        init_hstate (chex.Array):
        init_obs (Observation):
        init_env_state (EnvState):
        update_grad (bool, optional): Defaults to True.

    Returns:
        Tuple[Tuple[chex.PRNGKey, TrainState, Observation, EnvState], Tuple[Observation, chex.Array, chex.Array, chex.Array, chex.Array, chex.Array, dict, chex.Array, chex.Array, chex.ArrayTree, chex.Array]]: This returns a tuple:
        (
            (rng, train_state, init_obs, init_env_state),
            (obs, actions, rewards, dones, log_probs, values, info, advantages, targets, losses, grads)
        )
    """

    def single_step(carry, _):
        rng, train_state, init_hstate, init_obs, init_env_state = carry
        (
            (rng, train_state, new_hstate, last_obs, last_env_state, last_value),
            traj,
        ) = sample_trajectories_rnn(
            rng,
            env,
            env_params,
            train_state,
            init_hstate,
            init_obs,
            init_env_state,
            config["num_train_envs"],
            config["num_steps"],
            return_states=return_states,
        )
        if return_states:
            states = traj[-1]
            traj = traj[:-1]
        (obs, actions, rewards, dones, log_probs, values, info) = traj
        advantages, targets = compute_gae(config["gamma"], config["gae_lambda"], last_value, values, rewards, dones)

        # Update the policy using trajectories collected from replay levels
        (rng, train_state), (losses, grads) = update_actor_critic_rnn(
            rng,
            train_state,
            init_hstate,
            (obs, actions, dones, log_probs, values, targets, advantages),
            config["num_train_envs"],
            config["num_steps"],
            config["num_minibatches"],
            config["update_epochs"],
            config["clip_eps"],
            config["ent_coef"],
            config["vf_coef"],
            update_grad=update_grad,
        )
        new_carry = (rng, train_state, new_hstate, last_obs, last_env_state)
        step = (obs, actions, rewards, dones, log_probs, values, info, advantages, targets, losses, grads)
        if return_states:
            step += (states,)
        return new_carry, step

    carry = (rng, train_state, init_hstate, init_obs, init_env_state)
    new_carry, all_rollouts = jax.lax.scan(single_step, carry, None, length=config["outer_rollout_steps"])

    # Concatenate the outer rollout steps along the time axis.
    all_rollouts = jax.tree_util.tree_map(lambda x: jnp.concatenate(x, axis=0), all_rollouts)
    return new_carry, all_rollouts

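# Hedged reference (added): the config keys that sample_trajectories_and_learn and update_actor_critic_rnn
# read above. The values here are illustrative placeholders, not the ones used by this project.
_EXAMPLE_PPO_CONFIG = dict(
    num_train_envs=32,  # number of vmapped training environments
    num_steps=128,  # rollout length per inner step
    outer_rollout_steps=4,  # rollout + update loops whose trajectories get concatenated
    num_minibatches=4,
    update_epochs=4,
    gamma=0.99,
    gae_lambda=0.95,
    clip_eps=0.2,
    ent_coef=0.01,
    vf_coef=0.5,
)
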
def no_op_rollout(
    env: UnderspecifiedEnv,
    env_params: EnvParams,
    rng: chex.PRNGKey,
    init_obs: Observation,
    init_env_state: EnvState,
    num_envs: int,
    max_episode_length: int,
    do_random=False,
):
    # Rolls out either a no-op policy or (if do_random) a uniformly random policy, and returns the mean
    # return, episode length, and success rate per environment over completed episodes.
    noop = jnp.array(env.action_type.noop_action())
    zero_action = jnp.repeat(noop[None, ...], num_envs, axis=0)
    SHAPE = zero_action.shape

    def sample_step(carry, _):
        rng, obs, env_state, last_done = carry
        rng, rng_step, _rng = jax.random.split(rng, 3)
        if do_random:
            action = jax.vmap(env.action_space(env_params).sample)(jax.random.split(_rng, num_envs))
        else:
            action = zero_action
        next_obs, env_state, reward, done, info = jax.vmap(env.step, in_axes=(0, 0, 0, None))(
            jax.random.split(rng_step, num_envs), env_state, action, env_params
        )
        carry = (rng, next_obs, env_state, done)
        return carry, (obs, action, reward, done, info)

    (rng, last_obs, last_env_state, last_done), traj = jax.lax.scan(
        sample_step,
        (
            rng,
            init_obs,
            init_env_state,
            jnp.zeros(num_envs, dtype=bool),
        ),
        None,
        length=max_episode_length,
    )
    info = traj[-1]
    dones = traj[-2]
    returns_per_env = (info["returned_episode_returns"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
    lens_per_env = (info["returned_episode_lengths"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
    success_per_env = (info["returned_episode_solved"] * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))
    return returns_per_env, lens_per_env, success_per_env

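# Minimal sketch (added; toy values assumed): shows how the per-env averages above only count completed
# episodes, with jnp.maximum(1, ...) guarding against division by zero when an env never finishes within
# max_episode_length.
def _example_per_env_average():
    returns = jnp.array([[1.0, 5.0], [2.0, 7.0]])  # (num_steps=2, num_envs=2)
    dones = jnp.array([[1.0, 0.0], [1.0, 0.0]])  # env 1 never completes an episode
    return (returns * dones).sum(axis=0) / jnp.maximum(1, dones.sum(axis=0))  # -> [1.5, 0.0]
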
def no_op_and_random_rollout(
    env: UnderspecifiedEnv,
    env_params: EnvParams,
    rng: chex.PRNGKey,
    init_obs: Observation,
    init_env_state: EnvState,
    num_envs: int,
    max_episode_length: int,
):
    returns_noop, lens_noop, success_noop = no_op_rollout(
        env, env_params, rng, init_obs, init_env_state, num_envs, max_episode_length, do_random=False
    )
    returns_random, lens_random, success_random = no_op_rollout(
        env, env_params, rng, init_obs, init_env_state, num_envs, max_episode_length, do_random=True
    )
    return returns_noop, lens_noop, success_noop, returns_random, lens_random, success_random