import pygame import random import numpy as np import sys from enum import Enum pygame.init() SCREEN_WIDTH = 640 SCREEN_HEIGHT = 480 BLOCK_SIZE = 20 class Direction(Enum): UP = 1 DOWN = 2 LEFT = 3 RIGHT = 4 class SnakeEnv: def __init__(self, render=True): """ render=False → runs invisibly (fast training, no window) render=True → shows the game window (watch it play) """ self.render_mode = render self.w = SCREEN_WIDTH self.h = SCREEN_HEIGHT if self.render_mode: self.display = pygame.display.set_mode((self.w, self.h)) pygame.display.set_caption("Snake AI Training") self.clock = pygame.time.Clock() self.font = pygame.font.Font(None, 36) self.reset() def reset(self): """Start a fresh game. Returns the initial state.""" self.direction = Direction.RIGHT head_x = (self.w // 2 // BLOCK_SIZE) * BLOCK_SIZE head_y = (self.h // 2 // BLOCK_SIZE) * BLOCK_SIZE self.snake = [ [head_x, head_y], [head_x - BLOCK_SIZE, head_y], [head_x - 2 * BLOCK_SIZE, head_y], ] self.score = 0 self.steps = 0 # track steps to detect loops self.max_steps = len(self.snake) * 100 # reset if stuck in loop self._place_food() return self.get_state() def _place_food(self): while True: x = random.randint(0, (self.w - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE y = random.randint(0, (self.h - BLOCK_SIZE) // BLOCK_SIZE) * BLOCK_SIZE self.food = [x, y] if self.food not in self.snake: break def get_state(self): head = self.snake[0] dir = self.direction # Points one step ahead in each relative direction point_straight = self._next_point(dir) point_right = self._next_point(self._turn_right(dir)) point_left = self._next_point(self._turn_left(dir)) state = [ # --- 3 danger sensors --- self._is_dangerous(point_straight), # wall or body ahead? self._is_dangerous(point_right), self._is_dangerous(point_left), # --- 4 current direction (one-hot) --- dir == Direction.UP, dir == Direction.DOWN, dir == Direction.LEFT, dir == Direction.RIGHT, # --- 4 food direction --- self.food[1] < head[1], # food is up self.food[1] > head[1], # food is down self.food[0] < head[0], # food is left self.food[0] > head[0], # food is right ] return np.array(state, dtype=float) def _next_point(self, direction): head = self.snake[0] if direction == Direction.UP: return [head[0], head[1] - BLOCK_SIZE] if direction == Direction.DOWN: return [head[0], head[1] + BLOCK_SIZE] if direction == Direction.LEFT: return [head[0] - BLOCK_SIZE, head[1]] if direction == Direction.RIGHT: return [head[0] + BLOCK_SIZE, head[1]] def _is_dangerous(self, point): """True if this point is a wall or snake body.""" x, y = point wall = x < 0 or x >= self.w or y < 0 or y >= self.h body = point in self.snake[1:] return float(wall or body) def _turn_right(self, d): order = [Direction.UP, Direction.RIGHT, Direction.DOWN, Direction.LEFT] return order[(order.index(d) + 1) % 4] def _turn_left(self, d): order = [Direction.UP, Direction.RIGHT, Direction.DOWN, Direction.LEFT] return order[(order.index(d) - 1) % 4] def step(self, action): """ action: 0=turn left, 1=go straight, 2=turn right returns: (next_state, reward, done) """ self.steps += 1 # Handle pygame quit even in training mode if self.render_mode: for event in pygame.event.get(): if event.type == pygame.QUIT: pygame.quit() sys.exit() # Convert relative action to absolute direction if action == 0: self.direction = self._turn_left(self.direction) elif action == 2: self.direction = self._turn_right(self.direction) # action == 1 → keep going straight # Move snake new_head = self._next_point(self.direction) self.snake.insert(0, new_head) # --- Reward logic --- reward = 0 done = False if self._is_dangerous(new_head): reward = -10 done = True self.snake.pop() if self.render_mode: self._draw() return self.get_state(), reward, done if new_head == self.food: reward = 10 self.score += 1 self.steps = 0 # reset loop counter on food self.max_steps = len(self.snake) * 100 self._place_food() else: reward = 1 # small reward for surviving self.snake.pop() # Punish the AI if it loops forever without eating if self.steps >= self.max_steps: reward = -10 done = True if self.render_mode: self._draw() return self.get_state(), reward, done def _draw(self): self.display.fill((15, 25, 35)) # Draw snake for i, seg in enumerate(self.snake): color = (46, 213, 115) if i == 0 else (34, 166, 90) pygame.draw.rect(self.display, color, pygame.Rect(seg[0], seg[1], BLOCK_SIZE, BLOCK_SIZE), border_radius=4) # Draw food pygame.draw.rect(self.display, (252, 92, 101), pygame.Rect(self.food[0], self.food[1], BLOCK_SIZE, BLOCK_SIZE), border_radius=BLOCK_SIZE // 2) # Score score_text = self.font.render(f"Score: {self.score}", True, (255, 255, 255)) self.display.blit(score_text, (10, 10)) pygame.display.flip() self.clock.tick(30)