trioskosmos committed on
Commit
5206d7b
·
verified ·
1 Parent(s): 25d7b1e

Upload ai/agents/mcts.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. ai/agents/mcts.py +348 -0
ai/agents/mcts.py ADDED
@@ -0,0 +1,348 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ MCTS (Monte Carlo Tree Search) implementation for AlphaZero-style self-play.
3
+
4
+ This module provides a pure MCTS implementation that can work with or without
5
+ a neural network. When using a neural network, it uses the network's value
6
+ and policy predictions to guide the search.
7
+ """
8
+
9
+ import math
10
+ from dataclasses import dataclass
11
+ from typing import Dict, List, Optional, Tuple
12
+
13
+ import numpy as np
14
+
15
+ from engine.game.game_state import GameState
16
+
17
+
18
@dataclass
class MCTSConfig:
    """Tunable hyperparameters for the MCTS search.

    Defaults mirror common AlphaZero-style settings, scaled down for
    quick experimentation.
    """

    # Number of simulations run per move decision.
    num_simulations: int = 10
    # PUCT exploration constant (higher -> broader exploration).
    c_puct: float = 1.4
    # Dirichlet concentration for root exploration noise.
    dirichlet_alpha: float = 0.3
    # Fraction of the root prior replaced by Dirichlet noise.
    dirichlet_epsilon: float = 0.25
    # Penalty steering parallel searches away from the same branch.
    virtual_loss: float = 3.0
    # Temperature applied to visit counts when forming the policy.
    temperature: float = 1.0
28
+
29
+
30
class MCTSNode:
    """A single node in the MCTS search tree."""

    def __init__(self, prior: float = 1.0):
        # N: number of completed simulations through this node.
        self.visit_count = 0
        # W: sum of all backed-up values.
        self.value_sum = 0.0
        # Pending virtual-loss penalty for parallel search.
        self.virtual_loss = 0.0
        # P: prior probability from the policy network.
        self.prior = prior
        # Action index -> child node; populated by expand().
        self.children: Dict[int, "MCTSNode"] = {}
        self.state: Optional[GameState] = None

    @property
    def value(self) -> float:
        """Mean action value, Q = (W - virtual_loss) / N.

        An unvisited node reports zero minus its virtual-loss penalty so
        concurrent searches are discouraged from re-selecting it.
        """
        if self.visit_count == 0:
            return 0.0 - self.virtual_loss
        return (self.value_sum - self.virtual_loss) / (self.visit_count + 1e-8)

    def is_expanded(self) -> bool:
        """True once expand() has attached children."""
        return len(self.children) > 0

    def select_child(self, c_puct: float) -> Tuple[int, "MCTSNode"]:
        """Return the (action, child) pair maximizing the PUCT score.

        Score = Q(child) + c_puct * P(child) * sqrt(N_parent) / (1 + N_child).
        The child's Q already carries its virtual-loss penalty.
        """
        sqrt_visits = math.sqrt(self.visit_count)
        chosen_action, chosen_child = -1, None
        top_score = -float("inf")

        for action, child in self.children.items():
            score = child.value + c_puct * child.prior * sqrt_visits / (1 + child.visit_count)
            if score > top_score:
                top_score = score
                chosen_action = action
                chosen_child = child

        return chosen_action, chosen_child

    def expand(self, state: GameState, policy: np.ndarray) -> None:
        """Create one child per legal action, seeded with its policy prior."""
        self.state = state
        legal = state.get_legal_actions()
        for action, is_legal in enumerate(legal):
            if is_legal:
                self.children[action] = MCTSNode(prior=policy[action])
88
+
89
+
90
class MCTS:
    """Monte Carlo Tree Search with AlphaZero-style guidance.

    Currently substitutes a uniform policy plus random rollouts for a
    neural network; replace get_policy_value() to plug a network in.
    """

    def __init__(self, config: Optional[MCTSConfig] = None):
        # FIX: annotation was `MCTSConfig = None`; None is a valid argument,
        # so the parameter is Optional.
        self.config = config or MCTSConfig()
        self.root: Optional[MCTSNode] = None

    def reset(self) -> None:
        """Discard the search tree (call between moves or games)."""
        self.root = None

    def get_policy_value(self, state: GameState) -> Tuple[np.ndarray, float]:
        """Return (policy, value) for `state`.

        Placeholder: uniform policy over legal actions and a random-rollout
        value estimate. Replace with a neural network for full AlphaZero.
        """
        legal = state.get_legal_actions()
        policy = legal.astype(np.float32)
        if policy.sum() > 0:
            policy /= policy.sum()

        value = self._random_rollout(state)
        return policy, value

    def _random_rollout(self, state: GameState, max_steps: int = 50) -> float:
        """Estimate value of `state` (from state.current_player's
        perspective) by playing uniformly random moves.

        Args:
            state: Position to evaluate (not mutated; a copy is stepped).
            max_steps: Rollout cutoff before falling back to a heuristic.
        """
        current = state.copy()
        current_player = state.current_player

        for _ in range(max_steps):
            if current.is_terminal():
                return current.get_reward(current_player)

            legal_indices = np.where(current.get_legal_actions())[0]
            if len(legal_indices) == 0:
                # No legal move but not terminal: treat as a draw.
                return 0.0

            current = current.step(np.random.choice(legal_indices))

        # Rollout truncated before the game ended - use a cheap heuristic.
        return self._heuristic_value(current, current_player)

    def _heuristic_value(self, state: GameState, player_idx: int) -> float:
        """Cheap positional evaluation for truncated rollouts."""
        p = state.players[player_idx]
        opp = state.players[1 - player_idx]

        # Success lives dominate the evaluation.
        my_lives = len(p.success_lives)
        opp_lives = len(opp.success_lives)
        if my_lives > opp_lives:
            return 0.5 + 0.1 * (my_lives - opp_lives)
        elif opp_lives > my_lives:
            return -0.5 - 0.1 * (opp_lives - my_lives)

        # Tie-break on board strength.
        my_blades = p.get_total_blades(state.member_db)
        opp_blades = opp.get_total_blades(state.member_db)
        return 0.1 * (my_blades - opp_blades) / 10.0

    def search(self, state: GameState) -> np.ndarray:
        """Run MCTS from `state` and return action probabilities.

        Probabilities are proportional to root-child visit counts, shaped
        by config.temperature (temperature 0 => one-hot argmax).
        """
        # Initialize root and seed it with the prior policy.
        policy, _ = self.get_policy_value(state)
        self.root = MCTSNode()
        self.root.expand(state, policy)

        # Dirichlet noise at the root encourages exploration in self-play.
        self._add_exploration_noise(self.root)

        for _ in range(self.config.num_simulations):
            self._simulate(state)

        # Visit-count distribution over the full action space.
        visits = np.zeros(len(policy), dtype=np.float32)
        for action, child in self.root.children.items():
            visits[action] = child.visit_count

        if self.config.temperature == 0:
            # Greedy: one-hot on the most-visited action.
            best = np.argmax(visits)
            visits = np.zeros_like(visits)
            visits[best] = 1.0
        else:
            visits = np.power(visits, 1.0 / self.config.temperature)

        if visits.sum() > 0:
            visits /= visits.sum()

        return visits

    def _add_exploration_noise(self, node: MCTSNode) -> None:
        """Mix Dirichlet noise into the root children's priors."""
        actions = list(node.children.keys())
        if not actions:
            return

        noise = np.random.dirichlet([self.config.dirichlet_alpha] * len(actions))
        eps = self.config.dirichlet_epsilon
        for i, action in enumerate(actions):
            child = node.children[action]
            child.prior = (1 - eps) * child.prior + eps * noise[i]

    def _simulate(self, root_state: GameState) -> None:
        """Run one selection / expansion / backup pass."""
        node = self.root
        state = root_state.copy()
        search_path = [node]

        # Selection: descend until an unexpanded or terminal node.
        while node.is_expanded() and not state.is_terminal():
            action, node = node.select_child(self.config.c_puct)
            state = state.step(action)
            search_path.append(node)

        if state.is_terminal():
            # BUGFIX: score from the perspective of the player to move at the
            # leaf (state.current_player), matching the non-terminal branch,
            # so the sign-flipping backup below is consistent. The old code
            # used root_state.current_player, which produced wrong signs at
            # odd depths.
            value = state.get_reward(state.current_player)
        else:
            # Expansion: evaluate the leaf and attach its children.
            policy, value = self.get_policy_value(state)
            node.expand(state, policy)

        # Backup: propagate the value toward the root, flipping sign each
        # ply (assumes the player to move alternates along the path).
        for node in reversed(search_path):
            node.visit_count += 1
            node.value_sum += value
            value = -value

    def select_action(self, state: GameState, greedy: bool = False) -> int:
        """Search from `state` and pick an action.

        Args:
            greedy: When True, force temperature 0 (argmax over visits).
        """
        if greedy:
            # FIX: restore the caller's temperature even if search() raises.
            saved = self.config.temperature
            self.config.temperature = 0
            try:
                action_probs = self.search(state)
            finally:
                self.config.temperature = saved
            return int(np.argmax(action_probs))

        action_probs = self.search(state)
        return int(np.random.choice(len(action_probs), p=action_probs))
255
+
256
+
257
def play_game(mcts1: MCTS, mcts2: MCTS, verbose: bool = True) -> int:
    """
    Play one complete game between two MCTS agents.

    Returns:
        Winner (0 or 1), or 2 when the move limit is hit (draw).
    """
    from engine.game.game_state import initialize_game

    state = initialize_game()
    agents = [mcts1, mcts2]
    max_moves = 500
    move_count = 0

    while move_count < max_moves and not state.is_terminal():
        action = agents[state.current_player].select_action(state)

        if verbose and move_count % 10 == 0:
            print(f"Move {move_count}: Player {state.current_player}, Phase {state.phase.name}, Action {action}")

        state = state.step(action)
        move_count += 1

    if not state.is_terminal():
        # Move limit exceeded without a result.
        if verbose:
            print(f"Game exceeded {max_moves} moves, declaring draw")
        return 2

    winner = state.get_winner()
    if verbose:
        print(f"Game over after {move_count} moves. Winner: {winner}")
    return winner
291
+
292
+
293
def self_play(num_games: int = 10, simulations: int = 50) -> List[Tuple[List, List, int]]:
    """
    Run self-play games to generate training data.

    Args:
        num_games: Number of games to play.
        simulations: MCTS simulations per move.

    Returns:
        List of (states, policies, winner) tuples for training, where
        winner is 0/1 for a decided game and 2 for a move-limit draw.
    """
    # BUGFIX: import once, from the same package path play_game uses
    # (`engine.game.game_state`); the old code re-imported `game.game_state`
    # (wrong path) inside the per-game loop.
    from engine.game.game_state import initialize_game

    training_data = []
    config = MCTSConfig(num_simulations=simulations)

    for game_idx in range(num_games):
        state = initialize_game()
        mcts = MCTS(config)

        game_states = []
        game_policies = []

        move_count = 0
        max_moves = 500

        while not state.is_terminal() and move_count < max_moves:
            # MCTS policy for the current position.
            policy = mcts.search(state)

            # Record (observation, policy) pair for training.
            game_states.append(state.get_observation())
            game_policies.append(policy)

            # Sample an action from the search policy.
            action = np.random.choice(len(policy), p=policy)
            state = state.step(action)

            # Discard the tree before the next move.
            mcts.reset()
            move_count += 1

        winner = state.get_winner() if state.is_terminal() else 2
        training_data.append((game_states, game_policies, winner))

        print(f"Game {game_idx + 1}/{num_games} complete. Moves: {move_count}, Winner: {winner}")

    return training_data
337
+
338
+
339
if __name__ == "__main__":
    # Smoke test: pit two lightweight MCTS agents against each other.
    print("Testing MCTS self-play...")

    test_config = MCTSConfig(num_simulations=20)  # Low for testing
    winner = play_game(MCTS(test_config), MCTS(test_config), verbose=True)
    print(f"Test game complete. Winner: {winner}")