| from typing import Any, Optional |
| from gym import error |
| from mlagents_envs.base_env import BaseEnv |
| from pettingzoo import AECEnv |
|
|
| from mlagents_envs.envs.unity_pettingzoo_base_env import UnityPettingzooBaseEnv |
|
|
|
|
class UnityAECEnv(UnityPettingzooBaseEnv, AECEnv):
    """
    PettingZoo AEC (agent-environment-cycle) wrapper around a Unity environment.

    Agents act one at a time: each `step` call records the action of the
    currently selected agent and moves the selection to the next agent.
    """

    def __init__(self, env: BaseEnv, seed: Optional[int] = None):
        """
        Builds the AEC wrapper by delegating to the base wrapper's initializer.

        :param env: The UnityEnvironment that is being wrapped.
        :param seed: The seed for the action spaces of the agents.
        """
        super().__init__(env, seed)

    def step(self, action: Any) -> None:
        """
        Records the action of the active agent and moves on to the next agent.

        Once every agent in `self._agents` has been given an action (and at
        least one agent exists), `self._step()` is invoked.

        :param action: The action for the active agent
        :raises error.Error: if there are no live agents, i.e. the environment
            has not been reset yet.
        """
        self._assert_loaded()
        if not self._live_agents:
            raise error.Error(
                "You must reset the environment before you can perform a step"
            )

        # Hand the action of the currently selected agent to the base class.
        acting_agent = self._agents[self._agent_index]
        self._process_action(acting_agent, action)

        # Move the selection forward to the next agent in the cycle.
        self._agent_index += 1

        # Zero out every per-step reward entry.
        for agent_key in self._rewards:
            self._rewards[agent_key] = 0

        # Nothing more to do while agents are still waiting for their action,
        # or when there are no agents at all.
        if self._agent_index < len(self._agents) or self.num_agents <= 0:
            return

        # Every agent has acted: run the base-class step.
        # NOTE(review): presumably this advances the Unity simulation and
        # rewinds the agent index -- confirm against the base class.
        self._step()
        self._live_agents.sort()

    def observe(self, agent_id):
        """
        Returns the stored (observation, cumulative reward, done, info) tuple
        for the given agent. `last()` calls this function.

        :param agent_id: Key of the agent to look up in the per-agent stores.
        """
        observation = self._observations[agent_id]
        cumulative_reward = self._cumm_rewards[agent_id]
        done = self._dones[agent_id]
        info = self._infos[agent_id]
        return observation, cumulative_reward, done, info

    def last(self, observe=True):
        """
        Returns observation, cumulative reward, done, info for the agent at the
        current agent index.

        :param observe: When falsy, the observation slot of the result is None.
        """
        # NOTE(review): this indexes `_agents` directly rather than going
        # through `agent_selection` -- confirm the two always agree.
        active_agent = self._agents[self._agent_index]
        observation, reward, done, info = self.observe(active_agent)
        if not observe:
            observation = None
        return observation, reward, done, info

    @property
    def agent_selection(self):
        """
        The agent whose turn it is to act.

        Falls back to the first entry of `self._agents` when there are no live
        agents, so a selection is still reported in that case.
        """
        selected_index = self._agent_index if self._live_agents else 0
        return self._agents[selected_index]
|
|