"""PPO agent adapter and checkpoint loading utilities.""" from __future__ import annotations from pathlib import Path from typing import Any import jax import jax.numpy as jnp import numpy as np import orbax.checkpoint as ocp from Craftax_Baselines.ppo import ActorCritic from Craftax_Baselines.ppo_rnn import ActorCriticRNN from Craftax_Baselines.ppo_rnd import ActorCriticRND def load_ppo_params( path: str, network: Any, model_type: str, num_envs: int, obs_shape: tuple, layer_size: int = 512, ) -> Any: """Restore PPO parameters from an Orbax checkpoint. Args: path: Path to the Orbax checkpoint directory. network: Instantiated Flax network (used only for structure). model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. num_envs: Number of parallel environments (affects RNN init shape). obs_shape: Observation shape tuple. layer_size: Hidden layer size (for RNN hidden state init). Returns: Restored parameter pytree. """ path = str(Path(path).resolve()) rng = jax.random.PRNGKey(0) if model_type == "ppo_rnn": init_x = (jnp.zeros((1, num_envs, *obs_shape)), jnp.zeros((1, num_envs))) abstract = network.init(rng, jnp.zeros((num_envs, layer_size)), init_x) else: abstract = network.init(rng, jnp.zeros((1, *obs_shape))) with ocp.CheckpointManager(path) as mgr: step = mgr.latest_step() if step is None: raise FileNotFoundError(f"No checkpoint at {path}") restored = mgr.restore( step, args=ocp.args.PyTreeRestore(item={"params": abstract}, partial_restore=True), ) print(f"Loaded {model_type.upper()} checkpoint from '{path}' (step {step})") return restored["params"] def build_ppo_network(model_type: str, num_actions: int, layer_size: int, config: dict) -> Any: """Instantiate the correct PPO architecture. Args: model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. num_actions: Size of the discrete action space. layer_size: Hidden layer width. config: Training config (forwarded to ``ActorCriticRNN``). Returns: Flax module instance. """ model_type = model_type.lower() if model_type == "ppo_rnn": return ActorCriticRNN(num_actions, config=config) if model_type == "ppo_rnd": return ActorCriticRND(num_actions, layer_size) return ActorCritic(num_actions, layer_size) def load_ppo_agent( path: str, num_actions: int, obs_dim: int, layer_size: int, model_type: str, config: dict, num_envs: int = 1, ) -> "PPOAgent": """Build network, load params, and return a :class:`PPOAgent`. Args: path: Path to the Orbax checkpoint directory. num_actions: Size of the discrete action space. obs_dim: Observation vector dimensionality. layer_size: Hidden layer width. model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. config: Training config dict. num_envs: Number of parallel environments. Returns: A fully initialised :class:`PPOAgent`. """ net = build_ppo_network(model_type, num_actions, layer_size, config) params = load_ppo_params(path, net, model_type, num_envs, (obs_dim,), layer_size) return PPOAgent(net, params, model_type, layer_size) class PPOAgent: """Uniform interface over PPO-RNN / PPO / PPO-RND for action collection. Args: network: Flax actor-critic module. params: Loaded parameter pytree. model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``. layer_size: Hidden layer width (used for RNN hidden-state shape). """ def __init__(self, network: Any, params: Any, model_type: str, layer_size: int = 512) -> None: self.network = network self.params = params self.model_type = model_type.lower() self.layer_size = layer_size def init_hidden(self, batch_size: int) -> jnp.ndarray | None: """Return a zero hidden state for RNN models, else ``None``.""" if self.model_type == "ppo_rnn": return jnp.zeros((batch_size, self.layer_size)) return None def act( self, obs: jnp.ndarray, done: jnp.ndarray, hidden: jnp.ndarray | None, rng: jax.Array, temperature: float = 1.0, ) -> tuple[jnp.ndarray, jnp.ndarray | None]: """Sample an action. Args: obs: Observation array ``[B, obs_dim]``. done: Episode-done flags ``[B]``. hidden: RNN hidden state (``None`` for non-RNN models). rng: PRNG key. temperature: Softmax temperature for sampling. Returns: ``(action, new_hidden)`` tuple. """ if self.model_type == "ppo_rnn": ac_in = (obs[np.newaxis, :], done[np.newaxis, :]) new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in) elif self.model_type == "ppo_rnd": pi, _, _ = self.network.apply(self.params, obs) new_hidden = hidden else: pi, _ = self.network.apply(self.params, obs) new_hidden = hidden action = jax.random.categorical(rng, pi.logits / temperature) if self.model_type == "ppo_rnn": action = action.squeeze(0) return action, new_hidden def get_pi( self, obs: jnp.ndarray, done: jnp.ndarray | None = None, hidden: jnp.ndarray | None = None, ) -> tuple[Any, jnp.ndarray | None]: """Return the policy distribution (used in DAgger expert labelling). Args: obs: Observation array ``[B, obs_dim]``. done: Episode-done flags ``[B]`` (required for RNN models). hidden: RNN hidden state. Returns: ``(pi, new_hidden)`` tuple. """ if self.model_type == "ppo_rnn": ac_in = (obs[np.newaxis, :], done[np.newaxis, :]) new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in) return pi, new_hidden if self.model_type == "ppo_rnd": pi, _, _ = self.network.apply(self.params, obs) return pi, hidden pi, _ = self.network.apply(self.params, obs) return pi, hidden