Instructions to use doraking/AlphaQuoridor with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Keras
How to use doraking/AlphaQuoridor with Keras:
# Available backend options are: "jax", "torch", "tensorflow". import os os.environ["KERAS_BACKEND"] = "jax" import keras model = keras.saving.load_model("hf://doraking/AlphaQuoridor") - Notebooks
- Google Colab
- Kaggle
Upload 10 files
Browse files- dual_network.py +79 -0
- evaluate_best_player.py +89 -0
- evaluate_network.py +95 -0
- game.py +594 -0
- human_play.py +229 -0
- pv_mcts.py +168 -0
- requirements.txt +4 -0
- self_play.py +101 -0
- train_cycle.py +33 -0
- train_network.py +66 -0
dual_network.py
ADDED
|
@@ -0,0 +1,79 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ====================
|
| 2 |
+
# Creating the Dual Network
|
| 3 |
+
# ====================
|
| 4 |
+
|
| 5 |
+
# Importing packages
|
| 6 |
+
from tensorflow.keras.layers import Activation, Add, BatchNormalization, Conv2D, Dense, GlobalAveragePooling2D, Input
|
| 7 |
+
from tensorflow.keras.models import Model
|
| 8 |
+
from tensorflow.keras.regularizers import l2
|
| 9 |
+
from tensorflow.keras import backend as K
|
| 10 |
+
import os
|
| 11 |
+
|
| 12 |
+
# Preparing parameters
|
| 13 |
+
DN_FILTERS = 128 # Number of kernels in the convolutional layer (256 in the original version)
|
| 14 |
+
DN_RESIDUAL_NUM = 16 # Number of residual blocks (19 in the original version)
|
| 15 |
+
DN_INPUT_SHAPE = (3, 3, 6) # Input shape
|
| 16 |
+
DN_OUTPUT_SIZE = 9 + 4 * 2 # Number of actions (placement locations (3*3))
|
| 17 |
+
|
| 18 |
+
# Creating the convolutional layer
|
| 19 |
+
def conv(filters):
|
| 20 |
+
return Conv2D(filters, 3, padding='same', use_bias=False,
|
| 21 |
+
kernel_initializer='he_normal', kernel_regularizer=l2(0.0005))
|
| 22 |
+
|
| 23 |
+
# Creating the residual block
|
| 24 |
+
def residual_block():
|
| 25 |
+
def f(x):
|
| 26 |
+
sc = x
|
| 27 |
+
x = conv(DN_FILTERS)(x)
|
| 28 |
+
x = BatchNormalization()(x)
|
| 29 |
+
x = Activation('relu')(x)
|
| 30 |
+
x = conv(DN_FILTERS)(x)
|
| 31 |
+
x = BatchNormalization()(x)
|
| 32 |
+
x = Add()([x, sc])
|
| 33 |
+
x = Activation('relu')(x)
|
| 34 |
+
return x
|
| 35 |
+
return f
|
| 36 |
+
|
| 37 |
+
# Creating the dual network
|
| 38 |
+
def dual_network():
|
| 39 |
+
# Do nothing if the model is already created
|
| 40 |
+
if os.path.exists('./model/best.keras'):
|
| 41 |
+
return
|
| 42 |
+
|
| 43 |
+
# Input layer
|
| 44 |
+
input = Input(shape=DN_INPUT_SHAPE)
|
| 45 |
+
|
| 46 |
+
# Convolutional layer
|
| 47 |
+
x = conv(DN_FILTERS)(input)
|
| 48 |
+
x = BatchNormalization()(x)
|
| 49 |
+
x = Activation('relu')(x)
|
| 50 |
+
|
| 51 |
+
# Residual blocks x 16
|
| 52 |
+
for i in range(DN_RESIDUAL_NUM):
|
| 53 |
+
x = residual_block()(x)
|
| 54 |
+
|
| 55 |
+
# Pooling layer
|
| 56 |
+
x = GlobalAveragePooling2D()(x)
|
| 57 |
+
|
| 58 |
+
# Policy output
|
| 59 |
+
p = Dense(DN_OUTPUT_SIZE, kernel_regularizer=l2(0.0005),
|
| 60 |
+
activation='softmax', name='pi')(x)
|
| 61 |
+
|
| 62 |
+
# Value output
|
| 63 |
+
v = Dense(1, kernel_regularizer=l2(0.0005))(x)
|
| 64 |
+
v = Activation('tanh', name='v')(v)
|
| 65 |
+
|
| 66 |
+
# Creating the model
|
| 67 |
+
model = Model(inputs=input, outputs=[p, v])
|
| 68 |
+
|
| 69 |
+
# Saving the model
|
| 70 |
+
os.makedirs('./model/', exist_ok=True) # Create folder if it does not exist
|
| 71 |
+
model.save('./model/best.keras') # Best player's model
|
| 72 |
+
|
| 73 |
+
# Clearing the model
|
| 74 |
+
K.clear_session()
|
| 75 |
+
del model
|
| 76 |
+
|
| 77 |
+
# Running the function
|
| 78 |
+
if __name__ == '__main__':
|
| 79 |
+
dual_network()
|
evaluate_best_player.py
ADDED
|
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ====================
|
| 2 |
+
# Evaluation of Best Player
|
| 3 |
+
# ====================
|
| 4 |
+
|
| 5 |
+
# Import packages
|
| 6 |
+
from game import State, random_action, alpha_beta_action, mcts_action
|
| 7 |
+
from pv_mcts import pv_mcts_action
|
| 8 |
+
from tensorflow.keras.models import load_model
|
| 9 |
+
from tensorflow.keras import backend as K
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import numpy as np
|
| 12 |
+
|
| 13 |
+
# Prepare parameters
|
| 14 |
+
EP_GAME_COUNT = 10 # Number of games per evaluation
|
| 15 |
+
|
| 16 |
+
# Points for the first player
|
| 17 |
+
def first_player_point(ended_state):
|
| 18 |
+
# 1: first player wins, 0: first player loses, 0.5: draw
|
| 19 |
+
if ended_state.is_lose():
|
| 20 |
+
return 0 if ended_state.is_first_player() else 1
|
| 21 |
+
return 0.5
|
| 22 |
+
|
| 23 |
+
# Execute one game
|
| 24 |
+
def play(next_actions):
|
| 25 |
+
# Generate state
|
| 26 |
+
state = State()
|
| 27 |
+
|
| 28 |
+
# Loop until the game ends
|
| 29 |
+
while True:
|
| 30 |
+
# When the game ends
|
| 31 |
+
if state.is_done():
|
| 32 |
+
break
|
| 33 |
+
|
| 34 |
+
# Get action
|
| 35 |
+
next_action = next_actions[0] if state.is_first_player() else next_actions[1]
|
| 36 |
+
action = next_action(state)
|
| 37 |
+
|
| 38 |
+
# Get the next state
|
| 39 |
+
state = state.next(action)
|
| 40 |
+
|
| 41 |
+
# Return points for the first player
|
| 42 |
+
return first_player_point(state)
|
| 43 |
+
|
| 44 |
+
# Evaluation of any algorithm
|
| 45 |
+
def evaluate_algorithm_of(label, next_actions):
|
| 46 |
+
# Repeat multiple matches
|
| 47 |
+
total_point = 0
|
| 48 |
+
for i in range(EP_GAME_COUNT):
|
| 49 |
+
# Execute one game
|
| 50 |
+
if i % 2 == 0:
|
| 51 |
+
total_point += play(next_actions)
|
| 52 |
+
else:
|
| 53 |
+
total_point += 1 - play(list(reversed(next_actions)))
|
| 54 |
+
|
| 55 |
+
# Output
|
| 56 |
+
print('\rEvaluate {}/{}'.format(i + 1, EP_GAME_COUNT), end='')
|
| 57 |
+
print('')
|
| 58 |
+
|
| 59 |
+
# Calculate average points
|
| 60 |
+
average_point = total_point / EP_GAME_COUNT
|
| 61 |
+
print(label, average_point)
|
| 62 |
+
|
| 63 |
+
# Evaluation of the best player
|
| 64 |
+
def evaluate_best_player():
|
| 65 |
+
# Load the model of the best player
|
| 66 |
+
model = load_model('./model/best.keras')
|
| 67 |
+
|
| 68 |
+
# Generate a function to select actions using PV MCTS
|
| 69 |
+
next_pv_mcts_action = pv_mcts_action(model, 0.0)
|
| 70 |
+
|
| 71 |
+
# VS Random
|
| 72 |
+
next_actions = (next_pv_mcts_action, random_action)
|
| 73 |
+
evaluate_algorithm_of('VS_Random', next_actions)
|
| 74 |
+
|
| 75 |
+
# VS Alpha-Beta
|
| 76 |
+
next_actions = (next_pv_mcts_action, alpha_beta_action)
|
| 77 |
+
evaluate_algorithm_of('VS_AlphaBeta', next_actions)
|
| 78 |
+
|
| 79 |
+
# VS Monte Carlo Tree Search
|
| 80 |
+
next_actions = (next_pv_mcts_action, mcts_action)
|
| 81 |
+
evaluate_algorithm_of('VS_MCTS', next_actions)
|
| 82 |
+
|
| 83 |
+
# Clear model
|
| 84 |
+
K.clear_session()
|
| 85 |
+
del model
|
| 86 |
+
|
| 87 |
+
# Operation check
|
| 88 |
+
if __name__ == '__main__':
|
| 89 |
+
evaluate_best_player()
|
evaluate_network.py
ADDED
|
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ====================
|
| 2 |
+
# New Parameter Evaluation Section
|
| 3 |
+
# ====================
|
| 4 |
+
|
| 5 |
+
# Import packages
|
| 6 |
+
from game import State
|
| 7 |
+
from pv_mcts import pv_mcts_action
|
| 8 |
+
from tensorflow.keras.models import load_model
|
| 9 |
+
from tensorflow.keras import backend as K
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
from shutil import copy
|
| 12 |
+
import numpy as np
|
| 13 |
+
|
| 14 |
+
# Prepare parameters
|
| 15 |
+
EN_GAME_COUNT = 10 # Number of games per evaluation (originally 400)
|
| 16 |
+
EN_TEMPERATURE = 1.0 # Temperature of the Boltzmann distribution
|
| 17 |
+
|
| 18 |
+
# Points for the first player
|
| 19 |
+
def first_player_point(ended_state):
|
| 20 |
+
# 1: first player wins, 0: first player loses, 0.5: draw
|
| 21 |
+
if ended_state.is_lose():
|
| 22 |
+
return 0 if ended_state.is_first_player() else 1
|
| 23 |
+
return 0.5
|
| 24 |
+
|
| 25 |
+
# Execute one game
|
| 26 |
+
def play(next_actions):
|
| 27 |
+
# Generate state
|
| 28 |
+
state = State()
|
| 29 |
+
|
| 30 |
+
# Loop until the game ends
|
| 31 |
+
while True:
|
| 32 |
+
# When the game ends
|
| 33 |
+
if state.is_done():
|
| 34 |
+
break
|
| 35 |
+
|
| 36 |
+
# Get action
|
| 37 |
+
next_action = next_actions[0] if state.is_first_player() else next_actions[1]
|
| 38 |
+
action = next_action(state)
|
| 39 |
+
|
| 40 |
+
# Get the next state
|
| 41 |
+
state = state.next(action)
|
| 42 |
+
|
| 43 |
+
# Return points for the first player
|
| 44 |
+
return first_player_point(state)
|
| 45 |
+
|
| 46 |
+
# Replace the best player
|
| 47 |
+
def update_best_player():
|
| 48 |
+
copy('./model/latest.keras', './model/best.keras')
|
| 49 |
+
print('Change BestPlayer')
|
| 50 |
+
|
| 51 |
+
# Network evaluation
|
| 52 |
+
def evaluate_network():
|
| 53 |
+
# Load the model of the latest player
|
| 54 |
+
model0 = load_model('./model/latest.keras')
|
| 55 |
+
|
| 56 |
+
# Load the model of the best player
|
| 57 |
+
model1 = load_model('./model/best.keras')
|
| 58 |
+
|
| 59 |
+
# Generate a function to select actions using PV MCTS
|
| 60 |
+
next_action0 = pv_mcts_action(model0, EN_TEMPERATURE)
|
| 61 |
+
next_action1 = pv_mcts_action(model1, EN_TEMPERATURE)
|
| 62 |
+
next_actions = (next_action0, next_action1)
|
| 63 |
+
|
| 64 |
+
# Repeat multiple matches
|
| 65 |
+
total_point = 0
|
| 66 |
+
for i in range(EN_GAME_COUNT):
|
| 67 |
+
# Execute one game
|
| 68 |
+
if i % 2 == 0:
|
| 69 |
+
total_point += play(next_actions)
|
| 70 |
+
else:
|
| 71 |
+
total_point += 1 - play(list(reversed(next_actions)))
|
| 72 |
+
|
| 73 |
+
# Output
|
| 74 |
+
print('\rEvaluate {}/{}'.format(i + 1, EN_GAME_COUNT), end='')
|
| 75 |
+
print('')
|
| 76 |
+
|
| 77 |
+
# Calculate average points
|
| 78 |
+
average_point = total_point / EN_GAME_COUNT
|
| 79 |
+
print('AveragePoint', average_point)
|
| 80 |
+
|
| 81 |
+
# Clear models
|
| 82 |
+
K.clear_session()
|
| 83 |
+
del model0
|
| 84 |
+
del model1
|
| 85 |
+
|
| 86 |
+
# Replace the best player
|
| 87 |
+
if average_point > 0.5:
|
| 88 |
+
update_best_player()
|
| 89 |
+
return True
|
| 90 |
+
else:
|
| 91 |
+
return False
|
| 92 |
+
|
| 93 |
+
# Operation check
|
| 94 |
+
if __name__ == '__main__':
|
| 95 |
+
evaluate_network()
|
game.py
ADDED
|
@@ -0,0 +1,594 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ====================
|
| 2 |
+
# Quoridor (3 x 3), wall = 1
|
| 3 |
+
# ====================
|
| 4 |
+
|
| 5 |
+
# Importing packages
|
| 6 |
+
import random
|
| 7 |
+
import math
|
| 8 |
+
from collections import deque
|
| 9 |
+
import copy
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
|
| 12 |
+
# Game state
|
| 13 |
+
class State:
|
| 14 |
+
def __init__(self, board_size=3, num_walls=1, player=None, enemy=None, walls=None, depth=0):
|
| 15 |
+
self.N = board_size
|
| 16 |
+
N = self.N
|
| 17 |
+
if N % 2 == 0:
|
| 18 |
+
raise ValueError('The board size must be an odd number.')
|
| 19 |
+
self.directions = [(-1, 0), (1, 0), (0, -1), (0, 1)]
|
| 20 |
+
self.player = player if player != None else [0] * 2 # Position, number of walls
|
| 21 |
+
self.enemy = enemy if enemy != None else [0] * 2
|
| 22 |
+
self.walls = walls if walls != None else [0] * ((N - 1) ** 2)
|
| 23 |
+
self.depth = depth
|
| 24 |
+
self.draw_depth = 30
|
| 25 |
+
|
| 26 |
+
if player == None or enemy == None:
|
| 27 |
+
init_pos = N * (N - 1) + N // 2
|
| 28 |
+
self.player[0] = init_pos
|
| 29 |
+
self.player[1] = num_walls
|
| 30 |
+
self.enemy[0] = init_pos
|
| 31 |
+
self.enemy[1] = num_walls
|
| 32 |
+
|
| 33 |
+
# Check if it's a loss
|
| 34 |
+
def is_lose(self):
|
| 35 |
+
if self.enemy[0] // self.N == 0:
|
| 36 |
+
return True
|
| 37 |
+
return False
|
| 38 |
+
|
| 39 |
+
# Check if it's a draw
|
| 40 |
+
def is_draw(self):
|
| 41 |
+
return self.depth >= self.draw_depth
|
| 42 |
+
|
| 43 |
+
# Check if the game is over
|
| 44 |
+
def is_done(self):
|
| 45 |
+
return self.is_lose() or self.is_draw()
|
| 46 |
+
|
| 47 |
+
def pieces_array(self):
|
| 48 |
+
N = self.N
|
| 49 |
+
def pieces_of(pieces):
|
| 50 |
+
tables = []
|
| 51 |
+
|
| 52 |
+
table = [0] * (N ** 2)
|
| 53 |
+
table[pieces[0]] = 1
|
| 54 |
+
tables.append(table)
|
| 55 |
+
|
| 56 |
+
table = [pieces[1]] * (N ** 2)
|
| 57 |
+
tables.append(table)
|
| 58 |
+
|
| 59 |
+
return tables
|
| 60 |
+
|
| 61 |
+
def walls_of(walls):
|
| 62 |
+
tables = []
|
| 63 |
+
|
| 64 |
+
table_h = [0] * (N ** 2)
|
| 65 |
+
table_v = [0] * (N ** 2)
|
| 66 |
+
|
| 67 |
+
for wp in range((N - 1) ** 2):
|
| 68 |
+
x, y = wp // (N - 1), wp % (N - 1)
|
| 69 |
+
|
| 70 |
+
if x < (N - 1) // 2 and y < (N - 1) // 2:
|
| 71 |
+
pos = N * x + y
|
| 72 |
+
elif x > (N - 1) // 2 and y < (N - 1) // 2:
|
| 73 |
+
pos = N * x + (y + 1)
|
| 74 |
+
elif x < (N - 1) // 2 and y > (N - 1) // 2:
|
| 75 |
+
pos = N * (x + 1) + y
|
| 76 |
+
else:
|
| 77 |
+
pos = N * (x + 1) + (y + 1)
|
| 78 |
+
|
| 79 |
+
if walls[wp] == 1:
|
| 80 |
+
table_h[pos] = 1
|
| 81 |
+
elif walls[wp] == 2:
|
| 82 |
+
table_v[pos] = 1
|
| 83 |
+
|
| 84 |
+
tables.append(table_h)
|
| 85 |
+
tables.append(table_v)
|
| 86 |
+
|
| 87 |
+
return tables
|
| 88 |
+
|
| 89 |
+
return [pieces_of(self.player), pieces_of(self.enemy), walls_of(self.walls)]
|
| 90 |
+
|
| 91 |
+
def legal_actions(self):
|
| 92 |
+
"""
|
| 93 |
+
0 - (N ** 2 - 1): Move to a position
|
| 94 |
+
N ** 2- (N ** 2 + (N - 1) ** 2 - 1): Place a horizontal wall
|
| 95 |
+
(N ** 2 + (N - 1) ** 2) - (N ** 2 + 2 * (N - 1) ** 2 - 1): Place a vertical wall
|
| 96 |
+
"""
|
| 97 |
+
actions = []
|
| 98 |
+
actions.extend(self.legal_actions_pos(self.player[0]))
|
| 99 |
+
|
| 100 |
+
if self.player[1] > 0:
|
| 101 |
+
for pos in range((self.N - 1) ** 2):
|
| 102 |
+
actions.extend(self.legal_actions_wall(pos))
|
| 103 |
+
|
| 104 |
+
return actions
|
| 105 |
+
|
| 106 |
+
def legal_actions_pos(self, pos):
|
| 107 |
+
actions = []
|
| 108 |
+
|
| 109 |
+
N = self.N
|
| 110 |
+
walls = self.walls
|
| 111 |
+
ep = self.enemy[0]
|
| 112 |
+
|
| 113 |
+
x, y = pos // N, pos % N
|
| 114 |
+
for dx, dy in self.directions:
|
| 115 |
+
nx, ny = x + dx, y + dy
|
| 116 |
+
if 0 <= nx < N and 0 <= ny < N:
|
| 117 |
+
np = N * nx + ny
|
| 118 |
+
wp = (N - 1) * nx + ny
|
| 119 |
+
|
| 120 |
+
if nx < x:
|
| 121 |
+
if y == 0:
|
| 122 |
+
if walls[wp] != 1:
|
| 123 |
+
if np + ep != N ** 2 - 1:
|
| 124 |
+
actions.append(np)
|
| 125 |
+
else:
|
| 126 |
+
if nx > 0 and walls[wp - (N - 1)] != 1:
|
| 127 |
+
nnp = np - N
|
| 128 |
+
actions.append(nnp)
|
| 129 |
+
elif (nx == 0 and walls[wp] != 2) or (nx > 0 and walls[wp - (N - 1)] != 2 and walls[wp] != 2):
|
| 130 |
+
nnp = np + 1
|
| 131 |
+
actions.append(nnp)
|
| 132 |
+
elif y == (N - 1):
|
| 133 |
+
if walls[wp - 1] != 1:
|
| 134 |
+
if np + ep != N ** 2 - 1:
|
| 135 |
+
actions.append(np)
|
| 136 |
+
else:
|
| 137 |
+
if nx > 0 and walls[wp - (N - 1) - 1] != 1:
|
| 138 |
+
nnp = np - N
|
| 139 |
+
actions.append(nnp)
|
| 140 |
+
elif (nx == 0 and walls[wp - 1] != 2) or (nx > 0 and walls[wp - (N - 1) - 1] != 2 and walls[wp - 1] != 2):
|
| 141 |
+
nnp = np - 1
|
| 142 |
+
actions.append(nnp)
|
| 143 |
+
else:
|
| 144 |
+
if walls[wp - 1] != 1 and walls[wp] != 1:
|
| 145 |
+
if np + ep != N ** 2 - 1:
|
| 146 |
+
actions.append(np)
|
| 147 |
+
else:
|
| 148 |
+
if nx > 0 and walls[wp - (N - 1)] != 1 and walls[wp - (N - 1) - 1] != 1:
|
| 149 |
+
nnp = np - N
|
| 150 |
+
actions.append(nnp)
|
| 151 |
+
else:
|
| 152 |
+
if (nx == 0 and walls[wp - 1] != 2) or (nx > 0 and walls[wp - (N - 1) - 1] != 2 and walls[wp - 1] != 2):
|
| 153 |
+
nnp = np - 1
|
| 154 |
+
actions.append(nnp)
|
| 155 |
+
if (nx == 0 and walls[wp] != 2) or (nx > 0 and walls[wp - (N - 1)] != 2 and walls[wp] != 2):
|
| 156 |
+
nnp = np + 1
|
| 157 |
+
actions.append(nnp)
|
| 158 |
+
if nx > x:
|
| 159 |
+
if y == 0:
|
| 160 |
+
if walls[wp - (N - 1)] != 1:
|
| 161 |
+
if np + ep != N ** 2 - 1:
|
| 162 |
+
actions.append(np)
|
| 163 |
+
else:
|
| 164 |
+
if nx < (N - 1) and walls[wp] != 1:
|
| 165 |
+
nnp = np + N
|
| 166 |
+
actions.append(nnp)
|
| 167 |
+
elif (nx == (N - 1) and walls[wp - (N - 1)] != 2) or (nx < (N - 1) and walls[wp - (N - 1)] != 2 and walls[wp] != 2):
|
| 168 |
+
nnp = np + 1
|
| 169 |
+
actions.append(nnp)
|
| 170 |
+
elif y == (N - 1):
|
| 171 |
+
if walls[wp - (N - 1) - 1] != 1:
|
| 172 |
+
if np + ep != N ** 2 - 1:
|
| 173 |
+
actions.append(np)
|
| 174 |
+
else:
|
| 175 |
+
if nx < (N - 1) and walls[wp - 1] != 1:
|
| 176 |
+
nnp = np + N
|
| 177 |
+
actions.append(nnp)
|
| 178 |
+
elif (nx == (N - 1) and walls[wp - (N - 1) - 1] != 2) or (nx < (N - 1) and walls[wp - (N - 1) - 1] != 2 and walls[wp - 1] != 2):
|
| 179 |
+
nnp = np - 1
|
| 180 |
+
actions.append(nnp)
|
| 181 |
+
else:
|
| 182 |
+
if walls[wp - (N - 1) - 1] != 1 and walls[wp - (N - 1)] != 1:
|
| 183 |
+
if np + ep != N ** 2 - 1:
|
| 184 |
+
actions.append(np)
|
| 185 |
+
else:
|
| 186 |
+
if nx < (N - 1) and walls[wp - 1] != 1 and walls[wp] != 1:
|
| 187 |
+
nnp = np + N
|
| 188 |
+
actions.append(nnp)
|
| 189 |
+
else:
|
| 190 |
+
if (nx == (N - 1) and walls[wp - (N - 1) - 1] != 2) or (nx < (N - 1) and walls[wp - (N - 1) - 1] != 2 and walls[wp - 1] != 2):
|
| 191 |
+
nnp = np - 1
|
| 192 |
+
actions.append(nnp)
|
| 193 |
+
if (nx == (N - 1) and walls[wp - (N - 1)] != 2) or (nx < (N - 1) and walls[wp - (N - 1)] != 2 and walls[wp] != 2):
|
| 194 |
+
nnp = np + 1
|
| 195 |
+
actions.append(nnp)
|
| 196 |
+
if ny < y:
|
| 197 |
+
if x == 0:
|
| 198 |
+
if walls[wp] != 2:
|
| 199 |
+
if np + ep != N ** 2 - 1:
|
| 200 |
+
actions.append(np)
|
| 201 |
+
else:
|
| 202 |
+
if ny > 0 and walls[wp - 1] != 2:
|
| 203 |
+
nnp = np - 1
|
| 204 |
+
actions.append(nnp)
|
| 205 |
+
elif (ny == 0 and walls[wp] != 1) or (ny > 0 and walls[wp - 1] != 1 and walls[wp] != 1):
|
| 206 |
+
nnp = np + N
|
| 207 |
+
actions.append(nnp)
|
| 208 |
+
elif x == (N - 1):
|
| 209 |
+
if walls[wp - (N - 1)] != 2:
|
| 210 |
+
if np + ep != N ** 2 - 1:
|
| 211 |
+
actions.append(np)
|
| 212 |
+
else:
|
| 213 |
+
if ny > 0 and walls[wp - (N - 1) - 1] != 2:
|
| 214 |
+
nnp = np - 1
|
| 215 |
+
actions.append(nnp)
|
| 216 |
+
elif (ny == 0 and walls[wp - (N - 1)] != 1) or (ny > 0 and walls[wp - (N - 1) - 1] != 2 and walls[wp - (N - 1)] != 1):
|
| 217 |
+
nnp = np - N
|
| 218 |
+
actions.append(nnp)
|
| 219 |
+
else:
|
| 220 |
+
if walls[wp - (N - 1)] != 2 and walls[wp] != 2:
|
| 221 |
+
if np + ep != N ** 2 - 1:
|
| 222 |
+
actions.append(np)
|
| 223 |
+
else:
|
| 224 |
+
if ny > 0 and walls[wp - (N - 1) - 1] != 2 and walls[wp - 1] != 2:
|
| 225 |
+
nnp = np - 1
|
| 226 |
+
actions.append(nnp)
|
| 227 |
+
else:
|
| 228 |
+
if (ny == 0 and walls[wp - (N - 1)] != 1) or (ny > 0 and walls[wp - (N - 1) - 1] != 2 and walls[wp - (N - 1)] != 1):
|
| 229 |
+
nnp = np - N
|
| 230 |
+
actions.append(nnp)
|
| 231 |
+
if (ny == 0 and walls[wp] != 1) or (ny > 0 and (walls[wp - 1] != 1 or walls[wp] != 1)):
|
| 232 |
+
nnp = np + N
|
| 233 |
+
actions.append(nnp)
|
| 234 |
+
if ny > y:
|
| 235 |
+
if x == 0:
|
| 236 |
+
if walls[wp - 1] != 2:
|
| 237 |
+
if np + ep != N ** 2 - 1:
|
| 238 |
+
actions.append(np)
|
| 239 |
+
else:
|
| 240 |
+
if ny < (N - 1) and walls[wp] != 2:
|
| 241 |
+
nnp = np + 1
|
| 242 |
+
actions.append(nnp)
|
| 243 |
+
elif (ny == (N - 1) and walls[wp - 1] != 1) or (ny < (N - 1) and walls[wp - 1] != 1 and walls[wp] != 1):
|
| 244 |
+
nnp = np + N
|
| 245 |
+
actions.append(nnp)
|
| 246 |
+
elif x == (N - 1):
|
| 247 |
+
if walls[wp - (N - 1) - 1] != 2:
|
| 248 |
+
if np + ep != N ** 2 - 1:
|
| 249 |
+
actions.append(np)
|
| 250 |
+
else:
|
| 251 |
+
if ny < (N - 1) and walls[wp - (N - 1)] != 2:
|
| 252 |
+
nnp = np + 1
|
| 253 |
+
actions.append(nnp)
|
| 254 |
+
elif (ny == (N - 1) and walls[wp - (N - 1) - 1] != 1) or (ny < (N - 1) and walls[wp - (N - 1) - 1] != 1 and walls[wp - (N - 1)] != 1):
|
| 255 |
+
nnp = np - N
|
| 256 |
+
actions.append(nnp)
|
| 257 |
+
else:
|
| 258 |
+
if walls[wp - (N - 1) - 1] != 2 and walls[wp - 1] != 2:
|
| 259 |
+
if np + ep != N ** 2 - 1:
|
| 260 |
+
actions.append(np)
|
| 261 |
+
else:
|
| 262 |
+
if ny < (N - 1) and walls[wp - (N - 1)] != 2 and walls[wp] != 2:
|
| 263 |
+
nnp = np + 1
|
| 264 |
+
actions.append(nnp)
|
| 265 |
+
else:
|
| 266 |
+
if (ny == (N - 1) and walls[wp - (N - 1) - 1] != 1) or (ny < (N - 1) and walls[wp - (N - 1) - 1] != 1 and walls[wp - (N - 1)] != 1):
|
| 267 |
+
nnp = np - N
|
| 268 |
+
actions.append(nnp)
|
| 269 |
+
if (ny == (N - 1) and walls[wp - 1] != 1) or (ny < (N - 1) and (walls[wp - 1] != 1 or walls[wp] != 1)):
|
| 270 |
+
nnp = np + N
|
| 271 |
+
actions.append(nnp)
|
| 272 |
+
|
| 273 |
+
return actions
|
| 274 |
+
|
| 275 |
+
def legal_actions_wall(self, pos):
|
| 276 |
+
N = self.N
|
| 277 |
+
walls = self.walls
|
| 278 |
+
def can_place_wall(orientation, pos):
|
| 279 |
+
if walls[pos] != 0:
|
| 280 |
+
return False
|
| 281 |
+
x, y = pos // (N - 1), pos % (N - 1)
|
| 282 |
+
if orientation == 1:
|
| 283 |
+
if y == 0:
|
| 284 |
+
if walls[pos + 1] == 1:
|
| 285 |
+
return False
|
| 286 |
+
elif y == (N - 2):
|
| 287 |
+
if walls[pos - 1] == 1:
|
| 288 |
+
return False
|
| 289 |
+
else:
|
| 290 |
+
if walls[pos - 1] == 1 or walls[pos + 1] == 1:
|
| 291 |
+
return False
|
| 292 |
+
else:
|
| 293 |
+
if x == 0:
|
| 294 |
+
if walls[pos + (N - 1)] == 2:
|
| 295 |
+
return False
|
| 296 |
+
elif x == (N - 2):
|
| 297 |
+
if walls[pos - (N - 1)] == 2:
|
| 298 |
+
return False
|
| 299 |
+
else:
|
| 300 |
+
if walls[pos - (N - 1)] == 2 or walls[pos + (N - 1)] == 2:
|
| 301 |
+
return False
|
| 302 |
+
return True
|
| 303 |
+
|
| 304 |
+
def can_reach_goal(orientation, pos):
|
| 305 |
+
def bfs(state):
|
| 306 |
+
queue = deque([state.player[0]])
|
| 307 |
+
visited = set()
|
| 308 |
+
while queue:
|
| 309 |
+
pos = queue.popleft()
|
| 310 |
+
nps = state.legal_actions_pos(pos)
|
| 311 |
+
for np in nps:
|
| 312 |
+
x, y = np // N, np % N
|
| 313 |
+
if y == 0:
|
| 314 |
+
return True
|
| 315 |
+
if np not in visited:
|
| 316 |
+
visited.add(np)
|
| 317 |
+
queue.append(np)
|
| 318 |
+
return False
|
| 319 |
+
|
| 320 |
+
self.walls[pos] = orientation
|
| 321 |
+
|
| 322 |
+
player_state = State(board_size=N, player=self.player.copy(), enemy=self.enemy.copy(), walls=deepcopy(self.walls), depth=self.depth)
|
| 323 |
+
|
| 324 |
+
can_reach_player = bfs(player_state)
|
| 325 |
+
|
| 326 |
+
action = pos
|
| 327 |
+
if orientation == 1:
|
| 328 |
+
action += N ** 2
|
| 329 |
+
else:
|
| 330 |
+
action += N ** 2 + (N - 1) ** 2
|
| 331 |
+
|
| 332 |
+
enemy_state = player_state.next(action)
|
| 333 |
+
|
| 334 |
+
can_reach_enemy = bfs(enemy_state)
|
| 335 |
+
|
| 336 |
+
self.walls[pos] = 0
|
| 337 |
+
|
| 338 |
+
return can_reach_player and can_reach_enemy
|
| 339 |
+
|
| 340 |
+
actions = []
|
| 341 |
+
|
| 342 |
+
if can_place_wall(1, pos) and can_reach_goal(1, pos):
|
| 343 |
+
actions.append(N ** 2 + pos)
|
| 344 |
+
if can_place_wall(2, pos) and can_reach_goal(2, pos):
|
| 345 |
+
actions.append(N ** 2 + (N - 1) ** 2 + pos)
|
| 346 |
+
|
| 347 |
+
return actions
|
| 348 |
+
|
| 349 |
+
def rotate_walls(self):
|
| 350 |
+
N = self.N
|
| 351 |
+
rotated_walls = [0] * len(self.walls)
|
| 352 |
+
for i in range((N - 1) ** 2):
|
| 353 |
+
rotated_walls[i] = self.walls[(N - 1) ** 2 - 1 - i]
|
| 354 |
+
self.walls = rotated_walls
|
| 355 |
+
|
| 356 |
+
def next(self, action):
|
| 357 |
+
N = self.N
|
| 358 |
+
# Create the next state
|
| 359 |
+
state = State(board_size=N, player=self.player.copy(), enemy=self.enemy.copy(), walls=deepcopy(self.walls), depth=self.depth + 1)
|
| 360 |
+
|
| 361 |
+
if action < N ** 2:
|
| 362 |
+
# Move piece
|
| 363 |
+
state.player[0] = action
|
| 364 |
+
elif action < N ** 2 + (N - 1) ** 2:
|
| 365 |
+
# Place horizontal wall
|
| 366 |
+
pos = action - N ** 2
|
| 367 |
+
state.walls[pos] = 1
|
| 368 |
+
state.player[1] -= 1
|
| 369 |
+
else:
|
| 370 |
+
# Place vertical wall
|
| 371 |
+
pos = action - N ** 2 - (N - 1) ** 2
|
| 372 |
+
state.walls[pos] = 2
|
| 373 |
+
state.player[1] -= 1
|
| 374 |
+
|
| 375 |
+
state.rotate_walls()
|
| 376 |
+
|
| 377 |
+
# Swap players
|
| 378 |
+
state.player, state.enemy = state.enemy, state.player
|
| 379 |
+
|
| 380 |
+
return state
|
| 381 |
+
|
| 382 |
+
# Check if it's the first player's turn
|
| 383 |
+
def is_first_player(self):
|
| 384 |
+
return self.depth % 2 == 0
|
| 385 |
+
|
| 386 |
+
def __str__(self):
|
| 387 |
+
"""Display the game state as a string."""
|
| 388 |
+
N = self.N
|
| 389 |
+
is_first_player = self.is_first_player()
|
| 390 |
+
|
| 391 |
+
board = [['o'] * (2 * N - 1) for _ in range(2 * N - 1)]
|
| 392 |
+
for i in range(2 * N - 1):
|
| 393 |
+
for j in range(2 * N - 1):
|
| 394 |
+
if i % 2 == 1 and j % 2 == 1:
|
| 395 |
+
board[i][j] = 'x'
|
| 396 |
+
|
| 397 |
+
p_pos = self.player[0] if is_first_player else self.enemy[0]
|
| 398 |
+
e_pos = self.enemy[0] if is_first_player else self.player[0]
|
| 399 |
+
|
| 400 |
+
e_pos = N ** 2 - 1 - e_pos
|
| 401 |
+
|
| 402 |
+
p_x, p_y = p_pos // N, p_pos % N
|
| 403 |
+
e_x, e_y = e_pos // N, e_pos % N
|
| 404 |
+
|
| 405 |
+
board[2 * p_x][2 * p_y] = 'P'
|
| 406 |
+
board[2 * e_x][2 * e_y] = 'E'
|
| 407 |
+
|
| 408 |
+
turn_info = "<Enemy's Turn>" if is_first_player else "<Player's Turn>"
|
| 409 |
+
|
| 410 |
+
if not is_first_player:
|
| 411 |
+
self.rotate_walls()
|
| 412 |
+
|
| 413 |
+
# Set walls
|
| 414 |
+
for i in range(N - 1):
|
| 415 |
+
for j in range(N - 1):
|
| 416 |
+
pos = i * (N - 1) + j
|
| 417 |
+
if self.walls[pos] == 1:
|
| 418 |
+
board[2 * i + 1][2 * j] = '-'
|
| 419 |
+
board[2 * i + 1][2 * (j + 1)] = '-'
|
| 420 |
+
if self.walls[pos] == 2:
|
| 421 |
+
board[2 * i][2 * j + 1] = '|'
|
| 422 |
+
board[2 * (i + 1)][2 * j + 1] = '|'
|
| 423 |
+
|
| 424 |
+
if not is_first_player:
|
| 425 |
+
self.rotate_walls()
|
| 426 |
+
|
| 427 |
+
board_str = '\n'.join([''.join(row) for row in board])
|
| 428 |
+
return turn_info + '\n' + board_str
|
| 429 |
+
|
| 430 |
+
|
| 431 |
+
# Randomly select an action
|
| 432 |
+
def random_action(state):
|
| 433 |
+
legal_actions = state.legal_actions()
|
| 434 |
+
action = random.randint(0, len(legal_actions) - 1)
|
| 435 |
+
return legal_actions[action]
|
| 436 |
+
|
| 437 |
+
# Calculate state value using alpha-beta pruning
|
| 438 |
+
def alpha_beta(state, alpha, beta):
|
| 439 |
+
# Loss is -1
|
| 440 |
+
if state.is_lose():
|
| 441 |
+
return -1
|
| 442 |
+
|
| 443 |
+
# Draw is 0
|
| 444 |
+
if state.is_draw():
|
| 445 |
+
return 0
|
| 446 |
+
|
| 447 |
+
# Calculate state values for legal actions
|
| 448 |
+
for action in state.legal_actions():
|
| 449 |
+
score = -alpha_beta(state.next(action), -beta, -alpha)
|
| 450 |
+
if score > alpha:
|
| 451 |
+
alpha = score
|
| 452 |
+
|
| 453 |
+
# If the best score for the current node exceeds the parent node, stop the search
|
| 454 |
+
if alpha >= beta:
|
| 455 |
+
return alpha
|
| 456 |
+
|
| 457 |
+
# Return the maximum value of the state values for legal actions
|
| 458 |
+
return alpha
|
| 459 |
+
|
| 460 |
+
# Select an action using alpha-beta pruning
|
| 461 |
+
def alpha_beta_action(state):
|
| 462 |
+
# Calculate state values for legal actions
|
| 463 |
+
best_action = 0
|
| 464 |
+
alpha = -float('inf')
|
| 465 |
+
for action in state.legal_actions():
|
| 466 |
+
score = -alpha_beta(state.next(action), -float('inf'), -alpha)
|
| 467 |
+
if score > alpha:
|
| 468 |
+
best_action = action
|
| 469 |
+
alpha = score
|
| 470 |
+
|
| 471 |
+
# Return the action with the maximum state value
|
| 472 |
+
return best_action
|
| 473 |
+
|
| 474 |
+
# Playout
|
| 475 |
+
def playout(state):
|
| 476 |
+
# Loss is -1
|
| 477 |
+
if state.is_lose():
|
| 478 |
+
return -1
|
| 479 |
+
|
| 480 |
+
# Draw is 0
|
| 481 |
+
if state.is_draw():
|
| 482 |
+
return 0
|
| 483 |
+
|
| 484 |
+
# Next state value
|
| 485 |
+
return -playout(state.next(random_action(state)))
|
| 486 |
+
|
| 487 |
+
# Return the index of the maximum value
|
| 488 |
+
def argmax(collection):
|
| 489 |
+
return collection.index(max(collection))
|
| 490 |
+
|
| 491 |
+
# Select an action using Monte Carlo Tree Search
|
| 492 |
+
def mcts_action(state):
|
| 493 |
+
# Node for Monte Carlo Tree Search
|
| 494 |
+
class node:
|
| 495 |
+
# Initialization
|
| 496 |
+
def __init__(self, state):
|
| 497 |
+
self.state = state # State
|
| 498 |
+
self.w = 0 # Cumulative value
|
| 499 |
+
self.n = 0 # Number of trials
|
| 500 |
+
self.child_nodes = None # Child nodes
|
| 501 |
+
|
| 502 |
+
# Evaluation
|
| 503 |
+
def evaluate(self):
|
| 504 |
+
# When the game ends
|
| 505 |
+
if self.state.is_done():
|
| 506 |
+
# Get value from the game result
|
| 507 |
+
value = -1 if self.state.is_lose() else 0 # Loss is -1, draw is 0
|
| 508 |
+
|
| 509 |
+
# Update cumulative value and number of trials
|
| 510 |
+
self.w += value
|
| 511 |
+
self.n += 1
|
| 512 |
+
return value
|
| 513 |
+
|
| 514 |
+
# When there are no child nodes
|
| 515 |
+
if not self.child_nodes:
|
| 516 |
+
# Get value from playout
|
| 517 |
+
value = playout(self.state)
|
| 518 |
+
|
| 519 |
+
# Update cumulative value and number of trials
|
| 520 |
+
self.w += value
|
| 521 |
+
self.n += 1
|
| 522 |
+
|
| 523 |
+
# Expand child nodes
|
| 524 |
+
if self.n == 10:
|
| 525 |
+
self.expand()
|
| 526 |
+
return value
|
| 527 |
+
|
| 528 |
+
# When there are child nodes
|
| 529 |
+
else:
|
| 530 |
+
# Get value from evaluating the child node with the maximum UCB1
|
| 531 |
+
value = -self.next_child_node().evaluate()
|
| 532 |
+
|
| 533 |
+
# Update cumulative value and number of trials
|
| 534 |
+
self.w += value
|
| 535 |
+
self.n += 1
|
| 536 |
+
return value
|
| 537 |
+
|
| 538 |
+
# Expand child nodes
|
| 539 |
+
def expand(self):
|
| 540 |
+
legal_actions = self.state.legal_actions()
|
| 541 |
+
self.child_nodes = []
|
| 542 |
+
for action in legal_actions:
|
| 543 |
+
self.child_nodes.append(node(self.state.next(action)))
|
| 544 |
+
|
| 545 |
+
# Get the child node with the maximum UCB1
|
| 546 |
+
def next_child_node(self):
|
| 547 |
+
# Return the child node with n=0
|
| 548 |
+
for child_node in self.child_nodes:
|
| 549 |
+
if child_node.n == 0:
|
| 550 |
+
return child_node
|
| 551 |
+
|
| 552 |
+
# Calculate UCB1
|
| 553 |
+
t = 0
|
| 554 |
+
for c in self.child_nodes:
|
| 555 |
+
t += c.n
|
| 556 |
+
ucb1_values = []
|
| 557 |
+
for child_node in self.child_nodes:
|
| 558 |
+
ucb1_values.append(-child_node.w/child_node.n + 2*(2*math.log(t)/child_node.n)**0.5)
|
| 559 |
+
|
| 560 |
+
# Return the child node with the maximum UCB1
|
| 561 |
+
return self.child_nodes[argmax(ucb1_values)]
|
| 562 |
+
|
| 563 |
+
# Generate the root node
|
| 564 |
+
root_node = node(state)
|
| 565 |
+
root_node.expand()
|
| 566 |
+
|
| 567 |
+
# Evaluate the root node 100 times
|
| 568 |
+
for _ in range(100):
|
| 569 |
+
root_node.evaluate()
|
| 570 |
+
|
| 571 |
+
# Return the action with the maximum number of trials
|
| 572 |
+
legal_actions = state.legal_actions()
|
| 573 |
+
n_list = []
|
| 574 |
+
for c in root_node.child_nodes:
|
| 575 |
+
n_list.append(c.n)
|
| 576 |
+
return legal_actions[argmax(n_list)]
|
| 577 |
+
|
| 578 |
+
# Running the function
|
| 579 |
+
if __name__ == '__main__':
|
| 580 |
+
# Generate the state
|
| 581 |
+
state = State()
|
| 582 |
+
|
| 583 |
+
# Loop until the game ends
|
| 584 |
+
while True:
|
| 585 |
+
# When the game ends
|
| 586 |
+
if state.is_done():
|
| 587 |
+
break
|
| 588 |
+
|
| 589 |
+
# Get the next state
|
| 590 |
+
state = state.next(random_action(state))
|
| 591 |
+
|
| 592 |
+
# Display as a string
|
| 593 |
+
print(state)
|
| 594 |
+
print()
|
human_play.py
ADDED
|
@@ -0,0 +1,229 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Importing necessary packages and modules
|
| 2 |
+
from game import State
|
| 3 |
+
from pv_mcts import pv_mcts_action, random_action
|
| 4 |
+
from tensorflow.keras.models import load_model
|
| 5 |
+
import tkinter as tk
|
| 6 |
+
|
| 7 |
+
# Loading the best player's model
|
| 8 |
+
model = load_model('./model/best.keras')
|
| 9 |
+
|
| 10 |
+
# Defining the Game UI
|
| 11 |
+
class GameUI(tk.Frame):
|
| 12 |
+
# Initialization
|
| 13 |
+
def __init__(self, master=None, model=None):
|
| 14 |
+
tk.Frame.__init__(self, master)
|
| 15 |
+
self.master.title('Quoridor')
|
| 16 |
+
|
| 17 |
+
# Generating the game state
|
| 18 |
+
self.state = State()
|
| 19 |
+
self.N = self.state.N
|
| 20 |
+
self.D = 200 # Cell size (pixels)
|
| 21 |
+
self.L = self.N * self.D # Canvas size
|
| 22 |
+
|
| 23 |
+
self.select = -1 # Selection (-1: none, 0~(N*N-1): square)
|
| 24 |
+
self.placing_wall = False # Flag to indicate if we are placing a wall
|
| 25 |
+
|
| 26 |
+
# Creating the function for action selection using PV MCTS
|
| 27 |
+
self.next_action = pv_mcts_action(model) if model else random_action()
|
| 28 |
+
|
| 29 |
+
# Main frame layout
|
| 30 |
+
self.grid()
|
| 31 |
+
|
| 32 |
+
# Creating the canvas for the game board
|
| 33 |
+
self.c = tk.Canvas(self, width=self.L, height=self.L, highlightthickness=0)
|
| 34 |
+
self.c.bind('<Button-1>', self.turn_of_human)
|
| 35 |
+
self.c.grid(row=1, column=1, padx=10, pady=10)
|
| 36 |
+
|
| 37 |
+
# Displaying the player's walls on the left
|
| 38 |
+
self.player_walls_frame = tk.Frame(self)
|
| 39 |
+
self.player_walls_frame.grid(row=1, column=2, padx=10, pady=10)
|
| 40 |
+
self.player_walls = tk.Label(self.player_walls_frame, text="Player Walls", anchor="center", justify=tk.CENTER, font=('Helvetica', 24))
|
| 41 |
+
self.player_walls.pack()
|
| 42 |
+
|
| 43 |
+
# Displaying the enemy's walls on the right
|
| 44 |
+
self.enemy_walls_frame = tk.Frame(self)
|
| 45 |
+
self.enemy_walls_frame.grid(row=1, column=0, padx=10, pady=10)
|
| 46 |
+
self.enemy_walls = tk.Label(self.enemy_walls_frame, text="Enemy Walls", anchor="center", justify=tk.CENTER, font=('Helvetica', 24))
|
| 47 |
+
self.enemy_walls.pack()
|
| 48 |
+
|
| 49 |
+
# Displaying the action buttons below the game board
|
| 50 |
+
self.controls_frame = tk.Frame(self)
|
| 51 |
+
self.controls_frame.grid(row=2, column=1, padx=10, pady=10)
|
| 52 |
+
self.wall_button = tk.Button(self.controls_frame, text="Place Wall", command=self.place_wall_mode)
|
| 53 |
+
self.wall_button.pack()
|
| 54 |
+
|
| 55 |
+
self.wall_direction = tk.StringVar(value="horizontal")
|
| 56 |
+
self.wall_horizontal_button = tk.Radiobutton(self.controls_frame, text="Horizontal", variable=self.wall_direction, value="horizontal")
|
| 57 |
+
self.wall_vertical_button = tk.Radiobutton(self.controls_frame, text="Vertical", variable=self.wall_direction, value="vertical")
|
| 58 |
+
self.wall_horizontal_button.pack()
|
| 59 |
+
self.wall_vertical_button.pack()
|
| 60 |
+
|
| 61 |
+
# Result message
|
| 62 |
+
self.result_message = tk.Label(self, text="", font=('Helvetica', 60))
|
| 63 |
+
self.result_message.grid(row=0, column=1, pady=10)
|
| 64 |
+
|
| 65 |
+
# Updating the drawing
|
| 66 |
+
self.on_draw()
|
| 67 |
+
|
| 68 |
+
def place_wall_mode(self):
|
| 69 |
+
self.placing_wall = not self.placing_wall
|
| 70 |
+
self.wall_button.config(text="Move Piece" if self.placing_wall else "Place Wall")
|
| 71 |
+
|
| 72 |
+
# Human's turn
|
| 73 |
+
def turn_of_human(self, event):
|
| 74 |
+
N = self.N
|
| 75 |
+
D = self.D
|
| 76 |
+
# If the game is over
|
| 77 |
+
if self.state.is_done():
|
| 78 |
+
return
|
| 79 |
+
|
| 80 |
+
# If it is not the first player's turn
|
| 81 |
+
if not self.state.is_first_player():
|
| 82 |
+
return
|
| 83 |
+
|
| 84 |
+
# Calculate the selection and move position
|
| 85 |
+
if self.placing_wall:
|
| 86 |
+
x, y = (event.x - D // 2) // D, (event.y - D // 2) // D
|
| 87 |
+
print(x, y)
|
| 88 |
+
if 0 <= x < N - 1 and 0 <= y < N - 1:
|
| 89 |
+
self.place_wall(x, y)
|
| 90 |
+
else:
|
| 91 |
+
x, y = event.x // D, event.y // D
|
| 92 |
+
self.select = N * y + x
|
| 93 |
+
action = self.select
|
| 94 |
+
|
| 95 |
+
# Convert selection and move to action
|
| 96 |
+
|
| 97 |
+
# If the action is not legal
|
| 98 |
+
if not (action in self.state.legal_actions()):
|
| 99 |
+
self.select = -1
|
| 100 |
+
self.on_draw()
|
| 101 |
+
return
|
| 102 |
+
|
| 103 |
+
# Get the next state
|
| 104 |
+
self.state = self.state.next(action)
|
| 105 |
+
self.select = -1
|
| 106 |
+
self.on_draw()
|
| 107 |
+
|
| 108 |
+
# AI's turn
|
| 109 |
+
self.master.after(500, self.turn_of_ai)
|
| 110 |
+
|
| 111 |
+
def place_wall(self, x, y):
|
| 112 |
+
N = self.N
|
| 113 |
+
# Adjusted logic for placing walls at grid points
|
| 114 |
+
if self.wall_direction.get() == "horizontal":
|
| 115 |
+
action = N ** 2 + (N - 1) * y + x
|
| 116 |
+
else:
|
| 117 |
+
action = N ** 2 + (N - 1) ** 2 + (N - 1) * y + x
|
| 118 |
+
|
| 119 |
+
# Check if the action is legal
|
| 120 |
+
if action in self.state.legal_actions():
|
| 121 |
+
# Get the next state
|
| 122 |
+
self.state = self.state.next(action)
|
| 123 |
+
self.placing_wall = False
|
| 124 |
+
self.wall_button.config(text="Place Wall")
|
| 125 |
+
self.on_draw()
|
| 126 |
+
else:
|
| 127 |
+
self.placing_wall = False
|
| 128 |
+
self.wall_button.config(text="Place Wall")
|
| 129 |
+
self.on_draw()
|
| 130 |
+
|
| 131 |
+
# AI's turn
|
| 132 |
+
def turn_of_ai(self):
|
| 133 |
+
# If the game is over
|
| 134 |
+
if self.state.is_done():
|
| 135 |
+
self.display_result()
|
| 136 |
+
self.master.after(1000, self.reset_game)
|
| 137 |
+
return
|
| 138 |
+
|
| 139 |
+
# Get the action
|
| 140 |
+
action = self.next_action(self.state)
|
| 141 |
+
|
| 142 |
+
# Get the next state
|
| 143 |
+
self.state = self.state.next(action)
|
| 144 |
+
self.on_draw()
|
| 145 |
+
|
| 146 |
+
if self.state.is_done():
|
| 147 |
+
self.display_result()
|
| 148 |
+
self.master.after(1000, self.reset_game)
|
| 149 |
+
return
|
| 150 |
+
|
| 151 |
+
def display_result(self):
|
| 152 |
+
is_lose = self.state.is_lose() if self.state.is_first_player() else not self.state.is_lose()
|
| 153 |
+
if is_lose:
|
| 154 |
+
self.result_message.config(text="You Lose", fg="blue")
|
| 155 |
+
else:
|
| 156 |
+
self.result_message.config(text="You Win", fg="red")
|
| 157 |
+
|
| 158 |
+
def reset_game(self):
|
| 159 |
+
self.state = State()
|
| 160 |
+
self.on_draw()
|
| 161 |
+
self.result_message.config(text="")
|
| 162 |
+
|
| 163 |
+
# Draw the piece
|
| 164 |
+
def draw_piece(self, index, color):
|
| 165 |
+
N = self.N
|
| 166 |
+
D = self.D
|
| 167 |
+
x = (index % N) * D
|
| 168 |
+
y = (index // N) * D
|
| 169 |
+
margin = D // 10
|
| 170 |
+
self.c.create_oval(x + margin, y + margin, x + D - margin, y + D - margin, fill=color, outline='black')
|
| 171 |
+
|
| 172 |
+
# Draw the walls
|
| 173 |
+
def draw_walls(self):
|
| 174 |
+
N = self.N
|
| 175 |
+
D = self.D
|
| 176 |
+
for i in range(len(self.state.walls)):
|
| 177 |
+
x, y = i % (N - 1), i // (N - 1)
|
| 178 |
+
if self.state.walls[i] == 1:
|
| 179 |
+
x1, y1 = x * D, (y + 1) * D
|
| 180 |
+
x2, y2 = (x + 2) * D, (y + 1) * D
|
| 181 |
+
self.c.create_line(x1, y1, x2, y2, width=16.0, fill='#D1B575')
|
| 182 |
+
elif self.state.walls[i] == 2:
|
| 183 |
+
x1, y1 = (x + 1) * D, y * D
|
| 184 |
+
x2, y2 = (x + 1) * D, (y + 2) * D
|
| 185 |
+
self.c.create_line(x1, y1, x2, y2, width=16.0, fill='#D1B575')
|
| 186 |
+
|
| 187 |
+
# Update the drawing
|
| 188 |
+
def on_draw(self):
|
| 189 |
+
N = self.N
|
| 190 |
+
D = self.D
|
| 191 |
+
L = self.L
|
| 192 |
+
is_first_player = self.state.is_first_player()
|
| 193 |
+
|
| 194 |
+
# Grid
|
| 195 |
+
self.c.delete('all')
|
| 196 |
+
self.c.create_rectangle(0, 0, L, L, width=0.0, fill='#4B4B4B')
|
| 197 |
+
for i in range(1, N):
|
| 198 |
+
self.c.create_line(i * D, 0, i * D, L, width=16.0, fill='#8B0000')
|
| 199 |
+
self.c.create_line(0, i * D, L, i * D, width=16.0, fill='#8B0000')
|
| 200 |
+
|
| 201 |
+
# Pieces
|
| 202 |
+
p_pos = self.state.player[0] if is_first_player else self.state.enemy[0]
|
| 203 |
+
e_pos = self.state.enemy[0] if is_first_player else self.state.player[0]
|
| 204 |
+
e_pos = N ** 2 - 1 - e_pos
|
| 205 |
+
|
| 206 |
+
self.draw_piece(p_pos, '#D2B48C')
|
| 207 |
+
self.draw_piece(e_pos, '#5D3A3A')
|
| 208 |
+
|
| 209 |
+
p_walls = self.state.player[1] if is_first_player else self.state.enemy[1]
|
| 210 |
+
e_walls = self.state.enemy[1] if is_first_player else self.state.player[1]
|
| 211 |
+
|
| 212 |
+
# Update the wall count
|
| 213 |
+
self.player_walls.config(text=f"Player Walls\n{p_walls}")
|
| 214 |
+
self.enemy_walls.config(text=f"Enemy Walls\n{e_walls}")
|
| 215 |
+
|
| 216 |
+
if not is_first_player:
|
| 217 |
+
self.state.rotate_walls()
|
| 218 |
+
|
| 219 |
+
# Walls
|
| 220 |
+
self.draw_walls()
|
| 221 |
+
|
| 222 |
+
if not is_first_player:
|
| 223 |
+
self.state.rotate_walls()
|
| 224 |
+
|
| 225 |
+
# Run the game UI
|
| 226 |
+
if __name__ == '__main__':
|
| 227 |
+
f = GameUI(model=model)
|
| 228 |
+
f.pack()
|
| 229 |
+
f.mainloop()
|
pv_mcts.py
ADDED
|
@@ -0,0 +1,168 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ====================
|
| 2 |
+
# Monte Carlo Tree Search Implementation
|
| 3 |
+
# ====================
|
| 4 |
+
|
| 5 |
+
# Import packages
|
| 6 |
+
from game import State
|
| 7 |
+
from dual_network import DN_INPUT_SHAPE
|
| 8 |
+
from math import sqrt
|
| 9 |
+
from tensorflow.keras.models import load_model
|
| 10 |
+
from pathlib import Path
|
| 11 |
+
import numpy as np
|
| 12 |
+
from copy import deepcopy
|
| 13 |
+
import random
|
| 14 |
+
|
| 15 |
+
# Prepare parameters
|
| 16 |
+
PV_EVALUATE_COUNT = 50 # Number of simulations per inference (original is 1600)
|
| 17 |
+
|
| 18 |
+
# Inference
|
| 19 |
+
def predict(model, state):
|
| 20 |
+
# Reshape input data for inference
|
| 21 |
+
a, b, c = DN_INPUT_SHAPE
|
| 22 |
+
x = np.array(state.pieces_array())
|
| 23 |
+
x = x.reshape(c, a, b).transpose(1, 2, 0).reshape(1, a, b, c)
|
| 24 |
+
|
| 25 |
+
# Inference
|
| 26 |
+
y = model.predict(x, batch_size=1)
|
| 27 |
+
|
| 28 |
+
# Get policy
|
| 29 |
+
policies = y[0][0][list(state.legal_actions())] # Only legal moves
|
| 30 |
+
policies /= np.sum(policies) if np.sum(policies) else 1 # Convert to a probability distribution summing to 1
|
| 31 |
+
|
| 32 |
+
# Get value
|
| 33 |
+
value = y[1][0][0]
|
| 34 |
+
return policies, value
|
| 35 |
+
|
| 36 |
+
# Convert list of nodes to list of scores
|
| 37 |
+
def nodes_to_scores(nodes):
|
| 38 |
+
scores = []
|
| 39 |
+
for c in nodes:
|
| 40 |
+
scores.append(c.n)
|
| 41 |
+
return scores
|
| 42 |
+
|
| 43 |
+
# Get Monte Carlo Tree Search scores
|
| 44 |
+
def pv_mcts_scores(model, state, temperature):
|
| 45 |
+
# Define Monte Carlo Tree Search node
|
| 46 |
+
class Node:
|
| 47 |
+
# Initialize node
|
| 48 |
+
def __init__(self, state, p):
|
| 49 |
+
self.state = state # State
|
| 50 |
+
self.p = p # Policy
|
| 51 |
+
self.w = 0 # Cumulative value
|
| 52 |
+
self.n = 0 # Number of simulations
|
| 53 |
+
self.child_nodes = None # Child nodes
|
| 54 |
+
|
| 55 |
+
# Calculate value of the state
|
| 56 |
+
def evaluate(self):
|
| 57 |
+
# If the game is over
|
| 58 |
+
if self.state.is_done():
|
| 59 |
+
# Get value from the game result
|
| 60 |
+
value = -1 if self.state.is_lose() else 0
|
| 61 |
+
|
| 62 |
+
# Update cumulative value and number of simulations
|
| 63 |
+
self.w += value
|
| 64 |
+
self.n += 1
|
| 65 |
+
return value
|
| 66 |
+
|
| 67 |
+
# If there are no child nodes
|
| 68 |
+
if not self.child_nodes:
|
| 69 |
+
# Get policy and value from neural network inference
|
| 70 |
+
policies, value = predict(model, self.state)
|
| 71 |
+
|
| 72 |
+
# Update cumulative value and number of simulations
|
| 73 |
+
self.w += value
|
| 74 |
+
self.n += 1
|
| 75 |
+
|
| 76 |
+
# Expand child nodes
|
| 77 |
+
self.child_nodes = []
|
| 78 |
+
for action, policy in zip(self.state.legal_actions(), policies):
|
| 79 |
+
self.child_nodes.append(Node(self.state.next(action), policy))
|
| 80 |
+
return value
|
| 81 |
+
|
| 82 |
+
# If there are child nodes
|
| 83 |
+
else:
|
| 84 |
+
# Get value from the evaluation of the child node with the maximum arc evaluation value
|
| 85 |
+
value = -self.next_child_node().evaluate()
|
| 86 |
+
|
| 87 |
+
# Update cumulative value and number of simulations
|
| 88 |
+
self.w += value
|
| 89 |
+
self.n += 1
|
| 90 |
+
return value
|
| 91 |
+
|
| 92 |
+
# Get child node with the maximum arc evaluation value
|
| 93 |
+
def next_child_node(self):
|
| 94 |
+
# Calculate arc evaluation value
|
| 95 |
+
C_PUCT = 1.0
|
| 96 |
+
t = sum(nodes_to_scores(self.child_nodes))
|
| 97 |
+
pucb_values = []
|
| 98 |
+
for child_node in self.child_nodes:
|
| 99 |
+
pucb_values.append((-child_node.w / child_node.n if child_node.n else 0.0) +
|
| 100 |
+
C_PUCT * child_node.p * sqrt(t) / (1 + child_node.n))
|
| 101 |
+
|
| 102 |
+
# Return child node with the maximum arc evaluation value
|
| 103 |
+
return self.child_nodes[np.argmax(pucb_values)]
|
| 104 |
+
|
| 105 |
+
# Create a node for the current state
|
| 106 |
+
root_node = Node(state, 0)
|
| 107 |
+
|
| 108 |
+
# Perform multiple evaluations
|
| 109 |
+
for _ in range(PV_EVALUATE_COUNT):
|
| 110 |
+
root_node.evaluate()
|
| 111 |
+
|
| 112 |
+
# Probability distribution of legal moves
|
| 113 |
+
scores = nodes_to_scores(root_node.child_nodes)
|
| 114 |
+
if temperature == 0: # Only the maximum value is 1
|
| 115 |
+
action = np.argmax(scores)
|
| 116 |
+
scores = np.zeros(len(scores))
|
| 117 |
+
scores[action] = 1
|
| 118 |
+
else: # Add variation with Boltzmann distribution
|
| 119 |
+
scores = boltzman(scores, temperature)
|
| 120 |
+
return scores
|
| 121 |
+
|
| 122 |
+
# Action selection with Monte Carlo Tree Search
|
| 123 |
+
def pv_mcts_action(model, temperature=0):
|
| 124 |
+
def pv_mcts_action(state):
|
| 125 |
+
scores = pv_mcts_scores(model, deepcopy(state), temperature)
|
| 126 |
+
|
| 127 |
+
return np.random.choice(state.legal_actions(), p=scores)
|
| 128 |
+
return pv_mcts_action
|
| 129 |
+
|
| 130 |
+
# Boltzmann distribution
|
| 131 |
+
def boltzman(xs, temperature):
|
| 132 |
+
xs = [x ** (1 / temperature) for x in xs]
|
| 133 |
+
return [x / sum(xs) for x in xs]
|
| 134 |
+
|
| 135 |
+
def random_action():
|
| 136 |
+
def random_action(state):
|
| 137 |
+
legal_actions = state.legal_actions()
|
| 138 |
+
action = random.randint(0, len(legal_actions) - 1)
|
| 139 |
+
|
| 140 |
+
return legal_actions[action]
|
| 141 |
+
return random_action
|
| 142 |
+
|
| 143 |
+
# Confirm operation
|
| 144 |
+
if __name__ == '__main__':
|
| 145 |
+
# Load model
|
| 146 |
+
path = sorted(Path('./model').glob('*.keras'))[-1]
|
| 147 |
+
model = load_model(str(path))
|
| 148 |
+
|
| 149 |
+
# Generate state
|
| 150 |
+
state = State()
|
| 151 |
+
|
| 152 |
+
# Create function to get actions with Monte Carlo Tree Search
|
| 153 |
+
next_action = pv_mcts_action(model, 1.0)
|
| 154 |
+
|
| 155 |
+
# Loop until the game is over
|
| 156 |
+
while True:
|
| 157 |
+
# If the game is over
|
| 158 |
+
if state.is_done():
|
| 159 |
+
break
|
| 160 |
+
|
| 161 |
+
# Get action
|
| 162 |
+
action = next_action(state)
|
| 163 |
+
|
| 164 |
+
# Get next state
|
| 165 |
+
state = state.next(action)
|
| 166 |
+
|
| 167 |
+
# Print state
|
| 168 |
+
print(state)
|
requirements.txt
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
tensorflow
|
| 2 |
+
numpy
|
| 3 |
+
matplotlib
|
| 4 |
+
pandas
|
self_play.py
ADDED
|
@@ -0,0 +1,101 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ====================
|
| 2 |
+
# Self-Play Part
|
| 3 |
+
# ====================
|
| 4 |
+
|
| 5 |
+
# Importing packages
|
| 6 |
+
from game import State
|
| 7 |
+
from pv_mcts import pv_mcts_scores
|
| 8 |
+
from dual_network import DN_OUTPUT_SIZE
|
| 9 |
+
from datetime import datetime
|
| 10 |
+
from tensorflow.keras.models import load_model
|
| 11 |
+
from tensorflow.keras import backend as K
|
| 12 |
+
from pathlib import Path
|
| 13 |
+
import numpy as np
|
| 14 |
+
import pickle
|
| 15 |
+
import os
|
| 16 |
+
from copy import deepcopy
|
| 17 |
+
|
| 18 |
+
# Preparing parameters
|
| 19 |
+
SP_GAME_COUNT = 50 # Number of games for self-play (25000 in the original version)
|
| 20 |
+
SP_TEMPERATURE = 1.0 # Temperature parameter for Boltzmann distribution
|
| 21 |
+
|
| 22 |
+
# Value of the first player
|
| 23 |
+
def first_player_value(ended_state):
|
| 24 |
+
# 1: First player wins, -1: First player loses, 0: Draw
|
| 25 |
+
if ended_state.is_lose():
|
| 26 |
+
return -1 if ended_state.is_first_player() else 1
|
| 27 |
+
return 0
|
| 28 |
+
|
| 29 |
+
# Saving training data
|
| 30 |
+
def write_data(history):
|
| 31 |
+
now = datetime.now()
|
| 32 |
+
os.makedirs('./data/', exist_ok=True) # Create folder if it does not exist
|
| 33 |
+
path = './data/{:04}{:02}{:02}{:02}{:02}{:02}.history'.format(
|
| 34 |
+
now.year, now.month, now.day, now.hour, now.minute, now.second)
|
| 35 |
+
with open(path, mode='wb') as f:
|
| 36 |
+
pickle.dump(history, f)
|
| 37 |
+
|
| 38 |
+
# Executing one game
|
| 39 |
+
def play(model):
|
| 40 |
+
# Training data
|
| 41 |
+
history = []
|
| 42 |
+
|
| 43 |
+
# Generating the state
|
| 44 |
+
state = State()
|
| 45 |
+
|
| 46 |
+
while True:
|
| 47 |
+
# When the game ends
|
| 48 |
+
if state.is_done():
|
| 49 |
+
break
|
| 50 |
+
|
| 51 |
+
# Getting the probability distribution of legal moves
|
| 52 |
+
scores = pv_mcts_scores(model, deepcopy(state), SP_TEMPERATURE)
|
| 53 |
+
|
| 54 |
+
# Adding the state and policy to the training data
|
| 55 |
+
policies = [0] * DN_OUTPUT_SIZE
|
| 56 |
+
for action, policy in zip(state.legal_actions(), scores):
|
| 57 |
+
policies[action] = policy
|
| 58 |
+
history.append([state.pieces_array(), policies, None])
|
| 59 |
+
|
| 60 |
+
# Getting the action
|
| 61 |
+
action = np.random.choice(state.legal_actions(), p=scores)
|
| 62 |
+
|
| 63 |
+
# Getting the next state
|
| 64 |
+
state = state.next(action)
|
| 65 |
+
|
| 66 |
+
# Adding the value to the training data
|
| 67 |
+
value = first_player_value(state)
|
| 68 |
+
for i in range(len(history)):
|
| 69 |
+
history[i][2] = value
|
| 70 |
+
value = -value
|
| 71 |
+
return history
|
| 72 |
+
|
| 73 |
+
# Self-Play
|
| 74 |
+
def self_play():
|
| 75 |
+
# Training data
|
| 76 |
+
history = []
|
| 77 |
+
|
| 78 |
+
# Loading the best player's model
|
| 79 |
+
model = load_model('./model/best.keras')
|
| 80 |
+
|
| 81 |
+
# Executing multiple games
|
| 82 |
+
for i in range(SP_GAME_COUNT):
|
| 83 |
+
# Executing one game
|
| 84 |
+
h = play(model)
|
| 85 |
+
history.extend(h)
|
| 86 |
+
|
| 87 |
+
# Output
|
| 88 |
+
print('\rSelfPlay {}/{}'.format(i+1, SP_GAME_COUNT), end='')
|
| 89 |
+
print('')
|
| 90 |
+
|
| 91 |
+
# Saving the training data
|
| 92 |
+
write_data(history)
|
| 93 |
+
|
| 94 |
+
# Clearing the model
|
| 95 |
+
K.clear_session()
|
| 96 |
+
del model
|
| 97 |
+
|
| 98 |
+
# Running the function
|
| 99 |
+
if __name__ == '__main__':
|
| 100 |
+
self_play()
|
| 101 |
+
|
train_cycle.py
ADDED
|
@@ -0,0 +1,33 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ====================
|
| 2 |
+
# Execution of Learning Cycle
|
| 3 |
+
# ====================
|
| 4 |
+
|
| 5 |
+
# Importing packages
|
| 6 |
+
from dual_network import dual_network
|
| 7 |
+
from self_play import self_play
|
| 8 |
+
from train_network import train_network
|
| 9 |
+
from evaluate_network import evaluate_network
|
| 10 |
+
from evaluate_best_player import evaluate_best_player
|
| 11 |
+
|
| 12 |
+
# Number of NUM_EPOCH
|
| 13 |
+
NUM_TRAIN_CYCLE = 3
|
| 14 |
+
|
| 15 |
+
# Main function
|
| 16 |
+
if __name__ == '__main__':
|
| 17 |
+
# Creating the dual network
|
| 18 |
+
dual_network()
|
| 19 |
+
|
| 20 |
+
for i in range(NUM_TRAIN_CYCLE):
|
| 21 |
+
print('Train', i, '====================')
|
| 22 |
+
# self-play part
|
| 23 |
+
self_play()
|
| 24 |
+
|
| 25 |
+
# parameter update part
|
| 26 |
+
train_network()
|
| 27 |
+
|
| 28 |
+
# Evaluating new parameters
|
| 29 |
+
update_best_player = evaluate_network()
|
| 30 |
+
|
| 31 |
+
# Evaluating the best player
|
| 32 |
+
if update_best_player:
|
| 33 |
+
evaluate_best_player()
|
train_network.py
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ====================
|
| 2 |
+
# parameter update part
|
| 3 |
+
# ====================
|
| 4 |
+
|
| 5 |
+
from dual_network import DN_INPUT_SHAPE
|
| 6 |
+
from tensorflow.keras.callbacks import LearningRateScheduler, LambdaCallback
|
| 7 |
+
from tensorflow.keras.models import load_model
|
| 8 |
+
from tensorflow.keras import backend as K
|
| 9 |
+
from pathlib import Path
|
| 10 |
+
import numpy as np
|
| 11 |
+
import pickle
|
| 12 |
+
|
| 13 |
+
NUM_EPOCH = 100
|
| 14 |
+
BATCH_SIZE = 128
|
| 15 |
+
|
| 16 |
+
def load_data():
|
| 17 |
+
history_path = sorted(Path('./data').glob('*.history'))[-1]
|
| 18 |
+
with history_path.open(mode='rb') as f:
|
| 19 |
+
return pickle.load(f)
|
| 20 |
+
|
| 21 |
+
# Training the dual network
|
| 22 |
+
def train_network():
|
| 23 |
+
# Loading training data
|
| 24 |
+
history = load_data()
|
| 25 |
+
s, p, v = zip(*history)
|
| 26 |
+
|
| 27 |
+
# Reshaping the input data for training
|
| 28 |
+
a, b, c = DN_INPUT_SHAPE
|
| 29 |
+
s = np.array(s)
|
| 30 |
+
s = s.reshape(len(s), c, a, b).transpose(0, 2, 3, 1)
|
| 31 |
+
p = np.array(p)
|
| 32 |
+
v = np.array(v)
|
| 33 |
+
|
| 34 |
+
# Loading the best player's model
|
| 35 |
+
model = load_model('./model/best.keras')
|
| 36 |
+
|
| 37 |
+
# Compiling the model
|
| 38 |
+
model.compile(loss=['categorical_crossentropy', 'mse'], optimizer='adam')
|
| 39 |
+
|
| 40 |
+
# Learning rate
|
| 41 |
+
def step_decay(epoch):
|
| 42 |
+
x = 0.001
|
| 43 |
+
if epoch >= 50: x = 0.0005
|
| 44 |
+
if epoch >= 80: x = 0.00025
|
| 45 |
+
return x
|
| 46 |
+
lr_decay = LearningRateScheduler(step_decay)
|
| 47 |
+
|
| 48 |
+
# Output
|
| 49 |
+
print_callback = LambdaCallback(
|
| 50 |
+
on_epoch_begin=lambda epoch, logs: print('\rTrain {}/{}'.format(epoch + 1, NUM_EPOCH), end='')
|
| 51 |
+
)
|
| 52 |
+
|
| 53 |
+
# Executing training
|
| 54 |
+
model.fit(
|
| 55 |
+
s, [p, v], batch_size=BATCH_SIZE , epochs=NUM_EPOCH, verbose=0, callbacks=[lr_decay, print_callback]
|
| 56 |
+
)
|
| 57 |
+
|
| 58 |
+
# Saving the latest player's model
|
| 59 |
+
model.save('./model/latest.keras')
|
| 60 |
+
|
| 61 |
+
# Clearing the model
|
| 62 |
+
K.clear_session()
|
| 63 |
+
del model
|
| 64 |
+
|
| 65 |
+
if __name__ == '__main__':
|
| 66 |
+
train_network()
|