| import gymnasium as gym | |
| import numpy as np | |
| import cv2 | |
| from gymnasium.spaces import Box | |
class CarRacingEnvWrapper(gym.Wrapper):
    """Resize, optionally grayscale, and frame-stack image observations.

    Observations returned by ``reset`` and ``step`` are ``uint8`` arrays of
    shape ``(resize_dim[0], resize_dim[1], C * num_stack_frames)``, where
    ``C`` is 1 in grayscale mode and the wrapped environment's channel count
    otherwise. Frames are laid out oldest-first along the last axis.

    ``step`` additionally applies reward shaping and an early-truncation
    heuristic based on the mean of a fixed pixel patch of the raw frame.

    Args:
        env: Environment to wrap; its observations are indexed as
            ``(rows, cols, channels)`` images.
        num_stack_frames: Number of consecutive frames kept in the stack.
        grayscale: If true, frames are reduced to a single luminance channel.
        resize_dim: ``(height, width)`` of each processed frame.
    """

    def __init__(self, env, num_stack_frames, grayscale, resize_dim):
        super().__init__(env)
        self.num_stack_frames = num_stack_frames
        self.grayscale = grayscale
        self.resize_dim = resize_dim
        self.orig_observation_space = self.env.observation_space

        # Channels contributed by a single processed frame.
        if self.grayscale:
            channels_per_frame = 1
        else:
            channels_per_frame = self.env.observation_space.shape[2]

        shape = (
            self.resize_dim[0],
            self.resize_dim[1],
            channels_per_frame * self.num_stack_frames,
        )
        self.observation_space = Box(low=0, high=255, shape=shape, dtype=np.uint8)
        # Placeholder buffer; ``reset`` replaces it with real frames.
        self.frames = np.zeros(shape, dtype=np.uint8)

    def _process_obs(self, obs):
        """Return ``obs`` resized (and optionally grayscaled) as uint8 HxWxC."""
        scaled = obs.astype(np.float32) / 255.0

        # cv2.resize expects dsize as (width, height), whereas resize_dim is
        # used as (height, width) when the observation space is declared.
        # Swapping here keeps non-square frames consistent with that space.
        dsize = (self.resize_dim[1], self.resize_dim[0])

        if self.grayscale:
            # Weighted sum of the first three channels (luma-style weights).
            gray = np.dot(scaled[..., :3], [0.2989, 0.5870, 0.1140])
            processed = cv2.resize(gray, dsize, interpolation=cv2.INTER_AREA)
        else:
            processed = cv2.resize(scaled, dsize, interpolation=cv2.INTER_AREA)

        # cv2.resize returns a 2-D array for single-channel input; restore the
        # trailing channel axis so callers always receive an HxWxC array.
        if processed.ndim == 2:
            processed = np.expand_dims(processed, axis=-1)

        # NOTE: astype truncates toward zero rather than rounding.
        return (processed * 255).astype(np.uint8)

    def _stack_frames(self, processed_obs):
        """Drop the oldest frame, append ``processed_obs``, return the stack."""
        channels = processed_obs.shape[-1]
        # Shift by one whole frame (all of its channels) so that older frames
        # stay channel-aligned; the wrapped-around slots are overwritten next.
        self.frames = np.roll(self.frames, shift=-channels, axis=-1)
        self.frames[..., -channels:] = processed_obs
        return self.frames

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with the first frame.

        Returns:
            ``(stacked_frames, info)`` where ``stacked_frames`` repeats the
            first processed frame ``num_stack_frames`` times.
        """
        obs, info = self.env.reset(**kwargs)
        first_frame = self._process_obs(obs)
        # Concatenate along the channel axis so the result is
        # (H, W, C * num_stack_frames), matching ``observation_space`` and the
        # layout ``_stack_frames`` maintains, in both grayscale and colour.
        self.frames = np.concatenate(
            [first_frame] * self.num_stack_frames, axis=-1
        )
        return self.frames, info

    def step(self, action):
        """Step the wrapped env, shape the reward, and return stacked frames.

        Shaping rules, all driven by the mean of channel 1 over the raw-frame
        patch ``obs[64:72, 32:64]``:

        * mean < 10 together with a negative env reward is treated as "stuck"
          and sets ``truncated`` to True;
        * mean > 100 is treated as "on grass" and costs 10 reward;
        * a truthy ``info['is_complete']`` adds 100 reward.

        The unshaped environment reward is stored in
        ``info['original_reward']`` unless that key is already present.

        Returns:
            ``(stacked_frames, reward, terminated, truncated, info)``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        env_reward = reward  # keep the unshaped value for reporting

        # NOTE(review): the patch indices are hard-coded; presumably tuned for
        # 96x96 RGB frames where channel 1 is green -- confirm against the env.
        patch_mean = np.mean(obs[64:72, 32:64, 1])

        if patch_mean < 10 and reward < 0:
            truncated = True
        if patch_mean > 100:
            reward -= 10
        if info.get('is_complete'):
            reward += 100

        stacked_frames = self._stack_frames(self._process_obs(obs))
        info['original_reward'] = info.get('original_reward', env_reward)
        return stacked_frames, reward, terminated, truncated, info