File size: 11,612 Bytes
5206d7b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
"""

MCTS (Monte Carlo Tree Search) implementation for AlphaZero-style self-play.

The search is structured to be guided by a policy/value estimator. The current
implementation uses a uniform prior over legal actions and random rollouts in
place of a neural network; `MCTS.get_policy_value` is the hook to replace with
network predictions for full AlphaZero.

"""

import math
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

import numpy as np

from engine.game.game_state import GameState


@dataclass
class MCTSConfig:
    """Tunable parameters for the MCTS search.

    Field order defines the positional constructor signature; keep it stable.
    """

    # Simulations (select / expand / back up passes) run per search() call.
    num_simulations: int = 10
    # PUCT exploration constant used when ranking children.
    c_puct: float = 1.4
    # Concentration parameter of the Dirichlet noise mixed into root priors.
    dirichlet_alpha: float = 0.3
    # Noise weight: prior <- (1 - epsilon) * prior + epsilon * noise.
    dirichlet_epsilon: float = 0.25
    # Virtual-loss magnitude intended for parallel search; not currently read
    # by the sequential MCTS class in this module.
    virtual_loss: float = 3.0
    # Visit counts are raised to 1 / temperature to form the move policy;
    # 0 means "pick the most-visited action".
    temperature: float = 1.0


class MCTSNode:
    """A node in the MCTS tree.

    ``visit_count`` / ``value_sum`` hold the backed-up statistics, ``prior`` is
    the policy probability of the action that leads to this node, and
    ``children`` maps action ids to child nodes once the node is expanded.
    """

    def __init__(self, prior: float = 1.0):
        self.visit_count = 0
        self.value_sum = 0.0
        self.virtual_loss = 0.0  # Accumulated virtual loss; subtracted from Q
        self.prior = prior  # Prior probability of reaching this node
        self.children: Dict[int, "MCTSNode"] = {}
        self.state: Optional["GameState"] = None  # Set by expand()

    @property
    def value(self) -> float:
        """Mean backed-up value penalised by pending virtual loss.

        Computed as ``(W - VL) / N``; an unvisited node reports ``-VL``.
        """
        if self.visit_count == 0:
            return 0.0 - self.virtual_loss
        return (self.value_sum - self.virtual_loss) / self.visit_count

    def is_expanded(self) -> bool:
        """Return True once children have been attached by expand()."""
        return bool(self.children)

    def select_child(self, c_puct: float) -> Tuple[int, "MCTSNode"]:
        """Return ``(action, child)`` with the highest PUCT score.

        score = Q(child) + c_puct * P(child) * sqrt(N(parent)) / (1 + N(child))

        Q already includes the child's virtual-loss penalty. Ties go to the
        first child in insertion order. Returns ``(-1, None)`` if there are no
        children.
        """
        # A freshly expanded node has N == 0; sqrt(0) would zero the
        # exploration term so every unvisited child scores 0 and the first
        # pick would ignore the priors. Clamping to 1 lets priors rank them.
        sqrt_parent_visits = math.sqrt(max(self.visit_count, 1))

        best_score = -math.inf
        best_action, best_child = -1, None
        for action, child in self.children.items():
            score = child.value + c_puct * child.prior * sqrt_parent_visits / (1 + child.visit_count)
            if score > best_score:
                best_score, best_action, best_child = score, action, child
        return best_action, best_child

    def expand(self, state: "GameState", policy: np.ndarray) -> None:
        """Attach one child per legal action of ``state``.

        Args:
            state: State this node represents; stored on ``self.state``.
            policy: Prior probabilities indexed by action id.
        """
        self.state = state
        legal = np.asarray(state.get_legal_actions())
        for action in np.flatnonzero(legal):
            # Plain Python int keys / float priors avoid numpy scalar dtypes
            # leaking into the tree arithmetic.
            self.children[int(action)] = MCTSNode(prior=float(policy[action]))


class MCTS:
    """Monte Carlo Tree Search with AlphaZero-style policy/value guidance.

    A fresh tree is built on every ``search`` call; ``self.root`` keeps the
    most recent tree until ``reset`` or the next search.
    """

    def __init__(self, config: Optional[MCTSConfig] = None):
        self.config = config or MCTSConfig()
        self.root: Optional[MCTSNode] = None

    def reset(self) -> None:
        """Discard the current search tree."""
        self.root = None

    def get_policy_value(self, state: GameState) -> Tuple[np.ndarray, float]:
        """Return ``(policy, value)`` for ``state``.

        Stand-in for a neural network: the policy is uniform over the legal
        actions, and the value is one random rollout scored from the
        perspective of ``state.current_player``. Replace with network
        inference for full AlphaZero.
        """
        legal = state.get_legal_actions()
        policy = legal.astype(np.float32)
        total = policy.sum()
        if total > 0:
            policy /= total

        value = self._random_rollout(state)
        return policy, value

    def _random_rollout(self, state: GameState, max_steps: int = 50) -> float:
        """Estimate ``state``'s value for its player to move by random play.

        Plays uniformly random legal actions on a copy of ``state`` for up to
        ``max_steps`` plies. Returns the terminal reward if the game ends,
        ``0.0`` if a position has no legal action, otherwise the heuristic.
        """
        current = state.copy()
        current_player = state.current_player

        for _ in range(max_steps):
            if current.is_terminal():
                return current.get_reward(current_player)

            legal_indices = np.where(current.get_legal_actions())[0]
            if len(legal_indices) == 0:
                return 0.0

            action = np.random.choice(legal_indices)
            current = current.step(action)

        # Rollout budget exhausted before the game finished.
        return self._heuristic_value(current, current_player)

    def _heuristic_value(self, state: GameState, player_idx: int) -> float:
        """Score a non-terminal ``state`` for ``player_idx`` (two players assumed).

        A success-lives lead dominates (magnitude >= 0.5); otherwise the
        total-blades difference is scaled by 0.01.
        """
        p = state.players[player_idx]
        opp = state.players[1 - player_idx]

        my_lives = len(p.success_lives)
        opp_lives = len(opp.success_lives)
        if my_lives > opp_lives:
            return 0.5 + 0.1 * (my_lives - opp_lives)
        if opp_lives > my_lives:
            return -0.5 - 0.1 * (opp_lives - my_lives)

        my_blades = p.get_total_blades(state.member_db)
        opp_blades = opp.get_total_blades(state.member_db)
        return 0.1 * (my_blades - opp_blades) / 10.0

    def search(self, state: GameState, temperature: Optional[float] = None) -> np.ndarray:
        """Run MCTS from ``state`` and return action probabilities.

        Args:
            state: Current game state.
            temperature: Overrides ``config.temperature`` for this call only.
                ``0`` yields a one-hot vector on the most-visited action.

        Returns:
            Probabilities over the action space derived from the root
            children's visit counts (all zeros if no child was visited).
        """
        if temperature is None:
            temperature = self.config.temperature

        policy, _ = self.get_policy_value(state)
        self.root = MCTSNode()
        self.root.expand(state, policy)
        self._add_exploration_noise(self.root)

        for _ in range(self.config.num_simulations):
            self._simulate(state)

        visits = np.zeros(len(policy), dtype=np.float32)
        for action, child in self.root.children.items():
            visits[action] = child.visit_count

        if visits.sum() == 0:
            return visits

        if temperature == 0:
            probs = np.zeros_like(visits)
            probs[int(np.argmax(visits))] = 1.0
            return probs

        # Visit counts raised to 1/T (not a softmax). Dividing by the max count
        # first keeps the result in [0, 1], so a small T cannot overflow
        # float32 to inf (and NaN after normalising); proportions are unchanged.
        probs = np.power(visits / visits.max(), 1.0 / temperature)
        return probs / probs.sum()

    def _add_exploration_noise(self, node: MCTSNode) -> None:
        """Mix Dirichlet noise into the priors of ``node``'s children."""
        actions = list(node.children.keys())
        if not actions:
            return

        noise = np.random.dirichlet([self.config.dirichlet_alpha] * len(actions))
        eps = self.config.dirichlet_epsilon
        for i, action in enumerate(actions):
            child = node.children[action]
            child.prior = (1 - eps) * child.prior + eps * noise[i]

    def _simulate(self, root_state: GameState) -> None:
        """Run one selection / expansion / backup pass from ``self.root``.

        Each child accumulates value from the perspective of the player who
        chose the action leading to it, which is the quantity the parent's
        child selection maximises. Values are tracked by player identity
        rather than by flipping the sign every ply, because
        ``current_player`` need not alternate between consecutive actions.
        Assumes a two-player zero-sum reward (one player's gain is the
        other's loss) — TODO confirm against GameState.get_reward.
        """
        root_player = root_state.current_player
        node = self.root
        state = root_state.copy()
        # (node, player who moved into it). The root has no mover; its
        # statistics are kept from the root player's perspective.
        search_path = [(node, root_player)]

        # Selection: descend until an unexpanded node or a terminal state.
        while node.is_expanded() and not state.is_terminal():
            mover = state.current_player
            action, node = node.select_child(self.config.c_puct)
            state = state.step(action)
            search_path.append((node, mover))

        # Evaluation, expressed from the root player's perspective.
        if state.is_terminal():
            value = state.get_reward(root_player)
        else:
            policy, leaf_value = self.get_policy_value(state)
            node.expand(state, policy)
            # leaf_value is from the leaf's player-to-move perspective.
            value = leaf_value if state.current_player == root_player else -leaf_value

        # Backup: credit each node from its mover's perspective.
        for path_node, mover in search_path:
            path_node.visit_count += 1
            path_node.value_sum += value if mover == root_player else -value

    def select_action(self, state: GameState, greedy: bool = False) -> int:
        """Search from ``state`` and choose an action.

        ``greedy=True`` returns the most-visited action; otherwise an action is
        sampled from the visit distribution at ``config.temperature``. The
        config object is never mutated (it may be shared between agents).
        """
        if greedy:
            action_probs = self.search(state, temperature=0)
            return int(np.argmax(action_probs))

        action_probs = self.search(state)
        return int(np.random.choice(len(action_probs), p=action_probs))


def play_game(mcts1: MCTS, mcts2: MCTS, verbose: bool = True, *, max_moves: int = 500) -> int:
    """Play one game between two MCTS agents.

    Args:
        mcts1: Agent used when ``state.current_player`` is 0.
        mcts2: Agent used when ``state.current_player`` is 1.
        verbose: Print progress every 10 moves and the final result.
        max_moves: Move cap; a game that reaches it is declared a draw.

    Returns:
        ``state.get_winner()`` if the game ended (documented as 0 or 1 for a
        winner — confirm draw encoding in GameState), or ``2`` if the move cap
        was reached.
    """
    from engine.game.game_state import initialize_game

    state = initialize_game()
    agents = (mcts1, mcts2)
    move_count = 0

    while not state.is_terminal() and move_count < max_moves:
        action = agents[state.current_player].select_action(state)

        if verbose and move_count % 10 == 0:
            print(f"Move {move_count}: Player {state.current_player}, Phase {state.phase.name}, Action {action}")

        state = state.step(action)
        move_count += 1

    if not state.is_terminal():
        if verbose:
            print(f"Game exceeded {max_moves} moves, declaring draw")
        return 2

    winner = state.get_winner()
    if verbose:
        print(f"Game over after {move_count} moves. Winner: {winner}")
    return winner


def self_play(
    num_games: int = 10, simulations: int = 50, *, max_moves: int = 500
) -> List[Tuple[List, List, int]]:
    """Run self-play games to generate training data.

    Args:
        num_games: Number of games to play.
        simulations: MCTS simulations per move.
        max_moves: Per-game move cap; a capped game is recorded with winner 2.

    Returns:
        One ``(observations, policies, winner)`` tuple per game.
        ``observations[i]`` is ``state.get_observation()`` taken before move
        ``i`` and ``policies[i]`` is the MCTS distribution that move was
        sampled from.
    """
    # Same package path as the module-level GameState import.
    from engine.game.game_state import initialize_game

    training_data = []
    config = MCTSConfig(num_simulations=simulations)

    for game_idx in range(num_games):
        state = initialize_game()
        mcts = MCTS(config)

        game_states = []
        game_policies = []
        move_count = 0

        while not state.is_terminal() and move_count < max_moves:
            policy = mcts.search(state)

            # Record the training target before advancing the game.
            game_states.append(state.get_observation())
            game_policies.append(policy)

            action = np.random.choice(len(policy), p=policy)
            state = state.step(action)

            # search() rebuilds the tree anyway; drop it to free memory.
            mcts.reset()
            move_count += 1

        winner = state.get_winner() if state.is_terminal() else 2
        training_data.append((game_states, game_policies, winner))

        print(f"Game {game_idx + 1}/{num_games} complete. Moves: {move_count}, Winner: {winner}")

    return training_data


if __name__ == "__main__":
    print("Testing MCTS self-play...")

    # Smoke test: a single quick game with a deliberately low simulation
    # budget. Both agents share one config object.
    smoke_config = MCTSConfig(num_simulations=20)
    agents = (MCTS(smoke_config), MCTS(smoke_config))
    result = play_game(*agents, verbose=True)
    print(f"Test game complete. Winner: {result}")