AubreeL committed on
Commit
5ebf90d
·
verified ·
1 Parent(s): e763fc3

Upload mcts.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. mcts.py +94 -0
mcts.py ADDED
@@ -0,0 +1,94 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import torch
3
+ import chess
4
+ from model import TinyPCN, encode_board, encode_move
5
+
6
class MCTSNode:
    """One position in the Monte Carlo search tree.

    Each node owns a private copy of the board it represents and keeps the
    running statistics used by the search (visit count, accumulated value,
    mean value and prior probability).
    """

    def __init__(self, board, parent=None, move=None):
        """Create a node for ``board``.

        Args:
            board: Position this node represents; it must provide ``copy()``.
                The node stores a copy, so the caller's object is not shared.
            parent: Node this one was reached from, or ``None`` for a root.
            move: Move that led from ``parent`` to this node, if any.
        """
        self.board = board.copy()
        self.parent = parent
        self.move = move
        self.children = {}  # child nodes, filled in when the node is expanded
        self.N = 0  # visit count
        self.W = 0.0  # total value
        self.Q = 0.0  # mean value
        self.P = None  # prior probability (unset until assigned)

    def is_expanded(self):
        """Return True once the node has at least one child."""
        return bool(self.children)

    def __repr__(self):
        """Return a compact summary of the node for debugging output."""
        return (
            f"{type(self).__name__}(move={self.move!r}, N={self.N!r}, "
            f"Q={self.Q!r}, P={self.P!r})"
        )
19
+
20
+
21
def softmax(x):
    """Return the softmax of ``x`` as a NumPy array.

    The maximum is subtracted before exponentiating so large scores do not
    overflow. Normalisation is over all elements of ``x``.

    Args:
        x: Array-like of numeric scores.

    Returns:
        Array of the same shape as ``x`` whose entries sum to 1. An empty
        input yields an empty array instead of raising.
    """
    x = np.asarray(x)
    if x.size == 0:
        # np.max raises ValueError on a zero-size array; nothing to normalise.
        return x
    shifted = x - np.max(x)
    exp_x = np.exp(shifted)
    return exp_x / np.sum(exp_x)
26
+
27
+
28
def select_child(node, c_puct=1.0):
    """Pick the child of ``node`` with the highest PUCT score.

    The score of a child is ``Q + c_puct * P * sqrt(node.N) / (1 + N)``,
    using the child's own ``Q``, ``P`` and ``N``.

    Args:
        node: Parent node; reads ``node.children`` and ``node.N``.
        c_puct: Exploration constant weighting the prior term.

    Returns:
        The best-scoring child, or ``None`` if ``node`` has no children.
        Ties go to the first child in iteration order. A child is always
        returned when children exist, even if every score is NaN or -inf,
        so the caller never receives ``None`` for an expanded node.
    """
    if not node.children:
        return None

    # Loop-invariant: depends only on the parent's visit count.
    sqrt_parent_visits = np.sqrt(node.N)

    # NOTE(review): child.Q is used as-is. Whether it should be negated
    # depends on whose perspective the stored values are from -- confirm
    # against the network's value convention and the backup step.
    def puct_score(child):
        exploration = c_puct * child.P * sqrt_parent_visits / (1 + child.N)
        return child.Q + exploration

    # max() keeps the first maximal element, matching a strict ">" scan.
    return max(node.children.values(), key=puct_score)
38
+
39
+
40
def expand_node(node, net):
    """Expand ``node`` with one child per legal move and return its value.

    The network is evaluated on the node's board. Its policy output is turned
    into probabilities, restricted to the legal moves and renormalised so the
    children's priors sum to 1.

    Args:
        node: Leaf node to expand; its ``children`` dict is filled in place.
        net: Callable returning ``(policy_logits, value)`` for a batched
            board tensor.

    Returns:
        The network's scalar value for the position as a float. If the
        position has no legal moves the network is not called: the result is
        -1.0 for checkmate and 0.0 otherwise.
    """
    legal_moves = list(node.board.legal_moves)
    if not legal_moves:
        # Terminal position: there is nothing to expand.
        # NOTE(review): -1.0 assumes values are from the side to move's
        # perspective, as the sign flip during backup suggests -- confirm
        # against the network's training target.
        return -1.0 if node.board.is_checkmate() else 0.0

    board_tensor = encode_board(node.board, "18").unsqueeze(0)
    with torch.no_grad():
        policy_logits, value = net(board_tensor)
    policy = torch.softmax(policy_logits[0], dim=0).cpu().numpy()
    value = float(value.item())

    # Probability of each legal move; indices outside the policy vector get
    # a tiny prior so the move remains selectable.
    move_indices = (encode_move(m, node.board) for m in legal_moves)
    scores = np.array(
        [policy[i] if 0 <= i < len(policy) else 1e-9 for i in move_indices],
        dtype=np.float64,
    )

    # `policy` is already a probability distribution, so renormalise over the
    # legal moves. A second softmax would flatten it towards uniform.
    total = scores.sum()
    if np.isfinite(total) and total > 0:
        priors = scores / total
    else:
        priors = np.full(len(legal_moves), 1.0 / len(legal_moves))

    # Only a node without a prior (the root) gets the default; a prior
    # assigned when the node was created as a child must be kept.
    if node.P is None:
        node.P = 1.0

    for move, prior in zip(legal_moves, priors):
        # MCTSNode copies the board it is given, so push the move on the
        # child's own copy rather than copying twice.
        child = MCTSNode(node.board, parent=node, move=move)
        child.board.push(move)
        child.P = float(prior)
        node.children[move] = child
    return value
63
+
64
+
65
def backup(node, value):
    """Propagate ``value`` from ``node`` up to the root.

    Every node on the path gets its visit count incremented, ``value`` added
    to its total and its mean recomputed. The sign of ``value`` is flipped at
    each step up so that consecutive levels see it from alternating
    perspectives.

    Args:
        node: Node to start from; ``None`` makes this a no-op.
        value: Evaluation to add at ``node``.
    """
    # Compare against None explicitly: a truthiness test would stop early
    # for any node type that defines __len__ or __bool__.
    while node is not None:
        node.N += 1
        node.W += value
        node.Q = node.W / node.N
        value = -value  # switch perspective
        node = node.parent
72
+
73
+
74
def mcts_search(root, net, num_simulations=100, c_puct=1.0):
    """Run Monte Carlo tree search from ``root`` and report visit counts.

    Each simulation walks down from the root by repeatedly selecting a child
    until it reaches a node with no children, expands that node with the
    network, and backs the resulting value up the path.

    Args:
        root: Node to search from; the tree below it is grown in place.
        net: Network passed to ``expand_node``.
        num_simulations: Number of select/expand/backup passes to run.
        c_puct: Exploration constant passed to ``select_child``.

    Returns:
        Dict mapping each move available at the root to the visit count of
        the corresponding child.
    """
    simulations_done = 0
    while simulations_done < num_simulations:
        simulations_done += 1

        # Selection: descend until a node without children is reached.
        leaf = root
        while leaf.is_expanded():
            leaf = select_child(leaf, c_puct)

        # Expansion & evaluation, then backup along the visited path.
        leaf_value = expand_node(leaf, net)
        backup(leaf, leaf_value)

    # Visit counts of the root's children, keyed by move.
    visit_counts = {}
    for root_move, root_child in root.children.items():
        visit_counts[root_move] = root_child.N
    return visit_counts
87
+
88
# Example usage: search the starting position with a freshly constructed
# network and print how often each root move was visited.
if __name__ == "__main__":
    net = TinyPCN()
    board = chess.Board()
    root = MCTSNode(board)
    # mcts_search returns {move: visit count}; print those counts rather
    # than the child node objects stored in root.children.
    visit_counts = mcts_search(root, net, num_simulations=50)
    print({str(move): visits for move, visits in visit_counts.items()})