File size: 6,481 Bytes
6140064
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
"""PPO agent adapter and checkpoint loading utilities."""

from __future__ import annotations

from pathlib import Path
from typing import Any

import jax
import jax.numpy as jnp
import numpy as np
import orbax.checkpoint as ocp

from Craftax_Baselines.ppo import ActorCritic
from Craftax_Baselines.ppo_rnn import ActorCriticRNN
from Craftax_Baselines.ppo_rnd import ActorCriticRND


def load_ppo_params(
    path: str,
    network: Any,
    model_type: str,
    num_envs: int,
    obs_shape: tuple,
    layer_size: int = 512,
) -> Any:
    """Restore PPO parameters from an Orbax checkpoint.

    The network is first initialised with zero-valued dummy inputs to obtain
    a parameter pytree, which serves as the restore target. The latest step
    found under ``path`` is then restored into it.

    Args:
        path:        Path to the Orbax checkpoint directory (resolved to an
                     absolute path before use).
        network:     Instantiated Flax network (used only for structure).
        model_type:  One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``
                     (case-insensitive, like :func:`build_ppo_network`).
        num_envs:    Number of parallel environments (affects RNN init shape).
        obs_shape:   Observation shape tuple.
        layer_size:  Hidden layer size (for RNN hidden state init).

    Returns:
        The ``"params"`` entry of the restored checkpoint.

    Raises:
        FileNotFoundError: If no checkpoint step exists under ``path``.
    """
    # Normalise case so "PPO_RNN" picks the same init path that
    # build_ppo_network / PPOAgent pick (both lower-case). Without this an
    # RNN network would be initialised with the feed-forward signature.
    model_type = model_type.lower()
    path = str(Path(path).resolve())
    rng = jax.random.PRNGKey(0)
    if model_type == "ppo_rnn":
        # Dummy inputs mirror how PPOAgent calls the RNN: a hidden state plus
        # an (obs, done) pair carrying a leading time axis of length 1.
        init_hidden = jnp.zeros((num_envs, layer_size))
        init_x = (jnp.zeros((1, num_envs, *obs_shape)), jnp.zeros((1, num_envs)))
        abstract = network.init(rng, init_hidden, init_x)
    else:
        abstract = network.init(rng, jnp.zeros((1, *obs_shape)))

    with ocp.CheckpointManager(path) as mgr:
        step = mgr.latest_step()
        if step is None:
            raise FileNotFoundError(f"No checkpoint at {path}")
        # partial_restore: only the "params" subtree requested here is read.
        restored = mgr.restore(
            step,
            args=ocp.args.PyTreeRestore(item={"params": abstract}, partial_restore=True),
        )
    print(f"Loaded {model_type.upper()} checkpoint from '{path}' (step {step})")
    return restored["params"]


def build_ppo_network(model_type: str, num_actions: int, layer_size: int, config: dict) -> Any:
    """Create the Flax actor-critic module matching ``model_type``.

    The comparison is case-insensitive. Any value other than ``"ppo_rnn"`` or
    ``"ppo_rnd"`` falls back to the plain feed-forward ``ActorCritic``.

    Args:
        model_type:  One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
        num_actions: Size of the discrete action space.
        layer_size:  Hidden layer width (ignored by the RNN variant).
        config:      Training config; only the RNN variant consumes it.

    Returns:
        Flax module instance.
    """
    kind = model_type.lower()
    if kind == "ppo_rnn":
        # The recurrent architecture takes its sizes from the config dict.
        net = ActorCriticRNN(num_actions, config=config)
    elif kind == "ppo_rnd":
        net = ActorCriticRND(num_actions, layer_size)
    else:
        net = ActorCritic(num_actions, layer_size)
    return net


def load_ppo_agent(
    path: str,
    num_actions: int,
    obs_dim: int,
    layer_size: int,
    model_type: str,
    config: dict,
    num_envs: int = 1,
) -> "PPOAgent":
    """Assemble a ready-to-use :class:`PPOAgent` from a checkpoint on disk.

    This builds the architecture for ``model_type``, restores its parameters
    from ``path`` and wraps both in a :class:`PPOAgent`.

    Args:
        path:        Path to the Orbax checkpoint directory.
        num_actions: Size of the discrete action space.
        obs_dim:     Observation vector dimensionality.
        layer_size:  Hidden layer width.
        model_type:  One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``.
        config:      Training config dict.
        num_envs:    Number of parallel environments.

    Returns:
        A fully initialised :class:`PPOAgent`.
    """
    network = build_ppo_network(model_type, num_actions, layer_size, config)
    # Observations are flat vectors, so the init shape is a 1-tuple.
    params = load_ppo_params(
        path,
        network,
        model_type,
        num_envs=num_envs,
        obs_shape=(obs_dim,),
        layer_size=layer_size,
    )
    return PPOAgent(network, params, model_type, layer_size=layer_size)


class PPOAgent:
    """Uniform interface over PPO-RNN / PPO / PPO-RND for action collection.

    Wraps a Flax actor-critic module and its parameters so callers can sample
    actions or query the policy without caring which PPO variant produced
    the checkpoint.

    Args:
        network:    Flax actor-critic module.
        params:     Loaded parameter pytree (passed as-is to ``network.apply``).
        model_type: One of ``"ppo_rnn"``, ``"ppo_rnd"``, or ``"ppo"``
                    (case-insensitive; any other value is treated as ``"ppo"``).
        layer_size: Hidden layer width (used for RNN hidden-state shape).
    """

    def __init__(self, network: Any, params: Any, model_type: str, layer_size: int = 512) -> None:
        self.network = network
        self.params = params
        # Lower-cased once so every branch below compares against lowercase names.
        self.model_type = model_type.lower()
        self.layer_size = layer_size

    def init_hidden(self, batch_size: int) -> jnp.ndarray | None:
        """Return a zero hidden state ``[batch_size, layer_size]`` for RNN models, else ``None``."""
        if self.model_type == "ppo_rnn":
            return jnp.zeros((batch_size, self.layer_size))
        return None

    def act(
        self,
        obs: jnp.ndarray,
        done: jnp.ndarray,
        hidden: jnp.ndarray | None,
        rng: jax.Array,
        temperature: float = 1.0,
    ) -> tuple[jnp.ndarray, jnp.ndarray | None]:
        """Sample an action.

        Args:
            obs:         Observation array ``[B, obs_dim]``.
            done:        Episode-done flags ``[B]`` (required for RNN models).
            hidden:      RNN hidden state. For RNN models ``None`` means a fresh
                         zero state (see :meth:`init_hidden`). It is passed
                         through unchanged for non-RNN models.
            rng:         PRNG key.
            temperature: Divisor applied to the logits before sampling.

        Returns:
            ``(action, new_hidden)`` tuple; ``action`` has shape ``[B]``.

        Raises:
            TypeError: If ``done`` is ``None`` for an RNN model.
        """
        # Share the per-variant forward pass with get_pi so the two cannot drift.
        pi, new_hidden = self.get_pi(obs, done, hidden)
        action = jax.random.categorical(rng, pi.logits / temperature)
        if self.model_type == "ppo_rnn":
            # The recurrent pass carries a leading time axis of length 1; drop it.
            action = action.squeeze(0)
        return action, new_hidden

    def get_pi(
        self,
        obs: jnp.ndarray,
        done: jnp.ndarray | None = None,
        hidden: jnp.ndarray | None = None,
    ) -> tuple[Any, jnp.ndarray | None]:
        """Return the policy distribution (used in DAgger expert labelling).

        Args:
            obs:    Observation array ``[B, obs_dim]``.
            done:   Episode-done flags ``[B]`` (required for RNN models).
            hidden: RNN hidden state. For RNN models ``None`` means a fresh
                    zero state (see :meth:`init_hidden`).

        Returns:
            ``(pi, new_hidden)`` tuple. For non-RNN models ``new_hidden`` is
            the ``hidden`` argument unchanged. For RNN models the policy
            carries a leading time axis of length 1.

        Raises:
            TypeError: If ``done`` is ``None`` for an RNN model.
        """
        if self.model_type == "ppo_rnn":
            if done is None:
                raise TypeError("done flags are required for ppo_rnn models")
            if hidden is None:
                hidden = self.init_hidden(obs.shape[0])
            # Add a leading time axis of length 1 to both obs and done.
            ac_in = (obs[np.newaxis, :], done[np.newaxis, :])
            new_hidden, pi, _ = self.network.apply(self.params, hidden, ac_in)
            return pi, new_hidden
        if self.model_type == "ppo_rnd":
            # The RND variant returns a third output that is not needed here.
            pi, _, _ = self.network.apply(self.params, obs)
            return pi, hidden
        pi, _ = self.network.apply(self.params, obs)
        return pi, hidden