Spaces:
Sleeping
Sleeping
| from collections import namedtuple | |
| import numpy as np | |
def convert(dictionary):
    """Wrap a mapping in an immutable record so its keys read as attributes.

    A namedtuple class called ``GenericDict`` is built on the fly, with one
    field per key of ``dictionary`` (in the mapping's iteration order). The
    class is then instantiated with the mapping's values.

    Args:
        dictionary: Mapping whose keys are used as field names, so they must be
            acceptable to ``collections.namedtuple``.

    Returns:
        A ``GenericDict`` instance exposing each key as an attribute.
    """
    record_type = namedtuple('GenericDict', dictionary.keys())
    return record_type(**dictionary)
class MultiAgentEnv(object):
    """Base interface shared by multi-agent environments.

    Every query method below raises ``NotImplementedError`` and is meant to be
    overridden by a concrete environment. ``get_env_info`` is the exception:
    it is built on top of the other queries. It also reads ``self.n_agents``
    and ``self.episode_limit``, which this base class never assigns, so
    subclasses must provide them.
    """

    def __init__(self, batch_size=None, **kwargs):
        """Store the environment arguments and optionally seed a numpy RNG.

        Args:
            batch_size: Accepted for interface compatibility; not used here.
            **kwargs: Must contain ``env_args``, either a dict or an
                attribute-style object. A dict (e.g. one handed over by
                sacred) is wrapped with ``convert`` so its keys become
                attributes. The result is stored as ``self.args``.

        Raises:
            KeyError: If ``env_args`` is not present in ``kwargs``.

        Note:
            When ``env_args`` carries a non-None ``seed``:

            - it is stored as the instance attribute ``self.seed``, which
              shadows the ``seed()`` method on this instance;
            - ``self.rs`` is set to a ``np.random.RandomState`` seeded with it.

            Otherwise neither instance attribute is assigned.
        """
        raw_args = kwargs["env_args"]
        self.args = convert(raw_args) if isinstance(raw_args, dict) else raw_args

        chosen_seed = getattr(self.args, "seed", None)
        if chosen_seed is None:
            return
        self.seed = chosen_seed
        self.rs = np.random.RandomState(self.seed)

    def step(self, actions):
        """Advance the environment with ``actions``; returns reward, terminated, info."""
        raise NotImplementedError

    def get_obs(self):
        """Return the observations of all agents as a list."""
        raise NotImplementedError

    def get_obs_agent(self, agent_id):
        """Return the observation of agent ``agent_id``."""
        raise NotImplementedError

    def get_obs_size(self):
        """Return the shape of an observation."""
        raise NotImplementedError

    def get_state(self):
        """Return the environment state."""
        raise NotImplementedError

    def get_state_size(self):
        """Return the shape of the state."""
        raise NotImplementedError

    def get_avail_actions(self):
        """Return available actions (presumably for every agent; subclass-defined)."""
        raise NotImplementedError

    def get_avail_agent_actions(self, agent_id):
        """Return the actions available to agent ``agent_id``."""
        raise NotImplementedError

    def get_total_actions(self):
        """Return the total number of actions an agent could ever take."""
        # TODO: only suitable for a discrete, 1-dimensional action space per agent.
        raise NotImplementedError

    def get_stats(self):
        """Return environment statistics (subclass-defined)."""
        raise NotImplementedError

    # TODO: temporary hack.
    def get_agg_stats(self, stats):
        """Placeholder aggregation: ignores ``stats`` and returns a fresh empty dict."""
        return dict()

    def reset(self):
        """Reset the environment; returns initial observations and states."""
        raise NotImplementedError

    def render(self):
        """Render the environment (subclass-defined)."""
        raise NotImplementedError

    def close(self):
        """Release environment resources (subclass-defined)."""
        raise NotImplementedError

    def seed(self, seed):
        """Seed the environment (subclass-defined).

        On an instance constructed with a seed, ``__init__`` replaces this
        method with an attribute of the same name holding the seed value.
        """
        raise NotImplementedError

    def get_env_info(self):
        """Summarise the environment dimensions in a dict.

        Queries are evaluated in this order: state size, observation size,
        total actions, then the ``n_agents`` and ``episode_limit`` attributes.
        """
        return dict(
            state_shape=self.get_state_size(),
            obs_shape=self.get_obs_size(),
            n_actions=self.get_total_actions(),
            n_agents=self.n_agents,
            episode_limit=self.episode_limit,
        )