Spaces:
Running
Running
| """ | |
| Neural Network for AlphaZero-style training. | |
| This module provides a simple neural network architecture for policy and value | |
| prediction. For a production system, you would use a more sophisticated | |
| architecture (e.g., ResNet with attention) and train on GPU with PyTorch/TensorFlow. | |
| """ | |
| from dataclasses import dataclass | |
| from typing import Tuple | |
| import numpy as np | |
@dataclass
class NetworkConfig:
    """Configuration for the AlphaZero network.

    The `dataclass` decorator was imported but never applied; without it the
    annotated defaults are plain class attributes and `NetworkConfig(hidden_size=...)`
    overrides are impossible. Adding it is backward compatible: `NetworkConfig()`
    and attribute reads behave exactly as before.
    """

    input_size: int = 800  # Feature-based encoding (32 floats per card slot)
    # Size of observation vector (matches GameState.get_observation)
    hidden_size: int = 256  # Width of each hidden layer
    num_hidden_layers: int = 3  # Depth of the shared trunk
    action_size: int = 1000  # Size of action space (matches GameState.get_legal_actions)
    learning_rate: float = 0.001  # SGD step size used by SimpleNetwork.train_step
    l2_reg: float = 0.0001  # L2 weight-decay coefficient applied to weights (not biases)
def sigmoid(x: np.ndarray) -> np.ndarray:
    """Logistic sigmoid, clipped to avoid overflow in np.exp for large |x|."""
    z = np.clip(x, -500, 500)
    return 1.0 / (1.0 + np.exp(-z))
def relu(x: np.ndarray) -> np.ndarray:
    """Rectified linear unit: element-wise max(x, 0)."""
    return np.maximum(x, 0)
def softmax(x: np.ndarray) -> np.ndarray:
    """Numerically stable softmax over the LAST axis.

    The previous version subtracted the global max and divided by the global
    sum, so a batched (B, A) logit matrix was normalized over the whole batch
    instead of per row — every caller that batches (predict_batch, train_step
    via forward) received non-normalized per-sample distributions. Using
    axis=-1 with keepdims fixes 2-D input and is identical for 1-D input.
    """
    shifted = x - np.max(x, axis=-1, keepdims=True)  # stability: exp of non-positive values
    exp_x = np.exp(shifted)
    return exp_x / exp_x.sum(axis=-1, keepdims=True)
def tanh(x: np.ndarray) -> np.ndarray:
    """Hyperbolic tangent activation; squashes values into (-1, 1)."""
    out = np.tanh(x)
    return out
class SimpleNetwork:
    """
    Simple feedforward neural network for policy and value prediction.

    Architecture:
        - Input layer (observation)
        - Hidden layers with ReLU
        - Policy head (softmax over actions)
        - Value head (tanh for [-1, 1])

    forward() accepts either a single observation (1-D) or a batch (2-D) and
    returns correspondingly shaped outputs; see the fix notes in forward().
    """

    def __init__(self, config: NetworkConfig = None):
        """Build the network, falling back to a default NetworkConfig."""
        self.config = config or NetworkConfig()
        self._init_weights()

    def _init_weights(self) -> None:
        """Initialize weights using He initialization (std = sqrt(2 / fan_in))."""
        config = self.config
        # Shared trunk
        self.hidden_weights = []
        self.hidden_biases = []
        in_size = config.input_size
        for _ in range(config.num_hidden_layers):
            std = np.sqrt(2.0 / in_size)
            w = np.random.randn(in_size, config.hidden_size) * std
            b = np.zeros(config.hidden_size)
            self.hidden_weights.append(w)
            self.hidden_biases.append(b)
            in_size = config.hidden_size
        # Policy head
        std = np.sqrt(2.0 / config.hidden_size)
        self.policy_weight = np.random.randn(config.hidden_size, config.action_size) * std
        self.policy_bias = np.zeros(config.action_size)
        # Value head (single scalar output)
        self.value_weight = np.random.randn(config.hidden_size, 1) * std
        self.value_bias = np.zeros(1)

    def forward(self, observation: np.ndarray) -> Tuple[np.ndarray, float]:
        """
        Forward pass.

        Args:
            observation: Input features, shape (input_size,) or (batch, input_size)

        Returns:
            (policy probabilities, value). For 1-D input the value is a scalar;
            for 2-D input the policy is (batch, action_size) and value is (batch,).

        Fixes vs. the original:
            - The policy softmax is computed row-wise (axis=-1) so batched
              logits are normalized per sample, not over the whole batch.
            - The value head previously did `tanh(...)[0]`, which for a batch
              returned only the FIRST sample's value (shape (1,)); this
              silently truncated predict_batch results and fed train_step a
              broadcast of one prediction. It now returns (batch,) for 2-D input.
        """
        # Store activations for the backward pass in train_step
        self.activations = [observation]
        x = observation
        for w, b in zip(self.hidden_weights, self.hidden_biases, strict=False):
            x = relu(x @ w + b)
            self.activations.append(x)
        # Policy head: numerically stable row-wise softmax (inlined so this
        # method is correct for batches on its own).
        policy_logits = x @ self.policy_weight + self.policy_bias
        shifted = policy_logits - np.max(policy_logits, axis=-1, keepdims=True)
        exp_logits = np.exp(shifted)
        policy = exp_logits / exp_logits.sum(axis=-1, keepdims=True)
        # Value head: tanh squashes into [-1, 1]
        value_out = tanh(x @ self.value_weight + self.value_bias)
        if observation.ndim == 1:
            value = value_out[0]  # scalar for a single observation (as before)
        else:
            value = value_out[:, 0]  # (batch,) — one value per sample
        self.last_policy_logits = policy_logits
        self.last_value = value
        return policy, value

    def predict(self, state) -> Tuple[np.ndarray, float]:
        """Get policy and value for a game state.

        Illegal actions are masked out and the distribution renormalized.
        Assumes state.get_legal_actions() returns a binary mask vector of
        length action_size — TODO confirm against GameState.
        """
        obs = state.get_observation()
        policy, value = self.forward(obs)
        # Mask illegal actions
        legal = state.get_legal_actions()
        masked_policy = policy * legal
        if masked_policy.sum() > 0:
            masked_policy /= masked_policy.sum()
        else:
            # Network put ~zero mass on every legal action: fall back to
            # uniform over legal moves so MCTS still has priors to expand.
            masked_policy = legal.astype(np.float32)
            masked_policy /= masked_policy.sum()
        return masked_policy, value

    def predict_batch(self, states) -> list:
        """Get (masked policy, value) for a batch of game states.

        Returns one (policy, value) pair per input state, in order.
        """
        if not states:
            return []
        obs = np.array([s.get_observation() for s in states])
        policies, values = self.forward(obs)
        results = []
        for i, (policy, value) in enumerate(zip(policies, values)):
            legal = states[i].get_legal_actions()
            masked_policy = policy * legal
            if masked_policy.sum() > 0:
                masked_policy /= masked_policy.sum()
            else:
                # Fall back to uniform over legal actions (see predict)
                masked_policy = legal.astype(np.float32)
                masked_policy /= masked_policy.sum()
            results.append((masked_policy, value))
        return results

    def train_step(
        self, observations: np.ndarray, target_policies: np.ndarray, target_values: np.ndarray
    ) -> Tuple[float, float, float]:
        """
        One vectorized SGD training step with L2 regularization.

        Args:
            observations: Batch of observations (batch_size, input_size)
            target_policies: Target policy distributions (batch_size, action_size)
            target_values: Target values (batch_size,)

        Returns:
            (total_loss, policy_loss, value_loss)
        """
        batch_size = len(observations)
        config = self.config
        # 1. Forward pass (batch). pred_policy: (B, A), pred_value: (B,)
        pred_policy, pred_value = self.forward(observations)
        # 2. Losses
        # Policy loss: cross-entropy, mean over the batch (epsilon guards log(0))
        policy_loss = -np.mean(np.sum(target_policies * np.log(pred_policy + 1e-8), axis=1))
        # Value loss: MSE
        value_loss = np.mean((pred_value - target_values) ** 2)
        total_loss = policy_loss + value_loss
        # 3. Backward pass.
        # Softmax + cross-entropy gradient w.r.t. logits: (pred - target) / B
        d_policy_logits = (pred_policy - target_policies) / batch_size
        # Value: d/dpre_tanh of mean (tanh(pre) - t)^2 = 2(pred - t)(1 - pred^2)/B
        d_value_out = 2 * (pred_value - target_values) / batch_size
        d_value_pre_tanh = d_value_out * (1 - pred_value**2)
        # Head gradients; hidden_out is the last trunk activation, (B, H)
        hidden_out = self.activations[-1]
        # Policy: (H, B) @ (B, A) -> (H, A)
        grad_policy_w = hidden_out.T @ d_policy_logits
        grad_policy_b = np.sum(d_policy_logits, axis=0)
        # Value: reshape error to (B, 1) so (H, B) @ (B, 1) -> (H, 1)
        d_value_pre_tanh = d_value_pre_tanh.reshape(-1, 1)
        grad_value_w = hidden_out.T @ d_value_pre_tanh
        grad_value_b = np.sum(d_value_pre_tanh, axis=0)
        # Combine both heads' gradients into the trunk: (B, H)
        d_hidden = d_policy_logits @ self.policy_weight.T + d_value_pre_tanh @ self.value_weight.T
        grads_w = []
        grads_b = []
        # Walk hidden layers back-to-front; weights[i] produced activations[i+1]
        for layer_idx in range(len(self.hidden_weights) - 1, -1, -1):
            # ReLU derivative: pass gradient only where the activation was > 0
            mask = (self.activations[layer_idx + 1] > 0).astype(np.float32)
            d_hidden = d_hidden * mask
            prev_activation = self.activations[layer_idx]
            # (In, B) @ (B, Out) -> (In, Out)
            g_w = prev_activation.T @ d_hidden
            g_b = np.sum(d_hidden, axis=0)
            grads_w.insert(0, g_w)
            grads_b.insert(0, g_b)
            if layer_idx > 0:
                # Propagate to the previous layer
                d_hidden = d_hidden @ self.hidden_weights[layer_idx].T
        # 4. Apply gradients (SGD + L2 on weights only, not biases)
        for i in range(len(self.hidden_weights)):
            self.hidden_weights[i] -= config.learning_rate * (grads_w[i] + config.l2_reg * self.hidden_weights[i])
            self.hidden_biases[i] -= config.learning_rate * grads_b[i]
        self.policy_weight -= config.learning_rate * (grad_policy_w + config.l2_reg * self.policy_weight)
        self.policy_bias -= config.learning_rate * grad_policy_b
        self.value_weight -= config.learning_rate * (grad_value_w + config.l2_reg * self.value_weight)
        self.value_bias -= config.learning_rate * grad_value_b
        return total_loss, policy_loss, value_loss

    def save(self, filepath: str) -> None:
        """Save network weights to file.

        Note: np.savez appends ".npz" when filepath lacks it; load() mirrors
        that so save(p)/load(p) round-trips with any path.
        """
        # Object arrays handle inhomogeneous layer shapes (first layer differs)
        np.savez(
            filepath,
            hidden_weights=np.array(self.hidden_weights, dtype=object),
            hidden_biases=np.array(self.hidden_biases, dtype=object),
            policy_weight=self.policy_weight,
            policy_bias=self.policy_bias,
            value_weight=self.value_weight,
            value_bias=self.value_bias,
        )

    def load(self, filepath: str) -> None:
        """Load network weights from file (accepts paths with or without .npz)."""
        try:
            data = np.load(filepath, allow_pickle=True)
        except FileNotFoundError:
            # save() via np.savez appends ".npz" to extension-less paths;
            # retry with the suffix so load(p) matches save(p).
            data = np.load(filepath + ".npz", allow_pickle=True)
        # Convert object arrays back to lists of per-layer arrays
        self.hidden_weights = list(data["hidden_weights"])
        self.hidden_biases = list(data["hidden_biases"])
        self.policy_weight = data["policy_weight"]
        self.policy_bias = data["policy_bias"]
        self.value_weight = data["value_weight"]
        self.value_bias = data["value_bias"]
class NeuralMCTS:
    """MCTS that uses a neural network for policy and value with parallel search.

    Parallelism is simulated within one process: a batch of leaves is selected
    using virtual loss to make the selections diverge, then evaluated in a
    single batched network call, then expanded and backpropagated together.
    """

    def __init__(
        self, network: SimpleNetwork, num_simulations: int = 100, batch_size: int = 8, virtual_loss: float = 3.0
    ):
        # network: supplies (policy, value) for leaf evaluation
        # num_simulations: total simulations per search() call (rounded up to batches)
        # batch_size: leaves collected per batched network evaluation
        # virtual_loss: temporary penalty applied during selection so that
        #   repeated selections within one batch explore different paths
        self.network = network
        self.num_simulations = num_simulations
        self.batch_size = batch_size
        self.c_puct = 1.4  # PUCT exploration constant
        self.virtual_loss = virtual_loss
        self.root = None  # MCTSNode root of the most recent search

    def get_policy_value(self, state) -> Tuple[np.ndarray, float]:
        """Get policy and value from the neural network (legal-action masked)."""
        return self.network.predict(state)

    def search(self, state) -> np.ndarray:
        """Run MCTS with neural network guidance (parallel via virtual loss).

        Returns a normalized visit-count distribution over the action space,
        usable directly as a policy training target.
        """
        # Local import avoids a circular dependency with the ai.mcts module
        from ai.mcts import MCTSNode

        # Initial root expansion (always a blocking, single-state evaluation)
        policy, _ = self.get_policy_value(state)
        self.root = MCTSNode()
        self.root.expand(state, policy)
        # Round simulations up to whole batches; slight overshoot is fine
        num_batches = (self.num_simulations + self.batch_size - 1) // self.batch_size
        for _ in range(num_batches):
            self._simulate_batch(state, self.batch_size)
        # Build a fixed-size visit-count vector indexed by action id.
        # Assumes get_legal_actions() returns a mask whose length equals the
        # action-space size — TODO confirm against GameState.
        action_size = len(state.get_legal_actions())
        visits = np.zeros(action_size, dtype=np.float32)
        for action, child in self.root.children.items():
            visits[action] = child.visit_count
        if visits.sum() > 0:
            visits /= visits.sum()
        return visits

    def _simulate_batch(self, root_state, batch_size) -> None:
        """Run one batch of MCTS simulations parallelized via virtual loss.

        Three phases: (1) select batch_size leaves, applying virtual loss along
        each path so selections diverge; (2) evaluate all non-terminal leaves
        in one batched network call; (3) expand leaves and backpropagate,
        removing the virtual loss.
        """
        paths = []  # (path, state) per simulated trajectory
        leaf_nodes = []  # final node of each trajectory
        request_states = []  # non-terminal leaf states awaiting evaluation

        # 1. Selection phase: batch_size independent descents from the root
        for _ in range(batch_size):
            node = self.root
            state = root_state.copy()
            path = [node]
            # Descend while the node has expanded children and play continues
            while node.is_expanded() and not state.is_terminal():
                action, child = node.select_child(self.c_puct)
                # Apply virtual loss immediately so subsequent selections in
                # this batch are discouraged from taking the identical path
                child.virtual_loss += self.virtual_loss
                state = state.step(action)
                node = child
                path.append(node)
            paths.append((path, state))
            leaf_nodes.append(node)
            if not state.is_terminal():
                request_states.append(state)

        # 2. Evaluation phase: one batched call when supported, else per-state
        responses = []
        if request_states:
            if hasattr(self.network, "predict_batch"):
                responses = self.network.predict_batch(request_states)
            else:
                responses = [self.network.predict(s) for s in request_states]

        # 3. Expansion & backpropagation phase
        resp_idx = 0  # responses only cover non-terminal leaves, in order
        for i in range(batch_size):
            path, state = paths[i]
            leaf = leaf_nodes[i]
            value = 0.0
            if state.is_terminal():
                # NOTE(review): get_reward is taken from the root player's
                # perspective here, while the network value is presumably from
                # the leaf player's perspective; the per-level sign flip below
                # relies on these conventions agreeing — confirm.
                value = state.get_reward(root_state.current_player)
            else:
                # Retrieve this leaf's prediction and expand it
                policy, v = responses[resp_idx]
                resp_idx += 1
                value = v
                leaf.expand(state, policy)
            # Backpropagate along the path, leaf to root
            for node in reversed(path):
                node.visit_count += 1
                node.value_sum += value
                # Virtual loss was added only to selected children during the
                # descent (never to the root, which is path[0]), so remove it
                # from every node on the path except the root.
                if node != self.root:
                    node.virtual_loss -= self.virtual_loss
                value = -value  # flip perspective at each level up the tree
| def train_network(network: SimpleNetwork, training_data: list, epochs: int = 10, batch_size: int = 32) -> None: | |
| """ | |
| Train network on self-play data. | |
| Args: | |
| network: Network to train | |
| training_data: List of (states, policies, winner) tuples | |
| epochs: Number of training epochs | |
| batch_size: Batch size for training | |
| """ | |
| print(f"Training on {len(training_data)} games...") | |
| # Flatten data with rewards | |
| all_states = [] | |
| all_policies = [] | |
| all_values = [] | |
| for states, policies, winner, r0, r1 in training_data: | |
| for i, (s, p) in enumerate(zip(states, policies, strict=False)): | |
| all_states.append(s) | |
| all_policies.append(p) | |
| # Value from perspective of player who made the move | |
| player_idx = i % 2 | |
| # Use actual calculated reward (with score shaping) | |
| if player_idx == 0: | |
| all_values.append(r0) | |
| else: | |
| all_values.append(r1) | |
| all_states = np.array(all_states) | |
| all_policies = np.array(all_policies) | |
| all_values = np.array(all_values) | |
| n_samples = len(all_states) | |
| for epoch in range(epochs): | |
| # Shuffle data | |
| indices = np.random.permutation(n_samples) | |
| total_loss = 0.0 | |
| for i in range(0, n_samples, batch_size): | |
| batch_idx = indices[i : i + batch_size] | |
| loss, p_loss, v_loss = network.train_step( | |
| all_states[batch_idx], all_policies[batch_idx], all_values[batch_idx] | |
| ) | |
| total_loss += loss | |
| num_batches = (n_samples + batch_size - 1) // batch_size | |
| print(f"Epoch {epoch + 1}/{epochs}, Loss: {total_loss / num_batches:.4f}") | |
if __name__ == "__main__":
    # Smoke-test: one prediction on a fresh game, then one training step.
    from engine.game.game_state import initialize_game

    print("Testing neural network...")
    cfg = NetworkConfig()
    net = SimpleNetwork(cfg)

    # Forward pass on the initial game state
    game = initialize_game()
    policy, value = net.predict(game)
    print(f"Policy shape: {policy.shape}")
    print(f"Policy sum: {policy.sum():.4f}")
    print(f"Value: {value:.4f}")

    # Single gradient step against a synthetic target distribution
    obs = game.get_observation()
    target_p = np.zeros(cfg.action_size)
    target_p[0] = 0.8
    target_p[1] = 0.2
    target_v = 0.5
    loss, p_loss, v_loss = net.train_step(obs.reshape(1, -1), target_p.reshape(1, -1), np.array([target_v]))
    print(f"Training loss: {loss:.4f} (policy: {p_loss:.4f}, value: {v_loss:.4f})")