Spaces:
Sleeping
Sleeping
sjz
committed on
Commit
·
9e90422
1
Parent(s):
ae94556
fix player vs AI bugs
Browse files- const.py +5 -3
- pages/test.py +592 -0
const.py
CHANGED
|
@@ -14,6 +14,8 @@ _BOARD_SIZE_1D = _BOARD_SIZE * _BOARD_SIZE
|
|
| 14 |
_BLANK = 0
|
| 15 |
_BLACK = 1
|
| 16 |
_WHITE = 2
|
|
|
|
|
|
|
| 17 |
_NEW = 3
|
| 18 |
_PLAYER_SYMBOL1 = {
|
| 19 |
_WHITE: "⚪",
|
|
@@ -31,10 +33,10 @@ _PLAYER_SYMBOL2 = {
|
|
| 31 |
|
| 32 |
|
| 33 |
|
| 34 |
-
|
| 35 |
-
|
| 36 |
_BLANK: "Blank",
|
| 37 |
-
|
| 38 |
}
|
| 39 |
_PLAYER_COLOR_AI_VS_AI = {
|
| 40 |
_WHITE: "WHITE",
|
|
|
|
| 14 |
_BLANK = 0
|
| 15 |
_BLACK = 1
|
| 16 |
_WHITE = 2
|
| 17 |
+
_HUMAN = 4
|
| 18 |
+
_AI = 5
|
| 19 |
_NEW = 3
|
| 20 |
_PLAYER_SYMBOL1 = {
|
| 21 |
_WHITE: "⚪",
|
|
|
|
| 33 |
|
| 34 |
|
| 35 |
|
| 36 |
+
_PLAYER_NAME = {
|
| 37 |
+
_AI: "AI",
|
| 38 |
_BLANK: "Blank",
|
| 39 |
+
_HUMAN: "YOU HUMAN",
|
| 40 |
}
|
| 41 |
_PLAYER_COLOR_AI_VS_AI = {
|
| 42 |
_WHITE: "WHITE",
|
pages/test.py
ADDED
|
@@ -0,0 +1,592 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
"""
|
| 2 |
+
FileName: test.py
|
| 3 |
+
Author: Benhao Huang
|
| 4 |
+
Create Date: 2023/11/19
|
| 5 |
+
Description: this file is used to display our project and add visualization elements to the game, using Streamlit
|
| 6 |
+
"""
|
| 7 |
+
|
| 8 |
+
import time
|
| 9 |
+
import pandas as pd
|
| 10 |
+
from copy import deepcopy
|
| 11 |
+
import numpy as np
|
| 12 |
+
import streamlit as st
|
| 13 |
+
from scipy.signal import convolve # this is used to check if any player wins
|
| 14 |
+
from streamlit import session_state
|
| 15 |
+
from streamlit_server_state import server_state, server_state_lock
|
| 16 |
+
from Gomoku_MCTS import MCTSpure, alphazero, Board, PolicyValueNet_old, PolicyValueNet_new, duel_PolicyValueNet, \
|
| 17 |
+
Gumbel_MCTSPlayer
|
| 18 |
+
from Gomoku_Bot import Gomoku_bot
|
| 19 |
+
from Gomoku_Bot import Board as Gomoku_bot_board
|
| 20 |
+
|
| 21 |
+
import matplotlib.pyplot as plt
|
| 22 |
+
|
| 23 |
+
from const import (
|
| 24 |
+
_BLACK, # 1
|
| 25 |
+
_WHITE, # 2
|
| 26 |
+
_HUMAN,
|
| 27 |
+
_AI,
|
| 28 |
+
_BLANK,
|
| 29 |
+
_PLAYER_NAME,
|
| 30 |
+
_PLAYER_SYMBOL1,
|
| 31 |
+
_PLAYER_SYMBOL2,
|
| 32 |
+
_ROOM_COLOR,
|
| 33 |
+
_VERTICAL,
|
| 34 |
+
_NEW,
|
| 35 |
+
_HORIZONTAL,
|
| 36 |
+
_DIAGONAL_UP_LEFT,
|
| 37 |
+
_DIAGONAL_UP_RIGHT,
|
| 38 |
+
_BOARD_SIZE,
|
| 39 |
+
_MODEL_PATH
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
# Symbol tables addressed by index; slot 0 is a filler entry.
_PLAYER_SYMBOL = [0, _PLAYER_SYMBOL1, _PLAYER_SYMBOL2]

# First run of this browser session: seed the per-session defaults.
if "FirstPlayer" not in session_state:
    session_state.FirstPlayer = _HUMAN
    session_state.Players = [_BLACK, _WHITE]
    session_state.Symbols = _PLAYER_SYMBOL1
|
| 49 |
+
|
| 50 |
+
# Utils
|
| 51 |
+
class Room:
    """State of one game room: board, turn bookkeeping, scores and AI players."""

    def __init__(self, room_id) -> None:
        """Create a fresh room.

        Args:
            room_id: Identifier stored as ``ROOM_ID``.
        """
        self.ROOM_ID = room_id
        self.BOARD = Board(
            width=_BOARD_SIZE,
            height=_BOARD_SIZE,
            n_in_row=5,
            players=session_state.Players,
        )
        self.TURN = _BLACK  # stone colour to be placed next
        self.CURR_PLAYER = session_state.FirstPlayer  # who places it (_HUMAN or _AI)
        self.HISTORY = (0, 0)  # running scores
        self.WINNER = _BLANK  # _BLANK while the round is still open
        self.TIME = time.time()  # timestamp of the last state change

        # The rule-based bot plays on its own board object.
        self.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)

        # All selectable players, keyed by the names offered in the UI.
        pure_mcts = MCTSpure(c_puct=5, n_playout=1000)
        alphazero_net = PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["AlphaZero"])
        alphazero_player = alphazero(alphazero_net.policy_value_fn, c_puct=5, n_playout=100)
        duel_net = duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["duel"])
        duel_player = alphazero(duel_net.policy_value_fn, c_puct=5, n_playout=100)
        gumbel_net = PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["Gumbel AlphaZero"])
        gumbel_player = Gumbel_MCTSPlayer(gumbel_net.policy_value_fn, c_puct=5, n_playout=100, m_action=8)
        bot_player = Gomoku_bot(self.gomoku_bot_board, -1)
        self.MCTS_dict = {
            'Pure MCTS': pure_mcts,
            'AlphaZero': alphazero_player,
            'duel': duel_player,
            'Gumbel AlphaZero': gumbel_player,
            'Gomoku Bot': bot_player,
        }

        self.MCTS = self.MCTS_dict['AlphaZero']  # current opponent
        self.last_mcts = self.MCTS
        self.AID_MCTS = self.MCTS_dict['AlphaZero']  # assistant used for move hints
        self.COORDINATE_1D = []  # flat indices (row * size + col) of occupied cells
        self.current_move = -1  # most recent move; -1 before the first one
        self.ai_simula_time_list = []
        self.human_simula_time_list = []
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def change_turn(cur):
    """Return the side that moves after ``cur``.

    ``_HUMAN`` and ``_AI`` toggle into each other; any other value is treated
    as a stone colour and mapped with ``cur % 2 + 1`` (1 -> 2, 2 -> 1).
    """
    if cur == _AI:
        return _HUMAN
    if cur == _HUMAN:
        return _AI
    return cur % 2 + 1
|
| 85 |
+
|
| 86 |
+
|
| 87 |
+
# Initialize the game: create the per-session objects only once.
if "ROOM" not in session_state:
    session_state.ROOM = Room("local")
for _flag in ("OWNER", "USE_AIAID"):
    if _flag not in session_state:
        session_state[_flag] = False

# Check server health: make sure the shared room registry exists.
if "ROOMS" not in server_state:
    with server_state_lock["ROOMS"]:
        server_state.ROOMS = {}
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
def handle_oppo_model_selection():
    """Switch the opponent to the newly selected model and start a fresh round.

    The model name is read from ``st.session_state['selected_oppo_model']``.
    The board, turn state, winner flag, timing lists and move list of the
    current room are reset; the scores in ``HISTORY`` are kept.
    """
    room = session_state.ROOM

    room.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)

    # The Gomoku Bot plays on the board object it was constructed with, and
    # moves are mirrored onto ``MCTS_dict["Gomoku Bot"].board``.  Replacing
    # ``gomoku_bot_board`` alone would leave the bot on the previous game's
    # stones, so rebuild the bot around the fresh board before selecting.
    room.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
    room.MCTS_dict['Gomoku Bot'] = Gomoku_bot(room.gomoku_bot_board, -1)

    opponent = room.MCTS_dict[st.session_state['selected_oppo_model']]
    room.MCTS = opponent
    room.last_mcts = opponent

    room.CURR_PLAYER = session_state.FirstPlayer
    room.TURN = _BLACK
    room.WINNER = _BLANK  # 0
    room.ai_simula_time_list = []
    room.human_simula_time_list = []
    room.COORDINATE_1D = []
|
| 114 |
+
|
| 115 |
+
|
| 116 |
+
def handle_aid_model_selection():
    """Apply the assistant choice stored in ``st.session_state['selected_aid_model']``.

    ``'None'`` switches the assistant off.  Any other name switches it on and
    makes that model the room's ``AID_MCTS``, seeded with a deep copy of the
    current opponent's search-tree root.
    """
    choice = st.session_state['selected_aid_model']
    if choice != 'None':
        session_state.USE_AIAID = True
        # Start the assistant from the same tree node as the opponent.
        opponent_root = session_state.ROOM.MCTS.mcts._root
        assistant = session_state.ROOM.MCTS_dict[choice]
        assistant.mcts._root = deepcopy(opponent_root)
        session_state.ROOM.AID_MCTS = assistant
    else:
        session_state.USE_AIAID = False
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
# Remember the last applied model choices so a change can be detected on rerun.
if 'selected_oppo_model' not in st.session_state:
    st.session_state['selected_oppo_model'] = 'AlphaZero'  # default value

if 'selected_aid_model' not in st.session_state:
    st.session_state['selected_aid_model'] = 'AlphaZero'  # default value

# Layout
TITLE = st.empty()
Model_Switch = st.empty()

TITLE.header("🤖 AI 3603 Gomoku")
selected_oppo_option = Model_Switch.selectbox(
    'Select Opponent Model',
    ['Pure MCTS', 'AlphaZero', 'Gomoku Bot', 'duel', 'Gumbel AlphaZero'],
    index=1,
    key='oppo_model',
)

# A changed selection resets the round against the new opponent.
if selected_oppo_option != st.session_state['selected_oppo_model']:
    st.session_state['selected_oppo_model'] = selected_oppo_option
    handle_oppo_model_selection()

ROUND_INFO = st.empty()
st.markdown("<br>", unsafe_allow_html=True)
# One placeholder per board cell, laid out as rows of equally wide columns.
BOARD_PLATE = [
    [column.empty() for column in st.columns([1] * _BOARD_SIZE)]
    for _ in range(_BOARD_SIZE)
]
LOG = st.empty()

# Sidebar
SCORE_TAG = st.sidebar.empty()
SCORE_PLATE = st.sidebar.columns(2)
# History scores
SCORE_TAG.subheader("Scores")

PLAY_MODE_INFO = st.sidebar.container()
MULTIPLAYER_TAG = st.sidebar.empty()
with st.sidebar.container():
    ANOTHER_ROUND = st.empty()
    RESTART = st.empty()
    GIVEIN = st.empty()
    CHANGE_PLAYER = st.empty()
    AIAID = st.empty()
    EXIT = st.empty()
selected_aid_option = AIAID.selectbox(
    'Select Assistant Model',
    ['None', 'Pure MCTS', 'AlphaZero'],
    index=0,
    key='aid_model',
)
if selected_aid_option != st.session_state['selected_aid_model']:
    st.session_state['selected_aid_model'] = selected_aid_option
    handle_aid_model_selection()

GAME_INFO = st.sidebar.container()
message = st.empty()
PLAY_MODE_INFO.write("---\n\n**You are Black, AI agent is White.**")
GAME_INFO.markdown(
    """
---
# <span style="color:black;">Freestyle Gomoku game. 🎲</span>
- no restrictions 🚫
- no regrets 😎
- no regrets 😎
- swap players after one round is over 🔁
Powered by an AlphaZero approach with our own improvements! 🚀 For the specific details, please check out our <a href="insert_report_link_here" style="color:blue;">report</a>.
##### Adapted and improved by us! 🌟 <a href="https://github.com/Lijiaxin0111/AI_3603_BIGHOME" style="color:blue;">Our Github repo</a>
""",
    unsafe_allow_html=True,
)
|
| 191 |
+
|
| 192 |
+
|
| 193 |
+
def restart() -> None:
    """Replace the room with a brand-new one (same id) and fall back to AlphaZero."""
    room_id = session_state.ROOM.ROOM_ID
    session_state.ROOM = Room(room_id)
    st.session_state['selected_oppo_model'] = 'AlphaZero'
|
| 199 |
+
|
| 200 |
+
def givein() -> None:
    """Concede the current round to the AI, then set up a fresh round."""
    session_state.ROOM = deepcopy(session_state.ROOM)
    room = session_state.ROOM

    # Book the round for the AI: HISTORY is (AI score, human score).
    room.WINNER = _AI
    ai_score, human_score = room.HISTORY[0], room.HISTORY[1]
    room.HISTORY = (
        ai_score + int(room.WINNER == _AI),
        human_score + int(room.WINNER == _HUMAN),
    )

    # Fresh boards and freshly built players for the next round.
    room.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
    room.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)
    pure_mcts = MCTSpure(c_puct=5, n_playout=1000)
    alphazero_net = PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["AlphaZero"])
    alphazero_player = alphazero(alphazero_net.policy_value_fn, c_puct=5, n_playout=100)
    duel_net = duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["duel"])
    duel_player = alphazero(duel_net.policy_value_fn, c_puct=5, n_playout=100)
    gumbel_net = PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["Gumbel AlphaZero"])
    gumbel_player = Gumbel_MCTSPlayer(gumbel_net.policy_value_fn, c_puct=5, n_playout=100, m_action=8)
    bot_player = Gomoku_bot(room.gomoku_bot_board, -1)
    room.MCTS_dict = {
        'Pure MCTS': pure_mcts,
        'AlphaZero': alphazero_player,
        'duel': duel_player,
        'Gumbel AlphaZero': gumbel_player,
        'Gomoku Bot': bot_player,
    }

    room.MCTS = room.MCTS_dict[st.session_state['selected_oppo_model']]
    room.last_mcts = room.MCTS
    room.TURN = _BLACK
    room.CURR_PLAYER = session_state.FirstPlayer
    room.WINNER = _BLANK  # 0
    room.ai_simula_time_list = []
    room.human_simula_time_list = []
    room.COORDINATE_1D = []
|
| 236 |
+
|
| 237 |
+
def swap_players() -> None:
    """Flip who moves first (human <-> AI) and start a fresh round.

    Only ``FirstPlayer`` is changed in the session; ``Players`` and
    ``Symbols`` stay as they are.  The scores in ``HISTORY`` are kept.
    """
    session_state.update(FirstPlayer=change_turn(session_state.FirstPlayer))

    room = session_state.ROOM
    room.BOARD = Board(
        width=_BOARD_SIZE,
        height=_BOARD_SIZE,
        n_in_row=5,
        players=session_state.Players,
    )
    room.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)

    # Rebuild every selectable player for the new round.
    pure_mcts = MCTSpure(c_puct=5, n_playout=1000)
    alphazero_net = PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["AlphaZero"])
    alphazero_player = alphazero(alphazero_net.policy_value_fn, c_puct=5, n_playout=100)
    duel_net = duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["duel"])
    duel_player = alphazero(duel_net.policy_value_fn, c_puct=5, n_playout=100)
    gumbel_net = PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["Gumbel AlphaZero"])
    gumbel_player = Gumbel_MCTSPlayer(gumbel_net.policy_value_fn, c_puct=5, n_playout=100, m_action=8)
    bot_player = Gomoku_bot(room.gomoku_bot_board, -1)
    room.MCTS_dict = {
        'Pure MCTS': pure_mcts,
        'AlphaZero': alphazero_player,
        'duel': duel_player,
        'Gumbel AlphaZero': gumbel_player,
        'Gomoku Bot': bot_player,
    }

    room.MCTS = room.MCTS_dict[st.session_state['selected_oppo_model']]
    room.last_mcts = room.MCTS
    room.TURN = _BLACK
    room.CURR_PLAYER = session_state.FirstPlayer
    room.WINNER = _BLANK  # 0
    room.ai_simula_time_list = []
    room.human_simula_time_list = []
    room.COORDINATE_1D = []
|
| 270 |
+
|
| 271 |
+
# Sidebar action buttons; each one runs its callback before the next rerun.
RESTART.button("Reset", on_click=restart, help="Clear the board as well as the scores")

GIVEIN.button("Give in", on_click=givein, help="Give in to AI")

CHANGE_PLAYER.button("Swap players", on_click=swap_players, help="Swap players")
|
| 288 |
+
|
| 289 |
+
|
| 290 |
+
# Draw the board
|
| 291 |
+
def gomoku():
|
| 292 |
+
"""
|
| 293 |
+
Draw the board.
|
| 294 |
+
Handle the main logic.
|
| 295 |
+
"""
|
| 296 |
+
|
| 297 |
+
# Restart the game
|
| 298 |
+
|
| 299 |
+
# Continue new round
|
| 300 |
+
def another_round() -> None:
    """Start the next round: fresh boards and players, scores kept."""
    session_state.ROOM = deepcopy(session_state.ROOM)
    room = session_state.ROOM

    room.BOARD = Board(width=_BOARD_SIZE, height=_BOARD_SIZE, n_in_row=5)
    room.gomoku_bot_board = Gomoku_bot_board(_BOARD_SIZE, 1)

    # Rebuild every selectable player for the new round.
    pure_mcts = MCTSpure(c_puct=5, n_playout=1000)
    alphazero_net = PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["AlphaZero"])
    alphazero_player = alphazero(alphazero_net.policy_value_fn, c_puct=5, n_playout=100)
    duel_net = duel_PolicyValueNet(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["duel"])
    duel_player = alphazero(duel_net.policy_value_fn, c_puct=5, n_playout=100)
    gumbel_net = PolicyValueNet_new(_BOARD_SIZE, _BOARD_SIZE, _MODEL_PATH["Gumbel AlphaZero"])
    gumbel_player = Gumbel_MCTSPlayer(gumbel_net.policy_value_fn, c_puct=5, n_playout=100, m_action=8)
    bot_player = Gomoku_bot(room.gomoku_bot_board, -1)
    room.MCTS_dict = {
        'Pure MCTS': pure_mcts,
        'AlphaZero': alphazero_player,
        'duel': duel_player,
        'Gumbel AlphaZero': gumbel_player,
        'Gomoku Bot': bot_player,
    }

    room.MCTS = room.MCTS_dict[st.session_state['selected_oppo_model']]
    room.last_mcts = room.MCTS
    room.TURN = _BLACK
    room.CURR_PLAYER = session_state.FirstPlayer
    room.WINNER = _BLANK  # 0
    room.ai_simula_time_list = []
    room.human_simula_time_list = []
    room.COORDINATE_1D = []
|
| 327 |
+
|
| 328 |
+
# Room status sync
|
| 329 |
+
def sync_room() -> bool:
    """Reconcile the session's room with the server-side copy.

    Returns:
        ``False`` when the room is unknown to the server (the session then
        falls back to a new ``"local"`` room) or when both copies carry the
        same ``TIME``; ``True`` when one copy was replaced by the newer one.
    """
    room_id = session_state.ROOM.ROOM_ID
    if room_id not in server_state.ROOMS:
        session_state.ROOM = Room("local")
        return False

    server_time = server_state.ROOMS[room_id].TIME
    local_time = session_state.ROOM.TIME
    if server_time == local_time:
        return False

    if server_time < local_time:
        # Local copy is newer: publish it.  Only take the lock for the write.
        with server_state_lock["ROOMS"]:
            rooms = server_state.ROOMS
            rooms[room_id] = session_state.ROOM
            server_state.ROOMS = rooms
    else:
        # Server copy is newer: adopt it.
        session_state.ROOM = server_state.ROOMS[room_id]
    return True
|
| 346 |
+
|
| 347 |
+
# Triggers the board response on click
|
| 348 |
+
def handle_click(x, y):
    """Handle a click on board cell ``(x, y)``.

    Occupied cells are ignored.  If the room is registered on the server and
    it is not this session's colour to move, the room is synced instead.
    Otherwise, while the round is open, the stone is placed, the turn passes
    to the other side, and a finished game is credited to the player who
    made the move.

    Args:
        x: Row index of the clicked cell.
        y: Column index of the clicked cell.
    """
    room = session_state.ROOM

    if room.BOARD.board_map[x][y] != _BLANK:
        return  # cell already taken

    if (
        room.ROOM_ID in server_state.ROOMS
        and _ROOM_COLOR[session_state.OWNER] != server_state.ROOMS[room.ROOM_ID].TURN
    ):
        sync_room()
        return

    if room.WINNER != _BLANK:
        return  # round already decided

    # Normal play situation.
    mover = room.CURR_PLAYER  # remember who is placing this stone
    move = room.BOARD.location_to_move((x, y))
    room.current_move = move
    room.BOARD.do_move(move)
    # Mirror the move onto the Gomoku Bot's board.  That board counts rows
    # from the opposite edge, hence the flipped row index.
    room.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - move // _BOARD_SIZE - 1,
                                           move % _BOARD_SIZE)
    room.BOARD.board_map[x][y] = room.TURN
    room.COORDINATE_1D.append(x * _BOARD_SIZE + y)

    room.TURN = change_turn(room.TURN)
    room.CURR_PLAYER = change_turn(mover)

    win, _ = room.BOARD.game_end()
    if win:
        # Credit the player who just moved.  Previously the already flipped
        # CURR_PLAYER was stored, which booked the win for the other side.
        # NOTE(review): if game_end() can also signal a draw, the mover is
        # still credited here -- confirm against Board.game_end().
        room.WINNER = mover
        room.HISTORY = (
            room.HISTORY[0] + int(room.WINNER == _AI),
            room.HISTORY[1] + int(room.WINNER == _HUMAN),
        )
    room.TIME = time.time()
|
| 384 |
+
|
| 385 |
+
def forbid_click(x, y):
    """Show an error telling the user that cell ``(x, y)`` is already taken."""
    st.error(f"({x}, {y}) has been occupied!!)", icon="🚨")
|
| 388 |
+
|
| 389 |
+
# Draw board
|
| 390 |
+
def draw_board(response: bool):
    """Render the board and, on the AI's turn, play the AI move.

    Args:
        response: ``True`` while the round is open: cells are rendered as
            buttons and, if it is the AI's turn, the AI moves first.
            ``False`` renders a read-only board.
    """
    room = session_state.ROOM
    symbols = session_state.Symbols

    def aid_suggestions():
        """Return (moves, probabilities) of the assistant's five best moves."""
        # Search on a copy so the assistant's stored tree is left untouched.
        aid_mcts = deepcopy(room.AID_MCTS.mcts)
        _, acts, probs, _ = aid_mcts.get_move_probs(room.BOARD)
        ranked = sorted(zip(acts, probs), key=lambda pair: pair[1], reverse=True)[:5]
        return [act for act, _ in ranked], [prob for _, prob in ranked]

    def render_buttons(latest, latest_handler, hint_acts, hint_probs):
        """Draw every cell as a button.

        Occupied cells reject clicks, except that the most recent stone
        (``latest``) is drawn with the ``_NEW`` symbol and ``latest_handler``.
        Free cells accept clicks; hinted ones also show their probability.
        """
        for i, row in enumerate(room.BOARD.board_map):
            for j, cell in enumerate(row):
                flat = i * _BOARD_SIZE + j
                label, handler = symbols[cell], handle_click
                if flat in room.COORDINATE_1D:
                    if (i, j) == latest:
                        label, handler = symbols[_NEW], latest_handler
                    else:
                        handler = forbid_click
                elif flat in hint_acts:
                    label += f"{round(hint_probs[hint_acts.index(flat)], 2)}"
                BOARD_PLATE[i][j].button(label, key=f"{i}:{j}", on_click=handler, args=(i, j))

    # Default to "no hints" so that later lookups never hit an unbound name.
    top_five_acts, top_five_probs = [], []
    if session_state.USE_AIAID and room.WINNER == _BLANK and room.CURR_PLAYER == _HUMAN:
        top_five_acts, top_five_probs = aid_suggestions()

    if response and room.CURR_PLAYER == _HUMAN:  # human turn
        start_time = time.time()
        print("Your turn")
        cur_move = (room.current_move // _BOARD_SIZE, room.current_move % _BOARD_SIZE)
        render_buttons(cur_move, forbid_click, top_five_acts, top_five_probs)
        end_time = time.time()
        print("Time used for human move: ", end_time - start_time)

    elif response and room.CURR_PLAYER == _AI:  # AI turn
        message.empty()
        with st.spinner('🔮✨ Waiting for AI response... ⏳🚀'):
            time.sleep(0.05)
            print("AI's turn")
            mover = room.CURR_PLAYER  # remember who is placing this stone

            # The Gomoku Bot reads its own board; the other players are
            # handed the game board.
            if st.session_state['selected_oppo_model'] != 'Gomoku Bot':
                move, simul_time = room.MCTS.get_action(room.BOARD, return_time=True)
            else:
                move, simul_time = room.MCTS.get_action(return_time=True)
            room.ai_simula_time_list.append(simul_time)
            print("AI takes move: ", move)
            room.current_move = move
            gpt_i, gpt_j = move // _BOARD_SIZE, move % _BOARD_SIZE
            print("AI's move is located at ({}, {}) :".format(gpt_i, gpt_j))
            move = room.BOARD.location_to_move((gpt_i, gpt_j))
            print("Location to move: ", move)
            # MCTS board
            room.BOARD.do_move(move)
            # Gomoku Bot board (row index flipped, as for human moves)
            room.MCTS_dict["Gomoku Bot"].board.put(_BOARD_SIZE - 1 - move // _BOARD_SIZE,
                                                   move % _BOARD_SIZE)
            room.COORDINATE_1D.append(gpt_i * _BOARD_SIZE + gpt_j)

            win, _ = room.BOARD.game_end()
            # Hints for the human's reply only make sense while the game is
            # open.  Previously a game-ending AI move left the hint lists
            # unbound and the render loop raised NameError.
            if session_state.USE_AIAID and not win:
                top_five_acts, top_five_probs = aid_suggestions()
            else:
                top_five_acts, top_five_probs = [], []

            render_buttons((gpt_i, gpt_j), handle_click, top_five_acts, top_five_probs)

            message.markdown(
                'AI agent has calculated its strategy, which takes <span style="color: blue; font-size: 20px;">{:.3e}</span>s per simulation.'.format(
                    simul_time),
                unsafe_allow_html=True
            )
            LOG.subheader("Logs")

            # change turn
            room.TURN = change_turn(room.TURN)
            room.CURR_PLAYER = change_turn(mover)

            if win:
                # Credit the player who just moved.  Previously the already
                # flipped CURR_PLAYER was stored, booking the AI's win for
                # the human.
                # NOTE(review): if game_end() can also signal a draw, the
                # mover is still credited -- confirm against Board.game_end().
                room.WINNER = mover
                room.HISTORY = (
                    room.HISTORY[0] + int(room.WINNER == _AI),
                    room.HISTORY[1] + int(room.WINNER == _HUMAN),
                )
            room.TIME = time.time()

    if not response or room.WINNER != _BLANK:
        if room.WINNER != _BLANK:
            print("Game over")
        # Read-only rendering: plain text replaces the buttons.
        for i, row in enumerate(room.BOARD.board_map):
            for j, cell in enumerate(row):
                BOARD_PLATE[i][j].write(symbols[cell])
|
| 556 |
+
|
| 557 |
+
# Game process control
|
| 558 |
+
def game_control():
    """Draw the board for the current state and offer a next round when due."""
    # Interactive board while the round is open, read-only once it is decided.
    draw_board(session_state.ROOM.WINNER == _BLANK)

    round_over = session_state.ROOM.WINNER != _BLANK
    if round_over or 0 not in session_state.ROOM.BOARD.board_map:
        GIVEIN.empty()
        ANOTHER_ROUND.button(
            "Play Next round!",
            help="Clear board and swap first player",
            on_click=another_round,
        )
|
| 570 |
+
|
| 571 |
+
# Infos
|
| 572 |
+
def update_info() -> None:
    """Refresh the score metrics, the winner banner and the AI timing chart."""
    room = session_state.ROOM

    # Additional information
    SCORE_PLATE[0].metric("Gomoku-Agent", room.HISTORY[0])
    SCORE_PLATE[1].metric("You", room.HISTORY[1])

    if room.WINNER != _BLANK:
        st.balloons()
        ROUND_INFO.write(
            f"#### **{_PLAYER_NAME[room.WINNER]} WIN!**\n**Click buttons on the left for more plays.**"
        )

    # Two line breaks of spacing above the chart.
    for _ in range(2):
        st.markdown("<br>", unsafe_allow_html=True)
    timings = pd.DataFrame(room.ai_simula_time_list, columns=["Simulation Time"])
    st.line_chart(timings)
|
| 586 |
+
|
| 587 |
+
game_control()
|
| 588 |
+
update_info()
|
| 589 |
+
|
| 590 |
+
|
| 591 |
+
if __name__ == "__main__":
|
| 592 |
+
gomoku()
|