from model import Model
from typing import Union, Tuple
from game import Connect4
from config import Config

import torch
from torch import Tensor

import numpy as np

class Node:
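    """A single search-tree node: one Connect4 position plus its MCTS statistics.

    Children are created lazily: an edge exists for every valid action, but a
    child's state is only computed the first time the edge is traversed.
    """
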
    def __init__(self, state: Union[Connect4, None], model: Model, name: str):

        # Current state that the node represents
        self.state = state

        # Name of the node to trace it
        self.name = name

        # A model instance that the node will use to get value and policy
        self.model = model

        # Visit count
        self.N = 0

        # Total value accumulated through this node from backups
        self.W = 0

        # Value of the node, as predicted by the model
        self.value = None

        # Prior policy over actions from this node
        self.policy = None

        # Winner of the current node's state.
        # None by default, indicating no one has won yet
        self.win = None

        # Children of current node
        self.children = {}

        # Valid and invalid actions that can be taken from this node
        self.valid_actions = None
        self.invalid_actions = None

        # Set the valid and invalid actions
        self.set_valid_actions()

        # Initialize the edges to the children
        self.initialize_edges()

    # Set the valid actions that can be taken from the state that
    # the node represents
    def set_valid_actions(self) -> None:
        if self.state is not None:
            self.valid_actions = self.state.get_valid_moves()
            self.invalid_actions = ~self.valid_actions

    # Initialize the edges from this node to its potential children
    def initialize_edges(self) -> None:
        if self.state is not None:
            self.children = {}
            for act, valid_move in enumerate(self.valid_actions):
                if valid_move:
                    # Set state as None for children, as their states are not computed yet
                    self.children[act] = Node(
                        state=None,
                        model=self.model,
                        name=self.name + '_' + str(act)
                    )

    # Convert the raw board into a float tensor and add a batch dimension
    def preprocess_state(self, x: np.ndarray) -> Tensor:
        x = torch.tensor(x, dtype=torch.float32, device=Config.device)
        x = x.unsqueeze(0)
        return x

    # Forward pass: query the model for this node's value and policy
    def forward(self) -> None:
        with torch.no_grad():
            value, policy = self.model(self.preprocess_state(self.state.get_state()))

        value = value[0, 0]
        policy = policy[0]

        # Convert the logits into probabilities first; zeroing the
        # logits before the softmax would still leave invalid actions
        # with non-zero probability
        policy = policy.softmax(dim=-1)

        # Mask the invalid actions
        policy[self.invalid_actions] = 0.

        # Prevent all probabilities from collapsing to 0 in case the
        # network put its entire mass on invalid actions
        if policy.sum() == 0:
            policy[self.valid_actions] = 1.

        # Renormalize over the valid actions
        policy = policy / policy.sum()

        self.value = value.detach().cpu().numpy()
        self.policy = policy.detach().cpu().numpy()


    # Get policy for the current node
    def get_policy(self) -> np.ndarray:
        if self.policy is None:
            self.forward()

        return self.policy

    # Get the value associated with the node
    def get_value(self) -> float:
        if self.value is None:
            self.forward()

        return self.value
    
class MCTS_NN:
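    """AlphaZero-style Monte Carlo Tree Search guided by a policy/value model.

    Repeated calls to ``selection`` from ``self.root`` grow the tree,
    ``get_policy_pie`` converts the root's visit counts into a move
    distribution, and ``update_root`` advances the tree after a move is played.
    """
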
    def __init__(self, state: Connect4, model: Model, log=None):
        self.root = Node(state=state, model=model, name='root')

        if log is not None:
            self.log = log

    # Run one simulation (the selection phase) of the Monte Carlo tree search
    def selection(self, node: Node, add_dirichlet: bool = False, iter: int = 0) -> float:
        # Get the best child of the current node
        # self.log.write(f'\nSelecting Best child of {node.name}')
        best_child, best_action = self.get_best_child(node, add_dirichlet, iter)
        # self.log.write(f"Iteartion {iter} - Best Action - {best_action} - Node: {node.name}")

        # If the child is a leaf node (i.e. it is either terminal or not
        # yet expanded), expand it
        if best_child.state is None:
            # self.log.write(f'\nExpanding node {best_child.name}')
            val = self.explore_and_expand(parent=node, child=best_child, action=best_action, iter=iter)

        # If the node is already expanded, then traverse it further
        else:
            # As per the paper, Dirichlet noise is only added when selecting
            # among the root node's children, not deeper in the tree
            # self.log.write(f'\nSelecting node further on {best_child.name}')
            val = self.selection(node=best_child, add_dirichlet=False, iter=iter)

        node.N += 1
        node.W += val

        return -val

    # Explore and expand the tree
    def explore_and_expand(self, parent: Node, child: Node, action: int, iter=0) -> float:
        # self.log.write(f'\n<========== Explore or Expand Iteration {iter} ==========>')
        # Expand the child if it is not a known terminal state
        if child.win is None:
            # It is not expanded and is not terminal
            # Perform the action for the parent state to get the next state
            next_state, win = parent.state.drop_piece(action)

            # First check if someone won in this next state
            if win is not None:
                val = -1 if win == parent.state.player_1 else 1
                child.win = win
                # self.log.write(f'\nPlayer Turn for child is {next_state.player_1} | [Winner Found]')
                # self.log.write(f'\nWinner in that state {win} - child.Value is {val}')
                # self.log.write(f'\nWinning Child in state {child.name}: state\n{next_state}\n')
                # self.log.write('='*100)
                # self.log.write('\n')

            # Else, check if the next state results in a draw
            elif next_state.is_draw():
                # 0 value if no one has won in the state
                val = 0

                # 0 for win means no one won
                child.win = 0
                # self.log.write(f'\nPlayer Turn for child is {next_state.player_1}')
                # self.log.write(f'\nDraw Child in state {child.name}: state\n{next_state}\n')
                # self.log.write('='*100)
                # self.log.write('\n')

            # If the next state is neither a win nor a draw,
            # expand it normally
            else:
                # If no one has won yet, get the value for the new
                # state from the child's model and set it
                child.state = next_state
                child.set_valid_actions()
                child.initialize_edges()
                val = child.get_value()
                # self.log.write(f'\nPlayer Turn for child is {next_state.player_1} | [No Winner]')
                # self.log.write(f'\nLeaf node expanded for "{child.name}" with val {val:.5f}\n')
                # self.log.write('='*100)
                # self.log.write('\n')

        else:
            # If the current child represents a draw state, give it value 0
            if child.win == 0:
                # self.log.write(f'\nTerminal DRAW state reached for child {child.name}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
                val = 0

            # If the winner in the child node is the player who moved in
            # the parent node, set the value to -1: the player to move in
            # the child node has lost
            elif child.win == parent.state.player_1:
                # self.log.write(f'\nTerminal Parent Winning state reached for child {child.name}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
                val = -1

            # If the winner of the child node is the child node's own
            # player, give a value of +1
            else:
                # self.log.write(f'\nTerminal child Winning state reached for child {child.name}\n')
                # self.log.write('='*100)
                # self.log.write('\n')
                val = 1

        # Update the visit count and accumulated value of the child node
        child.N += 1
        child.W += val

        # Return the negative of val because the parent node's player is
        # the opponent of the current node's player: what is good for the
        # current player is bad for the parent node's player
        return -val


    # Calculate the PUCT score for one of a node's children
    def get_puct_score(self, parent: Node, child: Node, prior: float) -> float:
        # PUCT is the sum of the child's Q-value and the exploration
        # term U(s, a). The mean value is negated because W is stored
        # from the child player's perspective
        if child.N == 0:
            q_value = 0
        else:
            # q_value = 1 - ((child.W/child.N) + 1)/2
            q_value = -child.W / child.N

        # c_puct is the exploration constant
        c_puct = 1
        u_sa = c_puct * prior * np.sqrt(parent.N) / (1 + child.N)
        return q_value + u_sa

    # Sample Dirichlet noise over the valid actions of a node
    def get_dirichlet_noise(self, node: Node) -> np.ndarray:
        num_valid_action = node.valid_actions.sum()
        noise_vec = np.random.dirichlet([Config.DIRICHLET_ALPHA]*num_valid_action)
        noise_arr = np.zeros((len(node.valid_actions),), dtype=noise_vec.dtype)
        noise_arr[node.valid_actions] = noise_vec
        return noise_arr

    # Get the best child for any node
    def get_best_child(self, node: Node, add_dirichlet: bool, iter=0) -> Tuple[Node, int]:
        # The best child is simply the one with the highest PUCT score
        policy = node.get_policy()

        # Mix Dirichlet noise into the prior to encourage exploration at the root
        if add_dirichlet:
            noise_arr = self.get_dirichlet_noise(node)
            policy = (1 - Config.EPSILON) * policy + Config.EPSILON * noise_arr

        best_puct = float('-inf')
        best_child = None
        best_action = None
        # self.log.write(f'\n\n==================== Iteration {iter} ====================\n')
        for action, child in node.children.items():
            puct = self.get_puct_score(parent=node, child=child, prior=policy[action])
            # self.log.write(f'{action} - PUCT: {puct:.4f} | N = {child.N} | W = {child.W:.4f} | P = {policy[action]:.4f}\n')
            if puct > best_puct:
                best_puct = puct
                best_child = child
                best_action = action

        return best_child, best_action

    # Return the policy π for the root node, built from the visit counts
    def get_policy_pie(self, temperature: float = 1) -> np.ndarray:
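        # Worked example: with root visit counts [10, 30, 60], temperature 1
        # gives pi = [0.1, 0.3, 0.6]; temperature 0.5 raises the counts to the
        # power 2, sharpening pi to roughly [0.022, 0.196, 0.783]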
        actions = np.zeros((len(self.root.valid_actions),))

        for action, child in self.root.children.items():
            actions[action] = (child.N)**(1/temperature)

        actions /= actions.sum()

        return actions

    # Step down the tree by moving the root to one of its children
    def update_root(self, action: int) -> None:
        self.root = self.root.children[action]
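

# A minimal self-play driver sketch, not part of the original training loop:
# it assumes Connect4() constructs the initial position, Model() takes no
# constructor arguments, and Config exposes a NUM_SIMULATIONS count (a
# hypothetical name). Dirichlet noise is requested on every simulation; the
# selection code only applies it at the root, per the AlphaZero setup.
if __name__ == '__main__':
    game = Connect4()
    model = Model().to(Config.device)
    tree = MCTS_NN(state=game, model=model)

    # Run the simulations; each call grows the tree by one leaf
    for i in range(Config.NUM_SIMULATIONS):
        tree.selection(tree.root, add_dirichlet=True, iter=i)

    # Sample a move from the visit-count policy and advance the root
    pi = tree.get_policy_pie(temperature=1)
    action = int(np.random.choice(len(pi), p=pi))
    tree.update_root(action)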