Update Environment_Wrapper.py
Browse files- Environment_Wrapper.py +73 -71
Environment_Wrapper.py
CHANGED
|
@@ -1,72 +1,74 @@
|
|
| 1 |
-
|
| 2 |
-
|
| 3 |
-
|
| 4 |
-
import
|
| 5 |
-
|
| 6 |
-
|
| 7 |
-
|
| 8 |
-
|
| 9 |
-
|
| 10 |
-
|
| 11 |
-
self.
|
| 12 |
-
|
| 13 |
-
self.
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
|
| 21 |
-
|
| 22 |
-
|
| 23 |
-
self.observation_shape =
|
| 24 |
-
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
|
| 29 |
-
|
| 30 |
-
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
|
| 34 |
-
|
| 35 |
-
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
if self.
|
| 39 |
-
frame = cv2.
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
| 49 |
-
|
| 50 |
-
|
| 51 |
-
|
| 52 |
-
|
| 53 |
-
|
| 54 |
-
|
| 55 |
-
|
| 56 |
-
|
| 57 |
-
|
| 58 |
-
|
| 59 |
-
|
| 60 |
-
|
| 61 |
-
|
| 62 |
-
|
| 63 |
-
|
| 64 |
-
|
| 65 |
-
info['
|
| 66 |
-
|
| 67 |
-
|
| 68 |
-
|
| 69 |
-
self.
|
| 70 |
-
|
| 71 |
-
|
|
|
|
|
|
|
| 72 |
return stacked_frames, modified_reward, terminated, truncated, info
|
|
|
|
| 1 |
+
#G
|
| 2 |
+
#D
|
| 3 |
+
import gymnasium as gym
|
| 4 |
+
import numpy as np
|
| 5 |
+
from collections import deque
|
| 6 |
+
import cv2
|
| 7 |
+
|
| 8 |
+
class CarRacingEnvWrapper(gym.Wrapper):
|
| 9 |
+
def __init__(self, env, num_stack_frames=4, grayscale=True, resize_dim=(84, 84)):
|
| 10 |
+
super().__init__(env)
|
| 11 |
+
self.num_stack_frames = num_stack_frames
|
| 12 |
+
self.grayscale = grayscale
|
| 13 |
+
self.resize_dim = resize_dim
|
| 14 |
+
|
| 15 |
+
self.frames = deque(maxlen=num_stack_frames)
|
| 16 |
+
|
| 17 |
+
original_shape = self.env.observation_space.shape
|
| 18 |
+
if grayscale:
|
| 19 |
+
|
| 20 |
+
original_shape = original_shape[:-1]
|
| 21 |
+
|
| 22 |
+
if resize_dim:
|
| 23 |
+
self.observation_shape = (resize_dim[1], resize_dim[0])
|
| 24 |
+
else:
|
| 25 |
+
self.observation_shape = original_shape[:2]
|
| 26 |
+
|
| 27 |
+
self.observation_space = gym.spaces.Box(
|
| 28 |
+
low=0, high=255,
|
| 29 |
+
shape=(self.observation_shape[0], self.observation_shape[1], num_stack_frames),
|
| 30 |
+
dtype=np.uint8
|
| 31 |
+
)
|
| 32 |
+
|
| 33 |
+
self.OFF_TRACK_PENALTY_SCALE = 0.1
|
| 34 |
+
self.GRASS_COLOR_THRESHOLD = 180
|
| 35 |
+
|
| 36 |
+
def _preprocess_frame(self, frame):
|
| 37 |
+
|
| 38 |
+
if self.grayscale:
|
| 39 |
+
frame = cv2.cvtColor(frame, cv2.COLOR_RGB2GRAY)
|
| 40 |
+
if self.resize_dim:
|
| 41 |
+
frame = cv2.resize(frame, self.resize_dim, interpolation=cv2.INTER_AREA)
|
| 42 |
+
return frame
|
| 43 |
+
|
| 44 |
+
def reset(self, **kwargs):
|
| 45 |
+
observation, info = self.env.reset(**kwargs)
|
| 46 |
+
processed_frame = self._preprocess_frame(observation)
|
| 47 |
+
|
| 48 |
+
for _ in range(self.num_stack_frames):
|
| 49 |
+
self.frames.append(processed_frame)
|
| 50 |
+
|
| 51 |
+
stacked_frames = np.stack(self.frames, axis=-1)
|
| 52 |
+
return stacked_frames, info
|
| 53 |
+
|
| 54 |
+
def step(self, action):
|
| 55 |
+
|
| 56 |
+
observation, reward, terminated, truncated, info = self.env.step(action)
|
| 57 |
+
|
| 58 |
+
modified_reward = reward
|
| 59 |
+
|
| 60 |
+
is_on_grass = np.mean(observation[:, :, 1]) > self.GRASS_COLOR_THRESHOLD
|
| 61 |
+
|
| 62 |
+
if is_on_grass:
|
| 63 |
+
modified_reward -= self.OFF_TRACK_PENALTY_SCALE
|
| 64 |
+
|
| 65 |
+
info['is_on_grass'] = is_on_grass
|
| 66 |
+
info['original_reward'] = reward
|
| 67 |
+
info['modified_reward'] = modified_reward
|
| 68 |
+
|
| 69 |
+
processed_frame = self._preprocess_frame(observation)
|
| 70 |
+
|
| 71 |
+
self.frames.append(processed_frame)
|
| 72 |
+
stacked_frames = np.stack(self.frames, axis=-1)
|
| 73 |
+
|
| 74 |
return stacked_frames, modified_reward, terminated, truncated, info
|