import collections
import random
from collections import deque

import gymnasium as gym
import numpy as np
import pygame
from gymnasium import spaces

from Environment_Constants import (
    GRID_SIZE, CELL_SIZE, SCREEN_WIDTH, SCREEN_HEIGHT,
    WHITE, BLACK, GREEN, RED, BLUE,
    UP, DOWN, LEFT, RIGHT,
    REWARD_FOOD, REWARD_COLLISION, REWARD_STEP,
    FPS, OBSERVATION_SPACE_SIZE,
)


class SnakeGameEnv(gym.Env):
    """Grid Snake as a Gymnasium environment with heading-relative actions.

    Actions (``Discrete(3)``), relative to the current heading:
        0 -- keep going straight
        1 -- turn right
        2 -- turn left

    Observation (float32 vector of length ``OBSERVATION_SPACE_SIZE``; indices
    0..10 are filled, so that constant is assumed to be >= 11 -- confirm in
    Environment_Constants):
        [0:3]  1.0 if the cell straight / right / left of the head is a collision
        [3]    1.0 if food y < head y
        [4]    1.0 if food y > head y
        [5]    1.0 if food x < head x
        [6]    1.0 if food x > head x
        [7:11] one-hot heading, in the order UP, DOWN, LEFT, RIGHT

    ``info`` carries ``score``, ``snake_length`` and ``action_mask``. The mask
    is a bool array of length 3 in action order; True means the move does not
    immediately collide. At least one entry is always True.
    """

    metadata = {'render_modes': ['human', 'rgb_array'], 'render_fps': FPS}

    def __init__(self, render_mode=None):
        """Create the environment.

        Args:
            render_mode: None, 'human' (draw to a pygame window on every
                reset/step) or 'rgb_array' (frames available via ``render()``).
        """
        super().__init__()
        self.grid_size = GRID_SIZE
        self.cell_size = CELL_SIZE
        self.screen_width = SCREEN_WIDTH
        self.screen_height = SCREEN_HEIGHT

        self.action_space = spaces.Discrete(3)
        self.observation_space = spaces.Box(
            low=0, high=1, shape=(OBSERVATION_SPACE_SIZE,), dtype=np.float32
        )

        self.render_mode = render_mode
        self.window = None
        self.clock = None
        self._font = None  # created lazily on the first human-mode frame

        self._init_game_state()

    # ------------------------------------------------------------------ state

    def _init_game_state(self):
        """Reset snake, heading, score, counters and food to the episode start."""
        cx = cy = self.grid_size // 2
        self.head = (cx, cy)
        # Head first; the body trails along +y behind it and the snake heads UP.
        self.snake = deque([self.head, (cx, cy + 1), (cx, cy + 2)])
        self.direction = UP
        self.score = 0
        self.game_over = False
        self.steps_since_food = 0
        self.length = len(self.snake)
        self.food = self._place_food()

    def _place_food(self):
        """Return a uniformly random unoccupied cell, or None if the board is full.

        Draws from ``self.np_random`` so that ``reset(seed=...)`` makes food
        placement reproducible.
        """
        occupied = set(self.snake)
        free_cells = [
            (x, y)
            for x in range(self.grid_size)
            for y in range(self.grid_size)
            if (x, y) not in occupied
        ]
        if not free_cells:
            return None
        return free_cells[int(self.np_random.integers(len(free_cells)))]

    def _relative_directions(self):
        """Return the absolute (straight, right, left) directions for the current heading.

        The order matches the action indices 0, 1, 2.
        """
        if self.direction == UP:
            return UP, RIGHT, LEFT
        if self.direction == DOWN:
            return DOWN, LEFT, RIGHT
        if self.direction == LEFT:
            return LEFT, UP, DOWN
        if self.direction == RIGHT:
            return RIGHT, DOWN, UP
        raise ValueError(f"Unknown direction {self.direction!r}")

    def _is_collision(self, pos):
        """Return True if moving the head to ``pos`` ends the game.

        ``pos`` is a collision if it is outside the grid or on any snake cell.
        The current tail cell counts as occupied, so chasing the tail is a
        collision even though the tail would move away on a non-eating step.
        Change the rule here if tail-chasing should be allowed; step, the
        observation and the action mask all use this one check.
        """
        x, y = pos
        if not (0 <= x < self.grid_size and 0 <= y < self.grid_size):
            return True
        return pos in self.snake

    # ----------------------------------------------------------- observation

    def _get_observation(self):
        """Build the feature vector described in the class docstring."""
        obs = np.zeros(OBSERVATION_SPACE_SIZE, dtype=np.float32)
        hx, hy = self.head

        # [0:3] danger straight / right / left.
        for idx, (dx, dy) in enumerate(self._relative_directions()):
            obs[idx] = float(self._is_collision((hx + dx, hy + dy)))

        # [3:7] food position relative to the head (all zero if no food exists).
        if self.food is not None:
            fx, fy = self.food
            obs[3] = float(fy < hy)
            obs[4] = float(fy > hy)
            obs[5] = float(fx < hx)
            obs[6] = float(fx > hx)

        # [7:11] one-hot heading.
        for idx, heading in enumerate((UP, DOWN, LEFT, RIGHT), start=7):
            if self.direction == heading:
                obs[idx] = 1
                break

        return obs

    def _get_action_mask(self):
        """Return a bool mask over actions (straight, right, left); True = safe.

        If every move collides, one action is enabled at random, drawn from
        ``self.np_random``. This guarantees a masked policy always has
        something to pick, even though that move ends the episode.
        """
        hx, hy = self.head
        mask = np.array(
            [not self._is_collision((hx + dx, hy + dy))
             for dx, dy in self._relative_directions()],
            dtype=bool,
        )
        if not mask.any():
            print(f"Warning: All actions masked out at head {self.head}, direction {self.direction}, food {self.food}. Enabling a random action to prevent deadlock.")
            mask[int(self.np_random.integers(3))] = True
        return mask

    def _get_info(self):
        """Returns environment information, including the action mask."""
        return {
            "score": self.score,
            "snake_length": len(self.snake),
            "action_mask": self._get_action_mask(),
        }

    # ------------------------------------------------------------- gym API

    def reset(self, seed=None, options=None):
        """Start a new episode and return ``(observation, info)``.

        Passing ``seed`` reseeds ``self.np_random``, which drives food
        placement and the action-mask fallback.
        """
        super().reset(seed=seed)
        self._init_game_state()
        observation = self._get_observation()
        info = self._get_info()
        if self.render_mode == 'human':
            self._render_frame()
        return observation, info

    def step(self, action):
        """Advance one tick and return ``(obs, reward, terminated, truncated, info)``.

        Raises:
            ValueError: if ``action`` is not 0, 1 or 2.
        """
        straight, right, left = self._relative_directions()
        if action == 0:
            new_direction = straight
        elif action == 1:
            new_direction = right
        elif action == 2:
            new_direction = left
        else:
            raise ValueError(f"Received invalid action={action} which is not part of the action space.")

        self.direction = new_direction
        hx, hy = self.head
        dx, dy = new_direction
        new_head = (hx + dx, hy + dy)

        reward = REWARD_STEP
        terminated = False
        truncated = False

        if self._is_collision(new_head):
            # On a collision the snake is left where it was.
            terminated = True
            reward = REWARD_COLLISION
        else:
            self.snake.appendleft(new_head)
            self.head = new_head
            if new_head == self.food:
                # Grow: keep the tail this step.
                self.score += 1
                self.length += 1
                reward = REWARD_FOOD
                self.steps_since_food = 0
                self.food = self._place_food()
                if self.food is None:
                    # The snake fills the whole board; nothing is left to eat.
                    terminated = True
            else:
                self.snake.pop()
                self.steps_since_food += 1
                # Starvation limit: ends episodes where the agent loops forever.
                if self.steps_since_food >= self.grid_size * self.grid_size * 1.5:
                    terminated = True
                    truncated = True
                    reward = REWARD_COLLISION

        if terminated:
            self.game_over = True

        observation = self._get_observation()
        info = self._get_info()

        if self.render_mode == 'human':
            self._render_frame()

        return observation, reward, terminated, truncated, info

    def render(self):
        """Return an (H, W, 3) uint8 frame in 'rgb_array' mode, otherwise None.

        In 'human' mode the window is already updated on every reset/step.
        """
        if self.render_mode == "rgb_array":
            return self._render_frame()
        return None

    # ------------------------------------------------------------ rendering

    def _draw_pieces(self, surface):
        """Draw the food (if any) and the snake (head in BLUE) onto ``surface``."""
        size = self.cell_size
        if self.food is not None:
            pygame.draw.rect(surface, RED, (self.food[0] * size, self.food[1] * size, size, size))
        for i, (sx, sy) in enumerate(self.snake):
            color = BLUE if i == 0 else GREEN
            pygame.draw.rect(surface, color, (sx * size, sy * size, size, size))

    def _render_frame(self):
        """Draw the current state for the active render mode.

        Returns the RGB frame in 'rgb_array' mode; returns None otherwise.
        """
        if self.render_mode == 'human':
            if self.window is None:
                pygame.init()
                pygame.display.init()
                self.window = pygame.display.set_mode((self.screen_width, self.screen_height))
                pygame.display.set_caption("Snake AI Training")
            if self.clock is None:
                self.clock = pygame.time.Clock()
            if self._font is None:
                self._font = pygame.font.Font(None, 25)

            self.window.fill(BLACK)
            self._draw_pieces(self.window)

            for x in range(0, self.screen_width, self.cell_size):
                pygame.draw.line(self.window, WHITE, (x, 0), (x, self.screen_height))
            for y in range(0, self.screen_height, self.cell_size):
                pygame.draw.line(self.window, WHITE, (0, y), (self.screen_width, y))

            text = self._font.render(f"Score: {self.score}", True, WHITE)
            self.window.blit(text, (5, 5))

            pygame.event.pump()
            pygame.display.flip()
            self.clock.tick(self.metadata["render_fps"])
            return None

        if self.render_mode == "rgb_array":
            surf = pygame.Surface((self.screen_width, self.screen_height))
            surf.fill(BLACK)
            self._draw_pieces(surf)
            # array3d copies the pixels (W, H, 3) without leaving the surface
            # locked, unlike pixels3d; transpose to the (H, W, 3) layout.
            return np.transpose(pygame.surfarray.array3d(surf), axes=(1, 0, 2))

        return None

    def close(self):
        """Shut down pygame if a window was opened."""
        if self.window is not None:
            pygame.display.quit()
            pygame.quit()
            self.window = None
            self.clock = None
            self._font = None