| import numpy as np |
| import gym |
|
|
|
|
| class SymetryWrapper(gym.Wrapper): |
| """ |
| :param env: (gym.Env) Gym environment that will be wrapped |
| """ |
|
|
| def __init__(self, env): |
| |
|
|
| self.symetry = False |
| super(SymetryWrapper, self).__init__(env) |
|
|
| def reset(self): |
| """ |
| Reset the environment |
| """ |
| obs = self.env.reset() |
|
|
| obs = self.post_obs(obs) |
|
|
| return obs |
|
|
| def step(self, action): |
| """ |
| :param action: ([float] or int) Action taken by the agent |
| :return: (np.ndarray, float, bool, dict) observation, reward, is the episode over?, additional informations |
| """ |
| if self.symetry: |
| action = symetrise_action(action) |
|
|
| obs, reward, done, info = self.env.step(action) |
|
|
| obs = self.post_obs(obs) |
|
|
| return obs, reward, done, info |
|
|
| def post_obs(self, obs): |
| self.symetry = get_symetry(obs) |
| if self.symetry: |
| obs = symetrise_obs(obs) |
| return obs |
|
|
|
|
| def get_symetry(obs): |
| blue_obs, red_obs, blue_fire, red_fire = obs |
|
|
| |
| count = 0 |
| for this_obs in (blue_obs, red_obs): |
| for d in this_obs: |
| add = 1 if (d[1] < 0.5) else 0 |
| count += add |
|
|
| |
| symetry = bool(2*count < (len(blue_obs) + len(red_obs))) |
|
|
| return symetry |
|
|
|
|
| def symetrise_obs(obs): |
|
|
| blue_obs, red_obs, blue_fire, red_fire = obs |
|
|
| for this_obs in (blue_obs, red_obs): |
| |
| this_obs[:, 1] = 1 - this_obs[:, 1] |
| this_obs[:, 4] = 1 - this_obs[:, 4] |
|
|
| return blue_obs, red_obs, blue_fire, red_fire |
|
|
|
|
| def symetrise_action(action): |
|
|
| blue_action, red_action = action |
|
|
| for this_action in (blue_action, red_action): |
| for act in this_action: |
|
|
| |
| act[1] = - act[1] |
|
|
| action = blue_action, red_action |
|
|
| return action |
|
|
|
|
| def test_symetrise_obs(): |
|
|
| obs = np.arange(12).reshape(2, 6), np.arange(12).reshape(2, 6), np.random.random((1, 1)), np.random.random((1, 1)) |
| print(obs) |
| symetrise_obs(obs) |
| print(obs) |
|
|