File size: 10,608 Bytes
9deb5ea
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
import math
import numpy as np
import torch
from src.game import OthelloGame
from src.bitboard import make_input_planes, bit_to_row_col, popcount

class MCTSNode:
    """A single state node in the MCTS tree.

    Holds the prior probability assigned by the policy network, running
    visit/value statistics, and a map from move bitboards to child nodes.
    """

    def __init__(self, prior, to_play):
        self.prior = prior          # P(s, a) from the policy head
        self.visit_count = 0
        self.value_sum = 0          # accumulated backed-up value
        self.children = {}          # move_bit (int bitboard) -> MCTSNode
        self.to_play = to_play      # side to move at this node (1 or -1)

    def value(self):
        """Mean backed-up value; 0 for a node that was never visited."""
        return self.value_sum / self.visit_count if self.visit_count else 0

    def expand(self, policy_logits, valid_moves, next_to_play):
        """Create one child per valid move, with priors from the network.

        The raw logits are softmaxed over all 65 entries (64 squares + pass),
        then the probabilities of the valid moves are renormalized among
        themselves. A zero total falls back to a uniform prior.
        """
        # Numerically stable softmax over the full policy vector.
        shifted = np.exp(policy_logits - np.max(policy_logits))
        policy = shifted / np.sum(shifted)

        # Collect the raw probability of each valid move.
        # Index layout: square moves map to r*8+c, a pass (move_bit == 0)
        # maps to index 64.
        raw = {}
        for move_bit in valid_moves:
            if move_bit == 0:
                idx = 64  # pass move
            else:
                r, c = bit_to_row_col(move_bit)
                # r == -1 is a defensive guard; should not occur for a
                # non-zero move_bit.
                idx = 64 if r == -1 else r * 8 + c
            raw[move_bit] = policy[idx]

        total = sum(raw.values())
        if total > 0:
            # Renormalize over the valid moves only.
            for move_bit, p in raw.items():
                self.children[move_bit] = MCTSNode(prior=p / total, to_play=next_to_play)
        else:
            # Degenerate case: the policy put zero mass on every valid move.
            uniform = 1.0 / len(valid_moves)
            for move_bit in raw:
                self.children[move_bit] = MCTSNode(prior=uniform, to_play=next_to_play)

class MCTS:
    """AlphaZero-style Monte Carlo Tree Search over Othello bitboard states.

    Convention: all backed-up values are stored from BLACK's perspective
    (+1 = Black wins, -1 = White wins). `_select_child` negates Q when White
    is to play, so a single shared backup serves both players.
    """

    def __init__(self, model, cpuct=1.0, num_simulations=800):
        self.model = model                      # policy/value network
        self.cpuct = cpuct                      # PUCT exploration constant
        self.num_simulations = num_simulations  # simulations per search() call

    def search(self, game: "OthelloGame"):
        """Run `num_simulations` simulations from `game`'s current position.

        Returns the expanded root MCTSNode (whose children carry the visit
        counts used for move selection), or None if the position is terminal.
        """
        valid_moves_bb = game.get_valid_moves(game.player_bb, game.opponent_bb)
        valid_moves_list = self._get_moves_list(valid_moves_bb)

        # No legal move: the side to move must pass (move_bit == 0) unless
        # BOTH sides are stuck, in which case the game is over.
        if valid_moves_bb == 0:
            if game.is_terminal():
                return None
            valid_moves_list = [0]

        root = MCTSNode(prior=0, to_play=game.turn)

        # In Othello the turn always swaps after any move — even a forced
        # pass is played by the opponent — so children are always for -turn.
        policy_logits, _ = self._infer(game.player_bb, game.opponent_bb)
        root.expand(policy_logits, valid_moves_list, -game.turn)

        # Root exploration noise (self-play move diversity).
        self._add_dirichlet_noise(root)

        for _ in range(self.num_simulations):
            node = root
            sim_game = self._clone_game(game)
            search_path = [node]

            # 1. Selection: descend to a leaf via PUCT.
            while node.children:
                move_bit, node = self._select_child(node)
                search_path.append(node)
                sim_game.play_move(move_bit)

            # 2. Evaluation & expansion.
            if sim_game.is_terminal():
                leaf_value = self._terminal_value(sim_game)
            else:
                policy_logits, value = self._infer(sim_game.player_bb, sim_game.opponent_bb)
                # The value head scores the side to move; convert to Black's
                # perspective (negate when White is to move).
                leaf_value = value if sim_game.turn == 1 else -value

                valid_bb = sim_game.get_valid_moves(sim_game.player_bb, sim_game.opponent_bb)
                valid_list = self._get_moves_list(valid_bb) or [0]  # forced pass
                node.expand(policy_logits, valid_list, -sim_game.turn)

            # 3. Backup.
            self._backpropagate(search_path, leaf_value)

        return root

    def _infer(self, player_bb, opponent_bb):
        """Evaluate one canonical (player, opponent) position with the network.

        Returns (flat numpy policy logits, scalar value for the side to move).
        """
        state = make_input_planes(player_bb, opponent_bb)
        device = next(self.model.parameters()).device
        self.model.eval()
        with torch.no_grad():
            policy_logits, value = self.model(state.to(device))
        return policy_logits.cpu().numpy().flatten(), value.item()

    def _terminal_value(self, sim_game):
        """Exact game outcome from Black's perspective: +1.0 / -1.0 / 0.0.

        `player_bb` always belongs to the side to move (`sim_game.turn`), so
        the canonical boards are mapped back to Black/White before scoring.
        """
        if sim_game.turn == 1:
            black = popcount(sim_game.player_bb)
            white = popcount(sim_game.opponent_bb)
        else:
            white = popcount(sim_game.player_bb)
            black = popcount(sim_game.opponent_bb)
        diff = black - white
        if diff > 0:
            return 1.0
        if diff < 0:
            return -1.0
        return 0.0

    def _select_child(self, node):
        """Return the (action, child) pair maximizing the PUCT score Q + U.

        Stored child values are Black-perspective, so Q is negated when
        White is to play at `node`. Q lies in [-1, 1] (tanh-style values).
        """
        best_score = -float('inf')
        best_action = None
        best_child = None

        for action, child in node.children.items():
            mean_val = child.value()  # Black-perspective running average
            q = mean_val if node.to_play == 1 else -mean_val
            u = self.cpuct * child.prior * math.sqrt(node.visit_count) / (1 + child.visit_count)
            score = q + u
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child

        return best_action, best_child

    def _backpropagate(self, search_path, value):
        """Add `value` (Black's perspective) to every node along the path."""
        for node in search_path:
            node.value_sum += value
            node.visit_count += 1

    def _add_dirichlet_noise(self, node, eps=0.25, alpha=0.3):
        """Blend Dirichlet noise into the root priors (AlphaZero exploration)."""
        moves = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(moves))
        for move, n in zip(moves, noise):
            node.children[move].prior = (1 - eps) * node.children[move].prior + eps * n

    def _get_moves_list(self, moves_bb):
        """Split a move bitboard into a list of single-bit ints, LSB first."""
        moves = []
        temp = moves_bb
        while temp:
            lsb = temp & -temp  # isolate lowest set bit
            moves.append(lsb)
            temp ^= lsb         # clear it
        return moves

    def _clone_game(self, game):
        """Lightweight copy of the mutable game state for one simulation."""
        new_game = OthelloGame()
        new_game.player_bb = game.player_bb
        new_game.opponent_bb = game.opponent_bb
        new_game.turn = game.turn
        return new_game