import jax
import jax.numpy as jnp
import chex
from flax import struct
from functools import partial
from typing import Dict, Optional, List, Tuple, Union  # noqa: F401
from jaxmarl.environments.multi_agent_env import MultiAgentEnv, State
from jaxmarl.wrappers.baselines import JaxMARLWrapper


@struct.dataclass
class LogEnvState:
    # State of the wrapped environment, passed through untouched.
    env_state: State
    # Running statistics of the episode in progress; initialised by
    # LogWrapper as arrays of shape (num_agents,).
    episode_returns: float
    episode_lengths: int
    # Statistics of the episode that finished on the current step.
    returned_episode_returns: float
    returned_episode_lengths: int


class LogWrapper(JaxMARLWrapper):
    """Track per-agent episode returns and lengths of a wrapped env.

    NOTE: for now this only targets envs where all agents terminate at
    the same time (only ``done["__all__"]`` is consulted).

    Derived from the JaxMARL LogWrapper, adapted so that it also works
    when the wrapped env auto-resets.
    """

    def __init__(self, env: MultiAgentEnv, replace_info: bool = False):
        """Wrap ``env``.

        Args:
            env: environment to wrap.
            replace_info: when True, ``step`` discards the wrapped env's
                info dict and returns only the logging entries.
        """
        super().__init__(env)
        self.replace_info = replace_info

    def _fresh_log_state(self, env_state) -> LogEnvState:
        """Return a LogEnvState around ``env_state`` with zeroed statistics."""
        stats_shape = (self._env.num_agents,)
        return LogEnvState(
            env_state,
            jnp.zeros(stats_shape),
            jnp.zeros(stats_shape),
            jnp.zeros(stats_shape),
            jnp.zeros(stats_shape),
        )

    @partial(jax.jit, static_argnums=(0,))
    def get_avail_actions(self, state):
        """Delegate to the wrapped env's ``get_avail_actions``.

        ``state`` may be either a LogEnvState (e.g. when the full state
        is passed in) or the already-unwrapped inner state (e.g. when
        the caller did ``env_state.env_state`` itself).
        """
        inner_state = (
            state.env_state if isinstance(state, LogEnvState) else state
        )
        return self._env.get_avail_actions(inner_state)

    @partial(jax.jit, static_argnums=(0,))
    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, State]:
        """Reset the wrapped env and start with zeroed statistics.

        Returns:
            The wrapped env's observation and a LogEnvState holding the
            wrapped env's state.
        """
        obs, inner_state = self._env.reset(key)
        return obs, self._fresh_log_state(inner_state)

    @partial(jax.jit, static_argnums=(0,))
    def step(
        self,
        key: chex.PRNGKey,
        state: LogEnvState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
        """Step the wrapped env and update the episode statistics.

        The info dict gains ``returned_episode_returns``,
        ``returned_episode_lengths`` and ``returned_episode``. If
        ``replace_info`` is set these are the only entries. Otherwise
        they are added to the dict returned by the wrapped env.

        When ``done["__all__"]`` is set, the returned state carries
        zeroed statistics, so the finished episode's values are exposed
        through ``info`` on that step only.
        """
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action
        )
        ep_done = done["__all__"]
        keep = 1 - ep_done  # 1 while the episode is running, 0 once it ends

        # _batchify_floats comes from the base wrapper; presumably it
        # turns the per-agent rewards into a (num_agents,) array.
        running_return = state.episode_returns + self._batchify_floats(reward)
        running_length = state.episode_lengths + 1

        logged = LogEnvState(
            env_state=inner_state,
            episode_returns=running_return * keep,
            episode_lengths=running_length * keep,
            returned_episode_returns=(
                state.returned_episode_returns * keep
                + running_return * ep_done
            ),
            returned_episode_lengths=(
                state.returned_episode_lengths * keep
                + running_length * ep_done
            ),
        )

        if self.replace_info:
            info = {}
        info.update(
            {
                "returned_episode_returns": logged.returned_episode_returns,
                "returned_episode_lengths": logged.returned_episode_lengths,
                "returned_episode": jnp.full(
                    (self._env.num_agents,), ep_done
                ),
            }
        )

        # for compatibility with auto-resetting wrapped envs: once the
        # episode is over, hand back zeroed statistics. Both candidates
        # wrap the same inner state, so only the counters are affected.
        next_state = jax.tree.map(
            lambda fresh, current: jax.lax.select(ep_done, fresh, current),
            self._fresh_log_state(inner_state),
            logged,
        )
        return obs, next_state, reward, done, info