privateboss commited on
Commit
93dd337
·
verified ·
1 Parent(s): e57d2f3

Create Snake_EnvAndAgent.py

Browse files
Files changed (1) hide show
  1. Snake_EnvAndAgent.py +251 -0
Snake_EnvAndAgent.py ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gymnasium as gym
2
+ from gymnasium import spaces
3
+ import random
4
+ import pygame
5
+ import numpy as np
6
+ import collections
7
+ from collections import deque
8
+ from Environment_Constants import (
9
+     GRID_SIZE, CELL_SIZE, SCREEN_WIDTH, SCREEN_HEIGHT,
10
+     WHITE, BLACK, GREEN, RED, BLUE,
11
+     UP, DOWN, LEFT, RIGHT,
12
+     REWARD_FOOD, REWARD_COLLISION, REWARD_STEP,
13
+     FPS
14
+ )
15
+
16
+ class SnakeGameEnv(gym.Env):
17
+     metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': FPS}
18
+
19
+     def __init__(self, render_mode=None):
20
+         super().__init__()
21
+         self.grid_size = GRID_SIZE
22
+         self.cell_size = CELL_SIZE
23
+         self.screen_width = SCREEN_WIDTH
24
+         self.screen_height = SCREEN_HEIGHT
25
+
26
+         self.action_space = spaces.Discrete(3)
27
+
28
+         self.observation_space = spaces.Box(low=0, high=1, shape=(11,), dtype=np.float32)
29
+
30
+         self.render_mode = render_mode
31
+         self.window = None
32
+         self.clock = None
33
+
34
+         self._init_game_state()
35
+
36
+     def _init_game_state(self):
37
+         self.snake = deque()
38
+         self.head = (self.grid_size // 2, self.grid_size // 2)
39
+         self.snake.append(self.head)
40
+    
41
+         for _ in range(2):
42
+             self.snake.append((self.head[0], self.head[1] + (_ + 1)))
43
+
44
+         self.direction = UP
45
+         self.score = 0
46
+         self.food = self._place_food()
47
+         self.game_over = False
48
+         self.steps_since_food = 0
49
+
50
+     def _place_food(self):
51
+         while True:
52
+             x = random.randrange(self.grid_size)
53
+             y = random.randrange(self.grid_size)
54
+             food_pos = (x, y)
55
+             if food_pos not in self.snake:
56
+                 return food_pos
57
+
58
+     def _get_observation(self):
59
+      
60
+         obs = np.zeros(11, dtype=np.float32)
61
+
62
+         hx, hy = self.head
63
+
64
+         if self.direction == UP:
65
+             dir_straight = UP
66
+             dir_right = RIGHT
67
+             dir_left = LEFT
68
+         elif self.direction == DOWN:
69
+             dir_straight = DOWN
70
+             dir_right = LEFT
71
+             dir_left = RIGHT
72
+         elif self.direction == LEFT:
73
+             dir_straight = LEFT
74
+             dir_right = UP
75
+             dir_left = DOWN
76
+         elif self.direction == RIGHT:
77
+             dir_straight = RIGHT
78
+             dir_right = DOWN
79
+             dir_left = UP
80
+
81
+         check_pos_straight = (hx + dir_straight[0], hy + dir_straight[1])
82
+         check_pos_right = (hx + dir_right[0], hy + dir_right[1])
83
+         check_pos_left = (hx + dir_left[0], hy + dir_left[1])
84
+
85
+         def is_danger(pos):
86
+             px, py = pos
87
+
88
+             if not (0 <= px < self.grid_size and 0 <= py < self.grid_size):
89
+                 return True
90
+            
91
+             if pos in list(self.snake)[1:]:
92
+                 return True
93
+             return False
94
+
95
+         obs[0] = 1 if is_danger(check_pos_straight) else 0
96
+         obs[1] = 1 if is_danger(check_pos_right) else 0  
97
+         obs[2] = 1 if is_danger(check_pos_left) else 0
98
+
99
+        
100
+         fx, fy = self.food
101
+         if fy < hy: obs[3] = 1
102
+         if fy > hy: obs[4] = 1
103
+         if fx < hx: obs[5] = 1
104
+         if fx > hx: obs[6] = 1
105
+
106
+         if self.direction == UP: obs[7] = 1
107
+         elif self.direction == DOWN: obs[8] = 1
108
+         elif self.direction == LEFT: obs[9] = 1
109
+         elif self.direction == RIGHT: obs[10] = 1
110
+
111
+         return obs
112
+
113
+     def reset(self, seed=None, options=None):
114
+         super().reset(seed=seed)
115
+         self._init_game_state()
116
+
117
+         observation = self._get_observation()
118
+         info = self._get_info()
119
+        
120
+         if self.render_mode == 'human':
121
+             self._render_frame()
122
+
123
+         return observation, info
124
+
125
+     def _get_info(self):
126
+         return {"score": self.score, "snake_length": len(self.snake)}
127
+
128
+     def step(self, action):
129
+      
130
+         current_dir_idx = [UP, DOWN, LEFT, RIGHT].index(self.direction)
131
+        
132
+         if action == 0:
133
+             new_direction = self.direction
134
+         elif action == 1:
135
+             if self.direction == UP: new_direction = RIGHT
136
+             elif self.direction == DOWN: new_direction = LEFT
137
+             elif self.direction == LEFT: new_direction = UP
138
+             elif self.direction == RIGHT: new_direction = DOWN
139
+         elif action == 2:
140
+             if self.direction == UP: new_direction = LEFT
141
+             elif self.direction == DOWN: new_direction = RIGHT
142
+             elif self.direction == LEFT: new_direction = DOWN
143
+             elif self.direction == RIGHT: new_direction = UP
144
+         else:
145
+             raise ValueError(f"Received invalid action={action} which is not part of the action space.")
146
+
147
+         self.direction = new_direction
148
+
149
+         hx, hy = self.head
150
+         dx, dy = self.direction
151
+         new_head = (hx + dx, hy + dy)
152
+
153
+         reward = REWARD_STEP
154
+
155
+         terminated = False
156
+        
157
+         if not (0 <= new_head[0] < self.grid_size and 0 <= new_head[1] < self.grid_size):
158
+             terminated = True
159
+             reward = REWARD_COLLISION
160
+        
161
+         elif new_head in list(self.snake):
162
+             terminated = True
163
+             reward = REWARD_COLLISION
164
+
165
+         if terminated:
166
+             self.game_over = True
167
+         else:
168
+             self.snake.appendleft(new_head)
169
+             self.head = new_head
170
+
171
+             if new_head == self.food:
172
+                 self.score += 1
173
+                 reward = REWARD_FOOD
174
+                 self.food = self._place_food()
175
+                 self.steps_since_food = 0
176
+             else:
177
+                 self.snake.pop()
178
+                 self.steps_since_food += 1
179
+
180
+             if self.steps_since_food > self.grid_size * self.grid_size * 2:
181
+                  terminated = True
182
+                  reward = REWARD_COLLISION
183
+
184
+         observation = self._get_observation()
185
+         info = self._get_info()
186
+         truncated = False
187
+
188
+         if self.render_mode == 'human':
189
+             self._render_frame()
190
+
191
+         return observation, reward, terminated, truncated, info
192
+
193
+     def _render_frame(self):
194
+         if self.window is None and self.render_mode == 'human':
195
+             pygame.init()
196
+             pygame.display.init()
197
+             self.window = pygame.display.set_mode((self.screen_width, self.screen_height))
198
+             pygame.display.set_caption("Snake AI Training")
199
+         if self.clock is None and self.render_mode == 'human':
200
+             self.clock = pygame.time.Clock()
201
+
202
+         if self.render_mode == 'human':
203
+        
204
+             self.window.fill(BLACK)
205
+
206
+             pygame.draw.rect(self.window, RED, (self.food[0] * self.cell_size,
207
+                                                self.food[1] * self.cell_size,
208
+                                                self.cell_size, self.cell_size))
209
+
210
+             for i, segment in enumerate(self.snake):
211
+                 color = BLUE if i == 0 else GREEN
212
+                 pygame.draw.rect(self.window, color, (segment[0] * self.cell_size,
213
+                                                       segment[1] * self.cell_size,
214
+                                                       self.cell_size, self.cell_size))
215
+
216
+             for x in range(0, self.screen_width, self.cell_size):
217
+                 pygame.draw.line(self.window, WHITE, (x, 0), (x, self.screen_height))
218
+             for y in range(0, self.screen_height, self.cell_size):
219
+                 pygame.draw.line(self.window, WHITE, (0, y), (self.screen_width, y))
220
+
221
+             font = pygame.font.Font(None, 25)
222
+             text = font.render(f"Score: {self.score}", True, WHITE)
223
+             self.window.blit(text, (5, 5))
224
+
225
+             pygame.event.pump()
226
+             pygame.display.flip()
227
+
228
+             self.clock.tick(self.metadata["render_fps"])
229
+         elif self.render_mode == "rgb_array":
230
+
231
+             surf = pygame.Surface((self.screen_width, self.screen_height))
232
+             surf.fill(BLACK)
233
+
234
+             pygame.draw.rect(surf, RED, (self.food[0] * self.cell_size,
235
+                                          self.food[1] * self.cell_size,
236
+                                          self.cell_size, self.cell_size))
237
+            
238
+             for i, segment in enumerate(self.snake):
239
+                 color = BLUE if i == 0 else GREEN
240
+                 pygame.draw.rect(surf, color, (segment[0] * self.cell_size,
241
+                                                segment[1] * self.cell_size,
242
+                                                self.cell_size, self.cell_size))
243
+
244
+             return np.transpose(np.array(pygame.surfarray.pixels3d(surf)), axes=(1, 0, 2))
245
+
246
+     def close(self):
247
+         if self.window is not None:
248
+             pygame.display.quit()
249
+             pygame.quit()
250
+             self.window = None
251
+             self.clock = None