# Scraped from a Hugging Face Space page; the Space status was "Runtime error".
| import gradio as gr | |
| import chess | |
| import numpy as np | |
| import random | |
| from pettingzoo import AECEnv | |
| #from pettingzoo.utils.agent_selector import AgentSelector | |
| from gymnasium import spaces | |
| import zipfile | |
| import os | |
| from stable_baselines3 import PPO, DQN | |
| # Chess Environment Setup | |
class ChessEnvironment(AECEnv):
    """PettingZoo AEC chess environment backed by python-chess.

    Two agents, "w" (white) and "b" (black), alternate turns. Actions are
    indices into a Discrete(4672) space, but only indices that map onto the
    current position's legal-move list are valid (see _get_legal_move_mask).
    Observations are an (8, 8, 12) one-hot int8 tensor, one plane per
    piece type/color combination.
    """

    metadata = {'render_modes': ['human'], 'name': "Chess-v0", 'is_parallelizable': True}

    def __init__(self, render_mode=None):
        super().__init__()
        self.render_mode = render_mode
        self.board = chess.Board()

        # Per-agent bookkeeping required by the AEC API.
        self.agents = ["w", "b"]
        self.possible_agents = self.agents[:]
        self.current_agent_index = 0
        self.current_agent = self.agents[self.current_agent_index]
        self.agent_selection = self.current_agent
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.rewards = {agent: 0 for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.infos = {agent: {} for agent in self.agents}

        # Bug fix: previously set only in reset(), so calling step() before
        # reset() raised AttributeError on self._game_over_pending.
        self._game_over_pending = False

        # Action and observation spaces are shared by both agents.
        self._action_space = spaces.Discrete(4672)
        self._observation_space = spaces.Box(low=0, high=1, shape=(8, 8, 12), dtype=np.int8)

        # Load pretrained models from hard-coded paths; raises if the .zip
        # files are absent. NOTE(review): these are loaded but never used
        # inside this class -- presumably consumed by external driver code;
        # confirm before removing.
        self.model_black = DQN.load("dqn_model_black.zip")
        self.model_white = DQN.load("dqn_model_white.zip")
        self.ppo_model = PPO.load("ppo_chess_model.zip")

    def action_space(self, agent):
        """Return the (shared) Discrete(4672) action space for *agent*."""
        return self._action_space

    def observation_space(self, agent):
        """Return the (shared) Box(8, 8, 12) observation space for *agent*."""
        return self._observation_space

    def reset(self, seed=None, options=None):
        """Reset the board and all per-agent AEC bookkeeping to a new game."""
        self.board.reset()
        self.agents = ["w", "b"]
        self.current_agent_index = 0
        self.current_agent = self.agents[self.current_agent_index]
        self.agent_selection = self.agents[0]
        self._cumulative_rewards = {agent: 0 for agent in self.agents}
        self.rewards = {agent: 0 for agent in self.agents}
        self.terminations = {agent: False for agent in self.agents}
        self.truncations = {agent: False for agent in self.agents}
        self.infos = {agent: {} for agent in self.agents}
        self._game_over_pending = False

    def observe(self, agent):
        """Return the current board tensor (identical for both agents)."""
        return self._board_to_tensor()

    def _board_to_tensor(self):
        """Encode the board as an (8, 8, 12) one-hot int8 tensor.

        Planes 0-5 hold white pawn..king, planes 6-11 black pawn..king.
        Row 0 corresponds to rank 8 (the board is flipped vertically).
        """
        tensor = np.zeros((8, 8, 12), dtype=np.int8)
        for square, piece in self.board.piece_map().items():
            row = 7 - (square // 8)
            col = square % 8
            plane = (piece.piece_type - 1) + (0 if piece.color == chess.WHITE else 6)
            tensor[row, col, plane] = 1
        return tensor

    def step(self, action):
        """Apply *action* for the agent currently selected to move.

        NOTE(review): contrary to the standard AEC API (step returns None),
        this implementation returns a Gym-style 5-tuple
        (obs, reward, terminated, truncated, info) for the *next* agent.
        Kept as-is because callers may rely on it.
        """
        agent = self.agent_selection
        if self.terminations[agent] or self.truncations[agent]:
            # Already-finished agent: re-sync agent_selection and bail out.
            self.agent_selection = self.agents[self.current_agent_index]
            return

        legal_move_mask = self._get_legal_move_mask()
        if legal_move_mask[action] == 1:
            self.board.push(self._index_to_move(action))
        else:
            # Illegal move: punish the offender, reward the opponent, and
            # flag the game as pending-over (terminations are set once play
            # wraps back to white so both agents see a terminal step).
            self.rewards[agent] = -1
            other_agent = [a for a in self.agents if a != agent][0]
            self.rewards[other_agent] = 1
            self._game_over_pending = True
            self.current_agent_index = 1 - self.current_agent_index
            self.current_agent = self.agents[self.current_agent_index]
            self.agent_selection = self.current_agent
            return (self.observe(self.current_agent),
                    self.rewards[self.current_agent],
                    self.terminations[self.current_agent],
                    self.truncations[self.current_agent],
                    self.infos[self.current_agent])

        # Natural game-over detection (checkmate, stalemate, draw rules).
        game_over = False
        if self.board.is_game_over():
            result = self.board.result()
            if result == "1-0":
                self.rewards = {"w": 1, "b": -1}
            elif result == "0-1":
                self.rewards = {"w": -1, "b": 1}
            elif result == "1/2-1/2":  # Draw condition
                self.rewards = {"w": 0, "b": 0}
            game_over = True
        if game_over:
            self._game_over_pending = True

        # Hand the turn to the other agent.
        self.current_agent_index = 1 - self.current_agent_index
        self.current_agent = self.agents[self.current_agent_index]
        self.agent_selection = self.current_agent

        # Defer setting terminations until play wraps back to the first
        # agent so both sides observe the terminal state.
        if self._game_over_pending and self.agent_selection == self.agents[0]:
            self.terminations = {agent: True for agent in self.agents}
            self._game_over_pending = False

        return (self.observe(self.current_agent),
                self.rewards[self.current_agent],
                self.terminations[self.current_agent],
                self.truncations[self.current_agent],
                self.infos[self.current_agent])

    def _index_to_move(self, action_index):
        """Map an action index to a chess.Move via the current legal-move list.

        NOTE(review): the mapping is position-dependent (index i is simply
        the i-th legal move in the current position), so action indices are
        not stable across positions -- confirm this matches how the loaded
        models were trained.
        """
        return list(self.board.legal_moves)[action_index]

    def render(self):
        """Return an ASCII rendering of the current board."""
        return str(self.board)

    def close(self):
        """No resources to release."""
        pass

    def _get_legal_move_mask(self):
        """Return an int8 mask over all 4672 action slots; 1 marks legal moves."""
        legal_move_mask = np.zeros(self._action_space.n, dtype=np.int8)
        for move in self.board.legal_moves:
            legal_move_mask[self._move_to_index(move)] = 1
        return legal_move_mask

    def _move_to_index(self, move):
        """Inverse of _index_to_move: position of *move* in the legal-move list."""
        return list(self.board.legal_moves).index(move)
| # Gradio Interface Setup | |
def start_game():
    """Create a fresh environment and return the starting board rendering.

    Bug fix: the original bound ``env`` as a *local* variable that was
    immediately discarded, so make_move's reference to a module-level
    ``env`` raised NameError on the first move. Bind it globally instead
    so the two Gradio callbacks share one environment.
    """
    global env
    env = ChessEnvironment()
    env.reset()
    return env.render()
def make_move(action):
    """Apply one action to the shared environment and return the new board.

    Bug fixes:
    - Guard against the game not having been started yet (the shared
      ``env`` global otherwise raises NameError).
    - The Gradio slider delivers a float; coerce to int before stepping,
      since indexing the numpy legal-move mask with a float raises.
    """
    if globals().get("env") is None:
        return "Click 'Start Game' first."
    env.step(int(action))
    return env.render()
| # Gradio UI for Chess Game | |
| with gr.Blocks() as demo: | |
| game_output = gr.Textbox(label="Chess Board", interactive=False) | |
| action_input = gr.Slider(minimum=0, maximum=4671, label="Choose a Move", step=1) | |
| start_button = gr.Button("Start Game") | |
| move_button = gr.Button("Make Move") | |
| start_button.click(start_game, outputs=game_output) | |
| move_button.click(make_move, inputs=action_input, outputs=game_output) | |
| demo.launch() |