# Car_Race_AI_V0 / Environment_Wrapper.py
# Uploaded by privateboss (commit f0c8b65, "Upload 6 files").
import gymnasium as gym
import numpy as np
import cv2
from gymnasium.spaces import Box
class CarRacingEnvWrapper(gym.Wrapper):
    """Preprocess CarRacing observations, stack recent frames and shape rewards.

    Every raw frame is optionally reduced to a single grayscale channel,
    resized to ``resize_dim`` and appended to a rolling stack along the
    channel axis, so observations have shape
    ``(height, width, channels_per_frame * num_stack_frames)`` and dtype uint8.

    Args:
        env: Environment to wrap. Its observations are indexed as
            ``obs[row, col, channel]``; the first three channels are treated
            as RGB (assumed, as for CarRacing -- confirm for other envs).
        num_stack_frames: Number of consecutive processed frames to stack.
        grayscale: If true, each frame contributes one luminance channel;
            otherwise it contributes all channels of the original observation.
        resize_dim: Target ``(height, width)`` of every processed frame.
    """

    def __init__(self, env, num_stack_frames, grayscale, resize_dim):
        super().__init__(env)
        self.num_stack_frames = num_stack_frames
        self.grayscale = grayscale
        self.resize_dim = resize_dim
        self.orig_observation_space = self.env.observation_space
        # Number of channels a single processed frame occupies in the stack.
        self.frame_channels = 1 if self.grayscale else self.env.observation_space.shape[2]
        shape = (
            self.resize_dim[0],
            self.resize_dim[1],
            self.frame_channels * self.num_stack_frames,
        )
        self.observation_space = Box(low=0, high=255, shape=shape, dtype=np.uint8)
        self.frames = np.zeros(shape, dtype=np.uint8)

    def _process_obs(self, obs):
        """Return ``obs`` as a resized ``(height, width, frame_channels)`` uint8 image."""
        obs = obs.astype(np.float32) / 255.0
        # cv2.resize expects dsize as (width, height); resize_dim is (height, width).
        dsize = (self.resize_dim[1], self.resize_dim[0])
        if self.grayscale:
            # Weighted sum of the RGB channels (approximately ITU-R BT.601 luma).
            grayscale_obs = np.dot(obs[..., :3], [0.2989, 0.5870, 0.1140])
            resized_obs = cv2.resize(grayscale_obs, dsize, interpolation=cv2.INTER_AREA)
            processed_obs = np.expand_dims(resized_obs, axis=-1)
        else:
            processed_obs = cv2.resize(obs, dsize, interpolation=cv2.INTER_AREA)
            if processed_obs.ndim == 2:
                # cv2.resize drops the channel axis for single-channel input.
                processed_obs = np.expand_dims(processed_obs, axis=-1)
        return (processed_obs * 255).astype(np.uint8)

    def _stack_frames(self, processed_obs):
        """Drop the oldest frame from the stack, append ``processed_obs`` and return the stack."""
        channels = self.frame_channels
        # Shift by one whole frame so each older frame keeps its channels together.
        self.frames = np.roll(self.frames, shift=-channels, axis=-1)
        self.frames[..., -channels:] = processed_obs
        return self.frames

    def reset(self, **kwargs):
        """Reset the wrapped env and fill every stack slot with the first processed frame.

        Returns:
            ``(stacked_frames, info)`` as produced by the wrapped env's reset.
        """
        obs, info = self.env.reset(**kwargs)
        processed_obs = self._process_obs(obs)
        # Concatenate along the channel axis so the result matches
        # observation_space for both grayscale and colour frames.
        self.frames = np.concatenate([processed_obs] * self.num_stack_frames, axis=-1)
        return self.frames, info

    def step(self, action):
        """Step the wrapped env, apply reward shaping and return the stacked observation.

        Shaping, based on the mean green value of a fixed patch of the raw frame
        (rows 64:72, cols 32:64 -- presumably around the car; TODO confirm):
          * mean < 10 while the step reward is negative -> episode is truncated;
          * mean > 100 -> reward is reduced by 10 (treated as driving on grass);
          * ``info['is_complete']`` truthy -> reward is increased by 100.

        The unshaped reward is stored in ``info['original_reward']`` unless an
        inner wrapper has already set that key.
        """
        obs, reward, terminated, truncated, info = self.env.step(action)
        # Capture the raw reward before any shaping modifies it.
        original_reward = reward
        patch_green = np.mean(obs[64:72, 32:64, 1])
        is_stuck = patch_green < 10
        if is_stuck and reward < 0:
            truncated = True
        is_on_grass = patch_green > 100
        if is_on_grass:
            reward -= 10
        # NOTE(review): verify the wrapped env actually reports 'is_complete';
        # if the key is never present this bonus never fires.
        if info.get('is_complete'):
            reward += 100
        processed_obs = self._process_obs(obs)
        stacked_frames = self._stack_frames(processed_obs)
        info['original_reward'] = info.get('original_reward', original_reward)
        return stacked_frames, reward, terminated, truncated, info