import gymnasium as gym
import numpy as np
import cv2
from gymnasium.spaces import Box


class CarRacingEnvWrapper(gym.Wrapper):
    """Preprocess, frame-stack and reward-shape a car-racing environment.

    Each raw frame is optionally converted to grayscale, resized to
    ``resize_dim`` and pushed into a rolling buffer that holds the last
    ``num_stack_frames`` processed frames concatenated along the channel
    (last) axis. ``step`` additionally applies a few reward/termination
    heuristics based on a fixed patch of the raw frame.

    Args:
        env: The environment to wrap. Its observations are indexed as
            ``obs[rows, cols, channel]`` with values scaled by 1/255
            (presumably uint8 RGB frames -- confirm against the wrapped env).
        num_stack_frames: Number of processed frames kept in the stack.
        grayscale: If truthy, frames are reduced to a single luma channel.
        resize_dim: Two-element sequence ``(rows, cols)`` giving the spatial
            size of the processed frames, matching the first two entries of
            ``observation_space.shape``.
    """

    # Patch of the raw frame inspected by the heuristics in ``step``.
    # NOTE(review): assumes raw frames are at least 72x64 pixels and that
    # channel index 1 is green (RGB ordering) -- confirm for the wrapped env.
    _PROBE_ROWS = slice(64, 72)
    _PROBE_COLS = slice(32, 64)
    _PROBE_CHANNEL = 1

    _STUCK_MAX_MEAN = 10       # patch mean below this counts as "stuck"
    _GRASS_MIN_MEAN = 100      # patch mean above this counts as "on grass"
    _GRASS_PENALTY = 10        # subtracted from the reward while on grass
    _COMPLETION_BONUS = 100    # added when info['is_complete'] is truthy

    # Weights used to collapse the first three channels into one luma value.
    _LUMA_WEIGHTS = (0.2989, 0.5870, 0.1140)

    def __init__(self, env, num_stack_frames, grayscale, resize_dim):
        super().__init__(env)
        self.num_stack_frames = num_stack_frames
        self.grayscale = grayscale
        self.resize_dim = resize_dim
        self.orig_observation_space = self.env.observation_space

        # Number of channels contributed by ONE processed frame; the stack
        # holds ``num_stack_frames`` of these side by side on the last axis.
        if self.grayscale:
            self._channels_per_frame = 1
        else:
            self._channels_per_frame = self.env.observation_space.shape[2]

        shape = (
            self.resize_dim[0],
            self.resize_dim[1],
            self._channels_per_frame * self.num_stack_frames,
        )
        self.observation_space = Box(low=0, high=255, shape=shape, dtype=np.uint8)
        self.frames = np.zeros(shape, dtype=np.uint8)

    def _process_obs(self, obs):
        """Return ``obs`` resized (and optionally grayscaled) as uint8.

        The result always has three dimensions,
        ``(resize_dim[0], resize_dim[1], channels_per_frame)``.
        """
        scaled = obs.astype(np.float32) / 255.0

        # cv2.resize expects dsize as (width, height), i.e. (cols, rows),
        # whereas resize_dim is (rows, cols) in the observation space. Swap
        # so non-square targets match the declared observation shape.
        dsize = (self.resize_dim[1], self.resize_dim[0])

        if self.grayscale:
            scaled = np.dot(scaled[..., :3], self._LUMA_WEIGHTS)
        resized = cv2.resize(scaled, dsize, interpolation=cv2.INTER_AREA)

        # cv2 returns a 2-D array for single-channel images; restore the
        # trailing channel axis so stacking is uniform.
        if resized.ndim == 2:
            resized = resized[..., np.newaxis]

        # Round to nearest instead of truncating: float round-off (and luma
        # weights summing to slightly under 1) would otherwise bias pixel
        # values downward, e.g. pure white becoming 254.
        return np.clip(np.rint(resized * 255), 0, 255).astype(np.uint8)

    def _stack_frames(self, processed_obs):
        """Push ``processed_obs`` as the newest frame and return the stack.

        The oldest frame is dropped. ``np.roll`` returns a new array, so
        stacks returned by earlier calls are not modified.
        """
        width = self._channels_per_frame
        # Shift by a whole frame (all of its channels), not a single channel,
        # so the per-frame channel groups stay aligned.
        self.frames = np.roll(self.frames, shift=-width, axis=-1)
        self.frames[..., -width:] = processed_obs
        return self.frames

    def reset(self, **kwargs):
        """Reset the wrapped env and fill the stack with the first frame.

        Returns:
            ``(stacked_frames, info)`` where ``stacked_frames`` has the shape
            declared by ``observation_space``.
        """
        obs, info = self.env.reset(**kwargs)
        processed_obs = self._process_obs(obs)
        # Concatenate along the channel axis so the result is
        # (H, W, channels_per_frame * num_stack_frames); stacking on a new
        # axis would yield a 4-D array that violates observation_space.
        self.frames = np.concatenate(
            [processed_obs] * self.num_stack_frames, axis=-1
        )
        return self.frames, info

    def step(self, action):
        """Step the wrapped env, shape the reward and return stacked frames.

        Heuristics, all based on the mean of one channel over a fixed patch
        of the raw frame:

        * "stuck" (mean below ``_STUCK_MAX_MEAN``) together with a negative
          env reward sets ``truncated`` to True;
        * "on grass" (mean above ``_GRASS_MIN_MEAN``) subtracts
          ``_GRASS_PENALTY`` from the reward;
        * a truthy ``info['is_complete']`` adds ``_COMPLETION_BONUS``.

        ``info['original_reward']`` is set to the reward returned by the
        wrapped env before any shaping, unless the key is already present.

        Returns:
            ``(stacked_frames, reward, terminated, truncated, info)``.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        unshaped_reward = reward

        probe_mean = np.mean(
            obs[self._PROBE_ROWS, self._PROBE_COLS, self._PROBE_CHANNEL]
        )

        if probe_mean < self._STUCK_MAX_MEAN and reward < 0:
            truncated = True
        # Use non-augmented arithmetic so ``unshaped_reward`` can never be
        # mutated in place if the env returns an array-like reward.
        if probe_mean > self._GRASS_MIN_MEAN:
            reward = reward - self._GRASS_PENALTY
        if info.get('is_complete'):
            reward = reward + self._COMPLETION_BONUS

        stacked_frames = self._stack_frames(self._process_obs(obs))

        # Record the env's own reward, not the shaped one; keep any value an
        # inner wrapper already stored under the same key.
        info.setdefault('original_reward', unshaped_reward)
        return stacked_frames, reward, terminated, truncated, info