Spaces:
Sleeping
Sleeping
AI VS AI bugs
Browse files- Gomoku_MCTS/policy_value_net_pytorch_new.py +4 -4
- const.py +17 -2
- pages/AI_VS_AI.py +256 -188
- pages/Player_VS_AI.py +59 -18
- pages/Try.py +17 -0
Gomoku_MCTS/policy_value_net_pytorch_new.py
CHANGED
|
@@ -20,8 +20,6 @@ def set_learning_rate(optimizer, lr):
|
|
| 20 |
param_group['lr'] = lr
|
| 21 |
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
class ResidualBlock(nn.Module):
|
| 26 |
def __init__(self, channels):
|
| 27 |
super(ResidualBlock, self).__init__()
|
|
@@ -37,8 +35,10 @@ class ResidualBlock(nn.Module):
|
|
| 37 |
out += residual
|
| 38 |
return F.relu(out)
|
| 39 |
|
|
|
|
| 40 |
class Net(nn.Module):
|
| 41 |
"""Policy-Value network module for AlphaZero Gomoku."""
|
|
|
|
| 42 |
def __init__(self, board_width, board_height, num_residual_blocks=5):
|
| 43 |
super(Net, self).__init__()
|
| 44 |
self.board_width = board_width
|
|
@@ -78,7 +78,7 @@ class PolicyValueNet():
|
|
| 78 |
"""policy-value network """
|
| 79 |
|
| 80 |
def __init__(self, board_width, board_height,
|
| 81 |
-
model_file=None, use_gpu=False, bias
|
| 82 |
self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 83 |
self.use_gpu = use_gpu
|
| 84 |
self.l2_const = 1e-4 # coef of l2 penalty
|
|
@@ -111,7 +111,7 @@ class PolicyValueNet():
|
|
| 111 |
self.policy_value_net = Net(board_width, board_height)
|
| 112 |
|
| 113 |
self.optimizer = optim.Adam(self.policy_value_net.parameters(),
|
| 114 |
-
|
| 115 |
|
| 116 |
def infer_board_size_from_model(self, model):
|
| 117 |
# Use the size of the act_fc1 layer to infer board dimensions
|
|
|
|
| 20 |
param_group['lr'] = lr
|
| 21 |
|
| 22 |
|
|
|
|
|
|
|
| 23 |
class ResidualBlock(nn.Module):
|
| 24 |
def __init__(self, channels):
|
| 25 |
super(ResidualBlock, self).__init__()
|
|
|
|
| 35 |
out += residual
|
| 36 |
return F.relu(out)
|
| 37 |
|
| 38 |
+
|
| 39 |
class Net(nn.Module):
|
| 40 |
"""Policy-Value network module for AlphaZero Gomoku."""
|
| 41 |
+
|
| 42 |
def __init__(self, board_width, board_height, num_residual_blocks=5):
|
| 43 |
super(Net, self).__init__()
|
| 44 |
self.board_width = board_width
|
|
|
|
| 78 |
"""policy-value network """
|
| 79 |
|
| 80 |
def __init__(self, board_width, board_height,
|
| 81 |
+
model_file=None, use_gpu=False, bias=False):
|
| 82 |
self.device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
| 83 |
self.use_gpu = use_gpu
|
| 84 |
self.l2_const = 1e-4 # coef of l2 penalty
|
|
|
|
| 111 |
self.policy_value_net = Net(board_width, board_height)
|
| 112 |
|
| 113 |
self.optimizer = optim.Adam(self.policy_value_net.parameters(),
|
| 114 |
+
weight_decay=self.l2_const)
|
| 115 |
|
| 116 |
def infer_board_size_from_model(self, model):
|
| 117 |
# Use the size of the act_fc1 layer to infer board dimensions
|
const.py
CHANGED
|
@@ -15,13 +15,22 @@ _BLANK = 0
|
|
| 15 |
_BLACK = 1
|
| 16 |
_WHITE = 2
|
| 17 |
_NEW = 3
|
| 18 |
-
|
| 19 |
_WHITE: "⚪",
|
| 20 |
_BLANK: "➕",
|
| 21 |
_BLACK: "⚫",
|
| 22 |
_NEW: "🔴",
|
|
|
|
| 23 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 24 |
}
|
|
|
|
|
|
|
|
|
|
| 25 |
_PLAYER_COLOR = {
|
| 26 |
_WHITE: "AI",
|
| 27 |
_BLANK: "Blank",
|
|
@@ -70,8 +79,14 @@ _ROOM_COLOR = {
|
|
| 70 |
}
|
| 71 |
|
| 72 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 73 |
_MODEL_PATH = {
|
| 74 |
"AlphaZero": "Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/best_policy.model",
|
| 75 |
-
"duel": "Gomoku_MCTS/checkpoint/2023-12-14-
|
| 76 |
"Gumbel AlphaZero": "Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/best_policy.model",
|
| 77 |
}
|
|
|
|
| 15 |
_BLACK = 1
|
| 16 |
_WHITE = 2
|
| 17 |
_NEW = 3
|
| 18 |
+
_PLAYER_SYMBOL1 = {
|
| 19 |
_WHITE: "⚪",
|
| 20 |
_BLANK: "➕",
|
| 21 |
_BLACK: "⚫",
|
| 22 |
_NEW: "🔴",
|
| 23 |
+
}
|
| 24 |
|
| 25 |
+
_PLAYER_SYMBOL2 = {
|
| 26 |
+
_BLACK: "⚪",
|
| 27 |
+
_BLANK: "➕",
|
| 28 |
+
_WHITE: "⚫",
|
| 29 |
+
_NEW: "🔴",
|
| 30 |
}
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
|
| 34 |
_PLAYER_COLOR = {
|
| 35 |
_WHITE: "AI",
|
| 36 |
_BLANK: "Blank",
|
|
|
|
| 79 |
}
|
| 80 |
|
| 81 |
|
| 82 |
+
# _MODEL_PATH = {
|
| 83 |
+
# "AlphaZero": "Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/best_policy.model",
|
| 84 |
+
# "duel": "Gomoku_MCTS/checkpoint/2023-12-14-18-16-09_test_teaching_learning_collect_epochs=1000_size=9_model=duel/best_policy.model",
|
| 85 |
+
# "Gumbel AlphaZero": "Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/best_policy.model",
|
| 86 |
+
# }
|
| 87 |
+
|
| 88 |
_MODEL_PATH = {
|
| 89 |
"AlphaZero": "Gomoku_MCTS/checkpoint/2023-12-14-18-17-07_test_teaching_learning_collect_epochs=1000_size=9_model=normal/best_policy.model",
|
| 90 |
+
"duel": "/Users/husky/GomokuDemo/Gomoku_MCTS/checkpoint/2023-12-14-10-22-12_test_teaching_learning_collect_epochs=1000_size=9_model=duel/best_policy.model",
|
| 91 |
"Gumbel AlphaZero": "Gomoku_MCTS/checkpoint/2023-12-14-21-19-40_selfplay_epochs=1000_size=9_model=gumbel/best_policy.model",
|
| 92 |
}
|
pages/AI_VS_AI.py
CHANGED
|
@@ -8,27 +8,25 @@ Description: this file is used to display our project and add visualization elem
|
|
| 8 |
import time
|
| 9 |
import pandas as pd
|
| 10 |
from copy import deepcopy
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
# import torch
|
| 14 |
import numpy as np
|
| 15 |
import streamlit as st
|
| 16 |
from scipy.signal import convolve # this is used to check if any player wins
|
| 17 |
from streamlit import session_state
|
| 18 |
from streamlit_server_state import server_state, server_state_lock
|
| 19 |
-
from Gomoku_MCTS import MCTSpure, alphazero, Board,
|
|
|
|
| 20 |
from Gomoku_Bot import Gomoku_bot
|
| 21 |
from Gomoku_Bot import Board as Gomoku_bot_board
|
| 22 |
-
import matplotlib.pyplot as plt
|
| 23 |
-
|
| 24 |
|
|
|
|
| 25 |
|
| 26 |
from const import (
|
| 27 |
_BLACK, # 1, for human
|
| 28 |
_WHITE, # 2 , for AI
|
| 29 |
_BLANK,
|
| 30 |
_PLAYER_COLOR,
|
| 31 |
-
|
|
|
|
| 32 |
_ROOM_COLOR,
|
| 33 |
_VERTICAL,
|
| 34 |
_NEW,
|
|
@@ -36,41 +34,71 @@ from const import (
|
|
| 36 |
_DIAGONAL_UP_LEFT,
|
| 37 |
_DIAGONAL_UP_RIGHT,
|
| 38 |
_BOARD_SIZE,
|
| 39 |
-
|
| 40 |
-
_AI_AID_INFO
|
| 41 |
)
|
| 42 |
|
|
|
|
| 43 |
|
| 44 |
-
|
| 45 |
-
|
| 46 |
-
|
| 47 |
-
|
| 48 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 49 |
|
| 50 |
-
gpt2 = load_model()
|
| 51 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 52 |
|
| 53 |
# Utils
|
| 54 |
class Room:
|
| 55 |
def __init__(self, room_id) -> None:
|
| 56 |
self.ROOM_ID = room_id
|
| 57 |
# self.BOARD = np.zeros(shape=(_BOARD_SIZE, _BOARD_SIZE), dtype=int)
|
| 58 |
-
self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=
|
| 59 |
-
self.PLAYER =
|
| 60 |
self.TURN = self.PLAYER
|
| 61 |
self.HISTORY = (0, 0)
|
| 62 |
self.WINNER = _BLANK
|
| 63 |
self.TIME = time.time()
|
| 64 |
self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
| 65 |
self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
| 66 |
-
'AlphaZero': alphazero(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 67 |
'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
|
| 68 |
self.MCTS = self.MCTS_dict['AlphaZero']
|
|
|
|
| 69 |
self.last_mcts = self.MCTS
|
| 70 |
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
| 71 |
-
self.COORDINATE_1D = [
|
| 72 |
self.current_move = -1
|
| 73 |
-
self.
|
|
|
|
|
|
|
| 74 |
|
| 75 |
|
| 76 |
def change_turn(cur):
|
|
@@ -90,9 +118,9 @@ if "ROOMS" not in server_state:
|
|
| 90 |
with server_state_lock["ROOMS"]:
|
| 91 |
server_state.ROOMS = {}
|
| 92 |
|
|
|
|
| 93 |
def handle_oppo_model_selection():
|
| 94 |
if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
|
| 95 |
-
session_state.ROOM.last_mcts = session_state.ROOM.MCTS # since use different mechanism, store previous mcts first
|
| 96 |
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
|
| 97 |
return
|
| 98 |
else:
|
|
@@ -100,37 +128,66 @@ def handle_oppo_model_selection():
|
|
| 100 |
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
| 101 |
new_mct.mcts._root = deepcopy(TreeNode)
|
| 102 |
session_state.ROOM.MCTS = new_mct
|
| 103 |
-
session_state.ROOM.last_mcts
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 104 |
return
|
| 105 |
|
|
|
|
| 106 |
def handle_aid_model_selection():
|
| 107 |
if st.session_state['selected_aid_model'] == 'None':
|
| 108 |
session_state.USE_AIAID = False
|
| 109 |
return
|
| 110 |
session_state.USE_AIAID = True
|
| 111 |
-
TreeNode = session_state.ROOM.MCTS.mcts._root
|
| 112 |
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
|
| 113 |
new_mct.mcts._root = deepcopy(TreeNode)
|
| 114 |
session_state.ROOM.AID_MCTS = new_mct
|
| 115 |
return
|
| 116 |
|
|
|
|
| 117 |
if 'selected_oppo_model' not in st.session_state:
|
| 118 |
st.session_state['selected_oppo_model'] = 'AlphaZero' # 默认值
|
| 119 |
|
|
|
|
|
|
|
|
|
|
| 120 |
if 'selected_aid_model' not in st.session_state:
|
| 121 |
st.session_state['selected_aid_model'] = 'AlphaZero' # 默认值
|
| 122 |
|
| 123 |
# Layout
|
| 124 |
TITLE = st.empty()
|
| 125 |
Model_Switch = st.empty()
|
|
|
|
| 126 |
|
| 127 |
TITLE.header("🤖 AI 3603 Gomoku")
|
| 128 |
-
selected_oppo_option = Model_Switch.selectbox('Select
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 129 |
|
| 130 |
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
| 131 |
st.session_state['selected_oppo_model'] = selected_oppo_option
|
| 132 |
handle_oppo_model_selection()
|
| 133 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 134 |
ROUND_INFO = st.empty()
|
| 135 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 136 |
BOARD_PLATE = [
|
|
@@ -149,9 +206,11 @@ MULTIPLAYER_TAG = st.sidebar.empty()
|
|
| 149 |
with st.sidebar.container():
|
| 150 |
ANOTHER_ROUND = st.empty()
|
| 151 |
RESTART = st.empty()
|
|
|
|
| 152 |
AIAID = st.empty()
|
| 153 |
EXIT = st.empty()
|
| 154 |
-
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0,
|
|
|
|
| 155 |
if st.session_state['selected_aid_model'] != selected_aid_option:
|
| 156 |
st.session_state['selected_aid_model'] = selected_aid_option
|
| 157 |
handle_aid_model_selection()
|
|
@@ -174,7 +233,6 @@ GAME_INFO.markdown(
|
|
| 174 |
)
|
| 175 |
|
| 176 |
|
| 177 |
-
|
| 178 |
def restart() -> None:
|
| 179 |
"""
|
| 180 |
Restart the game.
|
|
@@ -182,12 +240,52 @@ def restart() -> None:
|
|
| 182 |
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
| 183 |
st.session_state['selected_oppo_model'] = 'AlphaZero'
|
| 184 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 185 |
RESTART.button(
|
| 186 |
"Reset",
|
| 187 |
on_click=restart,
|
| 188 |
help="Clear the board as well as the scores",
|
| 189 |
)
|
| 190 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 191 |
|
| 192 |
# Draw the board
|
| 193 |
def gomoku():
|
|
@@ -207,14 +305,25 @@ def gomoku():
|
|
| 207 |
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
| 208 |
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
| 209 |
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
| 210 |
-
|
| 211 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
| 213 |
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
| 214 |
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
| 215 |
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
| 216 |
session_state.ROOM.WINNER = _BLANK # 0
|
| 217 |
-
session_state.ROOM.
|
|
|
|
|
|
|
| 218 |
|
| 219 |
# Room status sync
|
| 220 |
def sync_room() -> bool:
|
|
@@ -235,59 +344,6 @@ def gomoku():
|
|
| 235 |
session_state.ROOM = server_state.ROOMS[room_id]
|
| 236 |
return True
|
| 237 |
|
| 238 |
-
# Check if winner emerge from move
|
| 239 |
-
def check_win() -> int:
|
| 240 |
-
"""
|
| 241 |
-
Use convolution to check if any player wins.
|
| 242 |
-
"""
|
| 243 |
-
vertical = convolve(
|
| 244 |
-
session_state.ROOM.BOARD.board_map,
|
| 245 |
-
_VERTICAL,
|
| 246 |
-
mode="same",
|
| 247 |
-
)
|
| 248 |
-
horizontal = convolve(
|
| 249 |
-
session_state.ROOM.BOARD.board_map,
|
| 250 |
-
_HORIZONTAL,
|
| 251 |
-
mode="same",
|
| 252 |
-
)
|
| 253 |
-
diagonal_up_left = convolve(
|
| 254 |
-
session_state.ROOM.BOARD.board_map,
|
| 255 |
-
_DIAGONAL_UP_LEFT,
|
| 256 |
-
mode="same",
|
| 257 |
-
)
|
| 258 |
-
diagonal_up_right = convolve(
|
| 259 |
-
session_state.ROOM.BOARD.board_map,
|
| 260 |
-
_DIAGONAL_UP_RIGHT,
|
| 261 |
-
mode="same",
|
| 262 |
-
)
|
| 263 |
-
if (
|
| 264 |
-
np.max(
|
| 265 |
-
[
|
| 266 |
-
np.max(vertical),
|
| 267 |
-
np.max(horizontal),
|
| 268 |
-
np.max(diagonal_up_left),
|
| 269 |
-
np.max(diagonal_up_right),
|
| 270 |
-
]
|
| 271 |
-
)
|
| 272 |
-
== 5 * _BLACK
|
| 273 |
-
):
|
| 274 |
-
winner = _BLACK
|
| 275 |
-
elif (
|
| 276 |
-
np.min(
|
| 277 |
-
[
|
| 278 |
-
np.min(vertical),
|
| 279 |
-
np.min(horizontal),
|
| 280 |
-
np.min(diagonal_up_left),
|
| 281 |
-
np.min(diagonal_up_right),
|
| 282 |
-
]
|
| 283 |
-
)
|
| 284 |
-
== 5 * _WHITE
|
| 285 |
-
):
|
| 286 |
-
winner = _WHITE
|
| 287 |
-
else:
|
| 288 |
-
winner = _BLANK
|
| 289 |
-
return winner
|
| 290 |
-
|
| 291 |
# Triggers the board response on click
|
| 292 |
def handle_click(x, y):
|
| 293 |
"""
|
|
@@ -310,7 +366,8 @@ def gomoku():
|
|
| 310 |
session_state.ROOM.current_move = move
|
| 311 |
session_state.ROOM.BOARD.do_move(move)
|
| 312 |
# Gomoku Bot BOARD
|
| 313 |
-
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(move // _BOARD_SIZE
|
|
|
|
| 314 |
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
| 315 |
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
| 316 |
|
|
@@ -333,98 +390,121 @@ def gomoku():
|
|
| 333 |
# Draw board
|
| 334 |
def draw_board(response: bool):
|
| 335 |
"""construct each buttons for all cells of the board"""
|
| 336 |
-
if
|
| 337 |
-
|
| 338 |
-
|
| 339 |
-
|
| 340 |
-
|
| 341 |
-
|
| 342 |
-
|
| 343 |
-
|
| 344 |
-
|
| 345 |
-
|
| 346 |
-
|
| 347 |
-
|
| 348 |
-
|
| 349 |
-
|
| 350 |
-
|
| 351 |
-
|
| 352 |
-
|
| 353 |
-
|
| 354 |
-
|
| 355 |
-
|
| 356 |
-
|
| 357 |
-
|
| 358 |
-
|
| 359 |
-
|
| 360 |
-
|
| 361 |
-
|
| 362 |
-
|
| 363 |
-
|
| 364 |
-
|
| 365 |
-
|
| 366 |
-
|
| 367 |
-
|
| 368 |
-
|
| 369 |
-
|
| 370 |
-
|
| 371 |
-
|
| 372 |
-
|
| 373 |
-
|
| 374 |
-
|
| 375 |
-
|
| 376 |
-
|
| 377 |
-
|
| 378 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 379 |
else:
|
| 380 |
-
|
| 381 |
-
|
| 382 |
-
|
| 383 |
-
|
| 384 |
-
|
| 385 |
-
|
| 386 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 387 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 388 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 389 |
elif response and session_state.ROOM.TURN == _WHITE: # AI turn
|
| 390 |
message.empty()
|
| 391 |
with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
|
| 392 |
time.sleep(0.1)
|
| 393 |
print("AI's turn")
|
| 394 |
print("Below are current board under AI's view")
|
| 395 |
-
|
| 396 |
-
# move = _BOARD_SIZE * _BOARD_SIZE
|
| 397 |
-
# forbid = []
|
| 398 |
-
# step = 0.1
|
| 399 |
-
# tmp = 0.7
|
| 400 |
-
# while move >= _BOARD_SIZE * _BOARD_SIZE or move in session_state.ROOM.COORDINATE_1D:
|
| 401 |
-
#
|
| 402 |
-
# gpt_predictions = generate_gpt2(
|
| 403 |
-
# gpt2,
|
| 404 |
-
# torch.tensor(session_state.ROOM.COORDINATE_1D).unsqueeze(0),
|
| 405 |
-
# tmp
|
| 406 |
-
# )
|
| 407 |
-
# print(gpt_predictions)
|
| 408 |
-
# move = gpt_predictions[len(session_state.ROOM.COORDINATE_1D)]
|
| 409 |
-
# print(move)
|
| 410 |
-
# tmp += step
|
| 411 |
-
# # if move >= _BOARD_SIZE * _BOARD_SIZE:
|
| 412 |
-
# # forbid.append(move)
|
| 413 |
-
# # else:
|
| 414 |
-
# # break
|
| 415 |
-
#
|
| 416 |
-
#
|
| 417 |
-
# gpt_response = move
|
| 418 |
-
# gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
|
| 419 |
-
# print(gpt_i, gpt_j)
|
| 420 |
-
# # session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
| 421 |
-
#
|
| 422 |
-
# simul_time = 0
|
| 423 |
if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
|
| 424 |
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
| 425 |
else:
|
| 426 |
move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
|
| 427 |
-
session_state.ROOM.
|
| 428 |
print("AI takes move: ", move)
|
| 429 |
session_state.ROOM.current_move = move
|
| 430 |
gpt_response = move
|
|
@@ -436,7 +516,8 @@ def gomoku():
|
|
| 436 |
# MCTS BOARD
|
| 437 |
session_state.ROOM.BOARD.do_move(move)
|
| 438 |
# Gomoku Bot BOARD
|
| 439 |
-
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(
|
|
|
|
| 440 |
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
| 441 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
| 442 |
|
|
@@ -457,43 +538,43 @@ def gomoku():
|
|
| 457 |
for j, cell in enumerate(row):
|
| 458 |
if (
|
| 459 |
i * _BOARD_SIZE + j
|
| 460 |
-
in
|
| 461 |
):
|
| 462 |
if i == gpt_i and j == gpt_j:
|
| 463 |
BOARD_PLATE[i][j].button(
|
| 464 |
-
|
| 465 |
key=f"{i}:{j}",
|
| 466 |
args=(i, j),
|
| 467 |
-
on_click=
|
| 468 |
)
|
| 469 |
else:
|
| 470 |
# disable click for GPT choices
|
| 471 |
BOARD_PLATE[i][j].button(
|
| 472 |
-
|
| 473 |
key=f"{i}:{j}",
|
| 474 |
args=(i, j),
|
| 475 |
on_click=forbid_click
|
| 476 |
)
|
| 477 |
else:
|
| 478 |
-
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not
|
|
|
|
| 479 |
# enable click for other cells available for human choices
|
| 480 |
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
| 481 |
BOARD_PLATE[i][j].button(
|
| 482 |
-
|
| 483 |
key=f"{i}:{j}",
|
| 484 |
-
on_click=
|
| 485 |
args=(i, j),
|
| 486 |
)
|
| 487 |
else:
|
| 488 |
# enable click for other cells available for human choices
|
| 489 |
BOARD_PLATE[i][j].button(
|
| 490 |
-
|
| 491 |
key=f"{i}:{j}",
|
| 492 |
-
on_click=
|
| 493 |
args=(i, j),
|
| 494 |
)
|
| 495 |
|
| 496 |
-
|
| 497 |
message.markdown(
|
| 498 |
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
| 499 |
simul_time),
|
|
@@ -522,7 +603,7 @@ def gomoku():
|
|
| 522 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
| 523 |
for j, cell in enumerate(row):
|
| 524 |
BOARD_PLATE[i][j].write(
|
| 525 |
-
|
| 526 |
# key=f"{i}:{j}",
|
| 527 |
)
|
| 528 |
|
|
@@ -549,24 +630,11 @@ def gomoku():
|
|
| 549 |
ROUND_INFO.write(
|
| 550 |
f"#### **{_PLAYER_COLOR[session_state.ROOM.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
|
| 551 |
)
|
| 552 |
-
|
| 553 |
-
# elif 0 not in session_state.ROOM.BOARD.board_map:
|
| 554 |
-
# ROUND_INFO.write("#### **Tie**")
|
| 555 |
-
# else:
|
| 556 |
-
# ROUND_INFO.write(
|
| 557 |
-
# f"#### **{_PLAYER_SYMBOL[session_state.ROOM.TURN]} {_PLAYER_COLOR[session_state.ROOM.TURN]}'s turn...**"
|
| 558 |
-
# )
|
| 559 |
-
|
| 560 |
-
# draw the plot for simulation time
|
| 561 |
-
# 创建一个 DataFrame
|
| 562 |
-
|
| 563 |
-
# print(session_state.ROOM.simula_time_list)
|
| 564 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 565 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 566 |
-
chart_data = pd.DataFrame(session_state.ROOM.
|
| 567 |
st.line_chart(chart_data)
|
| 568 |
|
| 569 |
-
|
| 570 |
game_control()
|
| 571 |
update_info()
|
| 572 |
|
|
|
|
| 8 |
import time
|
| 9 |
import pandas as pd
|
| 10 |
from copy import deepcopy
|
|
|
|
|
|
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import streamlit as st
|
| 13 |
from scipy.signal import convolve # this is used to check if any player wins
|
| 14 |
from streamlit import session_state
|
| 15 |
from streamlit_server_state import server_state, server_state_lock
|
| 16 |
+
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_old, PolicyValueNet_new, duel_PolicyValueNet, \
|
| 17 |
+
Gumbel_MCTSPlayer
|
| 18 |
from Gomoku_Bot import Gomoku_bot
|
| 19 |
from Gomoku_Bot import Board as Gomoku_bot_board
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
|
| 23 |
from const import (
|
| 24 |
_BLACK, # 1, for human
|
| 25 |
_WHITE, # 2 , for AI
|
| 26 |
_BLANK,
|
| 27 |
_PLAYER_COLOR,
|
| 28 |
+
_PLAYER_SYMBOL1,
|
| 29 |
+
_PLAYER_SYMBOL2,
|
| 30 |
_ROOM_COLOR,
|
| 31 |
_VERTICAL,
|
| 32 |
_NEW,
|
|
|
|
| 34 |
_DIAGONAL_UP_LEFT,
|
| 35 |
_DIAGONAL_UP_RIGHT,
|
| 36 |
_BOARD_SIZE,
|
| 37 |
+
_MODEL_PATH
|
|
|
|
| 38 |
)
|
| 39 |
|
| 40 |
+
_PLAYER_SYMBOL = [0, _PLAYER_SYMBOL1, _PLAYER_SYMBOL2]
|
| 41 |
|
| 42 |
+
# '''
|
| 43 |
+
# from ai import (
|
| 44 |
+
# BOS_TOKEN_ID,
|
| 45 |
+
# generate_gpt2,
|
| 46 |
+
# load_model,
|
| 47 |
+
# )
|
| 48 |
+
#
|
| 49 |
+
# gpt2 = load_model()
|
| 50 |
+
#
|
| 51 |
+
# '''
|
| 52 |
|
|
|
|
| 53 |
|
| 54 |
+
if "FirstPlayer" not in session_state:
|
| 55 |
+
session_state.FirstPlayer = _BLACK
|
| 56 |
+
session_state.Player = [[], [ _BLACK,_WHITE], [_WHITE,_BLACK]][session_state.FirstPlayer]
|
| 57 |
+
session_state.Symbol = _PLAYER_SYMBOL[session_state.FirstPlayer]
|
| 58 |
|
| 59 |
# Utils
|
| 60 |
class Room:
|
| 61 |
def __init__(self, room_id) -> None:
|
| 62 |
self.ROOM_ID = room_id
|
| 63 |
# self.BOARD = np.zeros(shape=(_BOARD_SIZE, _BOARD_SIZE), dtype=int)
|
| 64 |
+
self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Player)
|
| 65 |
+
self.PLAYER = session_state.FirstPlayer
|
| 66 |
self.TURN = self.PLAYER
|
| 67 |
self.HISTORY = (0, 0)
|
| 68 |
self.WINNER = _BLANK
|
| 69 |
self.TIME = time.time()
|
| 70 |
self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
| 71 |
self.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
| 72 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 73 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
| 74 |
+
c_puct=5, n_playout=100),
|
| 75 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
| 76 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
| 77 |
+
c_puct=5, n_playout=100),
|
| 78 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 79 |
+
_MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
|
| 80 |
+
c_puct=5, n_playout=100, m_action=8),
|
| 81 |
+
'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
|
| 82 |
+
self.MCTS_dict_ = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
| 83 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 84 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
| 85 |
+
c_puct=5, n_playout=100),
|
| 86 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
| 87 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
| 88 |
+
c_puct=5, n_playout=100),
|
| 89 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 90 |
+
_MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
|
| 91 |
+
c_puct=5, n_playout=100, m_action=8),
|
| 92 |
'Gomoku Bot': Gomoku_bot(self.gomoku_bot_board, -1)}
|
| 93 |
self.MCTS = self.MCTS_dict['AlphaZero']
|
| 94 |
+
self.MCTS_ = self.MCTS_dict['AlphaZero']
|
| 95 |
self.last_mcts = self.MCTS
|
| 96 |
self.AID_MCTS = self.MCTS_dict['AlphaZero']
|
| 97 |
+
self.COORDINATE_1D = []
|
| 98 |
self.current_move = -1
|
| 99 |
+
self.ai_simula_time_list = []
|
| 100 |
+
self.ai_simula_time_list_ = []
|
| 101 |
+
self.human_simula_time_list = []
|
| 102 |
|
| 103 |
|
| 104 |
def change_turn(cur):
|
|
|
|
| 118 |
with server_state_lock["ROOMS"]:
|
| 119 |
server_state.ROOMS = {}
|
| 120 |
|
| 121 |
+
|
| 122 |
def handle_oppo_model_selection():
|
| 123 |
if st.session_state['selected_oppo_model'] == 'Gomoku Bot':
|
|
|
|
| 124 |
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict['Gomoku Bot']
|
| 125 |
return
|
| 126 |
else:
|
|
|
|
| 128 |
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
| 129 |
new_mct.mcts._root = deepcopy(TreeNode)
|
| 130 |
session_state.ROOM.MCTS = new_mct
|
| 131 |
+
session_state.ROOM.last_mcts = new_mct
|
| 132 |
+
return
|
| 133 |
+
|
| 134 |
+
def handle_oppo_model_selection_():
|
| 135 |
+
if st.session_state['selected_oppo_model_'] == 'Gomoku Bot':
|
| 136 |
+
session_state.ROOM.MCTS_ = session_state.ROOM.MCTS_dict_['Gomoku Bot']
|
| 137 |
+
return
|
| 138 |
+
else:
|
| 139 |
+
TreeNode = session_state.ROOM.last_mcts_.mcts._root
|
| 140 |
+
new_mct = session_state.ROOM.MCTS_dict_[st.session_state['selected_oppo_model_']]
|
| 141 |
+
new_mct.mcts._root = deepcopy(TreeNode)
|
| 142 |
+
session_state.ROOM.MCTS_ = new_mct
|
| 143 |
+
session_state.ROOM.last_mcts_ = new_mct
|
| 144 |
return
|
| 145 |
|
| 146 |
+
|
| 147 |
def handle_aid_model_selection():
|
| 148 |
if st.session_state['selected_aid_model'] == 'None':
|
| 149 |
session_state.USE_AIAID = False
|
| 150 |
return
|
| 151 |
session_state.USE_AIAID = True
|
| 152 |
+
TreeNode = session_state.ROOM.MCTS.mcts._root # use the same tree node
|
| 153 |
new_mct = session_state.ROOM.MCTS_dict[st.session_state['selected_aid_model']]
|
| 154 |
new_mct.mcts._root = deepcopy(TreeNode)
|
| 155 |
session_state.ROOM.AID_MCTS = new_mct
|
| 156 |
return
|
| 157 |
|
| 158 |
+
|
| 159 |
if 'selected_oppo_model' not in st.session_state:
|
| 160 |
st.session_state['selected_oppo_model'] = 'AlphaZero' # 默认值
|
| 161 |
|
| 162 |
+
if 'selected_oppo_model_' not in st.session_state:
|
| 163 |
+
st.session_state['selected_oppo_model_'] = 'AlphaZero' # 默认值
|
| 164 |
+
|
| 165 |
if 'selected_aid_model' not in st.session_state:
|
| 166 |
st.session_state['selected_aid_model'] = 'AlphaZero' # 默认值
|
| 167 |
|
| 168 |
# Layout
|
| 169 |
TITLE = st.empty()
|
| 170 |
Model_Switch = st.empty()
|
| 171 |
+
Model_Switch_ = st.empty()
|
| 172 |
|
| 173 |
TITLE.header("🤖 AI 3603 Gomoku")
|
| 174 |
+
selected_oppo_option = Model_Switch.selectbox('Select Model 1',
|
| 175 |
+
['Pure MCTS', 'AlphaZero', 'Gomoku Bot', 'duel', 'Gumbel AlphaZero'],
|
| 176 |
+
index=1, key='oppo_model')
|
| 177 |
+
|
| 178 |
+
selected_oppo_option_ = Model_Switch_.selectbox('Select Model 2',
|
| 179 |
+
['Pure MCTS', 'AlphaZero', 'Gomoku Bot', 'duel', 'Gumbel AlphaZero'],
|
| 180 |
+
index=1, key='oppo_model_')
|
| 181 |
|
| 182 |
if st.session_state['selected_oppo_model'] != selected_oppo_option:
|
| 183 |
st.session_state['selected_oppo_model'] = selected_oppo_option
|
| 184 |
handle_oppo_model_selection()
|
| 185 |
|
| 186 |
+
if st.session_state['selected_oppo_model_'] != selected_oppo_option_:
|
| 187 |
+
st.session_state['selected_oppo_model_'] = selected_oppo_option_
|
| 188 |
+
handle_oppo_model_selection_()
|
| 189 |
+
|
| 190 |
+
|
| 191 |
ROUND_INFO = st.empty()
|
| 192 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 193 |
BOARD_PLATE = [
|
|
|
|
| 206 |
with st.sidebar.container():
|
| 207 |
ANOTHER_ROUND = st.empty()
|
| 208 |
RESTART = st.empty()
|
| 209 |
+
CHANGE_PLAYER = st.empty()
|
| 210 |
AIAID = st.empty()
|
| 211 |
EXIT = st.empty()
|
| 212 |
+
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0,
|
| 213 |
+
key='aid_model')
|
| 214 |
if st.session_state['selected_aid_model'] != selected_aid_option:
|
| 215 |
st.session_state['selected_aid_model'] = selected_aid_option
|
| 216 |
handle_aid_model_selection()
|
|
|
|
| 233 |
)
|
| 234 |
|
| 235 |
|
|
|
|
| 236 |
def restart() -> None:
|
| 237 |
"""
|
| 238 |
Restart the game.
|
|
|
|
| 240 |
session_state.ROOM = Room(session_state.ROOM.ROOM_ID)
|
| 241 |
st.session_state['selected_oppo_model'] = 'AlphaZero'
|
| 242 |
|
| 243 |
+
|
| 244 |
+
def swap_players() -> None:
|
| 245 |
+
session_state.update(
|
| 246 |
+
FirstPlayer=change_turn(session_state.FirstPlayer),
|
| 247 |
+
)
|
| 248 |
+
session_state.update(
|
| 249 |
+
Player=[[], [_BLACK, _WHITE], [_WHITE, _BLACK]][session_state.FirstPlayer],
|
| 250 |
+
Symbol=_PLAYER_SYMBOL[session_state.FirstPlayer]
|
| 251 |
+
)
|
| 252 |
+
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Player)
|
| 253 |
+
session_state.ROOM.PLAYER = session_state.FirstPlayer
|
| 254 |
+
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
| 255 |
+
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
| 256 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 257 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
| 258 |
+
c_puct=5, n_playout=100),
|
| 259 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
| 260 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
| 261 |
+
c_puct=5, n_playout=100),
|
| 262 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 263 |
+
_MODEL_PATH[
|
| 264 |
+
"Gumbel AlphaZero"]).policy_value_fn,
|
| 265 |
+
c_puct=5, n_playout=100, m_action=8),
|
| 266 |
+
'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
| 267 |
+
|
| 268 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
| 269 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
| 270 |
+
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
| 271 |
+
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
| 272 |
+
session_state.ROOM.WINNER = _BLANK # 0
|
| 273 |
+
session_state.ROOM.ai_simula_time_list = []
|
| 274 |
+
session_state.ROOM.human_simula_time_list = []
|
| 275 |
+
session_state.ROOM.COORDINATE_1D = []
|
| 276 |
+
|
| 277 |
RESTART.button(
|
| 278 |
"Reset",
|
| 279 |
on_click=restart,
|
| 280 |
help="Clear the board as well as the scores",
|
| 281 |
)
|
| 282 |
|
| 283 |
+
CHANGE_PLAYER.button(
|
| 284 |
+
"Swap players",
|
| 285 |
+
on_click=swap_players,
|
| 286 |
+
help="Swap players",
|
| 287 |
+
)
|
| 288 |
+
|
| 289 |
|
| 290 |
# Draw the board
|
| 291 |
def gomoku():
|
|
|
|
| 305 |
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
|
| 306 |
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
| 307 |
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
| 308 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 309 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
| 310 |
+
c_puct=5, n_playout=100),
|
| 311 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
| 312 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
| 313 |
+
c_puct=5, n_playout=100),
|
| 314 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 315 |
+
_MODEL_PATH["Gumbel AlphaZero"]).policy_value_fn,
|
| 316 |
+
c_puct=5, n_playout=100, m_action=8),
|
| 317 |
+
'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
| 318 |
+
|
| 319 |
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
| 320 |
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
| 321 |
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
| 322 |
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
| 323 |
session_state.ROOM.WINNER = _BLANK # 0
|
| 324 |
+
session_state.ROOM.ai_simula_time_list = []
|
| 325 |
+
session_state.ROOM.human_simula_time_list = []
|
| 326 |
+
session_state.ROOM.COORDINATE_1D = []
|
| 327 |
|
| 328 |
# Room status sync
|
| 329 |
def sync_room() -> bool:
|
|
|
|
| 344 |
session_state.ROOM = server_state.ROOMS[room_id]
|
| 345 |
return True
|
| 346 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 347 |
# Triggers the board response on click
|
| 348 |
def handle_click(x, y):
|
| 349 |
"""
|
|
|
|
| 366 |
session_state.ROOM.current_move = move
|
| 367 |
session_state.ROOM.BOARD.do_move(move)
|
| 368 |
# Gomoku Bot BOARD
|
| 369 |
+
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - move // _BOARD_SIZE - 1,
|
| 370 |
+
move % _BOARD_SIZE) # # this move starts from left up corner (0,0), however, the move in the game starts from left bottom corner (0,0)
|
| 371 |
session_state.ROOM.BOARD.board_map[x][y] = session_state.ROOM.TURN
|
| 372 |
session_state.ROOM.COORDINATE_1D.append(x * _BOARD_SIZE + y)
|
| 373 |
|
|
|
|
| 390 |
# Draw board
|
| 391 |
def draw_board(response: bool):
|
| 392 |
"""construct each buttons for all cells of the board"""
|
| 393 |
+
if response and session_state.ROOM.TURN == _BLACK: # Another AI
|
| 394 |
+
message.empty()
|
| 395 |
+
with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
|
| 396 |
+
time.sleep(0.1)
|
| 397 |
+
print("AI's turn")
|
| 398 |
+
print("Below are current board under AI's view")
|
| 399 |
+
|
| 400 |
+
if st.session_state['selected_oppo_model_'] != 'Gomoku Bot':
|
| 401 |
+
move, simul_time = session_state.ROOM.MCTS_.get_action(session_state.ROOM.BOARD, return_time=True)
|
| 402 |
+
else:
|
| 403 |
+
move, simul_time = session_state.ROOM.MCTS_.get_action(return_time=True)
|
| 404 |
+
session_state.ROOM.ai_simula_time_list_.append(simul_time)
|
| 405 |
+
print("AI takes move: ", move)
|
| 406 |
+
session_state.ROOM.current_move = move
|
| 407 |
+
gpt_response = move
|
| 408 |
+
gpt_i, gpt_j = gpt_response // _BOARD_SIZE, gpt_response % _BOARD_SIZE
|
| 409 |
+
print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
|
| 410 |
+
move = session_state.ROOM.BOARD.location_to_move((gpt_i, gpt_j))
|
| 411 |
+
print("Location to move: ", move)
|
| 412 |
+
# print("Location to move: ", move)
|
| 413 |
+
# MCTS BOARD
|
| 414 |
+
session_state.ROOM.BOARD.do_move(move)
|
| 415 |
+
# Gomoku Bot BOARD
|
| 416 |
+
session_state.ROOM.MCTS_dict_["Gomoku Bot"].board.put(_BOARD_SIZE - 1 - move // _BOARD_SIZE,
|
| 417 |
+
move % _BOARD_SIZE)
|
| 418 |
+
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
| 419 |
+
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
| 420 |
+
|
| 421 |
+
if not session_state.ROOM.BOARD.game_end()[0]:
|
| 422 |
+
if session_state.USE_AIAID:
|
| 423 |
+
copy_mcts = deepcopy(session_state.ROOM.AID_MCTS.mcts)
|
| 424 |
+
_, acts_aid, probs_aid, simul_mean_time_aid = copy_mcts.get_move_probs(session_state.ROOM.BOARD)
|
| 425 |
+
sorted_acts_probs = sorted(zip(acts_aid, probs_aid), key=lambda x: x[1], reverse=True)
|
| 426 |
+
top_five_acts = [act for act, prob in sorted_acts_probs[:5]]
|
| 427 |
+
top_five_probs = [prob for act, prob in sorted_acts_probs[:5]]
|
| 428 |
+
else:
|
| 429 |
+
top_five_acts = []
|
| 430 |
+
top_five_probs = []
|
| 431 |
+
|
| 432 |
+
# construction of clickable buttons
|
| 433 |
+
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
| 434 |
+
# print("row:", row)
|
| 435 |
+
for j, cell in enumerate(row):
|
| 436 |
+
if (
|
| 437 |
+
i * _BOARD_SIZE + j
|
| 438 |
+
in session_state.ROOM.COORDINATE_1D
|
| 439 |
+
):
|
| 440 |
+
if i == gpt_i and j == gpt_j:
|
| 441 |
+
BOARD_PLATE[i][j].button(
|
| 442 |
+
session_state.Symbol[_NEW],
|
| 443 |
+
key=f"{i}:{j}",
|
| 444 |
+
args=(i, j),
|
| 445 |
+
on_click=forbid_click,
|
| 446 |
+
)
|
| 447 |
+
else:
|
| 448 |
+
# disable click for GPT choices
|
| 449 |
+
BOARD_PLATE[i][j].button(
|
| 450 |
+
session_state.Symbol[cell],
|
| 451 |
+
key=f"{i}:{j}",
|
| 452 |
+
args=(i, j),
|
| 453 |
+
on_click=forbid_click
|
| 454 |
+
)
|
| 455 |
else:
|
| 456 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not \
|
| 457 |
+
session_state.ROOM.BOARD.game_end()[0]:
|
| 458 |
+
# enable click for other cells available for human choices
|
| 459 |
+
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
| 460 |
+
BOARD_PLATE[i][j].button(
|
| 461 |
+
session_state.Symbol[cell] + f"({round(prob, 2)})",
|
| 462 |
+
key=f"{i}:{j}",
|
| 463 |
+
on_click=forbid_click,
|
| 464 |
+
args=(i, j),
|
| 465 |
+
)
|
| 466 |
+
else:
|
| 467 |
+
# enable click for other cells available for human choices
|
| 468 |
+
BOARD_PLATE[i][j].button(
|
| 469 |
+
session_state.Symbol[cell],
|
| 470 |
+
key=f"{i}:{j}",
|
| 471 |
+
on_click=forbid_click,
|
| 472 |
+
args=(i, j),
|
| 473 |
+
)
|
| 474 |
|
| 475 |
+
message.markdown(
|
| 476 |
+
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
| 477 |
+
simul_time),
|
| 478 |
+
unsafe_allow_html=True
|
| 479 |
+
)
|
| 480 |
+
LOG.subheader("Logs")
|
| 481 |
+
# change turn
|
| 482 |
+
session_state.ROOM.TURN = change_turn(session_state.ROOM.TURN)
|
| 483 |
+
# session_state.ROOM.WINNER = check_win()
|
| 484 |
+
|
| 485 |
+
win, winner = session_state.ROOM.BOARD.game_end()
|
| 486 |
+
if win:
|
| 487 |
+
session_state.ROOM.WINNER = winner
|
| 488 |
|
| 489 |
+
session_state.ROOM.HISTORY = (
|
| 490 |
+
session_state.ROOM.HISTORY[0]
|
| 491 |
+
+ int(session_state.ROOM.WINNER == _WHITE),
|
| 492 |
+
session_state.ROOM.HISTORY[1]
|
| 493 |
+
+ int(session_state.ROOM.WINNER == _BLACK),
|
| 494 |
+
)
|
| 495 |
+
session_state.ROOM.TIME = time.time()
|
| 496 |
elif response and session_state.ROOM.TURN == _WHITE: # AI turn
|
| 497 |
message.empty()
|
| 498 |
with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
|
| 499 |
time.sleep(0.1)
|
| 500 |
print("AI's turn")
|
| 501 |
print("Below are current board under AI's view")
|
| 502 |
+
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 503 |
if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
|
| 504 |
move, simul_time = session_state.ROOM.MCTS.get_action(session_state.ROOM.BOARD, return_time=True)
|
| 505 |
else:
|
| 506 |
move, simul_time = session_state.ROOM.MCTS.get_action(return_time=True)
|
| 507 |
+
session_state.ROOM.ai_simula_time_list.append(simul_time)
|
| 508 |
print("AI takes move: ", move)
|
| 509 |
session_state.ROOM.current_move = move
|
| 510 |
gpt_response = move
|
|
|
|
| 516 |
# MCTS BOARD
|
| 517 |
session_state.ROOM.BOARD.do_move(move)
|
| 518 |
# Gomoku Bot BOARD
|
| 519 |
+
session_state.ROOM.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - 1 - move // _BOARD_SIZE,
|
| 520 |
+
move % _BOARD_SIZE)
|
| 521 |
# session_state.ROOM.BOARD[gpt_i][gpt_j] = session_state.ROOM.TURN
|
| 522 |
session_state.ROOM.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)
|
| 523 |
|
|
|
|
| 538 |
for j, cell in enumerate(row):
|
| 539 |
if (
|
| 540 |
i * _BOARD_SIZE + j
|
| 541 |
+
in session_state.ROOM.COORDINATE_1D
|
| 542 |
):
|
| 543 |
if i == gpt_i and j == gpt_j:
|
| 544 |
BOARD_PLATE[i][j].button(
|
| 545 |
+
session_state.Symbol[_NEW],
|
| 546 |
key=f"{i}:{j}",
|
| 547 |
args=(i, j),
|
| 548 |
+
on_click=forbid_click,
|
| 549 |
)
|
| 550 |
else:
|
| 551 |
# disable click for GPT choices
|
| 552 |
BOARD_PLATE[i][j].button(
|
| 553 |
+
session_state.Symbol[cell],
|
| 554 |
key=f"{i}:{j}",
|
| 555 |
args=(i, j),
|
| 556 |
on_click=forbid_click
|
| 557 |
)
|
| 558 |
else:
|
| 559 |
+
if session_state.USE_AIAID and i * _BOARD_SIZE + j in top_five_acts and not \
|
| 560 |
+
session_state.ROOM.BOARD.game_end()[0]:
|
| 561 |
# enable click for other cells available for human choices
|
| 562 |
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
| 563 |
BOARD_PLATE[i][j].button(
|
| 564 |
+
session_state.Symbol[cell] + f"({round(prob, 2)})",
|
| 565 |
key=f"{i}:{j}",
|
| 566 |
+
on_click=forbid_click,
|
| 567 |
args=(i, j),
|
| 568 |
)
|
| 569 |
else:
|
| 570 |
# enable click for other cells available for human choices
|
| 571 |
BOARD_PLATE[i][j].button(
|
| 572 |
+
session_state.Symbol[cell],
|
| 573 |
key=f"{i}:{j}",
|
| 574 |
+
on_click=forbid_click,
|
| 575 |
args=(i, j),
|
| 576 |
)
|
| 577 |
|
|
|
|
| 578 |
message.markdown(
|
| 579 |
'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
|
| 580 |
simul_time),
|
|
|
|
| 603 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
| 604 |
for j, cell in enumerate(row):
|
| 605 |
BOARD_PLATE[i][j].write(
|
| 606 |
+
session_state.Symbol[cell],
|
| 607 |
# key=f"{i}:{j}",
|
| 608 |
)
|
| 609 |
|
|
|
|
| 630 |
ROUND_INFO.write(
|
| 631 |
f"#### **{_PLAYER_COLOR[session_state.ROOM.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
|
| 632 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 633 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 634 |
st.markdown("<br>", unsafe_allow_html=True)
|
| 635 |
+
chart_data = pd.DataFrame(session_state.ROOM.ai_simula_time_list, columns=["Simulation Time"])
|
| 636 |
st.line_chart(chart_data)
|
| 637 |
|
|
|
|
| 638 |
game_control()
|
| 639 |
update_info()
|
| 640 |
|
pages/Player_VS_AI.py
CHANGED
|
@@ -8,14 +8,10 @@ Description: this file is used to display our project and add visualization elem
|
|
| 8 |
import time
|
| 9 |
import pandas as pd
|
| 10 |
from copy import deepcopy
|
| 11 |
-
import torch
|
| 12 |
-
|
| 13 |
-
# import torch
|
| 14 |
import numpy as np
|
| 15 |
import streamlit as st
|
| 16 |
from scipy.signal import convolve # this is used to check if any player wins
|
| 17 |
from streamlit import session_state
|
| 18 |
-
from streamlit.delta_generator import DeltaGenerator
|
| 19 |
from streamlit_server_state import server_state, server_state_lock
|
| 20 |
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_old, PolicyValueNet_new, duel_PolicyValueNet, \
|
| 21 |
Gumbel_MCTSPlayer
|
|
@@ -29,7 +25,8 @@ from const import (
|
|
| 29 |
_WHITE, # 2 , for AI
|
| 30 |
_BLANK,
|
| 31 |
_PLAYER_COLOR,
|
| 32 |
-
|
|
|
|
| 33 |
_ROOM_COLOR,
|
| 34 |
_VERTICAL,
|
| 35 |
_NEW,
|
|
@@ -37,11 +34,11 @@ from const import (
|
|
| 37 |
_DIAGONAL_UP_LEFT,
|
| 38 |
_DIAGONAL_UP_RIGHT,
|
| 39 |
_BOARD_SIZE,
|
| 40 |
-
_BOARD_SIZE_1D,
|
| 41 |
-
_AI_AID_INFO,
|
| 42 |
_MODEL_PATH
|
| 43 |
)
|
| 44 |
|
|
|
|
|
|
|
| 45 |
# '''
|
| 46 |
# from ai import (
|
| 47 |
# BOS_TOKEN_ID,
|
|
@@ -54,13 +51,18 @@ from const import (
|
|
| 54 |
# '''
|
| 55 |
|
| 56 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 57 |
# Utils
|
| 58 |
class Room:
|
| 59 |
def __init__(self, room_id) -> None:
|
| 60 |
self.ROOM_ID = room_id
|
| 61 |
# self.BOARD = np.zeros(shape=(_BOARD_SIZE, _BOARD_SIZE), dtype=int)
|
| 62 |
-
self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=
|
| 63 |
-
self.PLAYER =
|
| 64 |
self.TURN = self.PLAYER
|
| 65 |
self.HISTORY = (0, 0)
|
| 66 |
self.WINNER = _BLANK
|
|
@@ -167,6 +169,7 @@ with st.sidebar.container():
|
|
| 167 |
ANOTHER_ROUND = st.empty()
|
| 168 |
RESTART = st.empty()
|
| 169 |
GIVEIN = st.empty()
|
|
|
|
| 170 |
AIAID = st.empty()
|
| 171 |
EXIT = st.empty()
|
| 172 |
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0,
|
|
@@ -237,6 +240,38 @@ def givein() -> None:
|
|
| 237 |
session_state.ROOM.human_simula_time_list = []
|
| 238 |
session_state.ROOM.COORDINATE_1D = []
|
| 239 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 240 |
|
| 241 |
RESTART.button(
|
| 242 |
"Reset",
|
|
@@ -250,6 +285,12 @@ GIVEIN.button(
|
|
| 250 |
help="Give in to AI",
|
| 251 |
)
|
| 252 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 253 |
|
| 254 |
# Draw the board
|
| 255 |
def gomoku():
|
|
@@ -428,7 +469,7 @@ def gomoku():
|
|
| 428 |
):
|
| 429 |
if i == cur_move[0] and j == cur_move[1]:
|
| 430 |
BOARD_PLATE[i][j].button(
|
| 431 |
-
|
| 432 |
key=f"{i}:{j}",
|
| 433 |
args=(i, j),
|
| 434 |
on_click=forbid_click,
|
|
@@ -436,7 +477,7 @@ def gomoku():
|
|
| 436 |
else:
|
| 437 |
# disable click for GPT choices
|
| 438 |
BOARD_PLATE[i][j].button(
|
| 439 |
-
|
| 440 |
key=f"{i}:{j}",
|
| 441 |
args=(i, j),
|
| 442 |
on_click=forbid_click
|
|
@@ -446,7 +487,7 @@ def gomoku():
|
|
| 446 |
# enable click for other cells available for human choices
|
| 447 |
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
| 448 |
BOARD_PLATE[i][j].button(
|
| 449 |
-
|
| 450 |
key=f"{i}:{j}",
|
| 451 |
on_click=handle_click,
|
| 452 |
args=(i, j),
|
|
@@ -454,7 +495,7 @@ def gomoku():
|
|
| 454 |
else:
|
| 455 |
# enable click for other cells available for human choices
|
| 456 |
BOARD_PLATE[i][j].button(
|
| 457 |
-
|
| 458 |
key=f"{i}:{j}",
|
| 459 |
on_click=handle_click,
|
| 460 |
args=(i, j),
|
|
@@ -538,7 +579,7 @@ def gomoku():
|
|
| 538 |
):
|
| 539 |
if i == gpt_i and j == gpt_j:
|
| 540 |
BOARD_PLATE[i][j].button(
|
| 541 |
-
|
| 542 |
key=f"{i}:{j}",
|
| 543 |
args=(i, j),
|
| 544 |
on_click=handle_click,
|
|
@@ -546,7 +587,7 @@ def gomoku():
|
|
| 546 |
else:
|
| 547 |
# disable click for GPT choices
|
| 548 |
BOARD_PLATE[i][j].button(
|
| 549 |
-
|
| 550 |
key=f"{i}:{j}",
|
| 551 |
args=(i, j),
|
| 552 |
on_click=forbid_click
|
|
@@ -557,7 +598,7 @@ def gomoku():
|
|
| 557 |
# enable click for other cells available for human choices
|
| 558 |
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
| 559 |
BOARD_PLATE[i][j].button(
|
| 560 |
-
|
| 561 |
key=f"{i}:{j}",
|
| 562 |
on_click=handle_click,
|
| 563 |
args=(i, j),
|
|
@@ -565,7 +606,7 @@ def gomoku():
|
|
| 565 |
else:
|
| 566 |
# enable click for other cells available for human choices
|
| 567 |
BOARD_PLATE[i][j].button(
|
| 568 |
-
|
| 569 |
key=f"{i}:{j}",
|
| 570 |
on_click=handle_click,
|
| 571 |
args=(i, j),
|
|
@@ -599,7 +640,7 @@ def gomoku():
|
|
| 599 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
| 600 |
for j, cell in enumerate(row):
|
| 601 |
BOARD_PLATE[i][j].write(
|
| 602 |
-
|
| 603 |
# key=f"{i}:{j}",
|
| 604 |
)
|
| 605 |
|
|
|
|
| 8 |
import time
|
| 9 |
import pandas as pd
|
| 10 |
from copy import deepcopy
|
|
|
|
|
|
|
|
|
|
| 11 |
import numpy as np
|
| 12 |
import streamlit as st
|
| 13 |
from scipy.signal import convolve # this is used to check if any player wins
|
| 14 |
from streamlit import session_state
|
|
|
|
| 15 |
from streamlit_server_state import server_state, server_state_lock
|
| 16 |
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_old, PolicyValueNet_new, duel_PolicyValueNet, \
|
| 17 |
Gumbel_MCTSPlayer
|
|
|
|
| 25 |
_WHITE, # 2 , for AI
|
| 26 |
_BLANK,
|
| 27 |
_PLAYER_COLOR,
|
| 28 |
+
_PLAYER_SYMBOL1,
|
| 29 |
+
_PLAYER_SYMBOL2,
|
| 30 |
_ROOM_COLOR,
|
| 31 |
_VERTICAL,
|
| 32 |
_NEW,
|
|
|
|
| 34 |
_DIAGONAL_UP_LEFT,
|
| 35 |
_DIAGONAL_UP_RIGHT,
|
| 36 |
_BOARD_SIZE,
|
|
|
|
|
|
|
| 37 |
_MODEL_PATH
|
| 38 |
)
|
| 39 |
|
| 40 |
+
_PLAYER_SYMBOL = [0, _PLAYER_SYMBOL1, _PLAYER_SYMBOL2]
|
| 41 |
+
|
| 42 |
# '''
|
| 43 |
# from ai import (
|
| 44 |
# BOS_TOKEN_ID,
|
|
|
|
| 51 |
# '''
|
| 52 |
|
| 53 |
|
| 54 |
+
if "FirstPlayer" not in session_state:
|
| 55 |
+
session_state.FirstPlayer = _BLACK
|
| 56 |
+
session_state.Player = [[], [ _BLACK,_WHITE], [_WHITE,_BLACK]][session_state.FirstPlayer]
|
| 57 |
+
session_state.Symbol = _PLAYER_SYMBOL[session_state.FirstPlayer]
|
| 58 |
+
|
| 59 |
# Utils
|
| 60 |
class Room:
|
| 61 |
def __init__(self, room_id) -> None:
|
| 62 |
self.ROOM_ID = room_id
|
| 63 |
# self.BOARD = np.zeros(shape=(_BOARD_SIZE, _BOARD_SIZE), dtype=int)
|
| 64 |
+
self.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Player)
|
| 65 |
+
self.PLAYER = session_state.FirstPlayer
|
| 66 |
self.TURN = self.PLAYER
|
| 67 |
self.HISTORY = (0, 0)
|
| 68 |
self.WINNER = _BLANK
|
|
|
|
| 169 |
ANOTHER_ROUND = st.empty()
|
| 170 |
RESTART = st.empty()
|
| 171 |
GIVEIN = st.empty()
|
| 172 |
+
CHANGE_PLAYER = st.empty()
|
| 173 |
AIAID = st.empty()
|
| 174 |
EXIT = st.empty()
|
| 175 |
selected_aid_option = AIAID.selectbox('Select Assistant Model', ['None', 'Pure MCTS', 'AlphaZero'], index=0,
|
|
|
|
| 240 |
session_state.ROOM.human_simula_time_list = []
|
| 241 |
session_state.ROOM.COORDINATE_1D = []
|
| 242 |
|
| 243 |
+
def swap_players() -> None:
|
| 244 |
+
session_state.update(
|
| 245 |
+
FirstPlayer=change_turn(session_state.FirstPlayer),
|
| 246 |
+
)
|
| 247 |
+
session_state.update(
|
| 248 |
+
Player=[[], [_BLACK, _WHITE], [_WHITE, _BLACK]][session_state.FirstPlayer],
|
| 249 |
+
Symbol=_PLAYER_SYMBOL[session_state.FirstPlayer]
|
| 250 |
+
)
|
| 251 |
+
session_state.ROOM.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5, players=session_state.Player)
|
| 252 |
+
session_state.ROOM.PLAYER = session_state.FirstPlayer
|
| 253 |
+
session_state.ROOM.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
|
| 254 |
+
session_state.ROOM.MCTS_dict = {'Pure MCTS': MCTSpure(c_puct=5, n_playout=1000),
|
| 255 |
+
'AlphaZero': alphazero(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 256 |
+
_MODEL_PATH["AlphaZero"]).policy_value_fn,
|
| 257 |
+
c_puct=5, n_playout=100),
|
| 258 |
+
'duel': alphazero(duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE,
|
| 259 |
+
_MODEL_PATH["duel"]).policy_value_fn,
|
| 260 |
+
c_puct=5, n_playout=100),
|
| 261 |
+
'Gumbel AlphaZero': Gumbel_MCTSPlayer(PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE,
|
| 262 |
+
_MODEL_PATH[
|
| 263 |
+
"Gumbel AlphaZero"]).policy_value_fn,
|
| 264 |
+
c_puct=5, n_playout=100, m_action=8),
|
| 265 |
+
'Gomoku Bot': Gomoku_bot(session_state.ROOM.gomoku_bot_board, -1)}
|
| 266 |
+
|
| 267 |
+
session_state.ROOM.MCTS = session_state.ROOM.MCTS_dict[st.session_state['selected_oppo_model']]
|
| 268 |
+
session_state.ROOM.last_mcts = session_state.ROOM.MCTS
|
| 269 |
+
session_state.ROOM.PLAYER = session_state.ROOM.PLAYER
|
| 270 |
+
session_state.ROOM.TURN = session_state.ROOM.PLAYER
|
| 271 |
+
session_state.ROOM.WINNER = _BLANK # 0
|
| 272 |
+
session_state.ROOM.ai_simula_time_list = []
|
| 273 |
+
session_state.ROOM.human_simula_time_list = []
|
| 274 |
+
session_state.ROOM.COORDINATE_1D = []
|
| 275 |
|
| 276 |
RESTART.button(
|
| 277 |
"Reset",
|
|
|
|
| 285 |
help="Give in to AI",
|
| 286 |
)
|
| 287 |
|
| 288 |
+
CHANGE_PLAYER.button(
|
| 289 |
+
"Swap players",
|
| 290 |
+
on_click=swap_players,
|
| 291 |
+
help="Swap players",
|
| 292 |
+
)
|
| 293 |
+
|
| 294 |
|
| 295 |
# Draw the board
|
| 296 |
def gomoku():
|
|
|
|
| 469 |
):
|
| 470 |
if i == cur_move[0] and j == cur_move[1]:
|
| 471 |
BOARD_PLATE[i][j].button(
|
| 472 |
+
session_state.Symbol[_NEW],
|
| 473 |
key=f"{i}:{j}",
|
| 474 |
args=(i, j),
|
| 475 |
on_click=forbid_click,
|
|
|
|
| 477 |
else:
|
| 478 |
# disable click for GPT choices
|
| 479 |
BOARD_PLATE[i][j].button(
|
| 480 |
+
session_state.Symbol[cell],
|
| 481 |
key=f"{i}:{j}",
|
| 482 |
args=(i, j),
|
| 483 |
on_click=forbid_click
|
|
|
|
| 487 |
# enable click for other cells available for human choices
|
| 488 |
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
| 489 |
BOARD_PLATE[i][j].button(
|
| 490 |
+
session_state.Symbol[cell] + f"{round(prob, 2)}",
|
| 491 |
key=f"{i}:{j}",
|
| 492 |
on_click=handle_click,
|
| 493 |
args=(i, j),
|
|
|
|
| 495 |
else:
|
| 496 |
# enable click for other cells available for human choices
|
| 497 |
BOARD_PLATE[i][j].button(
|
| 498 |
+
session_state.Symbol[cell],
|
| 499 |
key=f"{i}:{j}",
|
| 500 |
on_click=handle_click,
|
| 501 |
args=(i, j),
|
|
|
|
| 579 |
):
|
| 580 |
if i == gpt_i and j == gpt_j:
|
| 581 |
BOARD_PLATE[i][j].button(
|
| 582 |
+
session_state.Symbol[_NEW],
|
| 583 |
key=f"{i}:{j}",
|
| 584 |
args=(i, j),
|
| 585 |
on_click=handle_click,
|
|
|
|
| 587 |
else:
|
| 588 |
# disable click for GPT choices
|
| 589 |
BOARD_PLATE[i][j].button(
|
| 590 |
+
session_state.Symbol[cell],
|
| 591 |
key=f"{i}:{j}",
|
| 592 |
args=(i, j),
|
| 593 |
on_click=forbid_click
|
|
|
|
| 598 |
# enable click for other cells available for human choices
|
| 599 |
prob = top_five_probs[top_five_acts.index(i * _BOARD_SIZE + j)]
|
| 600 |
BOARD_PLATE[i][j].button(
|
| 601 |
+
session_state.Symbol[cell] + f"{round(prob, 2)}",
|
| 602 |
key=f"{i}:{j}",
|
| 603 |
on_click=handle_click,
|
| 604 |
args=(i, j),
|
|
|
|
| 606 |
else:
|
| 607 |
# enable click for other cells available for human choices
|
| 608 |
BOARD_PLATE[i][j].button(
|
| 609 |
+
session_state.Symbol[cell],
|
| 610 |
key=f"{i}:{j}",
|
| 611 |
on_click=handle_click,
|
| 612 |
args=(i, j),
|
|
|
|
| 640 |
for i, row in enumerate(session_state.ROOM.BOARD.board_map):
|
| 641 |
for j, cell in enumerate(row):
|
| 642 |
BOARD_PLATE[i][j].write(
|
| 643 |
+
session_state.Symbol[cell],
|
| 644 |
# key=f"{i}:{j}",
|
| 645 |
)
|
| 646 |
|
pages/Try.py
ADDED
|
@@ -0,0 +1,17 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import streamlit as st
|
| 2 |
+
|
| 3 |
+
# 示例数据
|
| 4 |
+
data = {
|
| 5 |
+
"Player": ["Alice", "Bob", "Charlie"],
|
| 6 |
+
"Score": [100, 95, 90]
|
| 7 |
+
}
|
| 8 |
+
|
| 9 |
+
# 将数据转换为 Markdown 格式的字符串
|
| 10 |
+
def create_leaderboard(data):
|
| 11 |
+
leaderboard = "### Leaderboard\n"
|
| 12 |
+
for i, (player, score) in enumerate(zip(data["Player"], data["Score"]), start=1):
|
| 13 |
+
leaderboard += f"{i}. **{player}**: {score} points\n"
|
| 14 |
+
return leaderboard
|
| 15 |
+
|
| 16 |
+
# 在 Streamlit 应用中显示排行榜
|
| 17 |
+
st.markdown(create_leaderboard(data))
|