| import gym |
| from gym import spaces |
| from tetris_gym.utils.board_utils import get_heights, get_bumps_from_heights |
| from agent.utils import calc_holes_array |
|
|
| import numpy as np |
|
|
|
|
class CustomObsWrapper(gym.ObservationWrapper):
    """Augment the Tetris env observation with hand-crafted board features.

    The wrapped env's ``"board"`` and ``"piece"`` entries are passed through
    unchanged, and the following entries are added:

    * ``"x"`` / ``"y"`` -- the current piece position, read from
      ``self.current_pos``. This wrapper does not set that attribute; it is
      resolved through gym's attribute forwarding to the wrapped env.
    * ``"piece_shape"`` -- the current piece (``self.piece``) zero-padded into
      a fixed 4x4 ``uint8`` grid.
    * ``"heights"`` -- per-column output of ``get_heights(board)``.
    * ``"bumps"`` -- ``get_bumps_from_heights(heights)``, one value per pair
      of adjacent columns (shape ``(width - 1,)``).
    * ``"holes_list"`` -- per-column output of ``calc_holes_array``.
    * ``"empty_above"`` -- for each column, ``max(heights) - height``, i.e.
      how far the column sits below the tallest one.
    """

    def __init__(self, env):
        super().__init__(env)
        width, height = env.width, env.height
        self.observation_space = spaces.Dict({
            "board": env.observation_space["board"],
            "piece": env.observation_space["piece"],
            # NOTE(review): low=1 excludes a per-column value of 0; if
            # calc_holes_array can return 0 (e.g. a column without holes),
            # this bound should be 0 -- confirm against agent.utils.
            "holes_list": spaces.Box(
                low=1,
                high=height,
                shape=(width,),
                dtype=np.uint8,
            ),
            "x": spaces.Discrete(width),
            # The y coordinate indexes board rows, so its range is the board
            # height (previously declared with the width by mistake).
            "y": spaces.Discrete(height),
            "piece_shape": spaces.Box(
                low=0,
                high=1,
                shape=(4, 4),
                dtype=np.uint8,
            ),
            "empty_above": spaces.Box(
                low=0,
                high=height,
                shape=(width,),
                dtype=np.uint8,
            ),
            "heights": spaces.Box(
                low=0,
                high=height,
                shape=(width,),
                dtype=np.uint8,
            ),
            "bumps": spaces.Box(
                low=0,
                high=height,
                shape=(width - 1,),
                dtype=int,
            ),
        })

    def observation(self, obs):
        """Build the augmented observation dict from a raw env observation.

        Args:
            obs: observation from the wrapped env. Only its ``"board"`` and
                ``"piece"`` entries are read.

        Returns:
            dict with the same keys as ``self.observation_space``.
        """
        board = obs["board"]

        heights = get_heights(board)
        bumps = get_bumps_from_heights(heights)
        # NOTE(review): the wrapper itself is passed as the first argument;
        # presumably calc_holes_array reads env attributes from it -- confirm.
        holes_array = calc_holes_array(self, board, heights)
        # Distance of each column below the tallest column. The subtraction
        # already yields a new array, so no explicit copy of heights is needed.
        empty_above = np.max(heights) - heights

        # Zero-pad the current piece into a fixed 4x4 grid. This assumes the
        # piece fits in 4x4; a larger piece makes the slice assignment raise.
        piece = self.piece
        piece_shape = np.zeros((4, 4), dtype=np.uint8)
        piece_shape[:len(piece), :len(piece[0])] = piece

        return {
            "board": board,
            "x": self.current_pos["x"],
            "y": self.current_pos["y"],
            "piece_shape": piece_shape,
            "piece": obs["piece"],
            "empty_above": empty_above,
            "holes_list": holes_array,
            "heights": heights,
            "bumps": bumps,
        }
|
|