# Tetris-RL / tetris_env.py
# Last commit d467399 by BaljinderH: "Update tetris_env.py" (verified)
import gym
from gym import spaces
import numpy as np
import pygame
import random
from sandtris import Tetris
import os
# RGB palette indexed by cell value. Index 0 stands for an empty cell
# (render() only paints cells whose value is > 0); indices 1-6 are the
# colours used for settled cells and the falling piece.
colors = [
    (0, 0, 0),        # 0: empty
    (120, 37, 179),   # 1
    (100, 179, 179),  # 2
    (80, 34, 22),     # 3
    (80, 134, 22),    # 4
    (180, 34, 22),    # 5
    (180, 34, 122),   # 6
]
class TetrisEnv(gym.Env):
    """Gym environment wrapping a ``sandtris.Tetris`` game.

    Actions (``Discrete(5)``):
        0 -- move the falling piece left (``go_side(-1)``)
        1 -- move the falling piece right (``go_side(1)``)
        2 -- rotate the falling piece
        3 -- hard drop (``go_space``)
        4 -- no-op

    After the chosen action the piece is always advanced one row with
    ``Tetris.go_down()``.

    Observations are the ``height x width`` playfield as an ``int32`` array;
    non-zero values index into the module-level ``colors`` palette.
    """

    # Advertise render modes under both the legacy ('render.modes') and the
    # newer ('render_modes') gym metadata keys.
    metadata = {
        'render.modes': ['human', 'rgb_array'],
        'render_modes': ['human', 'rgb_array'],
    }

    # Reward-shaping weights; class attributes so they can be tuned per
    # subclass/instance without editing step().
    line_reward = 10
    height_penalty = 0.5
    hole_penalty = 0.7
    bumpiness_penalty = 0.3
    gameover_penalty = 10

    def __init__(self):
        super().__init__()
        self.action_space = spaces.Discrete(5)
        self.height = 20
        self.width = 10
        self.observation_space = spaces.Box(
            low=0, high=6, shape=(self.height, self.width), dtype=np.int32
        )
        self.game = Tetris(self.height, self.width)
        # Game score at the end of the previous step. Rewards use the change
        # in score so points already earned are not rewarded again.
        self._last_score = self.game.score
        self.screen = None
        # Rendering geometry: cell size in pixels and the board's offset from
        # the top-left corner of the render surface.
        self.zoom = 20
        self.x = 100
        self.y = 60

    def reset(self):
        """Start a fresh game, spawn the first piece, return the observation."""
        self.game = Tetris(self.height, self.width)
        self.game.new_figure()
        self._last_score = self.game.score
        return self._get_obs()

    def step(self, action):
        """Apply *action*, advance the piece one row, and compute the reward.

        Returns the old-style gym 4-tuple ``(obs, reward, done, info)``.

        reward = line_reward * (score gained during this step)
                 - height_penalty * aggregate height
                 - hole_penalty * holes
                 - bumpiness_penalty * bumpiness
                 - gameover_penalty (only when the game is over)
        """
        if action == 0:
            self.game.go_side(-1)  # Move left
        elif action == 1:
            self.game.go_side(1)  # Move right
        elif action == 2:
            self.game.rotate()  # Rotate
        elif action == 3:
            self.game.go_space()  # Hard drop
        # action == 4 (and any unrecognised value) is a no-op.

        # The piece falls one row on every step, whatever the action was.
        self.game.go_down()

        # Tetris.score is cumulative over the episode, so reward only the
        # points earned during this step. Previously the running total was
        # re-added on every step.
        score_gained = self.game.score - self._last_score
        self._last_score = self.game.score
        reward = score_gained * self.line_reward

        reward -= (
            self.calculate_aggregate_height() * self.height_penalty
            + self.calculate_holes() * self.hole_penalty
            + self.calculate_bumpiness() * self.bumpiness_penalty
        )

        done = self.game.state == "gameover"
        if done:
            reward -= self.gameover_penalty
        return self._get_obs(), reward, done, {}

    def _get_obs(self):
        """Return the playfield as an ``int32`` array matching observation_space.

        Raises:
            ValueError: if the game has no field.
        """
        if self.game.field is None:
            raise ValueError("The field attribute in self.game is None.")
        return np.array(self.game.field, dtype=np.int32)

    def _window_size(self):
        """Pixel size of the render surface: the board plus x/y margins on both sides."""
        return (self.x * 2 + self.zoom * self.width,
                self.y * 2 + self.zoom * self.height)

    def _draw(self, surface):
        """Paint background, grid, settled cells and the falling piece onto *surface*.

        Raises:
            ValueError: if the game has no field.
        """
        surface.fill((173, 216, 230))
        field = self.game.field
        if field is None:
            raise ValueError("Game field is None.")
        for i in range(self.game.height):
            for j in range(self.game.width):
                rect = pygame.Rect(self.x + self.zoom * j, self.y + self.zoom * i,
                                   self.zoom, self.zoom)
                pygame.draw.rect(surface, (128, 128, 128), rect, 1)  # Grid lines
                if field[i][j] > 0:
                    pygame.draw.rect(surface, colors[field[i][j]], rect.inflate(-2, -2))

        figure = self.game.figure
        if figure is not None:
            # The figure occupies a 4x4 box; image() lists the occupied
            # indices (row * 4 + col). Fetch it once instead of per cell.
            cells = figure.image()
            for i in range(4):
                for j in range(4):
                    if i * 4 + j in cells:
                        rect = pygame.Rect(self.x + self.zoom * (j + figure.x),
                                           self.y + self.zoom * (i + figure.y),
                                           self.zoom, self.zoom)
                        pygame.draw.rect(surface, colors[figure.color],
                                         rect.inflate(-2, -2))

    def render(self, mode='human'):
        """Render the current game state.

        Args:
            mode: ``'rgb_array'`` draws off-screen and returns the pixels;
                ``'human'`` draws into a pygame window.

        Returns:
            For ``'rgb_array'``, the result of ``pygame.surfarray.array3d``.
            NOTE(review): surfarray is indexed (x, y, channel), i.e. width
            first. Transpose if a (height, width, 3) image is needed.
            ``None`` for ``'human'`` or unrecognised modes.
        """
        if mode == 'rgb_array':
            if self.screen is None:
                pygame.init()
                self.screen = pygame.Surface(self._window_size())
            self._draw(self.screen)
            return pygame.surfarray.array3d(self.screen)
        elif mode == 'human':
            # Also (re)create the window if an earlier rgb_array render left
            # only an off-screen surface; flip() needs a display mode.
            if self.screen is None or pygame.display.get_surface() is None:
                pygame.init()
                self.screen = pygame.display.set_mode(self._window_size())
                pygame.display.set_caption("Tetris RL")
            # Service the OS event queue so the window stays responsive.
            pygame.event.pump()
            self._draw(self.screen)
            pygame.display.flip()

    def close(self):
        """Shut pygame down if rendering was started; a second call is a no-op."""
        if self.screen is not None:
            pygame.display.quit()
            pygame.quit()
            # Drop the dead surface so a later render() re-initialises pygame.
            self.screen = None

    def _column_heights(self):
        """Per-column height: rows from the bottom to the topmost filled cell (0 if empty).

        Row 0 of the field is the top row, so a first filled cell at row i
        gives a height of ``self.height - i``.
        """
        field = self.game.field
        heights = []
        for j in range(self.width):
            top = next((i for i in range(self.height) if field[i][j] != 0), None)
            heights.append(0 if top is None else self.height - top)
        return heights

    def calculate_aggregate_height(self):
        """Sum of all column heights."""
        return sum(self._column_heights())

    def calculate_holes(self):
        """Count empty cells that have a filled cell somewhere above them in the same column."""
        field = self.game.field
        holes = 0
        for j in range(self.width):
            block_found = False
            for i in range(self.height):
                if field[i][j] != 0:
                    block_found = True
                elif block_found:
                    holes += 1
        return holes

    def calculate_bumpiness(self):
        """Sum of absolute height differences between adjacent columns."""
        heights = self._column_heights()
        return sum(abs(a - b) for a, b in zip(heights, heights[1:]))