Spaces:
Sleeping
Sleeping
| """ | |
| MCTS (Monte Carlo Tree Search) implementation for AlphaZero-style self-play. | |
| This module provides a pure MCTS implementation that can work with or without | |
| a neural network. When using a neural network, it uses the network's value | |
| and policy predictions to guide the search. | |
| """ | |
| import math | |
| from dataclasses import dataclass | |
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| from engine.game.game_state import GameState | |
| class MCTSConfig: | |
| """Configuration for MCTS""" | |
| num_simulations: int = 10 # Number of simulations per move | |
| c_puct: float = 1.4 # Exploration constant | |
| dirichlet_alpha: float = 0.3 # For root exploration noise | |
| dirichlet_epsilon: float = 0.25 # Fraction of noise added to prior | |
| virtual_loss: float = 3.0 # Virtual loss for parallel search | |
| temperature: float = 1.0 # Policy temperature | |
| class MCTSNode: | |
| """A node in the MCTS tree""" | |
| def __init__(self, prior: float = 1.0): | |
| self.visit_count = 0 | |
| self.value_sum = 0.0 | |
| self.virtual_loss = 0.0 # Accumulated virtual loss | |
| self.prior = prior # Prior probability from policy network | |
| self.children: Dict[int, "MCTSNode"] = {} | |
| self.state: Optional[GameState] = None | |
| def value(self) -> float: | |
| """Average value of this node (adjusted for virtual loss)""" | |
| if self.visit_count == 0: | |
| return 0.0 - self.virtual_loss | |
| # Q = (W - VL) / N | |
| # Standard approach: subtract virtual loss from value sum logic? | |
| # Or (W / N) - VL? | |
| # AlphaZero: Q = (W - v_loss) / N | |
| return (self.value_sum - self.virtual_loss) / (self.visit_count + 1e-8) | |
| def is_expanded(self) -> bool: | |
| return len(self.children) > 0 | |
| def select_child(self, c_puct: float) -> Tuple[int, "MCTSNode"]: | |
| """Select child with highest UCB score""" | |
| best_score = -float("inf") | |
| best_action = -1 | |
| best_child = None | |
| # Virtual loss increases denominator in some implementations, | |
| # but here we just penalize Q and rely on high N to reduce UCB exploration if visited. | |
| # But wait, we want to discourage visiting the SAME node. | |
| # So we penalize Q. | |
| sqrt_parent_visits = math.sqrt(self.visit_count) | |
| for action, child in self.children.items(): | |
| # UCB formula: Q + c * P * sqrt(N) / (1 + n) | |
| # Child value includes its own virtual loss penalty | |
| ucb = child.value + c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count) | |
| if ucb > best_score: | |
| best_score = ucb | |
| best_action = action | |
| best_child = child | |
| return best_action, best_child | |
| def expand(self, state: GameState, policy: np.ndarray) -> None: | |
| """Expand node with children for all legal actions""" | |
| self.state = state | |
| legal_actions = state.get_legal_actions() | |
| for action in range(len(legal_actions)): | |
| if legal_actions[action]: | |
| self.children[action] = MCTSNode(prior=policy[action]) | |
| class MCTS: | |
| """Monte Carlo Tree Search with AlphaZero-style neural network guidance""" | |
| def __init__(self, config: MCTSConfig = None): | |
| self.config = config or MCTSConfig() | |
| self.root = None | |
| def reset(self) -> None: | |
| """Reset the search tree""" | |
| self.root = None | |
| def get_policy_value(self, state: GameState) -> Tuple[np.ndarray, float]: | |
| """ | |
| Get policy and value from neural network. | |
| For now, uses uniform policy and random rollout value. | |
| Replace with actual neural network for full AlphaZero. | |
| """ | |
| # Uniform policy over legal actions | |
| legal = state.get_legal_actions() | |
| policy = legal.astype(np.float32) | |
| if policy.sum() > 0: | |
| policy /= policy.sum() | |
| # Random rollout for value estimation | |
| value = self._random_rollout(state) | |
| return policy, value | |
| def _random_rollout(self, state: GameState, max_steps: int = 50) -> float: | |
| """Perform random rollout to estimate value""" | |
| current = state.copy() | |
| current_player = state.current_player | |
| for _ in range(max_steps): | |
| if current.is_terminal(): | |
| return current.get_reward(current_player) | |
| legal = current.get_legal_actions() | |
| legal_indices = np.where(legal)[0] | |
| if len(legal_indices) == 0: | |
| return 0.0 | |
| action = np.random.choice(legal_indices) | |
| current = current.step(action) | |
| # Game didn't finish - use heuristic | |
| return self._heuristic_value(current, current_player) | |
| def _heuristic_value(self, state: GameState, player_idx: int) -> float: | |
| """Simple heuristic value for non-terminal states""" | |
| p = state.players[player_idx] | |
| opp = state.players[1 - player_idx] | |
| # Compare success lives | |
| my_lives = len(p.success_lives) | |
| opp_lives = len(opp.success_lives) | |
| if my_lives > opp_lives: | |
| return 0.5 + 0.1 * (my_lives - opp_lives) | |
| elif opp_lives > my_lives: | |
| return -0.5 - 0.1 * (opp_lives - my_lives) | |
| # Compare board strength | |
| my_blades = p.get_total_blades(state.member_db) | |
| opp_blades = opp.get_total_blades(state.member_db) | |
| return 0.1 * (my_blades - opp_blades) / 10.0 | |
| def search(self, state: GameState) -> np.ndarray: | |
| """ | |
| Run MCTS and return action probabilities. | |
| Args: | |
| state: Current game state | |
| Returns: | |
| Action probabilities based on visit counts | |
| """ | |
| # Initialize root | |
| policy, _ = self.get_policy_value(state) | |
| self.root = MCTSNode() | |
| self.root.expand(state, policy) | |
| # Add exploration noise at root | |
| self._add_exploration_noise(self.root) | |
| # Run simulations | |
| for _ in range(self.config.num_simulations): | |
| self._simulate(state) | |
| # Return visit count distribution | |
| visits = np.zeros(len(policy), dtype=np.float32) | |
| for action, child in self.root.children.items(): | |
| visits[action] = child.visit_count | |
| # Apply temperature | |
| if self.config.temperature == 0: | |
| # Greedy - pick best | |
| best = np.argmax(visits) | |
| visits = np.zeros_like(visits) | |
| visits[best] = 1.0 | |
| else: | |
| # Softmax with temperature | |
| visits = np.power(visits, 1.0 / self.config.temperature) | |
| if visits.sum() > 0: | |
| visits /= visits.sum() | |
| return visits | |
| def _add_exploration_noise(self, node: MCTSNode) -> None: | |
| """Add Dirichlet noise to root node for exploration""" | |
| actions = list(node.children.keys()) | |
| if not actions: | |
| return | |
| noise = np.random.dirichlet([self.config.dirichlet_alpha] * len(actions)) | |
| for i, action in enumerate(actions): | |
| child = node.children[action] | |
| child.prior = (1 - self.config.dirichlet_epsilon) * child.prior + self.config.dirichlet_epsilon * noise[i] | |
| def _simulate(self, root_state: GameState) -> None: | |
| """Run one MCTS simulation""" | |
| node = self.root | |
| state = root_state.copy() | |
| search_path = [node] | |
| # Selection - traverse tree until we reach a leaf | |
| while node.is_expanded() and not state.is_terminal(): | |
| action, node = node.select_child(self.config.c_puct) | |
| state = state.step(action) | |
| search_path.append(node) | |
| # Get value for this node | |
| if state.is_terminal(): | |
| value = state.get_reward(root_state.current_player) | |
| else: | |
| # Expansion | |
| policy, value = self.get_policy_value(state) | |
| node.expand(state, policy) | |
| # Backpropagation | |
| for node in reversed(search_path): | |
| node.visit_count += 1 | |
| node.value_sum += value | |
| value = -value # Flip value for opponent's perspective | |
| def select_action(self, state: GameState, greedy: bool = False) -> int: | |
| """Select action based on MCTS policy""" | |
| temp = self.config.temperature | |
| if greedy: | |
| self.config.temperature = 0 | |
| action_probs = self.search(state) | |
| if greedy: | |
| self.config.temperature = temp | |
| action = np.argmax(action_probs) | |
| else: | |
| action = np.random.choice(len(action_probs), p=action_probs) | |
| return action | |
| def play_game(mcts1: MCTS, mcts2: MCTS, verbose: bool = True) -> int: | |
| """ | |
| Play a complete game between two MCTS agents. | |
| Returns: | |
| Winner (0 or 1) or 2 for draw | |
| """ | |
| from engine.game.game_state import initialize_game | |
| state = initialize_game() | |
| mcts_players = [mcts1, mcts2] | |
| move_count = 0 | |
| max_moves = 500 | |
| while not state.is_terminal() and move_count < max_moves: | |
| current_mcts = mcts_players[state.current_player] | |
| action = current_mcts.select_action(state) | |
| if verbose and move_count % 10 == 0: | |
| print(f"Move {move_count}: Player {state.current_player}, Phase {state.phase.name}, Action {action}") | |
| state = state.step(action) | |
| move_count += 1 | |
| if state.is_terminal(): | |
| winner = state.get_winner() | |
| if verbose: | |
| print(f"Game over after {move_count} moves. Winner: {winner}") | |
| return winner | |
| else: | |
| if verbose: | |
| print(f"Game exceeded {max_moves} moves, declaring draw") | |
| return 2 | |
| def self_play(num_games: int = 10, simulations: int = 50) -> List[Tuple[List, List, int]]: | |
| """ | |
| Run self-play games to generate training data. | |
| Returns: | |
| List of (states, policies, winner) tuples for training | |
| """ | |
| training_data = [] | |
| config = MCTSConfig(num_simulations=simulations) | |
| for game_idx in range(num_games): | |
| from game.game_state import initialize_game | |
| state = initialize_game() | |
| mcts = MCTS(config) | |
| game_states = [] | |
| game_policies = [] | |
| move_count = 0 | |
| max_moves = 500 | |
| while not state.is_terminal() and move_count < max_moves: | |
| # Get MCTS policy | |
| policy = mcts.search(state) | |
| # Store state and policy for training | |
| game_states.append(state.get_observation()) | |
| game_policies.append(policy) | |
| # Select action | |
| action = np.random.choice(len(policy), p=policy) | |
| state = state.step(action) | |
| # Reset MCTS tree for next move | |
| mcts.reset() | |
| move_count += 1 | |
| winner = state.get_winner() if state.is_terminal() else 2 | |
| training_data.append((game_states, game_policies, winner)) | |
| print(f"Game {game_idx + 1}/{num_games} complete. Moves: {move_count}, Winner: {winner}") | |
| return training_data | |
| if __name__ == "__main__": | |
| print("Testing MCTS self-play...") | |
| # Quick test game | |
| config = MCTSConfig(num_simulations=20) # Low for testing | |
| mcts1 = MCTS(config) | |
| mcts2 = MCTS(config) | |
| winner = play_game(mcts1, mcts2, verbose=True) | |
| print(f"Test game complete. Winner: {winner}") | |