| from flax.core import FrozenDict |
| import flax.linen as nn |
| import jax |
| import jax.numpy as jnp |
| from functools import partial |
|
|
|
|
| |
|
|
|
|
def scale(state: jnp.ndarray) -> jnp.ndarray:
    """Divide every entry of ``state`` by 255.0 and return the result.

    NOTE(review): presumably ``state`` holds 8-bit pixel intensities so the
    result lies in [0, 1] -- confirm against the caller.
    """
    max_pixel_value = 255.0
    scaled_state = state / max_pixel_value
    return scaled_state
|
|
|
|
class Torso(nn.Module):
    """Convolutional feature extractor: three Conv + ReLU layers, then flatten.

    Raises:
        ValueError: if ``initialization_type`` is neither ``"dqn"`` nor ``"iqn"``.
    """

    # Selects the kernel initializer used by every conv layer: "dqn" or "iqn".
    initialization_type: str

    @nn.compact
    def __call__(self, state):
        """Apply the conv stack to ``state`` and return a 1-D feature vector."""
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Fail with a clear message instead of hitting an unbound
            # ``initializer`` at the first Conv layer below.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )

        x = nn.Conv(features=32, kernel_size=(8, 8), strides=(4, 4), kernel_init=initializer)(state)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(4, 4), strides=(2, 2), kernel_init=initializer)(x)
        x = nn.relu(x)
        x = nn.Conv(features=64, kernel_size=(3, 3), strides=(1, 1), kernel_init=initializer)(x)
        x = nn.relu(x)

        # flatten() collapses every axis, so a leading batch axis would be
        # merged too. NOTE(review): presumably called on a single unbatched
        # state (batched via vmap by the caller) -- confirm.
        return x.flatten()
|
|
|
|
class Head(nn.Module):
    """Two-layer MLP head: Dense(512) + ReLU, then Dense(``n_actions``).

    Raises:
        ValueError: if ``initialization_type`` is neither ``"dqn"`` nor ``"iqn"``.
    """

    # Width of the output layer.
    n_actions: int
    # Selects the kernel initializer used by both dense layers: "dqn" or "iqn".
    initialization_type: str

    @nn.compact
    def __call__(self, x):
        """Map features ``x`` to an output whose last axis has ``n_actions`` entries."""
        if self.initialization_type == "dqn":
            initializer = nn.initializers.variance_scaling(scale=1.0, mode="fan_avg", distribution="truncated_normal")
        elif self.initialization_type == "iqn":
            initializer = nn.initializers.variance_scaling(
                scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
            )
        else:
            # Fail with a clear message instead of hitting an unbound
            # ``initializer`` at the first Dense layer below.
            raise ValueError(
                f"Unknown initialization_type {self.initialization_type!r}; expected 'dqn' or 'iqn'."
            )

        x = nn.Dense(features=512, kernel_init=initializer)(x)
        x = nn.relu(x)

        return nn.Dense(features=self.n_actions, kernel_init=initializer)(x)
|
|
|
|
class QuantileEmbedding(nn.Module):
    """Sample quantile fractions and embed them with a cosine basis + Dense + ReLU."""

    # Output width of the dense layer.
    # NOTE(review): the default presumably matches the length of Torso's
    # flattened output -- confirm against the input resolution in use.
    n_features: int = 7744
    # Number of cosine basis functions (frequencies 1..quantile_embedding_dim).
    quantile_embedding_dim: int = 64

    @nn.compact
    def __call__(self, key, n_quantiles):
        """Draw ``n_quantiles`` uniform samples with ``key`` and embed them.

        Returns:
            A pair ``(embedding, quantiles)``: the ReLU-activated dense
            embedding with one row per sampled quantile, and the sampled
            quantiles with their trailing unit axis squeezed away.
        """
        kernel_init = nn.initializers.variance_scaling(
            scale=1.0 / jnp.sqrt(3.0), mode="fan_in", distribution="uniform"
        )

        # Column vector of samples, shape (n_quantiles, 1).
        quantiles = jax.random.uniform(key, shape=(n_quantiles, 1))
        # Row vector of integer frequencies 1..quantile_embedding_dim.
        frequencies = jnp.arange(1, self.quantile_embedding_dim + 1).reshape((1, self.quantile_embedding_dim))

        # (pi * tau) @ i gives one row of cosine arguments per quantile.
        cosine_basis = jnp.cos(jnp.pi * quantiles @ frequencies)
        dense = nn.Dense(features=self.n_features, kernel_init=kernel_init)
        quantile_embedding = dense(cosine_basis)

        activated = nn.relu(quantile_embedding)
        flat_quantiles = jnp.squeeze(quantiles, axis=1)
        return (activated, flat_quantiles)
|
|
|
|
| |
|
|
|
|
class AtariSharediDQNNet:
    """Torso + head pair whose parameters are looked up per head index.

    Parameters are read from a flat mapping with keys ``torso_params_<i>`` and
    ``head_params_<i>``.
    """

    def __init__(self, n_actions: int) -> None:
        """Build the "dqn"-initialized torso and head for ``n_actions`` outputs."""
        self.n_heads = 5
        self.n_actions = n_actions
        self.torso = Torso("dqn")
        self.head = Head(self.n_actions, "dqn")

    def apply(self, params: FrozenDict, idx_head: int, state: jnp.ndarray) -> jnp.ndarray:
        """Run ``state`` through the torso and the head selected by ``idx_head``.

        The torso key uses ``min(idx_head, 1)``, so every ``idx_head >= 1``
        reads ``torso_params_1`` while the head key always uses ``idx_head``
        itself.
        """
        torso_index = min(idx_head, 1)
        torso_params = params[f"torso_params_{torso_index}"]
        feature = self.torso.apply(torso_params, state)

        head_params = params[f"head_params_{idx_head}"]
        return self.head.apply(head_params, feature)
|
|
|
|
class AtariiDQN:
    """Greedy-action wrapper around one head of ``AtariSharediDQNNet``."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        """Create the underlying network and remember which head to query."""
        self.network = AtariSharediDQNNet(n_actions)
        self.idx_head = idx_head

    # ``self`` is declared static so jit treats the instance as a hashable
    # constant rather than trying to trace it.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jnp.ndarray) -> jnp.int8:
        """Return the index of the largest network output for ``state``.

        Args:
            params: parameter mapping forwarded to ``self.network.apply``.
            state: raw state; it is passed through ``scale`` before the network.
            key: PRNG key; not used by this method.
                NOTE(review): presumably kept so the signature matches
                ``AtariiIQN.best_action`` -- confirm.

        Returns:
            The argmax over the network output, cast to ``int8``.
        """
        # ``key`` is annotated as ``jnp.ndarray`` rather than
        # ``jax.random.PRNGKeyArray``: annotations are evaluated when the class
        # is defined, and that alias no longer exists in recent JAX releases.
        return jnp.argmax(self.network.apply(params, self.idx_head, scale(state))).astype(jnp.int8)
|
|
|
|
| |
|
|
|
|
class AtariSharediIQNNet:
    """Torso, quantile embedding and head whose parameters are indexed per head.

    Every leaf of ``params["torso_params"]``, ``params["quantiles_params"]`` and
    ``params["head_params"]`` is indexed along its leading axis to pick the
    parameters for the requested head.
    """

    def __init__(self, n_actions: int) -> None:
        """Build the "iqn"-initialized torso, quantile embedding and head."""
        self.n_heads = 4
        self.n_actions = n_actions
        self.torso = Torso("iqn")
        self.quantile_embedding = QuantileEmbedding()
        self.head = Head(self.n_actions, "iqn")

    def apply(
        self, params: FrozenDict, idx_head: int, state: jnp.ndarray, key: jax.random.PRNGKey, n_quantiles: int
    ) -> jnp.ndarray:
        """Evaluate head ``idx_head`` on ``state`` for ``n_quantiles`` sampled quantiles.

        Torso and quantile-embedding parameters are taken from slot 1 when
        ``idx_head >= 1`` and from slot 0 otherwise; head parameters are taken
        from slot ``idx_head``.
        NOTE(review): the result presumably has one row per sampled quantile
        and one column per action -- confirm.
        """

        def pick_shared_slot(param):
            # lax.cond keeps the slot choice valid when idx_head is traced.
            return param[jax.lax.cond(idx_head >= 1, lambda: 1, lambda: 0)]

        def pick_head_slot(param):
            return param[idx_head]

        def modulate(quantile_feature_, state_feature_):
            return quantile_feature_ * state_feature_

        # State features from the torso.
        torso_params = jax.tree_util.tree_map(pick_shared_slot, params["torso_params"])
        state_feature = self.torso.apply(torso_params, state)

        # Quantile features; the sampled quantiles themselves are discarded.
        embedding_params = jax.tree_util.tree_map(pick_shared_slot, params["quantiles_params"])
        quantiles_feature, _ = self.quantile_embedding.apply(embedding_params, key, n_quantiles)

        # Multiply each quantile's feature row by the shared state feature.
        feature = jax.vmap(modulate, in_axes=(0, None))(quantiles_feature, state_feature)

        head_params = jax.tree_util.tree_map(pick_head_slot, params["head_params"])
        return self.head.apply(head_params, feature)
|
|
|
|
class AtariiIQN:
    """Greedy-action wrapper around one head of ``AtariSharediIQNNet``."""

    def __init__(self, n_actions: int, idx_head: int) -> None:
        """Create the underlying network and remember which head to query."""
        self.network = AtariSharediIQNNet(n_actions)
        self.idx_head = idx_head
        # Number of quantiles sampled each time an action is selected.
        self.n_quantiles_policy = 32

    # ``self`` is declared static so jit treats the instance as a hashable
    # constant rather than trying to trace it.
    @partial(jax.jit, static_argnames="self")
    def best_action(self, params: FrozenDict, state: jnp.ndarray, key: jnp.ndarray) -> jnp.int8:
        """Return the action with the largest quantile-averaged network output.

        Args:
            params: parameter mapping forwarded to ``self.network.apply``.
            state: raw state; it is passed through ``scale`` before the network.
            key: PRNG key forwarded to the network for quantile sampling.

        Returns:
            The argmax over the averaged outputs, cast to ``int8``.
        """
        # ``key`` is annotated as ``jnp.ndarray`` rather than
        # ``jax.random.PRNGKeyArray``: annotations are evaluated when the class
        # is defined, and that alias no longer exists in recent JAX releases.
        q_quantiles = self.network.apply(params, self.idx_head, scale(state), key, self.n_quantiles_policy)
        # Average over axis 0 of the network output before taking the argmax.
        q_values = jnp.mean(q_quantiles, axis=0)

        return jnp.argmax(q_values).astype(jnp.int8)
|
|