MathisW78's picture
Upload COMP0258 demo bundle (code + diffusion/PPO checkpoints + ablation assets)
6140064 verified
"""PPO agent adapter and checkpoint loading utilities."""
from __future__ import annotations
from pathlib import Path
from typing import Any
import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp
from Craftax_Baselines.ppo import ActorCritic
from Craftax_Baselines.ppo_rnn import ActorCriticRNN
from Craftax_Baselines.ppo_rnd import ActorCriticRND
def load_ppo_params(
    path: str,
    network: Any,
    model_type: str,
    num_envs: int,
    obs_shape: tuple,
    layer_size: int = 512,
) -> Any:
    """Restore PPO parameters from an Orbax checkpoint.

    Args:
        path: Path to the Orbax checkpoint directory.
        network: Instantiated Flax network. It is initialised once with dummy
            zero inputs, only to obtain the parameter structure to restore into.
        model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``
            (case-insensitive). Anything other than ``"ppo_rnn"`` is
            initialised as a feed-forward network.
        num_envs: Number of parallel environments (affects RNN init shape).
        obs_shape: Observation shape tuple.
        layer_size: Hidden layer size (for RNN hidden state init).

    Returns:
        Restored parameter pytree (the entry stored under ``"params"``).

    Raises:
        FileNotFoundError: If ``path`` is not an existing directory, or it
            contains no checkpoint step.
    """
    # Normalise case so e.g. "PPO_RNN" takes the recurrent init branch.
    # This matches build_ppo_network / PPOAgent, which also lower-case it.
    model_type = model_type.lower()
    ckpt_dir = Path(path).resolve()
    # Fail before touching Orbax. CheckpointManager's default options create a
    # missing directory, which would leave an empty folder for a mistyped path.
    if not ckpt_dir.is_dir():
        raise FileNotFoundError(f"No checkpoint at {ckpt_dir}")
    path = str(ckpt_dir)

    rng = jax.random.PRNGKey(0)
    if model_type == "ppo_rnn":
        # Recurrent init takes (hidden_state, (obs, dones)) with a leading length-1 axis.
        init_x = (jnp.zeros((1, num_envs, *obs_shape)), jnp.zeros((1, num_envs)))
        abstract = network.init(rng, jnp.zeros((num_envs, layer_size)), init_x)
    else:
        abstract = network.init(rng, jnp.zeros((1, *obs_shape)))

    with ocp.CheckpointManager(path) as mgr:
        step = mgr.latest_step()
        if step is None:
            raise FileNotFoundError(f"No checkpoint at {path}")
        # partial_restore=True: only the "params" subtree is requested; any other
        # saved entries (e.g. optimizer state) are not loaded.
        restored = mgr.restore(
            step,
            args=ocp.args.PyTreeRestore(item={"params": abstract}, partial_restore=True),
        )
    print(f"Loaded {model_type.upper()} checkpoint from '{path}' (step {step})")
    return restored["params"]
def build_ppo_network(model_type: str, num_actions: int, layer_size: int, config: dict) -> Any:
    """Create the Flax actor-critic module matching ``model_type``.

    Args:
        model_type: ``"ppo_rnn"``, ``"ppo_rnd"`` or ``"ppo"`` (case-insensitive).
            Unrecognised values fall back to the plain ``ActorCritic``.
        num_actions: Size of the discrete action space.
        layer_size: Hidden layer width (feed-forward and RND variants).
        config: Training config, passed only to ``ActorCriticRNN``.

    Returns:
        An un-initialised Flax module instance.
    """
    kind = model_type.lower()
    if kind == "ppo_rnn":
        # Only the recurrent variant receives the training config.
        network = ActorCriticRNN(num_actions, config=config)
    elif kind == "ppo_rnd":
        network = ActorCriticRND(num_actions, layer_size)
    else:
        network = ActorCritic(num_actions, layer_size)
    return network
def load_ppo_agent(
    path: str,
    num_actions: int,
    obs_dim: int,
    layer_size: int,
    model_type: str,
    config: dict,
    num_envs: int = 1,
) -> "PPOAgent":
    """Construct a ready-to-use :class:`PPOAgent` from an Orbax checkpoint.

    The network is built via ``build_ppo_network``. Its parameters are
    restored via ``load_ppo_params``, using a flat observation of length
    ``obs_dim``.

    Args:
        path: Path to the Orbax checkpoint directory.
        num_actions: Size of the discrete action space.
        obs_dim: Observation vector dimensionality.
        layer_size: Hidden layer width.
        model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
        config: Training config dict.
        num_envs: Number of parallel environments used for the init shapes.

    Returns:
        A :class:`PPOAgent` wrapping the network and its restored parameters.
    """
    network = build_ppo_network(model_type, num_actions, layer_size, config)
    restored_params = load_ppo_params(
        path,
        network,
        model_type,
        num_envs,
        obs_shape=(obs_dim,),
        layer_size=layer_size,
    )
    return PPOAgent(network, restored_params, model_type, layer_size=layer_size)
class PPOAgent:
"""Uniform interface over PPO-RNN / PPO / PPO-RND for action collection.
Args:
network: Flax actor-critic module.
params: Loaded parameter pytree.
model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
layer_size: Hidden layer width (used for RNN hidden-state shape).
"""
def __init__(self, network: Any, params: Any, model_type: str, layer_size: int = 512) -> None:
self.network = network
self.params = params
self.model_type = model_type.lower()
self.layer_size = layer_size
def init_hidden(self, batch_size: int) -> jnp.ndarray | None:
"""Return a zero hidden state for RNN models, else ``None``."""
if self.model_type == "ppo_rnn":
return jnp.zeros((batch_size, self.layer_size))
return None
def act(
self,
obs: jnp.ndarray,
done: jnp.ndarray,
hidden: jnp.ndarray | None,
rng: jax.Array,
temperature: float = 1.0,
) -> tuple[jnp.ndarray, jnp.ndarray | None]:
"""Sample an action.
Args:
obs: Observation array ``[B, obs_dim]``.
done: Episode-done flags ``[B]``.
hidden: RNN hidden state (``None`` for non-RNN models).
rng: PRNG key.
temperature: Softmax temperature for sampling.
Returns:
``(action, new_hidden)`` tuple.
"""
if self.model_type == "ppo_rnn":
ac_in = (obs[np.newaxis, :], done[np.newaxis, :])
new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in)
elif self.model_type == "ppo_rnd":
pi, _, _ = self.network.apply(self.params, obs)
new_hidden = hidden
else:
pi, _ = self.network.apply(self.params, obs)
new_hidden = hidden
action = jax.random.categorical(rng, pi.logits / temperature)
if self.model_type == "ppo_rnn":
action = action.squeeze(0)
return action, new_hidden
def get_pi(
self,
obs: jnp.ndarray,
done: jnp.ndarray | None = None,
hidden: jnp.ndarray | None = None,
) -> tuple[Any, jnp.ndarray | None]:
"""Return the policy distribution (used in DAgger expert labelling).
Args:
obs: Observation array ``[B, obs_dim]``.
done: Episode-done flags ``[B]`` (required for RNN models).
hidden: RNN hidden state.
Returns:
``(pi, new_hidden)`` tuple.
"""
if self.model_type == "ppo_rnn":
ac_in = (obs[np.newaxis, :], done[np.newaxis, :])
new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in)
return pi, new_hidden
if self.model_type == "ppo_rnd":
pi, _, _ = self.network.apply(self.params, obs)
return pi, hidden
pi, _ = self.network.apply(self.params, obs)
return pi, hidden