# chess-bot / mcts.py
# AubreeL's picture
# Upload mcts.py with huggingface_hub
# 5ebf90d verified
import numpy as np
import torch
import chess
from model import TinyPCN, encode_board, encode_move
class MCTSNode:
    """One position in the Monte Carlo search tree.

    Each node owns a private copy of the board it represents, a link back to
    its parent, the move that led to it, and the running search statistics.
    """

    def __init__(self, board, parent=None, move=None):
        """Create a node for ``board``.

        Args:
            board: Position this node represents; it is copied via
                ``board.copy()`` so the caller's object is never mutated
                through this node.
            parent: Node this one was reached from, or ``None`` for a root.
            move: Move played from ``parent`` to reach this node, or ``None``.
        """
        # Tree structure.
        self.parent = parent
        self.move = move
        self.board = board.copy()
        self.children = {}  # move -> MCTSNode

        # Search statistics.
        self.N = 0  # visit count
        self.W = 0.0  # total value
        self.Q = 0.0  # mean value
        self.P = None  # prior probability (assigned when the parent expands)

    def is_expanded(self):
        """Return ``True`` once at least one child has been attached."""
        return bool(self.children)
def softmax(x):
    """Return the softmax of ``x`` as a NumPy float array.

    The maximum is subtracted before exponentiating so large scores do not
    overflow. The normalisation runs over every element of ``x``.

    Args:
        x: Array-like of real-valued scores.

    Returns:
        Array of the same shape as ``x`` whose entries sum to 1. An empty
        input yields an empty array rather than raising, since ``np.max``
        has no identity for zero-size arrays.
    """
    x = np.asarray(x, dtype=float)
    if x.size == 0:
        # Nothing to normalise (e.g. a position with no legal moves).
        return x
    exp_x = np.exp(x - np.max(x))
    return exp_x / np.sum(exp_x)
def select_child(node, c_puct=1.0):
    """Pick the child of ``node`` with the highest ``Q + U`` score.

    ``U = c_puct * P * sqrt(node.N) / (1 + child.N)`` is the exploration
    bonus: it grows with the child's prior and shrinks as the child is
    visited.

    Args:
        node: Parent whose ``children`` mapping is scanned.
        c_puct: Weight of the exploration bonus relative to ``Q``.

    Returns:
        The best-scoring child (the first one encountered on ties), or
        ``None`` when ``node`` has no children.
    """
    # NOTE(review): backup() flips the sign of the value at every ply, so
    # child.Q looks like it is scored from the child's side-to-move point of
    # view. If that is the convention, the parent should maximise -child.Q
    # here. Confirm against the network's value head before changing.
    sqrt_parent_visits = np.sqrt(node.N)
    winner = None
    winner_score = -float('inf')
    for candidate in node.children.values():
        bonus = c_puct * candidate.P * sqrt_parent_visits / (1 + candidate.N)
        candidate_score = candidate.Q + bonus
        if candidate_score > winner_score:
            winner, winner_score = candidate, candidate_score
    return winner
def expand_node(node, net):
    """Expand ``node`` by one ply and return the value to back up from it.

    For a finished game no network call is made and the game result is
    returned. Otherwise the network is evaluated on the node's board, one
    child is attached per legal move, and each child's prior is the
    network's policy probability renormalised over the legal moves.

    Args:
        node: Leaf ``MCTSNode`` to expand; its ``children`` are populated in
            place. A root (``P is None``) is given a prior of 1.0.
        net: Callable that maps the encoded board batch to
            ``(policy_logits, value)``.

    Returns:
        A float evaluation of ``node.board``: the network's value output, or
        -1.0 for checkmate / 0.0 for any other finished game.
    """
    if node.board.is_game_over():
        # The game result is exact, so prefer it to a network guess.
        # Assumes values are scored from the side to move's point of view
        # (the mated side is the one to move), matching the per-ply sign flip
        # in backup() -- TODO confirm against the network's value head.
        return -1.0 if node.board.is_checkmate() else 0.0

    # "18" is passed through to encode_board unchanged; presumably it selects
    # the input-plane layout -- see model.encode_board.
    board_tensor = encode_board(node.board, "18").unsqueeze(0)
    with torch.no_grad():
        policy_logits, value = net(board_tensor)
    policy = torch.softmax(policy_logits[0], dim=0).cpu().numpy()
    value = float(value.item())

    legal_moves = list(node.board.legal_moves)
    if not legal_moves:
        # Defensive: nothing to attach, so just report the evaluation.
        return value

    num_outputs = len(policy)
    move_indices = [encode_move(m, node.board) for m in legal_moves]
    # Moves whose index falls outside the policy head get a tiny prior so
    # they stay selectable without competing with real predictions.
    scores = np.array(
        [policy[i] if 0 <= i < num_outputs else 1e-9 for i in move_indices],
        dtype=np.float64,
    )

    # ``policy`` is already a probability distribution, so renormalise over
    # the legal moves. A second softmax here would squash every prior
    # towards uniform because all inputs lie in [0, 1].
    total = scores.sum()
    if total > 0:
        priors = scores / total
    else:
        priors = np.full(len(scores), 1.0 / len(scores))

    # Only a root lacks a prior; every other node keeps the one its parent
    # assigned when it was created.
    if node.P is None:
        node.P = 1.0

    for move, prior in zip(legal_moves, priors):
        # MCTSNode copies the board it is given, so push the move on the
        # child's own copy instead of copying twice.
        child = MCTSNode(node.board, parent=node, move=move)
        child.board.push(move)
        child.P = float(prior)
        node.children[move] = child
    return value
def backup(node, value):
    """Propagate ``value`` from ``node`` up to the root.

    Every node on the path gets one more visit, has the (sign-alternating)
    value added to its total ``W``, and has its mean ``Q`` refreshed.

    Args:
        node: Node where the evaluation was obtained; the walk follows
            ``parent`` links until it runs out of nodes.
        value: Evaluation credited to ``node``; it is negated at each step
            up the tree.
    """
    current, signed_value = node, value
    while current:
        current.N += 1
        current.W += signed_value
        current.Q = current.W / current.N
        # One ply up, the same outcome is seen with the opposite sign.
        current, signed_value = current.parent, -signed_value
def mcts_search(root, net, num_simulations=100, c_puct=1.0):
    """Run ``num_simulations`` search iterations from ``root``.

    Each iteration descends through expanded nodes with ``select_child``,
    expands and evaluates the node it stops at, and backs the value up the
    visited path.

    Args:
        root: Node to search from; the tree beneath it is grown in place.
        net: Network handed to ``expand_node`` for evaluation.
        num_simulations: Number of select / expand / backup iterations.
        c_puct: Exploration weight forwarded to ``select_child``.

    Returns:
        Dict mapping each move available at ``root`` to the visit count of
        the corresponding child.
    """
    for _ in range(num_simulations):
        # Selection: walk down until reaching a node without children.
        leaf = root
        while leaf.is_expanded() and leaf.children:
            leaf = select_child(leaf, c_puct)

        # Expansion and evaluation, then backup along the visited path.
        backup(leaf, expand_node(leaf, net))

    # Visit counts for the root's children.
    return {move: child.N for move, child in root.children.items()}
# Example usage: search the starting position with an untrained network and
# print how often each root move was visited.
if __name__ == "__main__":
    net = TinyPCN()
    board = chess.Board()
    root = MCTSNode(board)
    visit_counts = mcts_search(root, net, num_simulations=50)
    # Print the counts returned by the search; iterating root.children
    # directly would print MCTSNode objects instead of numbers.
    print({str(move): n for move, n in visit_counts.items()})