Gruhit Patel committed on
Commit 1fab54b · verified · 1 Parent(s): 5b133f3

init-backend

Files changed (12)
  1. agent.py +72 -0
  2. arena.py +83 -0
  3. buffer.py +50 -0
  4. config.py +47 -0
  5. game.py +131 -0
  6. main.py +63 -0
  7. main2.py +92 -0
  8. mcts.py +282 -0
  9. model.py +106 -0
  10. requirement.txt +7 -0
  11. trainer.py +85 -0
  12. view_board.py +60 -0
agent.py ADDED
@@ -0,0 +1,72 @@
+ from model import Model
+ from buffer import Buffer
+ from game import Connect4
+ from mcts import MCTS_NN
+
+ import numpy as np
+ from typing import Tuple, List
+
+ class Agent:
+     def __init__(self, row: int, col: int, n_action: int, obs_shape: Tuple[int, int, int],
+                  model: Model, iteration: int, temperature: float):
+
+         self.row = row
+         self.col = col
+         self.n_action = n_action
+         self.obs_shape = obs_shape
+         self.iteration = iteration
+         self.temperature = temperature
+
+         # Create a buffer instance
+         self.buffer = Buffer(n_action=self.n_action, obs_shape=self.obs_shape)
+
+         # Target model instance
+         self.target_model = model
+
+     # Reset the MCTS class instance
+     def reset(self, state: Connect4, reset_buffer: bool = False) -> None:
+         # Reset the state of the Monte Carlo tree search instance
+         self.mcts = MCTS_NN(state=state, model=self.target_model)
+
+     # Reset the buffer
+     def reset_buffer(self) -> None:
+         self.buffer.reset()
+
+     # Get the policy from the MCTS simulation
+     def perform_mcts(self) -> np.ndarray:
+         for _ in range(self.iteration):
+             self.mcts.selection(self.mcts.root, add_dirichlet=True)
+
+         policy = self.mcts.get_policy_pie(self.temperature)
+
+         return policy
+
+     # Get an action for the current state
+     def get_action(self) -> Tuple[int, np.ndarray]:
+         policy = self.perform_mcts()
+         action = np.random.choice(self.n_action, p=policy)
+         return action, policy
+
+     # Backfill the episode values and send the experience to the buffer object
+     def update_buffer(self, episodic_buffer: List) -> None:
+         # Get the last index of the episodic buffer
+         idx = len(episodic_buffer) - 1
+
+         # The last state always gets value 1, as it corresponds to the winning move
+         value = 1
+         while idx >= 0:
+             episodic_buffer[idx][1] = value
+             value *= -1  # For the parent the value is negated
+             idx -= 1     # Go to the previous experience tuple
+
+         for state, value, policy in episodic_buffer:
+             self.buffer.store_experience(
+                 state = state,
+                 value = value,
+                 policy = policy
+             )
+
+     # Update the root to one of its child nodes
+     # based on the action taken in `get_action()` above
+     def update(self, action: int) -> None:
+         self.mcts.update_root(action)
arena.py ADDED
@@ -0,0 +1,83 @@
+ from game import Connect4
+ from agent import Agent
+ from model import Model
+ from mcts import MCTS_NN
+
+ from typing import Union
+ from tqdm import tqdm
+ import numpy as np
+
+ def play_selfgames(agent: Agent, training_games: int):
+
+     for _ in tqdm(range(training_games)):
+         board = Connect4(row=agent.row, col=agent.col)
+         agent.reset(state = board)
+
+         # A buffer list to store the transitions of the current episode
+         episodic_buffer = []
+
+         while not board.is_win() and not board.is_draw():
+             # While getting the action the tree search is performed
+             # and the experience is stored as well
+             action, policy = agent.get_action()
+             episodic_buffer.append([
+                 board.get_state(),
+                 board.player_1,
+                 policy
+             ])
+
+             board, _ = board.drop_piece(action)
+
+             # Update the root node of the MCTS to one of its child nodes
+             agent.update(action)
+
+         # When the episode is completed, update the buffer
+         agent.update_buffer(episodic_buffer)
+
+
+ def get_move_for_bot(state: Connect4, model: Model, tree_iters: int, random_move: bool = False) -> int:
+     mcts = MCTS_NN(state = state, model = model)
+
+     for _ in range(tree_iters):
+         mcts.selection(mcts.root, random_move)
+
+     policy = mcts.get_policy_pie()
+     act = np.argmax(policy)
+
+     return act
+
+ def play_game_against_bot(bot1: Model, bot2: Model, tree_iters: int) -> Union[None, int]:
+     board = Connect4()
+     player_1 = True
+
+     # In this function bot1 is always the datagen model, making the 1st move,
+     # and bot2 is the main_model, making the 2nd move.
+     # We randomly swap which bot makes the first move 50% of the time.
+     flip = False
+     if np.random.uniform() < 0.5:
+         flip = True
+         (bot1, bot2) = (bot2, bot1)
+         print("Bot has been flipped")
+
+     while not board.is_win() and not board.is_draw():
+         if player_1:
+             act = get_move_for_bot(board, model=bot1, tree_iters=tree_iters)
+             player_1 = False
+         else:
+             act = get_move_for_bot(board, model=bot2, tree_iters=tree_iters)
+             player_1 = True
+
+         board, win = board.drop_piece(act)
+         print(board)
+
+     # Return values:
+     #  0 - draw
+     #  1 - datagen won
+     # -1 - main_model won
+     # Hence when the bots are flipped we have to adjust the values accordingly.
+     if flip:
+         # If we have flipped, the main_model is player 1; if it has won
+         # we want to return -1 for it, and vice-versa
+         return 0 if win is None else win * -1
+     else:
+         return 0 if win is None else win
buffer.py ADDED
@@ -0,0 +1,50 @@
+ from typing import Iterator, Tuple
+ import numpy as np
+
+ class Buffer:
+     def __init__(self, n_action: int, obs_shape: Tuple[int, int, int]):
+         self.n_action = n_action
+         self.obs_shape = obs_shape
+         self.mem_size = 0
+
+         # Use plain lists for storage so the buffer can grow dynamically
+         self.state = []
+         self.value = []
+         self.policy = []
+
+     def store_experience(self, state: np.ndarray, value: float, policy: np.ndarray):
+         self.state.append(state)
+         self.value.append(value)
+         self.policy.append(policy)
+
+         self.mem_size += 1
+
+     def sample(self, batch_size: int) -> Iterator[Tuple[
+         np.ndarray,
+         np.ndarray,
+         np.ndarray
+     ]]:
+         # Shuffle the memory with a single permutation so that the
+         # state, value and policy entries stay aligned with each other
+         perm = np.random.permutation(self.mem_size)
+         self.state = [self.state[i] for i in perm]
+         self.value = [self.value[i] for i in perm]
+         self.policy = [self.policy[i] for i in perm]
+
+         for start_idx in range(0, self.mem_size, batch_size):
+             end_idx = min(start_idx + batch_size, self.mem_size)
+             s = self.state[start_idx:end_idx]
+             v = self.value[start_idx:end_idx]
+             p = self.policy[start_idx:end_idx]
+
+             yield (np.array(s), np.array(v), np.array(p))
+
+     # Reset the buffer to store new experience
+     def reset(self) -> None:
+         self.state = []
+         self.value = []
+         self.policy = []
+
+         self.mem_size = 0
+
+     # Return the length of the buffer
+     def __len__(self):
+         return self.mem_size
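For context, a minimal usage sketch (not part of the commit) of how this Buffer is expected to be used: store a few (state, value, policy) tuples and iterate over mini-batches via sample(). The shapes below assume the 4x6x7 encoded state and 7 actions from config.py; the dummy values are placeholders.

import numpy as np
from buffer import Buffer

buf = Buffer(n_action=7, obs_shape=(4, 6, 7))
for _ in range(10):
    buf.store_experience(
        state=np.zeros((4, 6, 7), dtype=np.int32),   # encoded board planes
        value=1.0,                                   # game outcome from this player's view
        policy=np.full(7, 1 / 7)                     # visit-count policy over the 7 columns
    )

for states, values, policies in buf.sample(batch_size=4):
    print(states.shape, values.shape, policies.shape)  # up to (4, 4, 6, 7), (4,), (4, 7)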
config.py ADDED
@@ -0,0 +1,47 @@
+ from typing import Tuple
+ import torch
+
+ class Config:
+     # Board
+     row: int = 6
+     col: int = 7
+
+     # Neural network
+     num_hidden: int = 64
+     num_res_block: int = 4
+     rate: float = 0.3
+     obs_shape: Tuple[int, int, int] = (4, row, col)
+     n_action: int = col
+     device: str = 'cuda:0' if torch.cuda.is_available() else 'cpu'
+     checkpoint_path: str = "../Models/azv3.pt"
+
+     # Optimizer
+     base_lr: float = 0.01
+     weight_decay: float = 1e-4
+
+     # Monte Carlo tree search
+     temperature: float = 1.0
+     tree_iter: int = 100
+
+     # Training
+     selfplay_games: int = 50
+     epoch: int = 10
+     batch_size: int = 128
+
+     # Tournament
+     eval_games: int = 10
+
+     # How many Elo rating points are awarded per win
+     k: int = 10
+
+     # Model update threshold
+     threshold: float = 0.55
+
+     # How many times to run self-play games and train the model
+     total_iters: int = 40
+
+     # Parallel games
+     parallel_run: int = 4
+
+     DIRICHLET_ALPHA: float = 0.3  # Avg legal moves / 75% of total moves
+     EPSILON: float = 0.25
game.py ADDED
@@ -0,0 +1,131 @@
+ from copy import deepcopy
+ import numpy as np
+ from typing import Tuple, Union
+
+ class Connect4:
+     def __init__(self, board: 'Connect4' = None, row: int = 6, col: int = 7):
+         self.row = row
+         self.col = col
+         self.player_1 = 1
+         self.player_2 = -1
+
+         self.board = np.zeros((self.row, self.col))
+
+         self.winning_start = None
+         self.winning_end = None
+
+         if board is not None:
+             self.__dict__ = deepcopy(board.__dict__)
+
+     def drop_piece(self, action: int) -> Tuple['Connect4', Union[None, int]]:
+
+         board = Connect4(board=self)
+
+         # Find the lowest empty row in that column where a piece can be dropped
+         valid_row_idx = sum(board.board[:, action] == 0) - 1
+         board.board[valid_row_idx, action] = self.player_1
+         (board.player_1, board.player_2) = (self.player_2, self.player_1)
+
+         return board, board.is_win()
+
+     # Get the encoded state for the board
+     def get_state(self) -> np.ndarray:
+         # Create a layer that encodes whose turn it is
+         turn = np.ones_like(self.board) if self.player_1 == 1 else np.zeros_like(self.board)
+         enc_state = np.stack(
+             (self.board == -1, self.board == 0, self.board == 1, turn)
+         ).astype(np.int32)
+
+         return enc_state
+
+     # Check if the board is in a draw state
+     def is_draw(self):
+         return (self.board != 0).all()
+
+     def is_win(self) -> Union[None, int]:
+         # Initially no one is the winner
+         winner = None
+
+         # Check the columns
+         if self.col_win():
+             winner = self.player_2
+         # Check the rows
+         elif self.row_win():
+             winner = self.player_2
+         # Check the diagonals
+         elif self.diag_win():
+             winner = self.player_2
+         return winner
+
+     # Check for a column win
+     def col_win(self) -> bool:
+         # Iterate over each column
+         for c in range(self.col):
+             # For 4 consecutive rows
+             for r in range(self.row - 3):
+                 # If all 4 elements belong to the player who just moved, it's a win
+                 if sum(self.board[r:r+4, c] == self.player_2) == 4:
+                     self.winning_start = (c, r)
+                     self.winning_end = (c, r+3)
+                     return True
+
+         return False
+
+     # Check for a row win
+     def row_win(self) -> bool:
+         # Iterate over each row
+         for r in range(self.row):
+             # For 4 consecutive columns
+             for c in range(self.col - 3):
+                 # If all 4 elements belong to the player who just moved, it's a win
+                 if sum(self.board[r, c:c+4] == self.player_2) == 4:
+                     self.winning_start = (c, r)
+                     self.winning_end = (c+3, r)
+                     return True
+
+         return False
+
+     # Check for a diagonal win
+     def diag_win(self) -> bool:
+         # For every 4x4 window, if the main diagonal or the anti-diagonal holds
+         # four discs of the player who just moved, it is a win
+         for r in range(self.row - 3):
+             for c in range(self.col - 3):
+                 # Get a window of size 4x4
+                 window = self.board[r:r+4, c:c+4]
+
+                 # If all 4 elements of the main diagonal belong to the player who just moved, it's a win
+                 if sum(np.diag(window) == self.player_2) == 4:
+                     self.winning_start = (c+3, r)
+                     self.winning_end = (c, r+3)
+                     # print("WinningMain Diag: ", self.winning_start, " - ", self.winning_end)
+                     return True
+
+                 # If all 4 elements of the anti-diagonal belong to the player who just moved, it's a win
+                 if sum(np.diag(window[:, ::-1]) == self.player_2) == 4:
+                     self.winning_start = (c, r)
+                     self.winning_end = (c+3, r+3)
+                     # print("WinningMain Diag: ", self.winning_start, " - ", self.winning_end)
+                     return True
+
+         return False
+
+     # Get a boolean mask of the valid moves for the current player
+     def get_valid_moves(self) -> np.ndarray:
+         valid_cols = [False] * self.col
+         for c in range(self.col):
+             if self.board[0, c] == 0:
+                 valid_cols[c] = True
+
+         return np.array(valid_cols, dtype=bool)
+
+     def __str__(self) -> str:
+         print_str = ""
+         for r in range(self.row):
+             for c in range(self.col):
+                 print_str += f"{self.board[r, c]:>3.0f}"
+
+             print_str += "\n"
+
+         return print_str
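As a quick illustration (not part of the commit), a short sketch of the Connect4 API: drop_piece returns a fresh board copy together with the winner (or None), and the player fields swap on every move, so the object can be used immutably inside the tree search.

from game import Connect4

board = Connect4(row=6, col=7)
board, win = board.drop_piece(3)   # player 1 drops in column 3, a new board is returned
board, win = board.drop_piece(3)   # now player -1 moves in the same column
print(board)                       # text rendering via __str__
print(board.get_valid_moves())     # boolean mask over the 7 columns
print(win)                         # None until someone connects four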
main.py ADDED
@@ -0,0 +1,63 @@
+ from fastapi import FastAPI
+ from fastapi.middleware.cors import CORSMiddleware
+ from game import Connect4
+ from model import Model
+ from config import Config
+ from pydantic import BaseModel
+ from typing import List, Union
+ import numpy as np
+ from arena import get_move_for_bot
+ import torch
+
+ class Request(BaseModel):
+     board: List[List[int]]
+     currentPlayer: str
+     randomMoves: Union[None, bool]
+     mctsIterations: Union[None, int]
+
+ # Create an application instance
+ app = FastAPI()
+ app.add_middleware(
+     CORSMiddleware,
+     allow_origins=["*"],
+     allow_credentials=True,
+     allow_methods=["GET", "POST"],
+     allow_headers=["*"]
+ )
+
+ # Create the model
+ model = Model(
+     n_action = Config.n_action,
+     num_hidden = Config.num_hidden,
+     num_resblock = Config.num_res_block,
+     rate = Config.rate,
+     row = Config.row,
+     col = Config.col,
+     device = Config.device
+ )
+ # map_location keeps checkpoint loading working on CPU-only machines
+ model.load_state_dict(torch.load(Config.checkpoint_path, map_location=Config.device))
+ model.eval()
+
+ @app.get("/")
+ def root():
+     return {"message": "This is a temporary response"}
+
+ @app.post("/get_move")
+ def get_move(req: Request):
+     global model
+     board_arr = np.array(req.board)
+     board = Connect4()
+     board.board = board_arr
+
+     if req.currentPlayer == "yellow":
+         (board.player_1, board.player_2) = (board.player_2, board.player_1)
+
+     # TODO: change the tree_iter to req.parameters
+     act = get_move_for_bot(
+         state = board,
+         model = model,
+         tree_iters = req.mctsIterations,
+         random_move = req.randomMoves
+     )
+
+     return {'move': int(act)}
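For reference, a hedged client-side sketch (not in the commit) of what a request to the /get_move endpoint could look like, assuming the app is served locally with uvicorn on port 8000. The field names match the Request model above; the requests library is an assumed client dependency and is not listed in requirement.txt.

import requests  # assumed client-side dependency

payload = {
    "board": [[0] * 7 for _ in range(6)],  # empty 6x7 board
    "currentPlayer": "red",                # "yellow" swaps the player encoding server-side
    "randomMoves": False,
    "mctsIterations": 100,
}
resp = requests.post("http://localhost:8000/get_move", json=payload)
print(resp.json())  # e.g. {"move": 3}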
main2.py ADDED
@@ -0,0 +1,92 @@
+ from model import Model
+ from config import Config
+ from arena import get_move_for_bot
+ from game import Connect4
+ import pygame
+ from view_board import draw_board, draw_winning_line
+ import sys
+ import torch
+
+ def play_game(model: Model):
+     board = Connect4(
+         row = Config.row,
+         col = Config.col
+     )
+
+     pygame.init()
+     screen = pygame.display.set_mode((Config.col*100, (Config.row+1)*100))
+
+     ai_turn = True
+     game_end = False
+     while True:
+         draw_board(screen, board.board)
+         draw_winning_line(screen, board.winning_start, board.winning_end)
+
+         # render(board.board)
+         if ai_turn and not game_end:
+             # print("Getting move from AI...")
+             act = get_move_for_bot(board, model, Config.tree_iter)
+             # print(f"AI moved in column {act}")
+             board, win = board.drop_piece(act)
+
+             if win is not None:
+                 print("AI has WON")
+                 print("Board \n")
+                 print(board)
+
+                 print("Winner is...", win)
+                 game_end = True
+
+             ai_turn = False
+             draw_board(screen, board.board)
+             pygame.display.update()
+
+         for event in pygame.event.get():
+             if event.type == pygame.QUIT:
+                 sys.exit()
+
+             if event.type == pygame.MOUSEBUTTONDOWN and not game_end:
+                 posx = event.pos[0]
+                 act = posx // 100
+                 board, win = board.drop_piece(act)
+                 ai_turn = True
+
+                 if win is not None:
+                     print("Human has Won")
+                     print("Board \n")
+                     print(board)
+                     game_end = True
+
+             if event.type == pygame.MOUSEMOTION and not game_end:
+                 pygame.draw.rect(screen, (0, 0, 0), (0, 0, 700, 100))
+                 posx = event.pos[0]
+
+                 # If the AI moves first, the human plays second (yellow)
+                 if board.player_1 == -1:
+                     pygame.draw.circle(screen, (230, 230, 20), (posx, int(100//2)), 50)
+                 else:
+                     pygame.draw.circle(screen, (52, 186, 235), (posx, int(100//2)), 50)
+
+                 pygame.display.update()
+
+ if __name__ == "__main__":
+     model = Model(
+         n_action = Config.n_action,
+         num_hidden = Config.num_hidden,
+         num_resblock = Config.num_res_block,
+         rate = Config.rate,
+         row = Config.row,
+         col = Config.col,
+         device = Config.device
+     )
+
+     # This is the LR = .01 model
+     # model_path = './Models/C4GruhitSPatel/FullBuffer5x5V1/TargetModel_500.pt'
+
+     # This is the LR = .001 model
+     model_path = "./Models/C4GruhitML/C4CyclicLRV3/TargetModel_500.pt"
+     model.load_state_dict(torch.load(model_path))
+     model.eval()
+
+     play_game(model)
+     # print("Model Loaded")
mcts.py ADDED
@@ -0,0 +1,282 @@
+ from model import Model
+ from typing import Union, Tuple
+ from game import Connect4
+ from config import Config
+
+ import torch
+ from torch import Tensor
+
+ import numpy as np
+
+ class Node:
+     def __init__(self, state: Union[Connect4, None], model: Model, name: str):
+
+         # Current state that the node represents
+         self.state = state
+
+         # Name of the node, used for tracing
+         self.name = name
+
+         # A model instance that the node uses to get its value and policy
+         self.model = model
+
+         # Visit count
+         self.N = 0
+
+         # Accumulated reward value
+         self.W = 0
+
+         # Value of the node
+         self.value = None
+
+         # Prior policy over the actions from this node
+         self.policy = None
+
+         # Winner of the current node.
+         # None by default, indicating no one has won
+         self.win = None
+
+         # Children of the current node
+         self.children = {}
+
+         # Valid and invalid actions that can be taken from this node
+         self.valid_actions = None
+         self.invalid_actions = None
+
+         # Set the valid and invalid actions
+         self.set_valid_actions()
+
+         # Initialize the branches to the children
+         self.initialize_edges()
+
+     # Set the valid actions that can be taken from the state
+     # that this node represents
+     def set_valid_actions(self) -> None:
+         if self.state is not None:
+             self.valid_actions = self.state.get_valid_moves()
+             self.invalid_actions = ~self.valid_actions
+
+     # Initialize the edges from this node to its potential children
+     def initialize_edges(self) -> None:
+         if self.state is not None:
+             self.children = {}
+             for act, valid_move in enumerate(self.valid_actions):
+                 if valid_move:
+                     # Set state as None for the children, as we do not have it yet
+                     self.children[act] = Node(
+                         state=None,
+                         model=self.model,
+                         name=self.name + '_' + str(act)
+                     )
+
+     def preprocess_state(self, x: np.ndarray) -> Tensor:
+         x = torch.tensor(x, dtype=torch.float32, device=Config.device)
+         x = x.unsqueeze(0)
+         return x
+
+     # Run the network forward pass for the current node
+     def forward(self) -> None:
+         with torch.no_grad():
+             value, policy = self.model(self.preprocess_state(self.state.get_state()))
+
+         value = value[0, 0]
+         policy = policy[0]
+
+         # Mask the invalid actions
+         policy[self.invalid_actions] = 0.
+
+         # Prevent all probabilities from turning to 0
+         if policy.sum() == 0:
+             policy[self.valid_actions] = 1.
+
+         policy = policy.softmax(dim=-1)
+
+         self.value = value.detach().cpu().numpy()
+         self.policy = policy.detach().cpu().numpy()
+
+
+     # Get the policy for the current node
+     def get_policy(self) -> np.ndarray:
+         if self.policy is None:
+             self.forward()
+
+         return self.policy
+
+     # Get the value associated with the node
+     def get_value(self) -> float:
+         if self.value is None:
+             self.forward()
+
+         return self.value
+
+ class MCTS_NN:
+     def __init__(self, state: Connect4, model: Model, log=None):
+         self.root = Node(state=state, model=model, name='root')
+
+         if log is not None:
+             self.log = log
+
+     # Run one simulation on the Monte Carlo tree
+     def selection(self, node: Node, add_dirichlet: bool = False, iter: int = 0) -> float:
+         # Get the best child of the current node
+         # self.log.write(f'\nSelecting Best child of {node.name}')
+         best_child, best_action = self.get_best_child(node, add_dirichlet, iter)
+         # self.log.write(f"Iteration {iter} - Best Action - {best_action} - Node: {node.name}")
+
+         # If the child is a leaf node (i.e. it is terminal or not yet expanded),
+         # expand that node
+         if best_child.state is None:
+             # self.log.write(f'\nExpanding node {best_child.name}')
+             val = self.expolore_and_expand(parent=node, child=best_child, action=best_action, iter=iter)
+
+         # If the node is already expanded, traverse it further
+         else:
+             # As per the paper, Dirichlet noise is only added when selecting the
+             # root node's children, not deeper in the tree
+             # self.log.write(f'\nSelecting node further on {best_child.name}')
+             val = self.selection(node=best_child, add_dirichlet=False, iter=iter)
+
+         node.N += 1
+         node.W += val
+
+         return -val
+
+     # Explore and expand the tree
+     def expolore_and_expand(self, parent: Node, child: Node, action: int, iter=0) -> float:
+         # self.log.write(f'\n<========== Explore or Expand Iteration {iter} ==========>')
+         # Check if the current state is a terminal state
+         if child.win is None:
+             # It is neither expanded nor terminal.
+             # Perform the action on the parent state to get the next state
+             next_state, win = parent.state.drop_piece(action)
+
+             # First check if someone won in this next state
+             if win is not None:
+                 val = -1 if win == parent.state.player_1 else 1
+                 child.win = win
+                 # self.log.write(f'\nPlayer Turn for child is {next_state.player_1} | [Winner Found]')
+                 # self.log.write(f'\nWinner in that state {win} - child.Value is {val}')
+                 # self.log.write(f'\nWinning Child in state {child.name}: state\n{next_state}\n')
+                 # self.log.write('='*100)
+                 # self.log.write('\n')
+
+             # Else check if the next state results in a draw
+             elif next_state.is_draw():
+                 # Value 0 if no one has won in the state
+                 val = 0
+
+                 # A win of 0 means no one won
+                 child.win = 0
+                 # self.log.write(f'\nPlayer Turn for child is {next_state.player_1}')
+                 # self.log.write(f'\nDraw Child in state {child.name}: state\n{next_state}\n')
+                 # self.log.write('='*100)
+                 # self.log.write('\n')
+
+             # If the next_state is neither winning nor a draw,
+             # expand it normally
+             else:
+                 # If no one is winning yet, get the value for the current
+                 # state from the child's model and set it
+                 child.state = next_state
+                 child.set_valid_actions()
+                 child.initialize_edges()
+                 val = child.get_value()
+                 # self.log.write(f'\nPlayer Turn for child is {next_state.player_1} | [No Winner]')
+                 # self.log.write(f'\nLeaf node expanded for "{child.name}" with val {val:.5f}\n')
+                 # self.log.write('='*100)
+                 # self.log.write('\n')
+
+         else:
+             # If the current child represents a draw state, give it value 0
+             if child.win == 0:
+                 # self.log.write(f'\nTerminal DRAW state reached for child {child.name}\n')
+                 # self.log.write('='*100)
+                 # self.log.write('\n')
+                 val = 0
+
+             # If the winner of the child node is the player who moved
+             # in the parent node, the value is -1, because it means
+             # the player to move in the child node has lost
+             elif child.win == parent.state.player_1:
+                 # self.log.write(f'\nTerminal Parent Winning state reached for child {child.name}\n')
+                 # self.log.write('='*100)
+                 # self.log.write('\n')
+                 val = -1
+
+             # If the winner of the child node is the same as the player
+             # of the child node, the value is +1
+             else:
+                 # self.log.write(f'\nTerminal child Winning state reached for child {child.name}\n')
+                 # self.log.write('='*100)
+                 # self.log.write('\n')
+                 val = 1
+
+         # Update the visit count and accumulated reward of the child node
+         child.N += 1
+         child.W += val
+
+         # Return the negation of val because the player in the parent node is
+         # the opponent of the player in the current node. Hence what is good
+         # for the current node's player is bad for the parent node's player.
+         return -val
+
+
+     # Calculate the PUCT score for one of a node's children
+     def get_puct_score(self, parent: Node, child: Node, prior: float) -> float:
+         # PUCT is the sum of the Q-value of the child and U(s, a)
+         q_value = 0
+         if child.N == 0:
+             q_value = 0
+         else:
+             # q_value = 1 - ((child.W/child.N) + 1)/2
+             q_value = -child.W / child.N
+
+         # c_puct is the exploration constant
+         c_puct = 1
+         u_sa = c_puct * prior * (np.sqrt(parent.N)) / (1 + child.N)
+         return q_value + u_sa
+
+     def get_dirichlet_noise(self, node: Node) -> np.ndarray:
+         num_valid_action = node.valid_actions.sum()
+         noise_vec = np.random.dirichlet([Config.DIRICHLET_ALPHA] * num_valid_action)
+         noise_arr = np.zeros((len(node.valid_actions),), dtype=noise_vec.dtype)
+         noise_arr[node.valid_actions] = noise_vec
+         return noise_arr
+
+     # Get the best child for any node
+     def get_best_child(self, node: Node, add_dirichlet: bool, iter=0) -> Tuple[Node, int]:
+         # The best child is simply the one with the highest PUCT value
+         policy = node.get_policy()
+
+         if add_dirichlet:
+             noise_arr = self.get_dirichlet_noise(node)
+             policy = (1 - Config.EPSILON) * policy + Config.EPSILON * noise_arr
+
+         best_puct = float('-inf')
+         best_child = None
+         best_action = None
+         # self.log.write(f'\n\n==================== Iteration {iter} ====================\n')
+         for action, child in node.children.items():
+             puct = self.get_puct_score(parent=node, child=child, prior=policy[action])
+             # self.log.write(f'{action} - PUCT: {puct:.4f} | N = {child.N} | W = {child.W:.4f} | P = {policy[action]:.4f}\n')
+             if puct > best_puct:
+                 best_puct = puct
+                 best_child = child
+                 best_action = action
+
+         return best_child, best_action
+
+     # Return the policy pi for the root node based on the visit counts
+     def get_policy_pie(self, temperature: float = 1):
+         actions = np.zeros((len(self.root.valid_actions),))
+
+         for action, child in self.root.children.items():
+             actions[action] = (child.N) ** (1 / temperature)
+
+         actions /= actions.sum()
+
+         return actions
+
+     # Traverse the tree by stepping to one of the root node's children
+     def update_root(self, action: int) -> None:
+         self.root = self.root.children[action]
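For orientation (not part of the commit), a minimal sketch of driving MCTS_NN directly: run a number of selection passes from the root, then read off the visit-count policy. This mirrors what arena.get_move_for_bot does; the untrained network here is for illustration only, whereas main.py loads a checkpoint.

import numpy as np
from game import Connect4
from mcts import MCTS_NN
from model import Model
from config import Config

# Untrained network just for illustration
model = Model(
    n_action=Config.n_action, num_hidden=Config.num_hidden,
    num_resblock=Config.num_res_block, rate=Config.rate,
    row=Config.row, col=Config.col, device=Config.device
)
model.eval()

board = Connect4(row=Config.row, col=Config.col)
mcts = MCTS_NN(state=board, model=model)

for _ in range(Config.tree_iter):               # run the tree-search simulations
    mcts.selection(mcts.root, add_dirichlet=False)

pi = mcts.get_policy_pie(temperature=Config.temperature)  # visit-count policy over columns
print(pi, int(np.argmax(pi)))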
model.py ADDED
@@ -0,0 +1,106 @@
+ import torch
+ from torch import Tensor
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ class ResNetBlock(nn.Module):
+     def __init__(self, num_hidden: int):
+         super(ResNetBlock, self).__init__()
+
+         self.conv1 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1, bias=False)
+         self.bn1 = nn.BatchNorm2d(num_hidden)
+         self.conv2 = nn.Conv2d(num_hidden, num_hidden, kernel_size=3, padding=1, bias=False)
+         self.bn2 = nn.BatchNorm2d(num_hidden)
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+         residual = x
+         x = F.relu(self.bn1(self.conv1(x)))
+         x = self.bn2(self.conv2(x))
+         x += residual
+         x = F.relu(x)
+
+         return x
+
+ class DropoutBlock(nn.Module):
+     def __init__(self, in_units: int, out_units: int, rate: float):
+         super(DropoutBlock, self).__init__()
+         self.model = nn.Sequential(
+             nn.Linear(in_units, out_units),
+             nn.BatchNorm1d(out_units),
+             nn.ReLU(),
+             nn.Dropout(rate)
+         )
+
+     def forward(self, x: Tensor) -> Tensor:
+         return self.model(x)
+
+ class Model(nn.Module):
+     def __init__(self, n_action: int, num_hidden: int, num_resblock: int,
+                  rate: float, row: int, col: int, device: str):
+         super(Model, self).__init__()
+
+         # Input layer
+         self.initial_block = nn.Sequential(
+             nn.Conv2d(4, num_hidden, kernel_size=3, padding=1),
+             nn.BatchNorm2d(num_hidden),
+             nn.ReLU()
+         ).to(device)
+
+         self.res_blocks = nn.Sequential(
+             *[ResNetBlock(num_hidden) for _ in range(num_resblock)]
+         ).to(device)
+
+         self.dropout_model = nn.Sequential(
+             DropoutBlock(num_hidden*row*col, 200, rate),
+             DropoutBlock(200, 100, rate)
+         )
+
+         self.model = nn.Sequential(
+             self.initial_block,
+             self.res_blocks,
+             nn.Flatten(),
+             self.dropout_model
+         )
+
+         self.policy_head = nn.Sequential(
+             nn.Linear(100, 100),
+             nn.ReLU(),
+             nn.Linear(100, n_action),
+         ).to(device)
+
+         self.value_head = nn.Sequential(
+             nn.Linear(100, 100),
+             nn.ReLU(),
+             nn.Linear(100, 1),
+             nn.Tanh()
+         ).to(device)
+
+         self.to(device)
+
+         self.device = device
+
+         # Losses
+         # Mean squared error to minimize the difference between the estimated value and the target value
+         self.mse_loss = nn.MSELoss()
+
+         # Cross-entropy loss to compare the predicted policy with the target policy
+         self.ce_loss = nn.CrossEntropyLoss()
+
+     def forward(self, x):
+         x = self.model(x)
+         value = self.value_head(x)
+         policy = self.policy_head(x)
+
+         return value, policy
+
+     # Perform the loss calculation
+     def get_loss(self, pred_val, pred_policy, true_val, true_policy):
+         val_loss = self.mse_loss(pred_val, true_val)
+         policy_loss = self.ce_loss(pred_policy, true_policy)
+
+         final_loss = val_loss + policy_loss
+         return {
+             'total_loss': final_loss,
+             'value_loss': val_loss,
+             'policy_loss': policy_loss
+         }
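A small sanity-check sketch (not part of the commit) of the network's input and output shapes, using the defaults from config.py: a batch of 4x6x7 encoded states maps to one value in [-1, 1] and 7 policy logits per sample.

import torch
from model import Model
from config import Config

net = Model(
    n_action=Config.n_action, num_hidden=Config.num_hidden,
    num_resblock=Config.num_res_block, rate=Config.rate,
    row=Config.row, col=Config.col, device=Config.device
)
net.eval()

x = torch.zeros((2, 4, Config.row, Config.col), device=Config.device)  # dummy batch of 2 encoded states
with torch.no_grad():
    value, policy = net(x)
print(value.shape, policy.shape)  # torch.Size([2, 1]) torch.Size([2, 7])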
requirement.txt ADDED
@@ -0,0 +1,7 @@
+ torch==2.2.2
+ tqdm==4.66.2
+ pygame==2.5.2
+ fastapi==0.110.1
+ pydantic==2.6.4
+ uvicorn==0.29.0
+ numpy==1.26.4
trainer.py ADDED
@@ -0,0 +1,85 @@
+ import torch
+ from torch import optim
+ from torch.utils.tensorboard import SummaryWriter
+
+ from tqdm import tqdm
+ import numpy as np
+
+ from model import Model
+ from buffer import Buffer
+
+
+ class Trainer:
+     def __init__(self, model: Model, buffer: Buffer, base_lr: float = 0.001,
+                  weight_decay=1e-4, device: str = 'cpu'):
+         self.main_model = model
+         self.main_buffer = buffer
+         self.global_step = 0
+         self.device = device
+
+         # Optimizer
+         self.optimizer = optim.SGD(
+             self.main_model.parameters(),
+             lr = base_lr,
+             weight_decay = weight_decay,
+             momentum = 0.9
+         )
+
+         # self.scheduler = optim.lr_scheduler.CyclicLR(
+         #     self.optimizer,
+         #     base_lr = base_lr,
+         #     max_lr = 0.1
+         # )
+
+         # TensorBoard summary writer
+         self.writer = SummaryWriter()
+
+     def transfer_buffer(self, buffer) -> None:
+         for state, value, policy in zip(buffer.state, buffer.value, buffer.policy):
+             self.main_buffer.store_experience(
+                 state = state,
+                 value = value,
+                 policy = policy
+             )
+
+     def reset_buffer(self) -> None:
+         self.main_buffer.reset()
+
+     # Learn from one batch of the buffer
+     def learn(self, state: np.ndarray, value: np.ndarray, policy: np.ndarray) -> float:
+
+         state = torch.tensor(state, dtype=torch.float32, device=self.device)
+         value = torch.tensor(value, dtype=torch.float32, device=self.device).unsqueeze(-1)
+         policy = torch.tensor(policy, dtype=torch.float32, device=self.device)
+
+         pred_val, pred_policy = self.main_model(state)
+
+         self.optimizer.zero_grad()
+         # get_loss returns a dict of losses; backpropagate through the total loss
+         losses = self.main_model.get_loss(pred_val, pred_policy, value, policy)
+         loss = losses['total_loss']
+         loss.backward()
+         self.optimizer.step()
+
+         return loss.detach().cpu().numpy()
+
+     # Training loop for the model
+     def train_model(self, epochs: int, batch_size: int):
+
+         train_steps = np.ceil(len(self.main_buffer) / batch_size).astype(np.int32)
+
+         # Perform the training
+         for epoch in range(epochs):
+             for state, value, policy in tqdm(self.main_buffer.sample(batch_size), total=train_steps, desc=f'Epoch:{epoch+1}'):
+                 loss = self.learn(state, value, policy)
+                 self.writer.add_scalar("loss", loss, self.global_step)
+                 self.global_step += 1
+
+         self.writer.flush()
+
+     # Close the writer
+     def close_writer(self):
+         self.writer.close()
+
+     # Save the model
+     def save_model(self, step: int):
+         torch.save(self.main_model.state_dict(), f'TargetModel_{step}.pt')
+         torch.save(self.optimizer.state_dict(), f'Optimizer_{step}.pt')
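A brief sketch (not part of the commit) of how Trainer and Buffer are expected to fit together for one training round; the self-play data collection itself lives in arena.play_selfgames and is omitted here.

from buffer import Buffer
from config import Config
from model import Model
from trainer import Trainer

model = Model(
    n_action=Config.n_action, num_hidden=Config.num_hidden,
    num_resblock=Config.num_res_block, rate=Config.rate,
    row=Config.row, col=Config.col, device=Config.device
)
buffer = Buffer(n_action=Config.n_action, obs_shape=Config.obs_shape)

trainer = Trainer(model, buffer, base_lr=Config.base_lr,
                  weight_decay=Config.weight_decay, device=Config.device)

# ... fill `buffer` with self-play experience (see arena.play_selfgames), then:
trainer.train_model(epochs=Config.epoch, batch_size=Config.batch_size)
trainer.save_model(step=1)
trainer.close_writer()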
view_board.py ADDED
@@ -0,0 +1,60 @@
+ import pygame
+ import numpy as np
+ import sys
+ from typing import Tuple, Union
+ from config import Config
+
+ SQUARESIZE = 100
+
+ def draw_board(screen, board):
+     COLUMN_COUNT = Config.col
+     ROW_COUNT = Config.row
+     SQUARESIZE = 100
+     RADIUS = int(SQUARESIZE/2 - 5)
+
+     BLUE = (52, 186, 235)
+     GREY = (70, 71, 70)
+     WHITE = (255, 255, 255)
+     YELLOW = (230, 230, 20)
+
+     width = COLUMN_COUNT * SQUARESIZE
+     height = (ROW_COUNT+1) * SQUARESIZE
+
+     size = (width, height)
+     board = np.flip(board, 0)
+     for c in range(COLUMN_COUNT):
+         for r in range(ROW_COUNT):
+             pygame.draw.rect(screen, GREY, (c*SQUARESIZE, r*SQUARESIZE+SQUARESIZE, SQUARESIZE, SQUARESIZE))
+             pygame.draw.circle(screen, WHITE, (int(c*SQUARESIZE+SQUARESIZE/2), int(r*SQUARESIZE+SQUARESIZE+SQUARESIZE/2)), RADIUS)
+
+     for c in range(COLUMN_COUNT):
+         for r in range(ROW_COUNT):
+             if board[r][c] == 1:
+                 pygame.draw.circle(screen, BLUE, (int(c*SQUARESIZE+SQUARESIZE/2), height-int(r*SQUARESIZE+SQUARESIZE/2)), RADIUS)
+             elif board[r][c] == -1:
+                 pygame.draw.circle(screen, YELLOW, (int(c*SQUARESIZE+SQUARESIZE/2), height-int(r*SQUARESIZE+SQUARESIZE/2)), RADIUS)
+
+ def draw_winning_line(screen, start_pos: Union[None, Tuple[int, int]], end_pos: Union[None, Tuple[int, int]]):
+     if start_pos is None or end_pos is None:
+         return
+
+     offset = SQUARESIZE // 2
+     start_line = (SQUARESIZE*start_pos[0]+1+offset, SQUARESIZE*(start_pos[1]+1)+offset)
+     end_line = (SQUARESIZE*end_pos[0]+offset, SQUARESIZE*(end_pos[1]+1)+offset)
+
+     # print("Start pos: ", start_pos)
+     # print("End pos: ", end_pos)
+     # print("Start line: ", start_line)
+     # print("End Line: ", end_line)
+     pygame.draw.line(screen, (255, 0, 0), start_line, end_line, 10)
+
+ def render(board):
+     pygame.init()
+     screen = pygame.display.set_mode((700, 700))
+
+     while True:
+         for event in pygame.event.get():
+             if event.type == pygame.QUIT:
+                 sys.exit()
+         draw_board(screen, board)
+         pygame.display.update()