Spaces
Gruhit Patel committed on init-backend
- agent.py +72 -0
- arena.py +83 -0
- buffer.py +50 -0
- config.py +47 -0
- game.py +131 -0
- main.py +63 -0
- main2.py +92 -0
- mcts.py +282 -0
- model.py +106 -0
- requirement.txt +7 -0
- trainer.py +85 -0
- view_board.py +60 -0
agent.py
ADDED
@@ -0,0 +1,72 @@
from model import Model
from buffer import Buffer
from game import Connect4
from mcts import MCTS_NN

import numpy as np
from typing import Tuple, List

class Agent:
    def __init__(self, row: int, col: int, n_action: int, obs_shape: Tuple[int, int, int],
                 model: Model, iteration: int, temperature: float):
        self.row = row
        self.col = col
        self.n_action = n_action
        self.obs_shape = obs_shape
        self.iteration = iteration
        self.temperature = temperature

        # Create buffer instance
        self.buffer = Buffer(n_action=self.n_action, obs_shape=self.obs_shape)

        # Target model instance
        self.target_model = model

    # Reset the MCTS instance and, optionally, the buffer
    def reset(self, state: Connect4, reset_buffer: bool = False) -> None:
        # Reset the state of the Monte Carlo tree search instance
        self.mcts = MCTS_NN(state=state, model=self.target_model)
        # The flag was previously unused; honor it here
        if reset_buffer:
            self.buffer.reset()

    # Reset the buffer
    def reset_buffer(self) -> None:
        self.buffer.reset()

    # Get the policy from the MCTS simulations
    def perform_mcts(self) -> np.ndarray:
        for _ in range(self.iteration):
            self.mcts.selection(self.mcts.root, add_dirichlet=True)

        policy = self.mcts.get_policy_pie(self.temperature)

        return policy

    # Sample an action for the current state
    def get_action(self) -> Tuple[int, np.ndarray]:
        policy = self.perform_mcts()
        action = np.random.choice(self.n_action, p=policy)
        return action, policy

    # Backfill the game outcome into the episode and push it to the buffer
    def update_buffer(self, episodic_buffer: List) -> None:
        # Get the last index of the episodic buffer
        idx = len(episodic_buffer) - 1

        # The last state always gets value 1, as it contains the winning move
        value = 1
        while idx >= 0:
            episodic_buffer[idx][1] = value
            value *= -1  # The value flips sign for the opposing player
            idx -= 1     # Go to the previous experience tuple

        for state, value, policy in episodic_buffer:
            self.buffer.store_experience(
                state = state,
                value = value,
                policy = policy
            )

    # Move the root to the child reached by the action
    # taken in `get_action()` above
    def update(self, action: int) -> None:
        self.mcts.update_root(action)
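The sign-alternating backfill in `update_buffer` propagates the final outcome to every position from the mover's point of view. A minimal illustration of the pattern, using a hypothetical 4-ply episode:

# Hypothetical episode: the mover at index 3 made the winning move.
episode = [["s0", None, "p0"], ["s1", None, "p1"],
           ["s2", None, "p2"], ["s3", None, "p3"]]
value, idx = 1, len(episode) - 1
while idx >= 0:
    episode[idx][1] = value
    value *= -1
    idx -= 1
# The values become [-1, 1, -1, 1]: the winner's positions get +1,
# the loser's positions get -1.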
arena.py
ADDED
@@ -0,0 +1,83 @@
from game import Connect4
from agent import Agent
from model import Model
from mcts import MCTS_NN

from typing import Union
from tqdm import tqdm
import numpy as np

def play_selfgames(agent: Agent, training_games: int):

    for _ in tqdm(range(training_games)):
        board = Connect4(row=agent.row, col=agent.col)
        agent.reset(state = board)

        # A buffer list to store the transitions of the current episode
        episodic_buffer = []

        while not board.is_win() and not board.is_draw():
            # get_action() runs the tree search and returns the sampled
            # action along with the search policy
            action, policy = agent.get_action()
            episodic_buffer.append([
                board.get_state(),
                board.player_1,
                policy
            ])

            board, _ = board.drop_piece(action)

            # Update the root node of the MCTS to the chosen child node
            agent.update(action)

        # When the episode is completed, update the buffer
        agent.update_buffer(episodic_buffer)


def get_move_for_bot(state: Connect4, model: Model, tree_iters: int, random_move: bool = False) -> int:
    mcts = MCTS_NN(state = state, model = model)

    for _ in range(tree_iters):
        mcts.selection(mcts.root, random_move)

    policy = mcts.get_policy_pie()
    act = np.argmax(policy)

    return act

def play_game_against_bot(bot1: Model, bot2: Model, tree_iters: int) -> Union[None, int]:
    board = Connect4()
    player_1 = True

    # In this function bot1 is always the datagen model making the 1st move
    # and bot2 is the main_model making the 2nd move.
    # We randomly swap them so each gets the first move 50% of the time.
    flip = False
    if np.random.uniform() < 0.5:
        flip = True
        (bot1, bot2) = (bot2, bot1)
        print("Bot has been flipped")

    while not board.is_win() and not board.is_draw():
        if player_1:
            act = get_move_for_bot(board, model=bot1, tree_iters=tree_iters)
            player_1 = False
        else:
            act = get_move_for_bot(board, model=bot2, tree_iters=tree_iters)
            player_1 = True

        board, win = board.drop_piece(act)
        print(board)

    # Return value:
    #  0 - draw
    #  1 - datagen won
    # -1 - main_model won
    # When the bots were flipped we have to adjust the sign accordingly.
    if flip:
        # If we flipped, main_model played first; if it won we want to
        # return -1 for it, and vice versa
        return 0 if win is None else win * -1
    else:
        return 0 if win is None else win
buffer.py
ADDED
@@ -0,0 +1,50 @@
from typing import Tuple
import numpy as np

class Buffer:
    def __init__(self, n_action: int, obs_shape: Tuple[int, int, int]):
        self.n_action = n_action
        self.obs_shape = obs_shape
        self.mem_size = 0

        # Empty lists for storage; lists keep the buffer dynamically sized
        self.state = []
        self.value = []
        self.policy = []

    def store_experience(self, state: np.ndarray, value: float, policy: np.ndarray):
        self.state.append(state)
        self.value.append(value)
        self.policy.append(policy)

        self.mem_size += 1

    def sample(self, batch_size: int) -> Tuple[
        np.ndarray,
        np.ndarray,
        np.ndarray
    ]:  # type: ignore
        # Shuffle the memory. One shared permutation is used for all three
        # lists so that (state, value, policy) triples stay aligned;
        # shuffling each list independently would break the correspondence.
        perm = np.random.permutation(self.mem_size)
        self.state = [self.state[i] for i in perm]
        self.value = [self.value[i] for i in perm]
        self.policy = [self.policy[i] for i in perm]

        for start_idx in range(0, self.mem_size, batch_size):
            end_idx = min(start_idx + batch_size, self.mem_size)
            s = self.state[start_idx:end_idx]
            v = self.value[start_idx:end_idx]
            p = self.policy[start_idx:end_idx]

            yield (np.array(s), np.array(v), np.array(p))

    # Reset the buffer to store new experience
    def reset(self) -> None:
        self.state = []
        self.value = []
        self.policy = []

        self.mem_size = 0

    # Return the length of the buffer
    def __len__(self):
        return self.mem_size
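A quick sketch of how the buffer is meant to be used; the shapes below assume the 6x7 board encoding from game.py:

import numpy as np
from buffer import Buffer

buf = Buffer(n_action=7, obs_shape=(4, 6, 7))
for _ in range(10):
    buf.store_experience(
        state=np.zeros((4, 6, 7), dtype=np.int32),  # encoded board
        value=1.0,                                  # outcome from mover's view
        policy=np.full(7, 1/7)                      # MCTS visit-count policy
    )

# sample() is a generator that yields shuffled mini-batches
for states, values, policies in buf.sample(batch_size=4):
    print(states.shape, values.shape, policies.shape)
    # (4, 4, 6, 7) (4,) (4, 7) -- the last batch is smaller: (2, ...)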
config.py
ADDED
@@ -0,0 +1,47 @@
from typing import Tuple
import torch

class Config:
    # Board
    row: int = 6
    col: int = 7

    # Neural network
    num_hidden: int = 64
    num_res_block: int = 4
    rate: float = 0.3
    obs_shape: Tuple[int, int, int] = (4, row, col)
    n_action: int = col
    device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
    checkpoint_path: str = "../Models/azv3.pt"

    # Optimizer
    base_lr: float = 0.01
    weight_decay: float = 1e-4

    # Monte Carlo tree search
    temperature = 1.0
    tree_iter = 100

    # Training
    selfplay_games: int = 50
    epoch: int = 10
    batch_size: int = 128

    # Tournament
    eval_games: int = 10

    # Elo K-factor: how many rating points are exchanged per win
    k: int = 10

    # Model update threshold
    threshold: float = 0.55

    # How many times to alternate self-play games and model training
    total_iters: int = 40

    # Parallel games
    parallel_run: int = 4

    DIRICHLET_ALPHA: float = 0.3  # ~avg legal moves / 75% of total moves
    EPSILON: float = 0.25
|
game.py
ADDED
|
@@ -0,0 +1,131 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from copy import deepcopy
import numpy as np
from typing import Tuple, Union

class Connect4:
    def __init__(self, board: 'Connect4' = None, row: int = 6, col: int = 7):
        self.row = row
        self.col = col
        self.player_1 = 1
        self.player_2 = -1

        self.board = np.zeros((self.row, self.col))

        self.winning_start = None
        self.winning_end = None

        if board is not None:
            self.__dict__ = deepcopy(board.__dict__)

    def drop_piece(self, action: int) -> Tuple['Connect4', Union[None, int]]:

        board = Connect4(board=self)

        # Find the lowest empty row in that column
        valid_row_idx = sum(board.board[:, action] == 0) - 1
        board.board[valid_row_idx, action] = self.player_1
        (board.player_1, board.player_2) = (self.player_2, self.player_1)

        return board, board.is_win()

    # Get the encoded state for the board
    def get_state(self) -> np.ndarray:
        # Create a layer encoding whose turn it is
        turn = np.ones_like(self.board) if self.player_1 == 1 else np.zeros_like(self.board)
        enc_state = np.stack(
            (self.board == -1, self.board == 0, self.board == 1, turn)
        ).astype(np.int32)

        return enc_state

    # Check whether the board is in a draw state
    def is_draw(self):
        return (self.board != 0).all()

    def is_win(self) -> Union[None, int]:
        # Initially no one is the winner
        winner = None

        # Check the columns
        if self.col_win():
            winner = self.player_2
        # Check the rows
        elif self.row_win():
            winner = self.player_2
        # Check the diagonals
        elif self.diag_win():
            winner = self.player_2
        return winner

    # Check for a column win
    def col_win(self) -> bool:
        # Iterate over each column
        for c in range(self.col):
            # For 4 consecutive rows
            for r in range(self.row - 3):
                # If all 4 cells belong to the player who just moved, it's a win
                if sum(self.board[r:r+4, c] == self.player_2) == 4:
                    self.winning_start = (c, r)
                    self.winning_end = (c, r+3)
                    return True

        return False

    # Check for a row win
    def row_win(self) -> bool:
        # Iterate over each row
        for r in range(self.row):
            # For 4 consecutive columns
            for c in range(self.col - 3):
                # If all 4 cells belong to the player who just moved, it's a win
                if sum(self.board[r, c:c+4] == self.player_2) == 4:
                    self.winning_start = (c, r)
                    self.winning_end = (c+3, r)
                    return True

        return False

    # Check for a diagonal win
    def diag_win(self) -> bool:
        # For every 4x4 window: if the main diagonal or the anti-diagonal
        # holds 4 discs of the player who just moved, it's a win
        for r in range(self.row - 3):
            for c in range(self.col - 3):
                # Get a window of size 4x4
                window = self.board[r:r+4, c:c+4]

                # If all 4 cells of the main diagonal belong to the player
                # who just moved, it's a win
                if sum(np.diag(window) == self.player_2) == 4:
                    self.winning_start = (c+3, r)
                    self.winning_end = (c, r+3)
                    return True

                # If all 4 cells of the anti-diagonal belong to the player
                # who just moved, it's a win
                if sum(np.diag(window[:, ::-1]) == self.player_2) == 4:
                    self.winning_start = (c, r)
                    self.winning_end = (c+3, r+3)
                    return True

        return False

    # Get a boolean mask of the valid moves for the current player
    def get_valid_moves(self) -> np.ndarray:
        valid_cols = [False] * self.col
        for c in range(self.col):
            if self.board[0, c] == 0:
                valid_cols[c] = True

        return np.array(valid_cols, dtype=bool)

    def __str__(self) -> str:
        print_str = ""
        for r in range(self.row):
            for c in range(self.col):
                print_str += f"{self.board[r, c]:>3.0f}"

            print_str += "\n"

        return print_str
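Note that drop_piece is non-mutating: it deep-copies the board, places a disc for the side to move, swaps player_1/player_2, and returns the new board plus the winner flag. A minimal sketch:

from game import Connect4

board = Connect4()                  # empty 6x7 board, player_1 == 1 to move
board, win = board.drop_piece(3)    # drop a disc in column 3
print(board.player_1)               # -1: it is now the other player's turn
print(win)                          # None: no four-in-a-row yet
print(board.get_valid_moves())      # boolean mask over the 7 columns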
main.py
ADDED
@@ -0,0 +1,63 @@
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from game import Connect4
from model import Model
from config import Config
from pydantic import BaseModel
from typing import List, Union
import numpy as np
from arena import get_move_for_bot
import torch

class Request(BaseModel):
    board: List[List[int]]
    currentPlayer: str
    randomMoves: Union[None, bool]
    mctsIterations: Union[None, int]

# Create an application instance
app = FastAPI()
app.add_middleware(
    CORSMiddleware,
    allow_origins=["*"],
    allow_credentials=True,
    allow_methods=["GET", "POST"],
    allow_headers=["*"]
)

# Create the model
model = Model(
    n_action = Config.n_action,
    num_hidden = Config.num_hidden,
    num_resblock = Config.num_res_block,
    rate = Config.rate,
    row = Config.row,
    col = Config.col,
    device = Config.device
)
model.load_state_dict(torch.load(Config.checkpoint_path))
model.eval()

@app.get("/")
def root():
    return {"message": "This is a temporary response"}

@app.post("/get_move")
def get_move(req: Request):
    board_arr = np.array(req.board)
    board = Connect4()
    board.board = board_arr

    if req.currentPlayer == "yellow":
        (board.player_1, board.player_2) = (board.player_2, board.player_1)

    # TODO: change the tree_iter to req.parameters
    act = get_move_for_bot(
        state = board,
        model = model,
        tree_iters = req.mctsIterations,
        random_move = req.randomMoves
    )

    return {'move': int(act)}
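For reference, a request against the /get_move endpoint might look like the sketch below, using the requests library; the host and port are assumptions for a local run (e.g. `uvicorn main:app`):

import requests

resp = requests.post("http://127.0.0.1:8000/get_move", json={
    "board": [[0]*7 for _ in range(6)],  # 6x7 grid of {-1, 0, 1}
    "currentPlayer": "red",              # anything but "yellow" keeps player 1 to move
    "randomMoves": False,
    "mctsIterations": 100
})
print(resp.json())  # e.g. {"move": 3}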
main2.py
ADDED
@@ -0,0 +1,92 @@
from model import Model
from config import Config
from arena import get_move_for_bot
from game import Connect4
import pygame
from view_board import draw_board, draw_winning_line
import sys
import torch

def play_game(model: Model):
    board = Connect4(
        row = Config.row,
        col = Config.col
    )

    pygame.init()
    screen = pygame.display.set_mode((Config.col*100, (Config.row+1)*100))

    ai_turn = True
    game_end = False
    while True:
        draw_board(screen, board.board)
        draw_winning_line(screen, board.winning_start, board.winning_end)

        if ai_turn and not game_end:
            act = get_move_for_bot(board, model, Config.tree_iter)
            board, win = board.drop_piece(act)

            if win is not None:
                print("AI has WON")
                print("Board \n")
                print(board)

                print("Winner is...", win)
                game_end = True

            ai_turn = False
            draw_board(screen, board.board)
            pygame.display.update()

        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()

            if event.type == pygame.MOUSEBUTTONDOWN and not game_end:
                posx = event.pos[0]
                act = posx // 100
                board, win = board.drop_piece(act)
                ai_turn = True

                if win is not None:
                    print("Human has won")
                    print("Board \n")
                    print(board)
                    game_end = True

            if event.type == pygame.MOUSEMOTION and not game_end:
                pygame.draw.rect(screen, (0, 0, 0), (0, 0, 700, 100))
                posx = event.pos[0]

                # If the AI moved first, the human plays second (yellow)
                if board.player_1 == -1:
                    pygame.draw.circle(screen, (230, 230, 20), (posx, int(100//2)), 50)
                else:
                    pygame.draw.circle(screen, (52, 186, 235), (posx, int(100//2)), 50)

                pygame.display.update()

if __name__ == "__main__":
    model = Model(
        n_action = Config.n_action,
        num_hidden = Config.num_hidden,
        num_resblock = Config.num_res_block,
        rate = Config.rate,
        row = Config.row,
        col = Config.col,
        device = Config.device
    )

    # This is the LR = .01 model
    # model_path = './Models/C4GruhitSPatel/FullBuffer5x5V1/TargetModel_500.pt'

    # This is the LR = .001 model
    model_path = "./Models/C4GruhitML/C4CyclicLRV3/TargetModel_500.pt"
    model.load_state_dict(torch.load(model_path))
    model.eval()

    play_game(model)
mcts.py
ADDED
@@ -0,0 +1,282 @@
from model import Model
from typing import Union, Tuple
from game import Connect4
from config import Config

import torch
from torch import Tensor

import numpy as np

class Node:
    def __init__(self, state: Union[Connect4, None], model: Model, name: str):

        # Current state that the node represents
        self.state = state

        # Name of the node, used to trace it
        self.name = name

        # The model instance this node uses to get value and policy
        self.model = model

        # Visit count
        self.N = 0

        # Accumulated value
        self.W = 0

        # Value of the node
        self.value = None

        # Prior policy over actions from this node
        self.policy = None

        # Winner at this node; None by default, meaning no one has won
        self.win = None

        # Children of the current node
        self.children = {}

        # Valid and invalid actions that can be taken from this node
        self.valid_actions = None
        self.invalid_actions = None

        # Set the valid and invalid actions
        self.set_valid_actions()

        # Initialize the edges to the children
        self.initialize_edges()

    # Set the valid actions that can be taken from the state
    # that this node represents
    def set_valid_actions(self) -> None:
        if self.state is not None:
            self.valid_actions = self.state.get_valid_moves()
            self.invalid_actions = ~self.valid_actions

    # Initialize the edges from this node to its potential children
    def initialize_edges(self) -> None:
        if self.state is not None:
            self.children = {}
            for act, valid_move in enumerate(self.valid_actions):
                if valid_move:
                    # Children start with state=None; the state is
                    # filled in once the child is expanded
                    self.children[act] = Node(
                        state=None,
                        model=self.model,
                        name=self.name + '_' + str(act)
                    )

    def preprocess_state(self, x: np.ndarray) -> Tensor:
        x = torch.tensor(x, dtype=torch.float32, device=Config.device)
        x = x.unsqueeze(0)
        return x

    # Run the network forward pass for the current node
    def forward(self) -> None:
        with torch.no_grad():
            value, policy = self.model(self.preprocess_state(self.state.get_state()))

        value = value[0, 0]
        policy = policy[0]

        # Convert the logits to probabilities first, then mask the invalid
        # actions. (Zeroing the raw logits before the softmax would not
        # remove them, since exp(0) = 1.)
        policy = policy.softmax(dim=-1)
        policy[self.invalid_actions] = 0.

        # Prevent all probabilities from turning 0
        if policy.sum() == 0:
            policy[self.valid_actions] = 1.

        # Renormalize over the valid actions
        policy = policy / policy.sum()

        self.value = value.detach().cpu().numpy()
        self.policy = policy.detach().cpu().numpy()


    # Get the policy for the current node
    def get_policy(self) -> np.ndarray:
        if self.policy is None:
            self.forward()

        return self.policy

    # Get the value associated with the node
    def get_value(self) -> float:
        if self.value is None:
            self.forward()

        return self.value

class MCTS_NN:
    def __init__(self, state: Connect4, model: Model, log=None):
        self.root = Node(state=state, model=model, name='root')

        if log is not None:
            self.log = log

    # Run one simulation on the Monte Carlo tree
    def selection(self, node: Node, add_dirichlet: bool = False, iter: int = 0) -> float:
        # Get the best child of the current node
        best_child, best_action = self.get_best_child(node, add_dirichlet, iter)

        # If the child is a leaf node (i.e. it is terminal or has not
        # been expanded), expand it
        if best_child.state is None:
            val = self.explore_and_expand(parent=node, child=best_child, action=best_action, iter=iter)

        # If the node is already expanded, traverse it further
        else:
            # As per the paper, Dirichlet noise is only added when selecting
            # the root node's children, not deeper in the tree
            val = self.selection(node=best_child, add_dirichlet=False, iter=iter)

        node.N += 1
        node.W += val

        return -val

    # Explore and expand the tree
    def explore_and_expand(self, parent: Node, child: Node, action: int, iter=0) -> float:
        # Check whether the child is already known to be terminal
        if child.win is None:
            # It is not expanded and not terminal:
            # perform the action on the parent state to get the next state
            next_state, win = parent.state.drop_piece(action)

            # First check if someone won in this next state
            if win is not None:
                val = -1 if win == parent.state.player_1 else 1
                child.win = win

            # Otherwise check if the next state results in a draw
            elif next_state.is_draw():
                # Value 0 when no one has won in the state
                val = 0

                # win == 0 means no one won
                child.win = 0

            # If the next state is neither a win nor a draw,
            # expand it normally
            else:
                # Get the value for the new state from the child's model
                child.state = next_state
                child.set_valid_actions()
                child.initialize_edges()
                val = child.get_value()

        else:
            # If the child represents a draw state, its value is 0
            if child.win == 0:
                val = 0

            # If the winner at the child node is the player who moved at the
            # parent node, the value is -1: the player to move at the child
            # node has lost
            elif child.win == parent.state.player_1:
                val = -1

            # If the winner of the child node is the child node's own player,
            # the value is +1
            else:
                val = 1

        # Update the visit count and accumulated value of the child node
        child.N += 1
        child.W += val

        # Return the negative of val because the parent node's player is the
        # opponent of the current node's player: what is good for the current
        # node's player is bad for the parent node's player
        return -val


    # Calculate the PUCT score for one of a node's children
    def get_puct_score(self, parent: Node, child: Node, prior: float) -> float:
        # PUCT is the sum of the Q-value of the current node and U(s, a)
        q_value = 0
        if child.N == 0:
            q_value = 0
        else:
            q_value = -child.W / child.N

        # c_puct is the exploration constant
        c_puct = 1
        u_sa = c_puct * prior * np.sqrt(parent.N) / (1 + child.N)
        return q_value + u_sa

    def get_dirichlet_noise(self, node: Node) -> np.ndarray:
        num_valid_action = node.valid_actions.sum()
        noise_vec = np.random.dirichlet([Config.DIRICHLET_ALPHA] * num_valid_action)
        noise_arr = np.zeros((len(node.valid_actions),), dtype=noise_vec.dtype)
        noise_arr[node.valid_actions] = noise_vec
        return noise_arr

    # Get the best child of a node
    def get_best_child(self, node: Node, add_dirichlet: bool, iter=0) -> Tuple[Node, int]:
        # The best child is simply the one with the highest PUCT value
        policy = node.get_policy()

        if add_dirichlet:
            noise_arr = self.get_dirichlet_noise(node)
            policy = (1 - Config.EPSILON) * policy + Config.EPSILON * noise_arr

        best_puct = float('-inf')
        best_child = None
        best_action = None
        for action, child in node.children.items():
            puct = self.get_puct_score(parent=node, child=child, prior=policy[action])
            if puct > best_puct:
                best_puct = puct
                best_child = child
                best_action = action

        return best_child, best_action

    # Return the policy pi for the root node based on the visit counts
    def get_policy_pie(self, temperature: float = 1):
        actions = np.zeros((len(self.root.valid_actions),))

        for action, child in self.root.children.items():
            actions[action] = child.N ** (1 / temperature)

        actions /= actions.sum()

        return actions

    # Traverse the tree by stepping to one of the root node's children
    def update_root(self, action: int) -> None:
        self.root = self.root.children[action]
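For reference, get_puct_score implements the standard PUCT selection rule from AlphaZero: score(s, a) = Q(s, a) + c_puct * P(s, a) * sqrt(N(s)) / (1 + N(s, a)). Here Q(s, a) = -W(s, a) / N(s, a) is negated because W accumulates values from the child's own perspective, and c_puct is hard-coded to 1. A standalone search run looks like arena.get_move_for_bot: construct MCTS_NN, call selection(mcts.root, ...) for tree_iter iterations, then read the visit-count policy from get_policy_pie().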
model.py
ADDED
@@ -0,0 +1,106 @@
import torch
from torch import Tensor
import torch.nn as nn
import torch.nn.functional as F

class ResNetBlock(nn.Module):
    def __init__(self, num_hidden: int):
        super(ResNetBlock, self).__init__()

        self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1, bias=False)
        self.bn1 = nn.BatchNorm2d(num_hidden)
        self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1, bias=False)
        self.bn2 = nn.BatchNorm2d(num_hidden)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        residual = x
        x = F.relu(self.bn1(self.conv1(x)))
        x = self.bn2(self.conv2(x))
        x += residual
        x = F.relu(x)

        return x

class DropoutBlock(nn.Module):
    def __init__(self, in_units: int, out_units: int, rate: float):
        super(DropoutBlock, self).__init__()
        self.model = nn.Sequential(
            nn.Linear(in_units, out_units),
            nn.BatchNorm1d(out_units),
            nn.ReLU(),
            nn.Dropout(rate)
        )

    def forward(self, x: Tensor) -> Tensor:
        return self.model(x)

class Model(nn.Module):
    def __init__(self, n_action: int, num_hidden: int, num_resblock: int,
                 rate: float, row: int, col: int, device: str):
        super(Model, self).__init__()

        # Input layer
        self.initial_block = nn.Sequential(
            nn.Conv2d(4, num_hidden, kernel_size=3, padding=1),
            nn.BatchNorm2d(num_hidden),
            nn.ReLU()
        ).to(device)

        self.res_blocks = nn.Sequential(
            *[ResNetBlock(num_hidden) for _ in range(num_resblock)]
        ).to(device)

        self.dropout_model = nn.Sequential(
            DropoutBlock(num_hidden*row*col, 200, rate),
            DropoutBlock(200, 100, rate)
        )

        self.model = nn.Sequential(
            self.initial_block,
            self.res_blocks,
            nn.Flatten(),
            self.dropout_model
        )

        self.policy_head = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, n_action),
        ).to(device)

        self.value_head = nn.Sequential(
            nn.Linear(100, 100),
            nn.ReLU(),
            nn.Linear(100, 1),
            nn.Tanh()
        ).to(device)

        self.to(device)

        self.device = device

        # Losses
        # Mean squared error: minimizes the difference between the estimated
        # value and the target value
        self.mse_loss = nn.MSELoss()

        # Cross-entropy loss: compares the predicted policy against the
        # target policy
        self.ce_loss = nn.CrossEntropyLoss()

    def forward(self, x):
        x = self.model(x)
        value = self.value_head(x)
        policy = self.policy_head(x)

        return value, policy

    # Perform the loss calculation
    def get_loss(self, pred_val, pred_policy, true_val, true_policy):
        val_loss = self.mse_loss(pred_val, true_val)
        policy_loss = self.ce_loss(pred_policy, true_policy)

        final_loss = val_loss + policy_loss
        return {
            'total_loss': final_loss,
            'value_loss': val_loss,
            'policy_loss': policy_loss
        }
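A quick shape check of the two-headed network, assuming the default Config values (6x7 board, 4 input planes):

import torch
from model import Model

net = Model(n_action=7, num_hidden=64, num_resblock=4,
            rate=0.3, row=6, col=7, device='cpu')
net.eval()  # BatchNorm layers require eval() for a batch of size 1

x = torch.zeros(1, 4, 6, 7)        # (batch, planes, row, col)
value, policy = net(x)
print(value.shape, policy.shape)   # torch.Size([1, 1]) torch.Size([1, 7])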
requirement.txt
ADDED
@@ -0,0 +1,7 @@
torch==2.2.2
tqdm==4.66.2
pygame==2.5.2
fastapi==0.110.1
pydantic==2.6.4
uvicorn==0.29.0
numpy==1.26.4
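Note that the file is named requirement.txt rather than the conventional requirements.txt, so installs have to reference it explicitly: `pip install -r requirement.txt`. The torch.utils.tensorboard import in trainer.py additionally needs the tensorboard package, which is not pinned here.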
trainer.py
ADDED
@@ -0,0 +1,85 @@
import torch
from torch import optim
from torch.utils.tensorboard import SummaryWriter

from tqdm import tqdm
import numpy as np

from model import Model
from buffer import Buffer


class Trainer:
    def __init__(self, model: Model, buffer: Buffer, base_lr: float = 0.001,
                 weight_decay=1e-4, device: str = 'cpu'):
        self.main_model = model
        self.main_buffer = buffer
        self.global_step = 0
        self.device = device

        # Optimizer
        self.optimizer = optim.SGD(
            self.main_model.parameters(),
            lr = base_lr,
            weight_decay = weight_decay,
            momentum = 0.9
        )

        # self.scheduler = optim.lr_scheduler.CyclicLR(
        #     self.optimizer,
        #     base_lr = base_lr,
        #     max_lr = 0.1
        # )

        # TensorBoard summary writer
        self.writer = SummaryWriter()

    def transfer_buffer(self, buffer) -> None:
        for state, value, policy in zip(buffer.state, buffer.value, buffer.policy):
            self.main_buffer.store_experience(
                state = state,
                value = value,
                policy = policy
            )

    def reset_buffer(self) -> None:
        self.main_buffer.reset()

    # Learn from one batch of the buffer
    def learn(self, state: np.ndarray, value: np.ndarray, policy: np.ndarray) -> float:

        state = torch.tensor(state, dtype=torch.float32, device=self.device)
        value = torch.tensor(value, dtype=torch.float32, device=self.device).unsqueeze(-1)
        policy = torch.tensor(policy, dtype=torch.float32, device=self.device)

        pred_val, pred_policy = self.main_model(state)

        self.optimizer.zero_grad()
        # get_loss returns a dict of losses; backpropagate the total loss
        loss = self.main_model.get_loss(pred_val, pred_policy, value, policy)
        loss['total_loss'].backward()
        self.optimizer.step()

        return loss['total_loss'].detach().cpu().numpy()

    # Training loop for the model
    def train_model(self, epochs: int, batch_size: int):

        train_steps = np.ceil(len(self.main_buffer) / batch_size).astype(np.int32)

        # Perform the training
        for epoch in range(epochs):
            for state, value, policy in tqdm(self.main_buffer.sample(batch_size), total=train_steps, desc=f'Epoch:{epoch+1}'):
                loss = self.learn(state, value, policy)
                self.writer.add_scalar("loss", loss, self.global_step)
                self.global_step += 1

        self.writer.flush()

    # Close the writer
    def close_writer(self):
        self.writer.close()

    # Save the model and optimizer state
    def save_model(self, step: int):
        torch.save(self.main_model.state_dict(), f'TargetModel_{step}.pt')
        torch.save(self.optimizer.state_dict(), f'Optimizer_{step}.pt')
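This commit ships no training entry point (main.py serves the API, main2.py is the pygame client), but the pieces compose into the usual AlphaZero-style outer loop. A hypothetical sketch, with all hyperparameters taken from Config:

from agent import Agent
from arena import play_selfgames
from buffer import Buffer
from config import Config
from model import Model
from trainer import Trainer

model = Model(n_action=Config.n_action, num_hidden=Config.num_hidden,
              num_resblock=Config.num_res_block, rate=Config.rate,
              row=Config.row, col=Config.col, device=Config.device)

agent = Agent(row=Config.row, col=Config.col, n_action=Config.n_action,
              obs_shape=Config.obs_shape, model=model,
              iteration=Config.tree_iter, temperature=Config.temperature)

trainer = Trainer(model=model, buffer=Buffer(Config.n_action, Config.obs_shape),
                  base_lr=Config.base_lr, weight_decay=Config.weight_decay,
                  device=Config.device)

for it in range(Config.total_iters):
    model.eval()                                   # MCTS queries use batch size 1
    play_selfgames(agent, Config.selfplay_games)   # fills agent.buffer
    trainer.transfer_buffer(agent.buffer)          # move the experience over
    agent.reset_buffer()
    model.train()
    trainer.train_model(epochs=Config.epoch, batch_size=Config.batch_size)
    trainer.reset_buffer()
    trainer.save_model(it)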
view_board.py
ADDED
@@ -0,0 +1,60 @@
import pygame
import numpy as np
import sys
from typing import Tuple, Union
from config import Config

SQUARESIZE = 100

def draw_board(screen, board):
    COLUMN_COUNT = Config.col
    ROW_COUNT = Config.row
    SQUARESIZE = 100
    RADIUS = int(SQUARESIZE/2 - 5)

    BLUE = (52, 186, 235)
    GREY = (70, 71, 70)
    WHITE = (255, 255, 255)
    YELLOW = (230, 230, 20)

    width = COLUMN_COUNT * SQUARESIZE
    height = (ROW_COUNT+1) * SQUARESIZE

    board = np.flip(board, 0)
    for c in range(COLUMN_COUNT):
        for r in range(ROW_COUNT):
            pygame.draw.rect(screen, GREY, (c*SQUARESIZE, r*SQUARESIZE+SQUARESIZE, SQUARESIZE, SQUARESIZE))
            pygame.draw.circle(screen, WHITE, (int(c*SQUARESIZE+SQUARESIZE/2), int(r*SQUARESIZE+SQUARESIZE+SQUARESIZE/2)), RADIUS)

    for c in range(COLUMN_COUNT):
        for r in range(ROW_COUNT):
            if board[r][c] == 1:
                pygame.draw.circle(screen, BLUE, (int(c*SQUARESIZE+SQUARESIZE/2), height-int(r*SQUARESIZE+SQUARESIZE/2)), RADIUS)
            elif board[r][c] == -1:
                pygame.draw.circle(screen, YELLOW, (int(c*SQUARESIZE+SQUARESIZE/2), height-int(r*SQUARESIZE+SQUARESIZE/2)), RADIUS)

def draw_winning_line(screen, start_pos: Union[None, Tuple[int, int]], end_pos: Union[None, Tuple[int, int]]):
    if start_pos is None or end_pos is None:
        return

    offset = SQUARESIZE//2
    start_line = (SQUARESIZE*start_pos[0]+1+offset, SQUARESIZE*(start_pos[1]+1)+offset)
    end_line = (SQUARESIZE*end_pos[0]+offset, SQUARESIZE*(end_pos[1]+1)+offset)

    pygame.draw.line(screen, (255, 0, 0), start_line, end_line, 10)

def render(board):
    pygame.init()
    screen = pygame.display.set_mode((700, 700))

    while True:
        for event in pygame.event.get():
            if event.type == pygame.QUIT:
                sys.exit()
        draw_board(screen, board)
        pygame.display.update()