Spaces:
Runtime error
Runtime error
| import os | |
| import hydra | |
| from omegaconf import OmegaConf | |
| from kinetix.environment.ued.ued import ( | |
| make_reset_train_function_with_list_of_levels, | |
| make_reset_train_function_with_mutations, | |
| ) | |
| from kinetix.render.renderer_pixels import make_render_pixels | |
| from kinetix.util.config import ( | |
| get_video_frequency, | |
| init_wandb, | |
| normalise_config, | |
| generate_params_from_config, | |
| ) | |
| os.environ["WANDB_DISABLE_SERVICE"] = "True" | |
| import sys | |
| from typing import Any, NamedTuple | |
| import jax | |
| import jax.numpy as jnp | |
| import numpy as np | |
| import optax | |
| from flax.training.train_state import TrainState | |
| from kinetix.models import make_network_from_config | |
| from kinetix.util.learning import general_eval, get_eval_levels | |
| from flax.serialization import to_state_dict | |
| import wandb | |
| from kinetix.environment.env import PixelObservations, make_kinetix_env_from_name | |
| from kinetix.environment.wrappers import ( | |
| AutoReplayWrapper, | |
| AutoResetWrapper, | |
| BatchEnvWrapper, | |
| DenseRewardWrapper, | |
| LogWrapper, | |
| UnderspecifiedToGymnaxWrapper, | |
| ) | |
| from kinetix.models.actor_critic import ScannedRNN | |
| from kinetix.util.saving import ( | |
| load_train_state_from_wandb_artifact_path, | |
| save_model_to_wandb, | |
| ) | |
class Transition(NamedTuple):
    """One recorded environment step of a PPO rollout.

    Instances are built inside the rollout scan as
    ``Transition(last_done, action, value, reward, log_prob, last_obs, info)``
    and stacked by ``jax.lax.scan`` into a trajectory batch.
    """

    # Done flag that was fed to the policy together with `obs` (i.e. the flag
    # produced by the *previous* env step, not the one caused by `action`).
    done: jnp.ndarray
    # Action sampled from the policy distribution.
    action: jnp.ndarray
    # Critic value estimate for `obs`.
    value: jnp.ndarray
    # Reward returned by the env step that executed `action`.
    reward: jnp.ndarray
    # Log-probability of `action` under the policy that sampled it.
    log_prob: jnp.ndarray
    # Observation the policy acted on; may be an arbitrary pytree.
    obs: Any
    # Info returned by the env step. It is indexed by string key elsewhere in
    # this file (e.g. info["returned_episode"]), so it is a mapping of arrays
    # rather than a single array.
    info: Any
def make_train(config, env_params, static_env_params):
    """Build the (jit-able) PPO training function for a Kinetix environment.

    Args:
        config: Mutable mapping of hyper-parameters. This function writes the
            derived keys ``num_updates`` and ``minibatch_size`` into it.
        env_params: Environment parameters forwarded to reset/step/eval calls.
        static_env_params: Static environment parameters used to construct the
            training and evaluation environments.

    Returns:
        A function ``train(rng)`` that runs ``config["num_updates"]`` PPO
        updates and returns ``{"runner_state": ..., "metric": ...}``.

    Raises:
        ValueError: If ``config["train_level_mode"]`` is neither ``"list"``
            nor ``"random"``.
    """
    # Derived quantities: number of PPO updates and samples per minibatch.
    config["num_updates"] = config["total_timesteps"] // config["num_steps"] // config["num_train_envs"]
    config["minibatch_size"] = config["num_train_envs"] * config["num_steps"] // config["num_minibatches"]

    env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)

    # Choose how training levels are produced on (auto-)reset.
    if config["train_level_mode"] == "list":
        reset_func = make_reset_train_function_with_list_of_levels(
            config, config["train_levels_list"], static_env_params, is_loading_train_levels=True
        )
    elif config["train_level_mode"] == "random":
        reset_func = make_reset_train_function_with_mutations(
            env.physics_engine, env_params, env.static_env_params, config
        )
    else:
        raise ValueError(f"Unknown train_level_mode: {config['train_level_mode']}")

    env = UnderspecifiedToGymnaxWrapper(AutoResetWrapper(env, reset_func))

    # The evaluation env is built separately and uses AutoReplayWrapper
    # instead of AutoResetWrapper.
    eval_env = make_kinetix_env_from_name(config["env_name"], static_env_params=static_env_params)
    eval_env = UnderspecifiedToGymnaxWrapper(AutoReplayWrapper(eval_env))

    env = DenseRewardWrapper(env)
    env = LogWrapper(env)
    env = BatchEnvWrapper(env, num_envs=config["num_train_envs"])

    eval_env_nonbatch = LogWrapper(DenseRewardWrapper(eval_env))

    def linear_schedule(count):
        """Linearly decay the learning rate from ``config["lr"]`` towards 0.

        ``count`` is the optimizer step counter; every PPO update performs
        ``num_minibatches * update_epochs`` optimizer steps, so the integer
        division recovers the index of the current PPO update.
        """
        frac = 1.0 - (count // (config["num_minibatches"] * config["update_epochs"])) / config["num_updates"]
        return config["lr"] * frac

    def linear_warmup_cosine_decay_schedule(count):
        """Linear warm-up to ``peak_lr`` followed by a cosine decay.

        While the training fraction is below ``warmup_frac`` the rate is
        interpolated linearly from ``initial_lr`` to ``peak_lr``; afterwards it
        is ``peak_lr`` scaled by a half-cosine of the remaining fraction.
        """
        frac = (count // (config["num_minibatches"] * config["update_epochs"])) / config[
            "num_updates"
        ]  # between 0 and 1
        delta = config["peak_lr"] - config["initial_lr"]
        frac_diff_max = 1.0 - config["warmup_frac"]
        frac_cosine = (frac - config["warmup_frac"]) / frac_diff_max

        # Both branches are evaluated; `select` only picks the result.
        return jax.lax.select(
            frac < config["warmup_frac"],
            config["initial_lr"] + delta * frac / config["warmup_frac"],
            config["peak_lr"] * jnp.maximum(0.0, 0.5 * (1.0 + jnp.cos(jnp.pi * ((frac_cosine) % 1.0)))),
        )

    def train(rng):
        """Run the full PPO training loop.

        Args:
            rng: JAX PRNG key used for network init, env resets, action
                sampling, minibatch shuffling and evaluation.

        Returns:
            Dict with the final ``runner_state`` tuple and the per-update
            ``metric`` pytree stacked by the outer scan.
        """
        # INIT NETWORK
        network = make_network_from_config(env, env_params, config)
        rng, _rng = jax.random.split(rng)
        # This first reset only provides example inputs for network.init.
        obsv, env_state = env.reset(_rng, env_params)
        dones = jnp.zeros((config["num_train_envs"]), dtype=jnp.bool_)
        rng, _rng = jax.random.split(rng)
        init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])
        # Add a leading axis to the example (obs, done) inputs.
        init_x = jax.tree.map(lambda x: x[None, ...], (obsv, dones))
        network_params = network.init(_rng, init_hstate, init_x)

        param_count = sum(x.size for x in jax.tree_util.tree_leaves(network_params))
        obs_size = sum(x.size for x in jax.tree_util.tree_leaves(obsv)) // config["num_train_envs"]
        print("Number of parameters", param_count, "size of obs: ", obs_size)

        # Optimizer: gradient clipping followed by Adam/AdamW with the
        # configured learning-rate schedule (or a constant rate).
        if config["anneal_lr"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(learning_rate=linear_schedule, eps=1e-5),
            )
        elif config["warmup_lr"]:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adamw(learning_rate=linear_warmup_cosine_decay_schedule, eps=1e-5),
            )
        else:
            tx = optax.chain(
                optax.clip_by_global_norm(config["max_grad_norm"]),
                optax.adam(config["lr"], eps=1e-5),
            )
        train_state = TrainState.create(
            apply_fn=network.apply,
            params=network_params,
            tx=tx,
        )

        # Optionally warm-start from a saved checkpoint.
        if config["load_from_checkpoint"] is not None:
            print("LOADING from", config["load_from_checkpoint"], "with only params =", config["load_only_params"])
            train_state = load_train_state_from_wandb_artifact_path(
                train_state, config["load_from_checkpoint"], load_only_params=config["load_only_params"]
            )

        # INIT ENV
        rng, _rng = jax.random.split(rng)
        obsv, env_state = env.reset(_rng, env_params)
        init_hstate = ScannedRNN.initialize_carry(config["num_train_envs"])

        # Renderer used for evaluation videos; built with downscale=1.
        render_static_env_params = env.static_env_params.replace(downscale=1)
        pixel_renderer = jax.jit(make_render_pixels(env_params, render_static_env_params))

        def pixel_render_fn(x):
            """Render a state and divide the result by 255."""
            return pixel_renderer(x) / 255.0

        eval_levels = get_eval_levels(config["eval_levels"], env.static_env_params)

        def _vmapped_eval_step(runner_state, rng):
            """Evaluate the current policy ``eval_num_attempts`` times.

            Returns a tuple of: the states, done indices and returns of the
            first attempt, followed by returns, episode lengths and solve
            rates averaged over attempts.
            """

            def _single_eval_step(rng):
                return general_eval(
                    rng,
                    eval_env_nonbatch,
                    env_params,
                    runner_state[0],  # train_state
                    eval_levels,
                    env_params.max_timesteps,
                    config["num_eval_levels"],
                    keep_states=True,
                    return_trajectories=True,
                )

            (states, returns, done_idxs, episode_lengths, eval_infos), (eval_dones, eval_rewards) = jax.vmap(
                _single_eval_step
            )(jax.random.split(rng, config["eval_num_attempts"]))
            # Fraction of finished episodes that were solved; the denominator
            # is clamped to 1 so that "no episode finished" gives 0.
            eval_solves = (eval_infos["returned_episode_solved"] * eval_dones).sum(axis=1) / jnp.maximum(
                1, eval_dones.sum(axis=1)
            )
            # Only the first attempt is kept for video rendering.
            states_to_plot = jax.tree.map(lambda x: x[0], states)

            return (
                states_to_plot,
                done_idxs[0],
                returns[0],
                returns.mean(axis=0),
                episode_lengths.mean(axis=0),
                eval_solves.mean(axis=0),
            )

        # TRAIN LOOP
        def _update_step(runner_state, unused):
            """One PPO update: rollout, GAE, optimisation epochs, logging."""

            # COLLECT TRAJECTORIES
            def _env_step(runner_state, unused):
                """Sample one action per env, step the envs, record it."""
                (
                    train_state,
                    env_state,
                    last_obs,
                    last_done,
                    hstate,
                    rng,
                    update_step,
                ) = runner_state

                # SELECT ACTION
                rng, _rng = jax.random.split(rng)
                # Add a leading axis to obs/done before calling the network.
                ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], last_obs), last_done[np.newaxis, :])
                hstate, pi, value = network.apply(train_state.params, hstate, ac_in)
                action = pi.sample(seed=_rng)
                log_prob = pi.log_prob(action)
                # Drop the leading axis again.
                value, action, log_prob = (
                    value.squeeze(0),
                    action.squeeze(0),
                    log_prob.squeeze(0),
                )

                # STEP ENV
                rng, _rng = jax.random.split(rng)
                obsv, env_state, reward, done, info = env.step(_rng, env_state, action, env_params)
                # Store the inputs the policy saw (last_done / last_obs), not
                # the outputs of this step.
                transition = Transition(last_done, action, value, reward, log_prob, last_obs, info)
                runner_state = (
                    train_state,
                    env_state,
                    obsv,
                    done,
                    hstate,
                    rng,
                    update_step,
                )
                return runner_state, transition

            # Recurrent state at the start of the rollout (index -3 of the
            # runner_state tuple); needed to replay the sequence in the loss.
            initial_hstate = runner_state[-3]
            runner_state, traj_batch = jax.lax.scan(_env_step, runner_state, None, config["num_steps"])

            # CALCULATE ADVANTAGE
            (
                train_state,
                env_state,
                last_obs,
                last_done,
                hstate,
                rng,
                update_step,
            ) = runner_state
            # Bootstrap value for the observation following the rollout.
            ac_in = (jax.tree.map(lambda x: x[np.newaxis, :], last_obs), last_done[np.newaxis, :])
            _, _, last_val = network.apply(train_state.params, hstate, ac_in)
            last_val = last_val.squeeze(0)

            def _calculate_gae(traj_batch, last_val, last_done):
                """Return ``(advantages, value_targets)`` computed with GAE."""

                def _get_advantages(carry, transition):
                    gae, next_value, next_done = carry
                    done, value, reward = (
                        transition.done,
                        transition.value,
                        transition.reward,
                    )
                    # next_done masks bootstrapping across episode ends.
                    delta = reward + config["gamma"] * next_value * (1 - next_done) - value
                    gae = delta + config["gamma"] * config["gae_lambda"] * (1 - next_done) * gae
                    return (gae, value, done), gae

                # Scan backwards in time, starting from the bootstrap value.
                _, advantages = jax.lax.scan(
                    _get_advantages,
                    (jnp.zeros_like(last_val), last_val, last_done),
                    traj_batch,
                    reverse=True,
                    unroll=16,
                )
                return advantages, advantages + traj_batch.value

            advantages, targets = _calculate_gae(traj_batch, last_val, last_done)

            # UPDATE NETWORK
            def _update_epoch(update_state, unused):
                """Shuffle envs, split into minibatches, take one pass."""

                def _update_minbatch(train_state, batch_info):
                    """Apply one gradient step on a single minibatch."""
                    init_hstate, traj_batch, advantages, targets = batch_info

                    def _loss_fn(params, init_hstate, traj_batch, gae, targets):
                        """Clipped PPO loss; returns ``(total, (value, actor, entropy))``."""
                        # RERUN NETWORK
                        # init_hstate[0] strips the leading axis that was
                        # added before minibatching.
                        _, pi, value = network.apply(params, init_hstate[0], (traj_batch.obs, traj_batch.done))
                        log_prob = pi.log_prob(traj_batch.action)

                        # CALCULATE VALUE LOSS
                        value_pred_clipped = traj_batch.value + (value - traj_batch.value).clip(
                            -config["clip_eps"], config["clip_eps"]
                        )
                        value_losses = jnp.square(value - targets)
                        value_losses_clipped = jnp.square(value_pred_clipped - targets)
                        value_loss = 0.5 * jnp.maximum(value_losses, value_losses_clipped).mean()

                        # CALCULATE ACTOR LOSS
                        ratio = jnp.exp(log_prob - traj_batch.log_prob)
                        # Normalise advantages within the minibatch.
                        gae = (gae - gae.mean()) / (gae.std() + 1e-8)
                        loss_actor1 = ratio * gae
                        loss_actor2 = (
                            jnp.clip(
                                ratio,
                                1.0 - config["clip_eps"],
                                1.0 + config["clip_eps"],
                            )
                            * gae
                        )
                        loss_actor = -jnp.minimum(loss_actor1, loss_actor2)
                        loss_actor = loss_actor.mean()
                        entropy = pi.entropy().mean()

                        total_loss = loss_actor + config["vf_coef"] * value_loss - config["ent_coef"] * entropy
                        return total_loss, (value_loss, loss_actor, entropy)

                    grad_fn = jax.value_and_grad(_loss_fn, has_aux=True)
                    total_loss, grads = grad_fn(train_state.params, init_hstate, traj_batch, advantages, targets)
                    train_state = train_state.apply_gradients(grads=grads)
                    return train_state, total_loss

                (
                    train_state,
                    init_hstate,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                ) = update_state

                rng, _rng = jax.random.split(rng)
                # Shuffle along axis 1 (the env axis) so that each env's
                # sequence along axis 0 stays intact.
                permutation = jax.random.permutation(_rng, config["num_train_envs"])
                batch = (init_hstate, traj_batch, advantages, targets)

                shuffled_batch = jax.tree_util.tree_map(lambda x: jnp.take(x, permutation, axis=1), batch)

                # Split axis 1 into (num_minibatches, envs_per_minibatch) and
                # move the minibatch axis to the front so scan iterates it.
                minibatches = jax.tree_util.tree_map(
                    lambda x: jnp.swapaxes(
                        jnp.reshape(
                            x,
                            [x.shape[0], config["num_minibatches"], -1] + list(x.shape[2:]),
                        ),
                        1,
                        0,
                    ),
                    shuffled_batch,
                )

                train_state, total_loss = jax.lax.scan(_update_minbatch, train_state, minibatches)
                update_state = (
                    train_state,
                    init_hstate,
                    traj_batch,
                    advantages,
                    targets,
                    rng,
                )
                return update_state, total_loss

            # Leading axis added so the hidden state can be shuffled along
            # axis 1 together with the trajectory batch.
            init_hstate = initial_hstate[None, :]  # TBH
            update_state = (
                train_state,
                init_hstate,
                traj_batch,
                advantages,
                targets,
                rng,
            )
            update_state, loss_info = jax.lax.scan(_update_epoch, update_state, None, config["update_epochs"])
            train_state = update_state[0]
            # Average every info entry over the transitions flagged by
            # info["returned_episode"]. The denominator is zero when no
            # transition in the batch carries that flag.
            metric = jax.tree.map(
                lambda x: (x * traj_batch.info["returned_episode"]).sum() / traj_batch.info["returned_episode"].sum(),
                traj_batch.info,
            )
            rng = update_state[-1]

            if config["use_wandb"]:
                vid_frequency = get_video_frequency(config, update_step)
                rng, _rng = jax.random.split(rng)
                to_log_videos = _vmapped_eval_step(runner_state, _rng)
                should_log_videos = update_step % vid_frequency == 0
                # Render frames only on video steps; otherwise pass a zeros
                # placeholder so both cond branches return an array.
                first = jax.lax.cond(
                    should_log_videos,
                    lambda: jax.vmap(jax.vmap(pixel_render_fn))(to_log_videos[0].env_state.env_state.env_state),
                    lambda: (
                        jnp.zeros(
                            (
                                env_params.max_timesteps,
                                config["num_eval_levels"],
                                *PixelObservations(env_params, render_static_env_params)
                                .observation_space(env_params)
                                .shape,
                            )
                        )
                    ),
                )
                to_log_videos = (first, should_log_videos, *to_log_videos[1:])

                def callback(metric, raw_info, loss_info, update_step, to_log_videos):
                    """Host-side callback: push evaluation stats/videos to wandb."""
                    to_log = {}
                    to_log["timing/num_updates"] = update_step
                    to_log["timing/num_env_steps"] = update_step * config["num_steps"] * config["num_train_envs"]
                    (
                        obs_vid,
                        should_log_videos,
                        idx_vid,
                        eval_return_vid,
                        eval_return_mean,
                        eval_eplen_mean,
                        eval_solverate_mean,
                    ) = to_log_videos
                    to_log["eval/mean_eval_return"] = eval_return_mean.mean()
                    to_log["eval/mean_eval_eplen"] = eval_eplen_mean.mean()
                    for i, eval_name in enumerate(config["eval_levels"]):
                        return_on_video = eval_return_vid[i]
                        to_log[f"eval_video/return_{eval_name}"] = return_on_video
                        to_log[f"eval_video/len_{eval_name}"] = idx_vid[i]
                        to_log[f"eval_avg/return_{eval_name}"] = eval_return_mean[i]
                        to_log[f"eval_avg/solve_rate_{eval_name}"] = eval_solverate_mean[i]

                    if should_log_videos:
                        for i, eval_name in enumerate(config["eval_levels"]):
                            # Keep frames up to this level's done index.
                            obs_to_use = obs_vid[: idx_vid[i], i]
                            # Reorder axes to (0, 3, 2, 1) and reverse axis 2
                            # before handing the frames to wandb.Video.
                            obs_to_use = np.asarray(obs_to_use).transpose(0, 3, 2, 1)[:, :, ::-1, :]
                            to_log[f"media/eval_video_{eval_name}"] = wandb.Video((obs_to_use * 255).astype(np.uint8))

                    wandb.log(to_log)

                jax.debug.callback(callback, metric, traj_batch.info, loss_info, update_step, to_log_videos)

            runner_state = (
                train_state,
                env_state,
                last_obs,
                last_done,
                hstate,
                rng,
                update_step + 1,
            )

            return runner_state, metric

        rng, _rng = jax.random.split(rng)
        # (train_state, env_state, obs, done, hstate, rng, update_step)
        runner_state = (
            train_state,
            env_state,
            obsv,
            jnp.zeros((config["num_train_envs"]), dtype=bool),
            init_hstate,
            _rng,
            0,
        )
        runner_state, metric = jax.lax.scan(_update_step, runner_state, None, config["num_updates"])
        return {"runner_state": runner_state, "metric": metric}

    return train
def main(config):
    """Entry point: normalise the config, train a PPO agent, optionally save it.

    Args:
        config: OmegaConf configuration object; it is converted to a plain
            container and normalised for "PPO" before use.
    """
    # NOTE(review): `hydra` is imported by this file but no `hydra.main`
    # decorator is visible on this function - confirm how `config` is supplied.
    cfg = normalise_config(OmegaConf.to_container(config), "PPO")
    env_params, static_env_params = generate_params_from_config(cfg)

    # Store serialisable copies of the env parameters alongside the config.
    cfg["env_params"] = to_state_dict(env_params)
    cfg["static_env_params"] = to_state_dict(static_env_params)

    if cfg["use_wandb"]:
        init_wandb(cfg, "PPO")

    # Derive the training key from the configured seed.
    seed_key = jax.random.PRNGKey(cfg["seed"])
    _, train_key = jax.random.split(seed_key)

    train_fn = make_train(cfg, env_params, static_env_params)
    result = jax.jit(train_fn)(train_key)

    # Saving requires both wandb logging and the save_policy flag.
    if not (cfg["use_wandb"] and cfg["save_policy"]):
        return

    final_train_state = jax.tree.map(lambda leaf: leaf, result["runner_state"][0])
    save_model_to_wandb(final_train_state, cfg["total_timesteps"], cfg)
# Script entry point.
# NOTE(review): `main` is declared with a required `config` parameter and no
# decorator that would inject it (e.g. `hydra.main`) is visible in this file,
# so this bare call would raise TypeError as written - confirm whether a
# decorator was dropped.
if __name__ == "__main__":
    main()