Spaces:
Running
Running
| import jax | |
| import jax.numpy as jnp | |
| import chex | |
| from flax import struct | |
| from functools import partial | |
| from typing import Dict, Optional, List, Tuple, Union # noqa: F401 | |
| from jaxmarl.environments.multi_agent_env import MultiAgentEnv, State | |
| from jaxmarl.wrappers.baselines import JaxMARLWrapper | |
@struct.dataclass
class LogEnvState:
    """Wrapped environment state bundled with episode statistics.

    Declared as a ``flax.struct.dataclass`` so that the class gets the
    generated constructor ``LogWrapper`` calls (positionally and by
    keyword) and is registered as a pytree, which ``LogWrapper.step``
    relies on when it passes instances to ``jax.tree.map``.

    ``LogWrapper.reset`` fills each of the four statistics with
    ``jnp.zeros((num_agents,))``, i.e. one entry per agent.
    """

    # State of the wrapped environment.
    env_state: State
    # Return accumulated so far in the episode currently running.
    episode_returns: float
    # Number of steps taken so far in the episode currently running.
    episode_lengths: int
    # Return of the episode that just finished, written when the episode ends.
    returned_episode_returns: float
    # Length of the episode that just finished, written when the episode ends.
    returned_episode_lengths: int
class LogWrapper(JaxMARLWrapper):
    """Track per-agent episode returns and lengths for a wrapped env.

    NOTE: for now this targets envs whose agents all terminate at the same
    time; the end of an episode is read from ``done["__all__"]``.

    Derived from the JaxMARL LogWrapper, adapted so that it also works on
    top of wrapped envs that auto-reset.
    """

    def __init__(self, env: MultiAgentEnv, replace_info: bool = False):
        super().__init__(env)
        # When True, ``step`` drops the wrapped env's info dict and reports
        # only the logging entries.
        self.replace_info = replace_info

    def _blank_log_state(self, env_state) -> LogEnvState:
        """Wrap ``env_state`` in a ``LogEnvState`` whose statistics are zero.

        Each statistic is a separately created array of shape
        ``(num_agents,)``.
        """
        stat_shape = (self._env.num_agents,)
        return LogEnvState(
            env_state=env_state,
            episode_returns=jnp.zeros(stat_shape),
            episode_lengths=jnp.zeros(stat_shape),
            returned_episode_returns=jnp.zeros(stat_shape),
            returned_episode_lengths=jnp.zeros(stat_shape),
        )

    def get_avail_actions(self, state):
        """Forward to the wrapped env's ``get_avail_actions``.

        Accepts either the full ``LogEnvState`` or a state the caller has
        already unwrapped (e.g. training code that passes
        ``env_state.env_state`` itself).
        """
        inner_state = state.env_state if isinstance(state, LogEnvState) else state
        return self._env.get_avail_actions(inner_state)

    def reset(self, key: chex.PRNGKey) -> Tuple[chex.Array, State]:
        """Reset the wrapped env and start with zeroed statistics.

        Returns:
            The wrapped env's observations and a fresh ``LogEnvState``.
        """
        obs, inner_state = self._env.reset(key)
        return obs, self._blank_log_state(inner_state)

    def step(
        self,
        key: chex.PRNGKey,
        state: LogEnvState,
        action: Union[int, float],
    ) -> Tuple[chex.Array, LogEnvState, float, bool, dict]:
        """Step the wrapped env and update the episode statistics.

        Returns:
            ``(obs, state, reward, done, info)`` where ``obs``, ``reward``
            and ``done`` come unchanged from the wrapped env, and ``info``
            carries ``returned_episode_returns``,
            ``returned_episode_lengths`` and ``returned_episode``. When
            ``replace_info`` is set, those three are the only entries.
        """
        obs, inner_state, reward, done, info = self._env.step(
            key, state.env_state, action
        )
        finished = done["__all__"]
        still_running = 1 - finished

        # presumably _batchify_floats turns the per-agent reward mapping into
        # one array aligned with the statistics -- defined on the base class.
        total_return = state.episode_returns + self._batchify_floats(reward)
        total_length = state.episode_lengths + 1

        # Running totals are cleared when the episode ends; the "returned"
        # fields keep their previous value until then and take the final
        # totals on the terminal step.
        logged = LogEnvState(
            env_state=inner_state,
            episode_returns=total_return * still_running,
            episode_lengths=total_length * still_running,
            returned_episode_returns=(
                state.returned_episode_returns * still_running
                + total_return * finished
            ),
            returned_episode_lengths=(
                state.returned_episode_lengths * still_running
                + total_length * finished
            ),
        )

        if self.replace_info:
            info = {}
        info["returned_episode_returns"] = logged.returned_episode_returns
        info["returned_episode_lengths"] = logged.returned_episode_lengths
        info["returned_episode"] = jnp.full((self._env.num_agents,), finished)

        # for compatibility with auto-resetting wrapped envs: once the
        # episode is over, carry forward a state whose statistics are all
        # zero (info above has already captured the final values).
        def pick(fresh, current):
            return jax.lax.select(finished, fresh, current)

        next_state = jax.tree.map(
            pick, self._blank_log_state(inner_state), logged
        )
        return obs, next_state, reward, done, info