File size: 6,481 Bytes
"""PPO agent adapter and checkpoint-loading utilities."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from Craftax_Baselines.ppo import ActorCritic
from Craftax_Baselines.ppo_rnn import ActorCriticRNN
from Craftax_Baselines.ppo_rnd import ActorCriticRND
def load_ppo_params(
    path: str,
    network: Any,
    model_type: str,
    num_envs: int,
    obs_shape: tuple,
    layer_size: int = 512,
) -> Any:
    """Restore PPO parameters from an Orbax checkpoint.

    Args:
        path: Path to the Orbax checkpoint directory.
        network: Instantiated Flax network. It is initialised with dummy
            inputs only to obtain the target pytree structure for the restore.
        model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
            Compared case-insensitively, like ``build_ppo_network`` does.
        num_envs: Number of parallel environments (affects RNN init shape).
        obs_shape: Observation shape tuple.
        layer_size: Hidden layer size (for RNN hidden state init).

    Returns:
        Restored parameter pytree (the ``"params"`` entry of the checkpoint).

    Raises:
        FileNotFoundError: If ``path`` is not an existing directory, or if it
            contains no checkpoint step.
    """
    ckpt_dir = Path(path).resolve()
    # Validate before touching Orbax. CheckpointManager creates its directory
    # by default, so opening a mistyped path would leave an empty directory
    # behind. This also skips a wasted network.init on a bad path.
    if not ckpt_dir.is_dir():
        raise FileNotFoundError(f"No checkpoint at {ckpt_dir}")
    path = str(ckpt_dir)
    # Normalise case so "PPO_RNN" takes the recurrent init path, matching
    # the architecture that build_ppo_network would construct for it.
    model_type = model_type.lower()

    rng = jax.random.PRNGKey(0)
    if model_type == "ppo_rnn":
        # Recurrent init takes (hidden, (obs, done)); both inputs carry a
        # leading length-1 axis (presumably time) ahead of the env axis.
        init_x = (jnp.zeros((1, num_envs, *obs_shape)), jnp.zeros((1, num_envs)))
        abstract = network.init(rng, jnp.zeros((num_envs, layer_size)), init_x)
    else:
        abstract = network.init(rng, jnp.zeros((1, *obs_shape)))

    with ocp.CheckpointManager(path) as mgr:
        step = mgr.latest_step()
        if step is None:
            raise FileNotFoundError(f"No checkpoint at {path}")
        # Only the "params" subtree is requested. partial_restore lets the
        # checkpoint hold additional items that are not restored here.
        restored = mgr.restore(
            step,
            args=ocp.args.PyTreeRestore(item={"params": abstract}, partial_restore=True),
        )
    print(f"Loaded {model_type.upper()} checkpoint from '{path}' (step {step})")
    return restored["params"]
def build_ppo_network(model_type: str, num_actions: int, layer_size: int, config: dict) -> Any:
    """Construct the Flax actor-critic module that matches ``model_type``.

    Args:
        model_type: Architecture name, compared case-insensitively:
            ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. Any other value
            falls back to the feed-forward ``ActorCritic``.
        num_actions: Size of the discrete action space.
        layer_size: Hidden layer width (not passed to the RNN variant).
        config: Training config; only forwarded to ``ActorCriticRNN``.

    Returns:
        A Flax module instance whose parameters are not yet initialised.
    """
    kind = model_type.lower()
    if kind == "ppo_rnn":
        # The recurrent variant is configured from ``config``, not from ``layer_size``.
        network = ActorCriticRNN(num_actions, config=config)
    elif kind == "ppo_rnd":
        network = ActorCriticRND(num_actions, layer_size)
    else:
        network = ActorCritic(num_actions, layer_size)
    return network
def load_ppo_agent(
    path: str,
    num_actions: int,
    obs_dim: int,
    layer_size: int,
    model_type: str,
    config: dict,
    num_envs: int = 1,
) -> "PPOAgent":
    """Create a ready-to-use :class:`PPOAgent` from an Orbax checkpoint.

    The network is constructed for ``model_type``, and its parameters are
    restored from ``path`` using a flat observation of size ``obs_dim``.
    Both are then wrapped in a :class:`PPOAgent`.

    Args:
        path: Path to the Orbax checkpoint directory.
        num_actions: Size of the discrete action space.
        obs_dim: Observation vector dimensionality.
        layer_size: Hidden layer width.
        model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
        config: Training config dict.
        num_envs: Number of parallel environments.

    Returns:
        A :class:`PPOAgent` holding the restored parameters.
    """
    network = build_ppo_network(model_type, num_actions, layer_size, config)
    restored = load_ppo_params(
        path,
        network,
        model_type,
        num_envs,
        obs_shape=(obs_dim,),
        layer_size=layer_size,
    )
    return PPOAgent(network, restored, model_type, layer_size=layer_size)
class PPOAgent:
    """Uniform interface over PPO-RNN / PPO / PPO-RND for action collection.

    Args:
        network: Flax actor-critic module.
        params: Loaded parameter pytree.
        model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
            Compared case-insensitively; any other value is treated like
            ``"ppo"``.
        layer_size: Hidden layer width (used for RNN hidden-state shape).
    """

    def __init__(self, network: Any, params: Any, model_type: str, layer_size: int = 512) -> None:
        self.network = network
        self.params = params
        self.model_type = model_type.lower()
        self.layer_size = layer_size

    def init_hidden(self, batch_size: int) -> jnp.ndarray | None:
        """Return a zero hidden state for RNN models, else ``None``."""
        if self.model_type == "ppo_rnn":
            # NOTE(review): assumes the RNN carry width equals ``layer_size``.
            # ActorCriticRNN is built from ``config``; confirm the two agree.
            return jnp.zeros((batch_size, self.layer_size))
        return None

    def act(
        self,
        obs: jnp.ndarray,
        done: jnp.ndarray,
        hidden: jnp.ndarray | None,
        rng: jax.Array,
        temperature: float = 1.0,
    ) -> tuple[jnp.ndarray, jnp.ndarray | None]:
        """Sample an action.

        Args:
            obs: Observation array ``[B, obs_dim]``.
            done: Episode-done flags ``[B]``.
            hidden: RNN hidden state (``None`` for non-RNN models).
            rng: PRNG key.
            temperature: Divisor applied to the logits before sampling.
                Values below 1 sharpen the distribution and values above 1
                flatten it. It must be non-zero; zero yields non-finite logits.

        Returns:
            ``(action, new_hidden)`` tuple, where ``action`` has shape ``[B]``.

        Raises:
            ValueError: For ``ppo_rnn`` models when ``done`` or ``hidden`` is ``None``.
        """
        # Reuse the forward pass instead of duplicating the per-model dispatch.
        pi, new_hidden = self.get_pi(obs, done, hidden)
        action = jax.random.categorical(rng, pi.logits / temperature)
        if self.model_type == "ppo_rnn":
            # Drop the length-1 leading axis added to the RNN inputs.
            action = action.squeeze(0)
        return action, new_hidden

    def get_pi(
        self,
        obs: jnp.ndarray,
        done: jnp.ndarray | None = None,
        hidden: jnp.ndarray | None = None,
    ) -> tuple[Any, jnp.ndarray | None]:
        """Return the policy distribution (used in DAgger expert labelling).

        Args:
            obs: Observation array ``[B, obs_dim]``.
            done: Episode-done flags ``[B]`` (required for RNN models).
            hidden: RNN hidden state (required for RNN models).

        Returns:
            ``(pi, new_hidden)`` tuple. For ``ppo_rnn``, ``pi`` keeps the
            leading length-1 axis added to the inputs. Non-RNN models
            return ``hidden`` unchanged.

        Raises:
            ValueError: For ``ppo_rnn`` models when ``done`` or ``hidden`` is ``None``.
        """
        if self.model_type == "ppo_rnn":
            # Fail with a clear message rather than a TypeError on None[...]
            # or an opaque error from inside the recurrent cell.
            if done is None or hidden is None:
                raise ValueError(
                    "ppo_rnn models require both `done` and `hidden` "
                    "(use init_hidden() for the initial state)"
                )
            ac_in = (obs[np.newaxis, :], done[np.newaxis, :])
            new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in)
            return pi, new_hidden
        if self.model_type == "ppo_rnd":
            # The RND network returns one extra output, which is ignored here.
            pi, _, _ = self.network.apply(self.params, obs)
        else:
            pi, _ = self.network.apply(self.params, obs)
        return pi, hidden
|