Spaces:
Running
Running
| from time import time | |
| from easydict import EasyDict | |
| import pytest | |
| import numpy as np | |
| from dizoo.overcooked.envs import OvercookEnv, OvercookGameEnv | |
class TestOvercooked:
    """Smoke tests that drive the Overcooked environments through one episode of random actions."""

    @pytest.mark.parametrize("action_mask", [True, False])
    def test_overcook(self, action_mask):
        """Roll out one episode of ``OvercookEnv`` and check observation shapes.

        Args:
            action_mask: forwarded to the env config. When true, each step's
                observation must be a dict whose only keys are
                ``'agent_state'`` and ``'action_mask'``, and each value's shape
                must match ``env.observation_space[key].shape``. When false,
                the observation itself must match
                ``env.observation_space.shape``.
        """
        env = OvercookEnv(EasyDict({'concat_obs': True, 'action_mask': action_mask}))
        env.reset()
        # An empty rollout would leave ``timestep`` unbound at the final check.
        assert env._horizon > 0, 'environment horizon must be positive'
        allowed_keys = ('agent_state', 'action_mask')
        for _ in range(env._horizon):
            timestep = env.step(env.random_action())
            obs = timestep.obs
            if action_mask:
                for key, value in obs.items():
                    assert key in allowed_keys, 'unexpected observation key: {!r}'.format(key)
                    assert value.shape == env.observation_space[key].shape
            else:
                assert obs.shape == env.observation_space.shape
        # After ``_horizon`` steps the episode should have ended.
        assert timestep.done
        episode_return = timestep.info['eval_episode_return'][0]
        print("sum reward is:", episode_return)

    @pytest.mark.parametrize("concat_obs", [True, False])
    def test_overcook_game(self, concat_obs):
        """Roll out one episode of ``OvercookGameEnv`` and check observation shapes.

        Args:
            concat_obs: forwarded to the env config. Every step's observation
                must match ``env.observation_space.shape`` in either mode.

        The final step's ``info`` is indexed per agent (``info[0]``,
        ``info[1]``) and each agent's ``'eval_episode_return'`` is printed.
        """
        env = OvercookGameEnv(EasyDict({'concat_obs': concat_obs}))
        print('observation space: {}'.format(env.observation_space.shape))
        env.reset()
        # An empty rollout would leave ``timestep`` unbound at the final check.
        assert env._horizon > 0, 'environment horizon must be positive'
        for _ in range(env._horizon):
            timestep = env.step(env.random_action())
            assert timestep.obs.shape == env.observation_space.shape
        # After ``_horizon`` steps the episode should have ended.
        assert timestep.done
        print("agent 0 sum reward is:", timestep.info[0]['eval_episode_return'])
        print("agent 1 sum reward is:", timestep.info[1]['eval_episode_return'])