# MARL-for-Chess / app.py
# (Hugging Face Space page header commented out so the module parses:
#  "vinzur's picture / Update app.py / 78e41b2 verified")
import gradio as gr
import chess
import numpy as np
import random
from pettingzoo import AECEnv
#from pettingzoo.utils.agent_selector import AgentSelector
from gymnasium import spaces
import zipfile
import os
from stable_baselines3 import PPO, DQN
# Chess Environment Setup
class ChessEnvironment(AECEnv):
    """Two-player PettingZoo AEC chess environment built on python-chess.

    Agents are "w" (white) and "b" (black), alternating turns. Observations
    are an (8, 8, 12) one-hot tensor (6 piece types x 2 colors). The action
    space is Discrete(4672), but an action is only valid when it indexes
    into the *current* legal-move list (see `_get_legal_move_mask`);
    playing an illegal action loses the game on the spot (-1 vs +1).

    NOTE(review): `__init__` eagerly loads three pre-trained SB3 models from
    the working directory; construction fails if the .zip files are absent.
    """

    metadata = {'render_modes': ['human'], 'name': "Chess-v0", 'is_parallelizable': True}

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        self.board = chess.Board()
        self.agents = ["w", "b"]
        self.possible_agents = self.agents[:]
        self.current_agent_index = 0
        self.current_agent = self.agents[self.current_agent_index]
        self.agent_selection = self.current_agent
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.rewards = {agent: 0 for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.infos = {agent: {} for agent in self.agents}
        # BUGFIX: originally only set in reset(), so calling step() on a
        # freshly constructed env raised AttributeError on this flag.
        self._game_over_pending = False
        # Both agents share a single action space and observation space.
        self._action_space = spaces.Discrete(4672)
        self._observation_space = spaces.Box(low=0, high=1, shape=(8, 8, 12), dtype=np.int8)
        # Pre-trained policies (not used by step(); kept for external callers).
        self.model_black = DQN.load("dqn_model_black.zip")
        self.model_white = DQN.load("dqn_model_white.zip")
        self.ppo_model = PPO.load("ppo_chess_model.zip")

    def action_space(self, agent):
        """Return the shared Discrete(4672) action space for `agent`."""
        return self._action_space

    def observation_space(self, agent):
        """Return the shared Box(8, 8, 12) observation space for `agent`."""
        return self._observation_space

    def reset(self, seed=None, options=None):
        """Reset the board and all per-agent bookkeeping for a new game."""
        self.board.reset()
        self.agents = ["w", "b"]
        self.current_agent_index = 0
        self.current_agent = self.agents[self.current_agent_index]
        self.agent_selection = self.current_agent  # white moves first
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.rewards = {agent: 0 for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.infos = {agent: {} for agent in self.agents}
        self._game_over_pending = False

    def observe(self, agent):
        """Return the current board tensor (identical view for both agents)."""
        return self._board_to_tensor()

    def _board_to_tensor(self):
        """Encode the board as an (8, 8, 12) one-hot int8 tensor.

        Planes 0-5 hold white pawn..king, planes 6-11 the black pieces;
        row 0 is rank 8, so the tensor reads like a printed board.
        """
        tensor = np.zeros((8, 8, 12), dtype=np.int8)
        for square, piece in self.board.piece_map().items():
            row = 7 - (square // 8)
            col = square % 8
            plane = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
            tensor[row, col, plane] = 1
        return tensor

    def step(self, action):
        """Apply `action` for the agent whose turn it is.

        Returns a gym-style (obs, reward, terminated, truncated, info)
        tuple for the *next* selected agent. An action whose mask entry is
        0 immediately loses the game for the acting agent. Terminations
        are flipped only once the turn wraps back to the first agent, so
        both agents observe the final state.
        """
        agent = self.agent_selection
        if self.terminations[agent] or self.truncations[agent]:
            # Game already over for this agent: do nothing, but keep the
            # return shape consistent (the original returned None here,
            # unlike every other path).
            self.agent_selection = self.agents[self.current_agent_index]
            return self._last_tuple()
        legal_move_mask = self._get_legal_move_mask()
        if legal_move_mask[action] == 1:
            self.board.push(self._index_to_move(action))
        else:
            # Illegal action: immediate loss for the acting agent.
            self.rewards[agent] = -1
            other_agent = [a for a in self.agents if a != agent][0]
            self.rewards[other_agent] = 1
            self._game_over_pending = True
            self._swap_turn()
            return self._last_tuple()
        # Natural game-over detection after a legal move.
        if self.board.is_game_over():
            result = self.board.result()
            if result == "1-0":
                self.rewards = {"w": 1, "b": -1}
            elif result == "0-1":
                self.rewards = {"w": -1, "b": 1}
            elif result == "1/2-1/2":  # draw
                self.rewards = {"w": 0, "b": 0}
            self._game_over_pending = True
        self._swap_turn()
        if self._game_over_pending and self.agent_selection == self.agents[0]:
            self.terminations = {agent: True for agent in self.agents}
            self._game_over_pending = False
        return self._last_tuple()

    def _swap_turn(self):
        """Hand the move to the other agent."""
        self.current_agent_index = 1 - self.current_agent_index
        self.current_agent = self.agents[self.current_agent_index]
        self.agent_selection = self.current_agent

    def _last_tuple(self):
        """Gym-style 5-tuple for the agent currently selected to move."""
        a = self.current_agent
        return (self.observe(a), self.rewards[a], self.terminations[a],
                self.truncations[a], self.infos[a])

    def _index_to_move(self, action_index):
        """Map an action index into the current legal-move list.

        NOTE: the encoding is position-dependent — index i means "the i-th
        legal move right now", not a fixed from/to encoding.
        """
        return list(self.board.legal_moves)[action_index]

    def render(self):
        """Return an ASCII rendering of the board."""
        return str(self.board)

    def close(self):
        pass

    def _get_legal_move_mask(self):
        """Binary mask over the full action space: 1 marks a legal index.

        Only indices < len(legal_moves) can ever be 1, so a masked action
        is always a safe input to `_index_to_move`.
        """
        legal_move_mask = np.zeros(self._action_space.n, dtype=np.int8)
        for move in self.board.legal_moves:
            legal_move_mask[self._move_to_index(move)] = 1
        return legal_move_mask

    def _move_to_index(self, move):
        """Inverse of `_index_to_move`: position of `move` in the legal list."""
        return list(self.board.legal_moves).index(move)
# Gradio Interface Setup
def start_game():
    """Start a fresh game and return the rendered board.

    BUGFIX: the environment is stored in the module-global `env` —
    the original kept it as a function local, so `make_move`'s reference
    to `env` always raised NameError.
    """
    global env
    env = ChessEnvironment()
    env.reset()
    return env.render()
def make_move(action):
    """Advance the running game by one action and return the rendered board.

    Guards against being called before start_game() — the global `env`
    does not exist yet — instead of crashing with NameError. The Gradio
    slider delivers the action as a float, so it is coerced to int.
    """
    if "env" not in globals():
        return "Click 'Start Game' before making a move."
    env.step(int(action))
    return env.render()
# Gradio UI for Chess Game: a text board view, a move-index slider,
# and two buttons wired to the game callbacks above.
with gr.Blocks() as demo:
    board_view = gr.Textbox(label="Chess Board", interactive=False)
    move_slider = gr.Slider(minimum=0, maximum=4671, step=1, label="Choose a Move")
    btn_start = gr.Button("Start Game")
    btn_move = gr.Button("Make Move")
    btn_start.click(start_game, outputs=board_view)
    btn_move.click(make_move, inputs=move_slider, outputs=board_view)
demo.launch()