import math

import numpy as np
import torch

from src.game import OthelloGame
from src.bitboard import make_input_planes, bit_to_row_col, popcount


class MCTSNode:
    """A node in the MCTS tree.

    Values accumulated in the tree are always stored from BLACK's
    perspective (+1 = Black wins, -1 = White wins); `MCTS._select_child`
    negates them when White is to move.
    """

    def __init__(self, prior, to_play):
        self.prior = prior        # P(s, a) assigned by the policy network
        self.visit_count = 0      # N(s, a)
        self.value_sum = 0        # W(s, a): sum of Black-perspective values
        self.children = {}        # move_bit -> MCTSNode
        self.to_play = to_play    # player to move at this node (1 = Black, -1 = White)

    def value(self):
        """Mean action value Q(s, a); 0 for an unvisited node."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def expand(self, policy_logits, valid_moves, next_to_play):
        """Create one child per valid move, with priors renormalized over them.

        Args:
            policy_logits: flat array of 65 logits (indices 0-63 for board
                squares row*8+col, index 64 for pass).
            valid_moves: list of single-bit bitboards; 0 denotes a pass.
            next_to_play: player to move in every child position.
        """
        # Numerically stable softmax over all 65 entries.
        policy = np.exp(policy_logits - np.max(policy_logits))
        policy /= np.sum(policy)

        # Collect the raw probability mass of the legal moves only.
        priors = {}
        mass = 0.0
        for move_bit in valid_moves:
            if move_bit == 0:
                idx = 64  # pass
            else:
                r, c = bit_to_row_col(move_bit)
                # r == -1 should not occur for a nonzero bit; map to pass defensively.
                idx = 64 if r == -1 else r * 8 + c
            priors[move_bit] = policy[idx]
            mass += policy[idx]

        if mass > 0:
            # Renormalize so the priors over legal moves sum to 1.
            for move, p in priors.items():
                self.children[move] = MCTSNode(prior=p / mass, to_play=next_to_play)
        else:
            # Degenerate case: the network put ~0 mass on every legal move.
            uniform = 1.0 / len(valid_moves)
            for move in valid_moves:
                self.children[move] = MCTSNode(prior=uniform, to_play=next_to_play)


class MCTS:
    """AlphaZero-style Monte Carlo Tree Search driven by a policy/value network."""

    def __init__(self, model, cpuct=1.0, num_simulations=800):
        self.model = model                      # net mapping input planes -> (policy logits, value)
        self.cpuct = cpuct                      # exploration constant in the PUCT formula
        self.num_simulations = num_simulations  # simulations per call to search()

    def search(self, game: OthelloGame):
        """Run simulations from `game` and return the root node, or None if terminal.

        The root's children visit counts are the basis for move selection by
        the caller. All node values are stored from Black's perspective.
        """
        valid_moves_bb = game.get_valid_moves(game.player_bb, game.opponent_bb)
        valid_moves_list = self._get_moves_list(valid_moves_bb)

        # A player with no legal move must pass (move_bit == 0) — unless
        # NEITHER side can move, in which case the game is over.
        if valid_moves_bb == 0:
            if game.is_terminal():
                return None
            valid_moves_list = [0]

        root = MCTSNode(prior=0, to_play=game.turn)

        # Hoisted out of the simulation loop: the model and its device do not
        # change during a search.
        device = next(self.model.parameters()).device
        self.model.eval()

        policy_logits, _ = self._infer(game.player_bb, game.opponent_bb, device)

        # In Othello the turn always alternates: even a forced pass is a move
        # made by the opponent, so every child belongs to -game.turn.
        root.expand(policy_logits, valid_moves_list, -game.turn)
        self._add_dirichlet_noise(root)

        for _ in range(self.num_simulations):
            node = root
            sim_game = self._clone_game(game)
            search_path = [node]

            # 1. Selection: descend via PUCT until we reach a leaf.
            while node.children:
                move_bit, node = self._select_child(node)
                search_path.append(node)
                sim_game.play_move(move_bit)

            # 2. Evaluation & expansion.
            if sim_game.is_terminal():
                last_value = self._terminal_value(sim_game)
            else:
                policy_logits, value = self._infer(
                    sim_game.player_bb, sim_game.opponent_bb, device
                )
                # The network's value is from the current player's perspective;
                # convert to Black's perspective for backup.
                last_value = value if sim_game.turn == 1 else -value

                valid_bb = sim_game.get_valid_moves(
                    sim_game.player_bb, sim_game.opponent_bb
                )
                # Empty move list means a forced pass.
                valid_list = self._get_moves_list(valid_bb) or [0]
                node.expand(policy_logits, valid_list, -sim_game.turn)

            # 3. Backup.
            self._backpropagate(search_path, last_value)

        return root

    def _infer(self, player_bb, opponent_bb, device):
        """Evaluate a canonical (player, opponent) position with the network.

        Returns:
            (flat ndarray of policy logits, scalar value as a Python float).
        """
        state_tensor = make_input_planes(player_bb, opponent_bb).to(device)
        with torch.no_grad():
            policy_logits, value = self.model(state_tensor)
        return policy_logits.cpu().numpy().flatten(), value.item()

    def _terminal_value(self, sim_game):
        """Final result from Black's perspective: 1 Black win, -1 White win, 0 draw.

        `player_bb` always belongs to the side whose turn it is, so map
        (player, opponent) onto (black, white) via sim_game.turn.
        """
        if sim_game.turn == 1:
            black_score = popcount(sim_game.player_bb)
            white_score = popcount(sim_game.opponent_bb)
        else:
            white_score = popcount(sim_game.player_bb)
            black_score = popcount(sim_game.opponent_bb)
        diff = black_score - white_score
        if diff > 0:
            return 1.0
        if diff < 0:
            return -1.0
        return 0.0

    def _select_child(self, node):
        """Return (move_bit, child) maximizing the PUCT score for node.to_play."""
        best_score = -float('inf')
        best_action = None
        best_child = None
        sqrt_parent = math.sqrt(node.visit_count)  # loop-invariant, hoisted
        for action, child in node.children.items():
            # child.value() is Black-perspective; negate when White is to move
            # so each player maximizes their own outcome.
            q = child.value() if node.to_play == 1 else -child.value()
            # q lies in [-1, 1] (tanh-style value head); U is on a comparable scale.
            u = self.cpuct * child.prior * sqrt_parent / (1 + child.visit_count)
            score = q + u
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child
        return best_action, best_child

    def _backpropagate(self, search_path, value):
        """Add `value` (Black's perspective) to every node on the path.

        Because values are stored in a fixed (Black) frame, no sign flip per
        level is needed here; the flip happens at selection time instead.
        """
        for node in search_path:
            node.value_sum += value
            node.visit_count += 1

    def _add_dirichlet_noise(self, node, eps=0.25, alpha=0.3):
        """Mix Dirichlet noise into the root priors to encourage exploration."""
        moves = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(moves))
        for move, n in zip(moves, noise):
            child = node.children[move]
            child.prior = (1 - eps) * child.prior + eps * n

    def _get_moves_list(self, moves_bb):
        """Split a moves bitboard into a list of single-bit integers (LSB first)."""
        moves = []
        while moves_bb:
            lsb = moves_bb & -moves_bb  # isolate the lowest set bit
            moves.append(lsb)
            moves_bb ^= lsb             # clear it
        return moves

    def _clone_game(self, game):
        """Copy the game state; bitboards and turn are immutable ints, so a
        field-by-field copy is a full deep copy."""
        clone = OthelloGame()
        clone.player_bb = game.player_bb
        clone.opponent_bb = game.opponent_bb
        clone.turn = game.turn
        return clone