| | import math
|
| | import numpy as np
|
| | import torch
|
| | from src.game import OthelloGame
|
| | from src.bitboard import make_input_planes, bit_to_row_col, popcount
|
| |
|
class MCTSNode:
    """A single search-tree node: prior probability, visit statistics, and
    children keyed by move bitboard."""

    def __init__(self, prior, to_play):
        self.prior = prior        # P(move): prior assigned by the parent's expansion
        self.visit_count = 0      # N: number of simulations through this node
        self.value_sum = 0        # W: accumulated backed-up value
        self.children = {}        # move bitboard -> MCTSNode
        self.to_play = to_play    # side to move at this node

    def value(self):
        """Mean action value Q = W / N; 0 for an unvisited node."""
        return self.value_sum / self.visit_count if self.visit_count else 0

    def expand(self, policy_logits, valid_moves, next_to_play):
        """
        Create one child per legal move, with priors taken from a softmax of
        the network's policy logits, renormalized over the legal moves only.
        """
        # Numerically stable softmax over the full policy head.
        exps = np.exp(policy_logits - np.max(policy_logits))
        policy = exps / np.sum(exps)

        raw_priors = {}
        total = 0
        for move_bit in valid_moves:
            if move_bit == 0:
                # The pass move occupies the last policy slot (index 64).
                idx = 64
            else:
                r, c = bit_to_row_col(move_bit)
                idx = 64 if r == -1 else r * 8 + c
            p = policy[idx]
            raw_priors[move_bit] = p
            total += p

        if total > 0:
            # Renormalize so the legal-move priors sum to 1.
            for move_bit, p in raw_priors.items():
                self.children[move_bit] = MCTSNode(prior=p / total, to_play=next_to_play)
        else:
            # Degenerate policy mass on legal moves: fall back to uniform priors.
            uniform = 1.0 / len(valid_moves)
            for move_bit in valid_moves:
                self.children[move_bit] = MCTSNode(prior=uniform, to_play=next_to_play)
|
| |
|
class MCTS:
    """AlphaZero-style Monte Carlo Tree Search for Othello.

    Values are stored in the tree from Black's perspective throughout
    (1 = Black winning, -1 = White winning); `_select_child` flips the sign
    for White via `node.to_play`.
    """

    def __init__(self, model, cpuct=1.0, num_simulations=800):
        self.model = model                      # policy/value net: planes -> (logits, value)
        self.cpuct = cpuct                      # exploration constant in the PUCT formula
        self.num_simulations = num_simulations  # simulations per call to search()

    def search(self, game: "OthelloGame"):
        """
        Run `num_simulations` MCTS simulations from `game`.

        Returns the expanded root MCTSNode (its children's visit counts drive
        move selection), or None if the position is already terminal.
        """
        valid_moves_bb = game.get_valid_moves(game.player_bb, game.opponent_bb)
        valid_moves_list = self._get_moves_list(valid_moves_bb)

        if valid_moves_bb == 0:
            if game.is_terminal():
                return None
            # No legal placement but game continues: the only move is a pass,
            # encoded as move bit 0.
            valid_moves_list = [0]

        root = MCTSNode(prior=0, to_play=game.turn)

        self.model.eval()
        policy_logits, _ = self._evaluate(game)
        root.expand(policy_logits, valid_moves_list, -game.turn)

        # Dirichlet noise at the root encourages exploration during self-play.
        self._add_dirichlet_noise(root)

        for _ in range(self.num_simulations):
            node = root
            sim_game = self._clone_game(game)
            search_path = [node]

            # Selection: descend via PUCT until we reach an unexpanded leaf.
            while node.children:
                move_bit, node = self._select_child(node)
                search_path.append(node)
                sim_game.play_move(move_bit)

            if sim_game.is_terminal():
                last_value = self._terminal_value(sim_game)
            else:
                # Expansion + evaluation at the leaf.
                policy_logits, value = self._evaluate(sim_game)
                # The net evaluates from the side-to-move's perspective
                # (assumed — consistent with training); convert to Black's
                # perspective for storage in the tree.
                last_value = value if sim_game.turn == 1 else -value

                valid_bb = sim_game.get_valid_moves(sim_game.player_bb, sim_game.opponent_bb)
                valid_list = self._get_moves_list(valid_bb) or [0]
                node.expand(policy_logits, valid_list, -sim_game.turn)

            self._backpropagate(search_path, last_value)

        return root

    def _evaluate(self, game):
        """Run the network on `game`; return (flat policy-logits ndarray, scalar value)."""
        state_tensor = make_input_planes(game.player_bb, game.opponent_bb)
        device = next(self.model.parameters()).device
        state_tensor = state_tensor.to(device)
        with torch.no_grad():
            policy_logits, v = self.model(state_tensor)
        return policy_logits.cpu().numpy().flatten(), v.item()

    def _terminal_value(self, game):
        """Final outcome from Black's perspective: 1.0 Black win, -1.0 White win, 0.0 draw."""
        # player_bb always belongs to the side to move, so map back to colors.
        if game.turn == 1:
            black_score = popcount(game.player_bb)
            white_score = popcount(game.opponent_bb)
        else:
            white_score = popcount(game.player_bb)
            black_score = popcount(game.opponent_bb)

        diff = black_score - white_score
        if diff > 0:
            return 1.0
        if diff < 0:
            return -1.0
        return 0.0

    def _select_child(self, node):
        """Return (move, child) maximizing the PUCT score Q + U."""
        best_score = -float('inf')
        best_action = None
        best_child = None

        for action, child in node.children.items():
            mean_val = child.value()
            # Stored values are Black-centric; White maximizes the negation.
            q = mean_val if node.to_play == 1 else -mean_val
            # Exploration bonus: high for high-prior, rarely-visited children.
            u = self.cpuct * child.prior * math.sqrt(node.visit_count) / (1 + child.visit_count)
            score = q + u
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child

        return best_action, best_child

    def _backpropagate(self, search_path, value):
        """
        value: the evaluation of the leaf node, from BLACK's perspective
        (1 = Black wins, -1 = White wins). Added as-is to every node on the
        path; perspective is handled at selection time.
        """
        for node in search_path:
            node.value_sum += value
            node.visit_count += 1

    def _add_dirichlet_noise(self, node):
        """Mix Dirichlet noise into the root priors (AlphaZero exploration trick)."""
        eps = 0.25    # fraction of the prior replaced by noise
        alpha = 0.3   # Dirichlet concentration
        moves = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(moves))

        for i, move in enumerate(moves):
            node.children[move].prior = (1 - eps) * node.children[move].prior + eps * noise[i]

    def _get_moves_list(self, moves_bb):
        """Decompose a move bitboard into a list of single-bit ints, LSB first."""
        moves = []
        temp = moves_bb
        while temp:
            lsb = temp & -temp  # isolate the lowest set bit
            moves.append(lsb)
            temp ^= lsb         # clear it
        return moves

    def _clone_game(self, game):
        """Cheap copy of the game state (bitboards and turn are plain ints)."""
        new_game = OthelloGame()
        new_game.player_bb = game.player_bb
        new_game.opponent_bb = game.opponent_bb
        new_game.turn = game.turn
        return new_game
|
| |
|