jaxaht-benchmark / envs /log_wrapper.py
lainwired's picture
Initial jaxaht-benchmark deployment
5146e76
import jax
import jax.numpy as jnp
import chex
from flax import struct
from functools import partial
from typing import Dict, Optional, List, Tuple, Union # noqa: F401
from jaxmarl.environments.multi_agent_env import MultiAgentEnv, State
from jaxmarl.wrappers.baselines import JaxMARLWrapper
@struct.dataclass
class LogEnvState:
    """Wrapped environment state bundled with episode statistics.

    The scalar-looking annotations below are nominal: ``LogWrapper.reset``
    in this file initialises every statistic with
    ``jnp.zeros((num_agents,))``, i.e. one float entry per agent.
    Field order matters because ``LogWrapper`` builds instances positionally.
    """

    # State object returned by the wrapped environment's reset/step.
    env_state: State
    # Return accumulated so far in the episode that is still running.
    episode_returns: float
    # Number of steps taken so far in the episode that is still running.
    episode_lengths: int
    # Total return of the most recently finished episode.
    returned_episode_returns: float
    # Length of the most recently finished episode.
    returned_episode_lengths: int
class LogWrapper(JaxMARLWrapper):
    """Record per-agent episode returns and lengths of a wrapped environment.

    Adapted from the JaxMARL ``LogWrapper`` and changed so that it also works
    on top of environments that reset themselves automatically.

    NOTE: for now this targets environments in which every agent terminates
    on the same step; only ``done["__all__"]`` is consulted.
    """

    def __init__(self, env: MultiAgentEnv, replace_info: bool = False):
        """Wrap ``env``.

        Args:
            env: Environment whose episodes should be logged.
            replace_info: When true, ``step`` discards the wrapped
                environment's info dict and returns only the logging entries.
        """
        super().__init__(env)
        self.replace_info = replace_info

    def _fresh_log_state(self, env_state) -> LogEnvState:
        """Pair ``env_state`` with all-zero statistics, one entry per agent."""
        stats_shape = (self._env.num_agents,)
        # Four independent zero arrays, in LogEnvState field order:
        # episode_returns, episode_lengths,
        # returned_episode_returns, returned_episode_lengths.
        return LogEnvState(
            env_state,
            *(jnp.zeros(stats_shape) for _ in range(4)),
        )

    @partial(jax.jit, static_argnums=(0,))
    def get_avail_actions(self, state):
        """Ask the wrapped environment which actions are available.

        ``state`` may be the full ``LogEnvState`` (e.g. passed by
        run_episodes) or the inner state that training code has already
        unwrapped via ``env_state.env_state``; both are accepted.
        """
        inner_state = state.env_state if isinstance(state, LogEnvState) else state
        return self._env.get_avail_actions(inner_state)

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, State]:
        """Reset the wrapped environment and start with zeroed statistics.

        Returns:
            The wrapped environment's observation and a ``LogEnvState``
            holding its state together with zeroed counters.
        """
        obs, inner_state = self._env.reset(key)
        return obs, self._fresh_log_state(inner_state)

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: LogEnvState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
        """Step the wrapped environment and update the episode statistics.

        ``action`` is forwarded unchanged to the wrapped environment.

        Returns:
            ``(obs, state, reward, done, info)``. ``obs``, ``reward`` and
            ``done`` come straight from the wrapped environment. ``info``
            gains the keys ``returned_episode_returns``,
            ``returned_episode_lengths`` and ``returned_episode``; it starts
            from an empty dict when ``replace_info`` is set.
        """
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action
        )

        episode_over = done["__all__"]
        still_running = 1 - episode_over  # arithmetic mask: 0 once the episode ends

        # _batchify_floats is not defined in this class; presumably inherited
        # and assumed to yield one reward entry per agent — TODO confirm.
        running_return = state.episode_returns + self._batchify_floats(reward)
        running_length = state.episode_lengths + 1

        logged = LogEnvState(
            env_state=inner_state,
            # Running totals restart from zero when the episode finishes.
            episode_returns=running_return * still_running,
            episode_lengths=running_length * still_running,
            # On the final step the completed totals replace the previous ones.
            returned_episode_returns=(
                state.returned_episode_returns * still_running
                + running_return * episode_over
            ),
            returned_episode_lengths=(
                state.returned_episode_lengths * still_running
                + running_length * episode_over
            ),
        )

        if self.replace_info:
            info = {}
        # Filled in before the reset below, so the finished episode's totals
        # are still visible to the caller through info.
        info["returned_episode_returns"] = logged.returned_episode_returns
        info["returned_episode_lengths"] = logged.returned_episode_lengths
        info["returned_episode"] = jnp.full(
            (self._env.num_agents,), episode_over
        )

        # for compatibility with auto-resetting wrapped envs: once the episode
        # is over, hand back a state whose statistics are all zero again.
        def _choose(if_over, otherwise):
            return jax.lax.select(episode_over, if_over, otherwise)

        next_state = jax.tree.map(
            _choose,
            self._fresh_log_state(inner_state),
            logged,
        )
        return obs, next_state, reward, done, info