# dualist/mcts.py
# Monte Carlo Tree Search for the Dualist Othello AI.
import math
import numpy as np
import torch
from src.game import OthelloGame
from src.bitboard import make_input_planes, bit_to_row_col, popcount
class MCTSNode:
    """A single node of the MCTS search tree.

    Attributes:
        prior: prior probability of reaching this node, from the policy head.
        visit_count: number of simulations that passed through this node.
        value_sum: accumulated backup values; by convention (see MCTS) these
            are always from Black's (player 1) perspective.
        children: mapping of move_bit -> MCTSNode; move_bit 0 encodes a pass.
        to_play: player whose turn it is at this node (1 = Black, -1 = White).
    """

    def __init__(self, prior, to_play):
        self.prior = prior
        self.visit_count = 0
        self.value_sum = 0
        self.children = {}  # move_bit -> MCTSNode
        self.to_play = to_play

    def value(self):
        """Mean backed-up value of this node; 0 for an unvisited node."""
        if self.visit_count == 0:
            return 0
        return self.value_sum / self.visit_count

    def expand(self, policy_logits, valid_moves, next_to_play):
        """Create one child per legal move with priors from the policy head.

        Args:
            policy_logits: flat array of 65 raw logits — indices 0-63 are
                board squares (r * 8 + c), index 64 is the pass move.
            valid_moves: list of single-bit bitboards (0 encodes a pass).
            next_to_play: player to move in every child position. In Othello
                the turn always swaps after a move, so this is one value for
                all children.

        Priors are renormalised over the legal moves only. If the network
        assigns zero mass to every legal move, a uniform prior is used.
        """
        if not valid_moves:
            # Defensive: callers pass [0] for a forced pass, so this should
            # not happen — but dividing by len([]) below must never occur.
            return

        # Numerically stable softmax over all 65 entries.
        policy = np.exp(policy_logits - np.max(policy_logits))
        policy /= np.sum(policy)

        # Collect the raw probability of each legal move.
        priors = {}
        for move_bit in valid_moves:
            if move_bit == 0:  # pass
                idx = 64
            else:
                r, c = bit_to_row_col(move_bit)
                # r == -1 should not occur for a set bit; fall back to pass.
                idx = 64 if r == -1 else r * 8 + c
            priors[move_bit] = policy[idx]

        total = sum(priors.values())
        if total > 0:
            # Renormalise over the legal moves.
            for move, p in priors.items():
                self.children[move] = MCTSNode(prior=p / total, to_play=next_to_play)
        else:
            # Degenerate case: zero mass on every legal move -> uniform prior.
            uniform = 1.0 / len(valid_moves)
            for move in priors:
                self.children[move] = MCTSNode(prior=uniform, to_play=next_to_play)
class MCTS:
    """AlphaZero-style Monte Carlo Tree Search for Othello.

    Sign convention: every value stored in the tree (node.value_sum) is from
    Black's (player 1) perspective; `_select_child` negates the mean value
    when White is the player choosing a move.
    """

    def __init__(self, model, cpuct=1.0, num_simulations=800):
        self.model = model  # network: input planes -> (policy_logits, value)
        self.cpuct = cpuct  # PUCT exploration constant
        self.num_simulations = num_simulations

    def search(self, game: "OthelloGame"):
        """Run `num_simulations` simulations from `game` and return the root.

        Returns None when the position is terminal (neither side can move).
        The returned root's children carry the visit counts used for move
        selection by the caller.
        """
        valid_moves_bb = game.get_valid_moves(game.player_bb, game.opponent_bb)
        valid_moves_list = self._get_moves_list(valid_moves_bb)
        # No legal placement: the side to move must pass (encoded as move 0),
        # unless neither side can move, in which case the game is over.
        if valid_moves_bb == 0:
            if game.is_terminal():
                return None
            valid_moves_list = [0]

        root = MCTSNode(prior=0, to_play=game.turn)

        # Look up the model device once and reuse it for every evaluation.
        device = next(self.model.parameters()).device
        self.model.eval()

        # Evaluate the root. Input planes are always canonical
        # (current player, opponent) regardless of colour.
        policy_logits, _ = self._evaluate(game.player_bb, game.opponent_bb, device)

        # In Othello the turn always swaps after a move — even when the next
        # player is immediately forced to pass, it is still their turn to do
        # so — hence every child is to be played by -game.turn.
        root.expand(policy_logits, valid_moves_list, -game.turn)

        # Dirichlet noise at the root encourages exploration in self-play.
        self._add_dirichlet_noise(root)

        for _ in range(self.num_simulations):
            node = root
            sim_game = self._clone_game(game)
            search_path = [node]

            # 1. Selection: descend to a leaf following the PUCT rule.
            while node.children:
                move_bit, node = self._select_child(node)
                search_path.append(node)
                sim_game.play_move(move_bit)

            # 2. Evaluation & expansion of the leaf.
            if sim_game.is_terminal():
                leaf_value = self._terminal_value(sim_game)
            else:
                policy_logits, val_for_current = self._evaluate(
                    sim_game.player_bb, sim_game.opponent_bb, device
                )
                # The network's value is from the current player's
                # perspective; convert to Black's perspective for storage.
                # (If White is winning, val_for_current = +1 with White to
                # move, which is -1 for Black.)
                leaf_value = val_for_current if sim_game.turn == 1 else -val_for_current

                valid_bb = sim_game.get_valid_moves(sim_game.player_bb, sim_game.opponent_bb)
                valid_list = self._get_moves_list(valid_bb) or [0]  # forced pass
                node.expand(policy_logits, valid_list, -sim_game.turn)

            # 3. Backup.
            self._backpropagate(search_path, leaf_value)

        return root

    def _evaluate(self, player_bb, opponent_bb, device):
        """Run the network on a canonical (player, opponent) position.

        Returns (flat numpy policy logits, scalar value for the current player).
        """
        state_tensor = make_input_planes(player_bb, opponent_bb).to(device)
        with torch.no_grad():
            policy_logits, v = self.model(state_tensor)
        return policy_logits.cpu().numpy().flatten(), v.item()

    def _terminal_value(self, sim_game):
        """Final result of a finished game from Black's perspective.

        Returns 1.0 for a Black win, -1.0 for a White win, 0.0 for a draw.
        `player_bb` always holds the side to move, so the bitboards are
        mapped back to colours via `turn`.
        """
        if sim_game.turn == 1:
            black_score = popcount(sim_game.player_bb)
            white_score = popcount(sim_game.opponent_bb)
        else:
            white_score = popcount(sim_game.player_bb)
            black_score = popcount(sim_game.opponent_bb)
        diff = black_score - white_score
        if diff > 0:
            return 1.0
        if diff < 0:
            return -1.0
        return 0.0

    def _select_child(self, node):
        """Return the (action, child) pair maximising the PUCT score Q + U.

        Stored values are Black-perspective, so Q is negated when White
        (node.to_play == -1) is the one choosing.
        """
        best_score = -float('inf')
        best_action = None
        best_child = None
        sqrt_visits = math.sqrt(node.visit_count)  # invariant over children
        for action, child in node.children.items():
            mean_val = child.value()  # Black-perspective average
            q = mean_val if node.to_play == 1 else -mean_val
            u = self.cpuct * child.prior * sqrt_visits / (1 + child.visit_count)
            score = q + u
            if score > best_score:
                best_score = score
                best_action = action
                best_child = child
        return best_action, best_child

    def _backpropagate(self, search_path, value):
        """Add `value` to every node on the path.

        `value` is the leaf evaluation from BLACK's perspective
        (1 = Black wins, -1 = White wins); node averages therefore stay in
        Black's perspective throughout the tree.
        """
        for node in search_path:
            node.value_sum += value
            node.visit_count += 1

    def _add_dirichlet_noise(self, node):
        """Mix Dirichlet noise into the root priors (AlphaZero exploration)."""
        eps = 0.25
        alpha = 0.3
        moves = list(node.children.keys())
        noise = np.random.dirichlet([alpha] * len(moves))
        for i, move in enumerate(moves):
            child = node.children[move]
            child.prior = (1 - eps) * child.prior + eps * noise[i]

    def _get_moves_list(self, moves_bb):
        """Split a moves bitboard into a list of single-bit integers (LSB first)."""
        moves = []
        temp = moves_bb
        while temp:
            lsb = temp & -temp  # isolate lowest set bit
            moves.append(lsb)
            temp ^= lsb  # clear it
        return moves

    def _clone_game(self, game):
        """Shallow copy of the game state for a simulation rollout."""
        new_game = OthelloGame()
        new_game.player_bb = game.player_bb
        new_game.opponent_bb = game.opponent_bb
        new_game.turn = game.turn
        return new_game