import gym
from gym import spaces
from tetris_gym.utils.board_utils import get_heights, get_bumps_from_heights
from agent.utils import calc_holes_array
import numpy as np


class CustomObsWrapper(gym.ObservationWrapper):
    """Augment the Tetris observation with hand-crafted board features.

    The wrapped env's ``board`` and ``piece`` entries are passed through
    unchanged. The wrapper adds:

    - per-column heights;
    - "bumps" derived from those heights;
    - per-column hole information;
    - each column's distance below the tallest column;
    - the current piece padded into a 4x4 grid;
    - the piece's ``x``/``y`` position.
    """

    def __init__(self, env):
        """Wrap *env* and declare the extended ``Dict`` observation space.

        Args:
            env: Environment to wrap. It must expose ``width`` and ``height``
                attributes and an observation space indexable by ``"board"``
                and ``"piece"``.
        """
        super().__init__(env)
        width, height = env.width, env.height
        self.observation_space = spaces.Dict({
            "board": env.observation_space["board"],
            "piece": env.observation_space["piece"],
            # NOTE(review): low=1 would exclude a column with zero holes;
            # confirm against what calc_holes_array actually returns.
            "holes_list": spaces.Box(
                low=1, high=height, shape=(width,), dtype=np.uint8,
            ),
            # x is a column index (bounded by the board width) and y a row
            # index, so y must be bounded by the board height, not its width.
            "x": spaces.Discrete(width),
            "y": spaces.Discrete(height),
            "piece_shape": spaces.Box(
                low=0, high=1, shape=(4, 4), dtype=np.uint8,
            ),
            "empty_above": spaces.Box(
                low=0, high=height, shape=(width,), dtype=np.uint8,
            ),
            "heights": spaces.Box(
                low=0, high=height, shape=(width,), dtype=np.uint8,
            ),
            # One entry per pair of adjacent columns.
            "bumps": spaces.Box(
                low=0, high=height, shape=(width - 1,), dtype=int,
            ),
        })

    def observation(self, obs):
        """Return a new observation dict extended with derived board features.

        Args:
            obs: Observation from the wrapped env. Only its ``"board"`` and
                ``"piece"`` entries are read.

        Returns:
            A new dict with the same keys as ``self.observation_space``.
        """
        board = obs["board"]
        heights = get_heights(board)

        # How far each column sits below the tallest column.
        empty_above = np.max(heights) - heights

        # NOTE(review): ``piece`` and ``current_pos`` are not defined on this
        # wrapper; they are presumably forwarded to the wrapped env by
        # gym.Wrapper attribute lookup. Confirm for the installed gym version.
        piece_cells = self.piece
        position = self.current_pos

        # Pad the (at most 4x4) piece matrix into a fixed 4x4 grid so the
        # observation shape does not depend on the piece type.
        piece_shape = np.zeros((4, 4), dtype=np.uint8)
        piece_shape[:len(piece_cells), :len(piece_cells[0])] = piece_cells

        return {
            "board": board,
            "x": position["x"],
            "y": position["y"],
            "piece_shape": piece_shape,
            "piece": obs["piece"],
            "empty_above": empty_above,
            "holes_list": calc_holes_array(self, board, heights),
            "heights": heights,
            "bumps": get_bumps_from_heights(heights),
        }