/**
 * Monte Carlo Tree Search (MCTS) Agent for Trigo
 *
 * Implements AlphaGo Zero-style MCTS with:
 * - PUCT (Polynomial Upper Confidence Trees) selection
 * - Neural network guidance for policy and value
 * - Visit count statistics for training data generation
 *
 * Based on: Silver et al., "Mastering the Game of Go without Human Knowledge"
 */

// All four symbols are used purely as types in this module.
import type { TrigoGame } from "./trigo/game";
import type { Move } from "./trigo/types";
import type { TrigoTreeAgent } from "./trigoTreeAgent";
import type { TrigoEvaluationAgent } from "./trigoEvaluationAgent";

/**
 * MCTS Configuration
 */
export interface MCTSConfig {
	numSimulations: number; // Number of MCTS simulations per move (default: 600)
	cPuct: number; // PUCT exploration constant (default: 1.0)
	temperature: number; // Selection temperature for the first 30 moves (default: 1.0)
	dirichletAlpha: number; // Dirichlet noise alpha parameter (default: 0.03)
	dirichletEpsilon: number; // Dirichlet noise mixing weight (default: 0.25)
}

/** Action key used for the pass move in every per-action map. */
const PASS_KEY = "pass";

/** Number of opening moves played with `config.temperature`. */
const EXPLORATION_MOVES = 30;

/** Temperature used after the exploration phase (approximates τ→0). */
const LATE_GAME_TEMPERATURE = 0.01;

/** Temperatures strictly below this are treated as τ=0, i.e. argmax over visit counts. */
const GREEDY_TEMPERATURE_THRESHOLD = 0.01;

/**
 * MCTS Tree Node
 * Stores search statistics for all legal actions from a given game state.
 *
 * Memory optimization: only the root node stores the full game state.
 * Non-root nodes only store the action that led to them.
 * During a simulation, a working state is cloned once from the root and mutated along the path.
 */
interface MCTSNode {
	state: TrigoGame | null; // Game state (only stored at the root node for memory efficiency)
	parent: MCTSNode | null; // Parent node (null for root)
	action: Move | null; // Action that led to this node (null for root)

	// MCTS statistics per action (action key -> value)
	N: Map<string, number>; // Visit counts N(s,a)
	W: Map<string, number>; // Total action-value W(s,a)
	Q: Map<string, number>; // Mean action-value Q(s,a) = W(s,a) / N(s,a)
	P: Map<string, number>; // Prior probabilities P(s,a) from the policy network

	children: Map<string, MCTSNode>; // Child nodes (action key -> child node)
	expanded: boolean; // Whether this node has been expanded
	terminalValue: number | null; // Cached terminal / solved value (white-positive), null if unknown

	// Bookkeeping used by terminal propagation in backup()
	depth: number; // Distance from root (0 for root)
	playerToMove: number; // Player to move at this node (1=Black, 2=White)
}

/** One edge on the selection path: the node and the action taken from it. */
type PathStep = { node: MCTSNode; actionKey: string };

/**
 * MCTS Agent
 * Combines tree search with neural network evaluation.
 *
 * Value convention: every value and Q-value in the tree is white-positive
 * (positive = White is winning). White maximizes, Black minimizes.
 */
export class MCTSAgent {
	private readonly treeAgent: TrigoTreeAgent; // For policy priors
	private readonly evaluationAgent: TrigoEvaluationAgent; // For value evaluation
	private readonly config: MCTSConfig;
	public debugMode = false; // Enable debug logging

	constructor(
		treeAgent: TrigoTreeAgent,
		evaluationAgent: TrigoEvaluationAgent,
		config: Partial<MCTSConfig> = {}
	) {
		this.treeAgent = treeAgent;
		this.evaluationAgent = evaluationAgent;

		// Default configuration (AlphaGo Zero-inspired)
		this.config = {
			numSimulations: config.numSimulations ?? 600,
			cPuct: config.cPuct ?? 1.0,
			temperature: config.temperature ?? 1.0,
			dirichletAlpha: config.dirichletAlpha ?? 0.03,
			dirichletEpsilon: config.dirichletEpsilon ?? 0.25
		};
	}

	/**
	 * Select a move using MCTS.
	 *
	 * If `game` is already terminal, no search is run and a pass is returned
	 * together with the terminal value.
	 *
	 * @param game Current game state (not mutated; simulations work on clones)
	 * @param moveNumber Move number, used for the temperature schedule
	 *   (`config.temperature` for the first 30 moves, 0.01 afterwards)
	 * @returns Selected move, raw root visit counts, normalized search policy π(a|s)
	 *   and the visit-weighted mean root Q-value (white-positive)
	 */
	async selectMove(game: TrigoGame, moveNumber: number): Promise<{
		move: Move;
		visitCounts: Map<string, number>;
		searchPolicy: Map<string, number>; // Normalized visit counts π(a|s)
		rootValue: number;
	}> {
		const playerColor = game.getCurrentPlayer() === 1 ? "black" : "white";

		// Create root node
		const root = this.createNode(game, null, null);

		// Root already terminal (game over): nothing to search
		const terminalResult = this.checkTerminal(game);
		if (terminalResult !== null) {
			return {
				move: { player: playerColor, isPass: true },
				visitCounts: new Map(),
				searchPolicy: new Map(),
				rootValue: terminalResult
			};
		}

		// Run MCTS simulations
		for (let i = 0; i < this.config.numSimulations; i++) {
			// Once terminal propagation has solved the root, selection stops at the root
			// with an empty path, so every further simulation would be a no-op.
			if (root.terminalValue !== null) {
				break;
			}
			await this.runSimulation(root, i);
		}

		// Temperature schedule: τ=config for the opening, τ→0 afterwards
		const temperature = moveNumber < EXPLORATION_MOVES ? this.config.temperature : LATE_GAME_TEMPERATURE;

		// Select move based on visit counts; decodeAction uses a placeholder player
		const move = this.selectPlayAction(root, temperature);
		move.player = playerColor;

		return {
			move,
			visitCounts: new Map(root.N),
			searchPolicy: this.computeSearchPolicy(root, temperature),
			rootValue: this.getRootValue(root)
		};
	}

	/**
	 * Run a single MCTS simulation: Select -> Expand & Evaluate -> Backup.
	 *
	 * Memory optimization: the root state is cloned once per simulation and mutated
	 * along the selected path instead of storing a state in every node.
	 *
	 * @throws Error if the root node has no stored state
	 */
	private async runSimulation(root: MCTSNode, simIndex?: number): Promise<void> {
		// Invariant: root node must always have a non-null state
		if (!root.state) {
			throw new Error("runSimulation: root node must have a non-null state");
		}

		// Clone root state once for this simulation
		const workingState = root.state.clone();

		// 1. Selection: traverse the tree using PUCT until reaching an unexpanded or terminal node
		const { node, path } = this.select(root, workingState);

		// 2. Expand and Evaluate: get the value from the neural network (or cached terminal value)
		const value = await this.expandAndEvaluate(node, workingState);

		// Debug logging (first 10 simulations only)
		if (this.debugMode && simIndex !== undefined && simIndex < 10) {
			const pathStr = path.map(p => p.actionKey).join(" → ");
			const terminalStr = node.terminalValue !== null ? " [TERMINAL]" : "";
			console.log(`Sim ${simIndex + 1}: ${pathStr || "(root)"} → value=${value.toFixed(4)}${terminalStr}`);
		}

		// 3. Backup: propagate the value up the tree
		this.backup(path, value);
	}

	/**
	 * Selection phase: traverse the tree using PUCT.
	 *
	 * @param root Root node to start selection from
	 * @param workingState Mutable game state; advanced in place by every selected action
	 * @returns Leaf node (unexpanded, terminal, or action-less) and the path taken to it
	 */
	private select(root: MCTSNode, workingState: TrigoGame): {
		node: MCTSNode;
		path: PathStep[];
	} {
		const path: PathStep[] = [];
		let node = root;

		while (node.expanded) {
			// Terminal (or solved) nodes are never expanded further; their cached value is used
			if (node.terminalValue !== null) {
				break;
			}

			const actionKeys = Array.from(node.P.keys());

			// Expanded but no actions: treat as a leaf
			if (actionKeys.length === 0) {
				break;
			}

			// Both players pick the HIGHEST PUCT value:
			// - Black: PUCT = -Q + U, so max PUCT favours min(Q)
			// - White: PUCT =  Q + U, so max PUCT favours max(Q)
			const isWhite = workingState.getCurrentPlayer() === 2;

			// Σ_b N(s,b) is the same for every candidate; compute it once per level
			const sqrtTotalVisits = Math.sqrt(this.totalVisits(node) + 1);

			let bestActionKey = actionKeys[0];
			let bestPuct = this.calculatePUCT(node, bestActionKey, isWhite, sqrtTotalVisits);
			for (let i = 1; i < actionKeys.length; i++) {
				const puct = this.calculatePUCT(node, actionKeys[i], isWhite, sqrtTotalVisits);
				if (puct > bestPuct) {
					bestPuct = puct;
					bestActionKey = actionKeys[i];
				}
			}

			path.push({ node, actionKey: bestActionKey });

			// Apply the action to the working state (instead of cloning)
			const action = this.decodeAction(bestActionKey);
			if (action.isPass) {
				workingState.pass();
			} else if (action.x !== undefined && action.y !== undefined && action.z !== undefined) {
				workingState.drop({ x: action.x, y: action.y, z: action.z });
			}

			// Move to the child, creating it lazily without storing a state (memory optimization)
			let child = node.children.get(bestActionKey);
			if (child === undefined) {
				// Read the side to move from the advanced state instead of assuming strict alternation
				child = this.createNode(null, node, action, workingState.getCurrentPlayer());
				node.children.set(bestActionKey, child);
			}
			node = child;
		}

		return { node, path };
	}

	/**
	 * Expand and evaluate a leaf node using the neural networks.
	 *
	 * Terminal positions are marked expanded with an empty action set and their
	 * value cached. Otherwise all valid moves plus pass are scored by the tree agent,
	 * converted to priors with a numerically stable softmax (uniform fallback),
	 * Dirichlet noise is mixed in at the root, and the evaluation agent's value is returned.
	 *
	 * @param node Leaf node to expand
	 * @param workingState Game state at this node (advanced by select())
	 * @returns Value estimate (white-positive)
	 */
	private async expandAndEvaluate(node: MCTSNode, workingState: TrigoGame): Promise<number> {
		// Cached terminal value
		if (node.terminalValue !== null) {
			return node.terminalValue;
		}

		// Terminal state: mark expanded with no actions so it is never revisited as a leaf
		const terminalValue = this.checkTerminal(workingState);
		if (terminalValue !== null) {
			node.expanded = true;
			node.terminalValue = terminalValue;
			node.P = new Map();
			node.N = new Map();
			node.W = new Map();
			node.Q = new Map();
			node.children = new Map();
			return terminalValue;
		}

		// Non-terminal state: enumerate all valid moves plus pass
		const currentPlayer = workingState.getCurrentPlayer() === 1 ? "black" : "white";
		const validPositions = workingState.validMovePositions();
		const moves: Move[] = validPositions.map(pos => ({
			x: pos.x,
			y: pos.y,
			z: pos.z,
			player: currentPlayer
		}));
		moves.push({ player: currentPlayer, isPass: true });

		// Policy priors from the tree agent (scores are treated as log-probabilities)
		const scoredMoves = await this.treeAgent.scoreMoves(workingState, moves);

		// Stable softmax; reduce avoids spreading a large argument list into Math.max
		const maxScore = scoredMoves.reduce((best, m) => Math.max(best, m.score), -Infinity);
		const expScores = scoredMoves.map(m => Math.exp(m.score - maxScore));
		const sumExp = expScores.reduce((sum, e) => sum + e, 0);

		node.P = new Map();
		node.N = new Map();
		node.W = new Map();
		node.Q = new Map();

		// All scores -Infinity / Infinity / NaN make sumExp 0 or NaN: fall back to uniform priors
		const useFallback = !isFinite(sumExp) || sumExp < 1e-10;

		for (let i = 0; i < scoredMoves.length; i++) {
			const actionKey = this.encodeAction(scoredMoves[i].move);
			const prior = useFallback ? (1.0 / scoredMoves.length) : (expScores[i] / sumExp);
			node.P.set(actionKey, prior);
			node.N.set(actionKey, 0);
			node.W.set(actionKey, 0);
			node.Q.set(actionKey, 0);
		}

		// Exploration noise at the root only
		if (node.parent === null) {
			this.addDirichletNoise(node.P);
		}

		node.expanded = true;

		// Value model returns a white-positive value by design
		const evaluation = await this.evaluationAgent.evaluatePosition(workingState);
		return evaluation.value;
	}

	/**
	 * Backup phase: propagate a value up the tree.
	 *
	 * Values are white-positive throughout, so no sign flipping is needed.
	 * Additionally, when every child of an expanded node is terminal, the node is
	 * marked terminal with the exact minimax value:
	 *   - White to move: max(children terminal values)
	 *   - Black to move: min(children terminal values)
	 *
	 * @param path Path from root to leaf
	 * @param value Value to propagate (white-positive: positive = White winning)
	 */
	private backup(path: PathStep[], value: number): void {
		for (let i = path.length - 1; i >= 0; i--) {
			const { node, actionKey } = path[i];

			// Update statistics
			const n = node.N.get(actionKey) ?? 0;
			const w = node.W.get(actionKey) ?? 0;
			node.N.set(actionKey, n + 1);
			node.W.set(actionKey, w + value);
			node.Q.set(actionKey, (w + value) / (n + 1));

			// ========== Terminal State Propagation ==========
			if (!node.expanded || node.terminalValue !== null) {
				continue;
			}

			const actionKeys = Array.from(node.P.keys());
			// No actions: already a terminal leaf (or an error state)
			if (actionKeys.length === 0) {
				continue;
			}

			// Collect children's terminal values; bail out on the first unexplored or non-terminal child
			let allChildrenTerminal = true;
			const childTerminalValues: number[] = [];
			for (const key of actionKeys) {
				const child = node.children.get(key);
				if (!child || child.terminalValue === null) {
					allChildrenTerminal = false;
					break;
				}
				childTerminalValues.push(child.terminalValue);
			}

			if (allChildrenTerminal && childTerminalValues.length > 0) {
				const isWhiteTurn = node.playerToMove === 2; // 2 = White, 1 = Black

				// White maximizes, Black minimizes the white-positive value
				const terminalValue = isWhiteTurn
					? Math.max(...childTerminalValues)
					: Math.min(...childTerminalValues);

				node.terminalValue = terminalValue;

				if (this.debugMode) {
					const playerName = isWhiteTurn ? 'White' : 'Black';
					console.log(
						`[Terminal Propagation] Node at depth ${node.depth} (${playerName}) marked terminal: ` +
						`value=${terminalValue.toFixed(4)}, children=[${childTerminalValues.map(v => v.toFixed(2)).join(', ')}]`
					);
				}
			}
			// ================================================
		}
	}

	/**
	 * Calculate the PUCT value for action selection.
	 *
	 * PUCT =  Q(s,a) + U(s,a)  [White, who maximizes]
	 * PUCT = -Q(s,a) + U(s,a)  [Black, who minimizes]
	 * where U(s,a) = c_puct * P(s,a) * sqrt(Σ_b N(s,b) + 1) / (1 + N(s,a)).
	 * The +1 inside the sqrt keeps exploration non-zero right after expansion.
	 *
	 * @param node Current node
	 * @param actionKey Action to evaluate
	 * @param isWhite Whether the player to move is White
	 * @param sqrtTotalVisits Optional precomputed sqrt(Σ_b N(s,b) + 1); computed if omitted
	 * @returns PUCT value
	 */
	private calculatePUCT(node: MCTSNode, actionKey: string, isWhite: boolean, sqrtTotalVisits?: number): number {
		const Q = node.Q.get(actionKey) ?? 0;
		const N = node.N.get(actionKey) ?? 0;
		const P = node.P.get(actionKey) ?? 0;

		const sqrtTotal = sqrtTotalVisits ?? Math.sqrt(this.totalVisits(node) + 1);
		const U = this.config.cPuct * P * sqrtTotal / (1 + N);

		return (isWhite ? Q : -Q) + U;
	}

	/**
	 * Select the action to play from the root visit counts.
	 *
	 * - No visit statistics: sample from the priors (uniform if they are unusable; pass if none).
	 * - τ < 0.01: greedy, first action with the highest visit count.
	 * - Otherwise: sample with probability ∝ N(s,a)^(1/τ) (uniform if nothing was visited).
	 *
	 * @param node Root node
	 * @param temperature Selection temperature
	 * @returns Selected move (placeholder player; caller sets it)
	 */
	private selectPlayAction(node: MCTSNode, temperature: number): Move {
		const actionKeys = Array.from(node.N.keys());

		if (actionKeys.length === 0) {
			return this.sampleFromPriors(node);
		}

		if (temperature < GREEDY_TEMPERATURE_THRESHOLD) {
			return this.decodeAction(this.argmaxVisits(node, actionKeys));
		}

		const weights = this.temperedVisitWeights(node, actionKeys, temperature);
		const total = weights === null ? NaN : weights.reduce((sum, w) => sum + w, 0);
		if (weights === null || !isFinite(total) || total <= 0) {
			// Nothing visited (or invalid temperature): uniform random choice
			const randomIndex = Math.floor(Math.random() * actionKeys.length);
			return this.decodeAction(actionKeys[randomIndex]);
		}

		return this.decodeAction(actionKeys[this.sampleIndex(weights, total)]);
	}

	/**
	 * Fallback move choice when a node has no visit statistics: sample from P,
	 * uniform over P's keys if the priors do not sum to a positive number,
	 * or pass if there are no actions at all.
	 */
	private sampleFromPriors(node: MCTSNode): Move {
		const priorKeys = Array.from(node.P.keys());
		if (priorKeys.length === 0) {
			return { player: "black", isPass: true };
		}

		const priors = priorKeys.map(key => node.P.get(key) ?? 0);
		const sumP = priors.reduce((sum, p) => sum + p, 0);
		if (sumP > 0) {
			return this.decodeAction(priorKeys[this.sampleIndex(priors, sumP)]);
		}

		const randomIndex = Math.floor(Math.random() * priorKeys.length);
		return this.decodeAction(priorKeys[randomIndex]);
	}

	/**
	 * Draw an index with probability weights[i] / total (roulette-wheel sampling).
	 * The final index absorbs floating-point rounding.
	 */
	private sampleIndex(weights: number[], total: number): number {
		let remaining = Math.random() * total;
		for (let i = 0; i < weights.length; i++) {
			remaining -= weights[i];
			if (remaining <= 0) {
				return i;
			}
		}
		return weights.length - 1;
	}

	/** First action key with the highest visit count (ties resolve to the earliest key). */
	private argmaxVisits(node: MCTSNode, actionKeys: string[]): string {
		let bestKey = actionKeys[0];
		let bestN = node.N.get(bestKey) ?? 0;
		for (let i = 1; i < actionKeys.length; i++) {
			const n = node.N.get(actionKeys[i]) ?? 0;
			if (n > bestN) {
				bestN = n;
				bestKey = actionKeys[i];
			}
		}
		return bestKey;
	}

	/**
	 * Unnormalized weights (N(s,a) / max_b N(s,b))^(1/τ).
	 *
	 * Dividing by the largest count first keeps every base in [0, 1], so the
	 * power cannot overflow even for τ=0.01 (exponent 100) and large visit
	 * counts — raw N^100 becomes Infinity once N exceeds roughly 1200.
	 * The distribution is unchanged by the rescaling.
	 *
	 * @returns Weights aligned with actionKeys, or null when no action has been visited
	 */
	private temperedVisitWeights(node: MCTSNode, actionKeys: string[], temperature: number): number[] | null {
		const counts = actionKeys.map(key => node.N.get(key) ?? 0);
		const maxCount = counts.reduce((best, n) => Math.max(best, n), 0);
		if (!(maxCount > 0)) {
			return null;
		}
		const exponent = 1 / temperature;
		return counts.map(n => Math.pow(n / maxCount, exponent));
	}

	/**
	 * Compute the search policy from root visit counts:
	 * π(a|s) = N(s,a)^(1/τ) / Σ_b N(s,b)^(1/τ).
	 *
	 * For τ < 0.01 the policy is one-hot on the action greedy play would pick
	 * (this also avoids 1/τ = Infinity at τ=0). Falls back to a uniform
	 * distribution when nothing was visited.
	 *
	 * @param node Root node
	 * @param temperature Selection temperature
	 * @returns Normalized policy distribution over node.N's keys
	 */
	private computeSearchPolicy(node: MCTSNode, temperature: number): Map<string, number> {
		const policy = new Map<string, number>();
		const actionKeys = Array.from(node.N.keys());
		if (actionKeys.length === 0) {
			return policy;
		}

		let weights: number[] | null;
		if (temperature < GREEDY_TEMPERATURE_THRESHOLD) {
			const best = this.argmaxVisits(node, actionKeys);
			weights = (node.N.get(best) ?? 0) > 0
				? actionKeys.map(key => (key === best ? 1 : 0))
				: null;
		} else {
			weights = this.temperedVisitWeights(node, actionKeys, temperature);
		}

		const total = weights === null ? NaN : weights.reduce((sum, w) => sum + w, 0);
		if (weights === null || !isFinite(total) || total <= 0) {
			const uniform = 1 / actionKeys.length;
			for (const key of actionKeys) {
				policy.set(key, uniform);
			}
			return policy;
		}

		for (let i = 0; i < actionKeys.length; i++) {
			policy.set(actionKeys[i], weights[i] / total);
		}
		return policy;
	}

	/** Sum of visit counts over all actions of a node, Σ_b N(s,b). */
	private totalVisits(node: MCTSNode): number {
		let total = 0;
		for (const n of node.N.values()) {
			total += n;
		}
		return total;
	}

	/**
	 * Root value estimate: visit-weighted average of the root Q-values
	 * (white-positive). Returns 0 if the root has no visits.
	 */
	private getRootValue(node: MCTSNode): number {
		const totalN = this.totalVisits(node);
		if (totalN === 0) {
			return 0;
		}

		let weightedSum = 0;
		for (const actionKey of node.N.keys()) {
			const q = node.Q.get(actionKey) ?? 0;
			const n = node.N.get(actionKey) ?? 0;
			weightedSum += q * n;
		}
		return weightedSum / totalN;
	}

	/**
	 * Add Dirichlet noise to the root priors (in place):
	 * P(s,a) = (1 - ε) * p_a + ε * η_a, with η ~ Dir(α) over the non-pass actions.
	 *
	 * Pass gets no noise (η_pass = 0) so that clearly poor passes are not explored at
	 * the root. It is still scaled by (1 - ε), so the mixed priors keep summing to 1.
	 * If no non-pass action exists, or the noise is degenerate, the priors are left unchanged.
	 */
	private addDirichletNoise(priors: Map<string, number>): void {
		const actionKeys = Array.from(priors.keys()).filter(key => key !== PASS_KEY);
		const alpha = this.config.dirichletAlpha;
		const epsilon = this.config.dirichletEpsilon;

		if (actionKeys.length === 0) {
			return;
		}

		// Dir(α) sample = independent Gamma(α, 1) draws normalized by their sum
		const noise = actionKeys.map(() => this.sampleGamma(alpha));
		const noiseSum = noise.reduce((sum, x) => sum + x, 0);

		// All Gamma samples underflowed to 0 (possible for tiny α): keep the original priors
		if (!isFinite(noiseSum) || noiseSum <= 0) {
			return;
		}

		for (let i = 0; i < actionKeys.length; i++) {
			const actionKey = actionKeys[i];
			const prior = priors.get(actionKey) ?? 0;
			priors.set(actionKey, (1 - epsilon) * prior + epsilon * (noise[i] / noiseSum));
		}

		const passPrior = priors.get(PASS_KEY);
		if (passPrior !== undefined) {
			priors.set(PASS_KEY, (1 - epsilon) * passPrior);
		}
	}

	/**
	 * Sample from Gamma(alpha, 1) using Marsaglia and Tsang's method (2000).
	 * For alpha < 1 the boost Gamma(alpha) = Gamma(alpha + 1) * U^(1/alpha) is used.
	 *
	 * @throws Error if alpha <= 0
	 */
	private sampleGamma(alpha: number): number {
		if (alpha <= 0) {
			throw new Error("Gamma distribution alpha must be > 0");
		}

		if (alpha < 1) {
			const u = Math.random();
			const g = this.sampleGamma(alpha + 1);
			return g * Math.pow(u, 1 / alpha);
		}

		const d = alpha - 1 / 3;
		const c = 1 / Math.sqrt(9 * d);

		while (true) {
			let x: number;
			let v: number;
			do {
				x = this.randomNormal();
				v = 1 + c * x;
			} while (v <= 0);

			v = v * v * v;
			const u = Math.random();

			// Fast acceptance (squeeze) check
			if (u < 1 - 0.0331 * x * x * x * x) {
				return d * v;
			}

			// Exact acceptance check
			if (Math.log(u) < 0.5 * x * x + d * (1 - v + Math.log(v))) {
				return d * v;
			}
		}
	}

	/**
	 * Sample from the standard normal distribution (Box-Muller transform).
	 * Math.random() is in [0, 1); using 1 - Math.random() keeps the log argument in (0, 1].
	 */
	private randomNormal(): number {
		const u1 = 1 - Math.random();
		const u2 = Math.random();
		return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
	}

	/**
	 * Check whether a game state is terminal and return its value if so.
	 *
	 * Terminal conditions (checked in order of cost):
	 * 1. Game status is "finished" (cheap status read)
	 * 2. Stone coverage > 50% AND state.isNaturallyTerminal() (expensive)
	 *
	 * The coverage gate avoids the expensive natural-termination check on sparse
	 * boards, where natural termination is unlikely.
	 *
	 * @param state Game state to check
	 * @returns Terminal value (white-positive) if terminal, null otherwise
	 */
	private checkTerminal(state: TrigoGame): number | null {
		if (state.getGameStatus() === "finished") {
			return this.calculateTerminalValue(state.getTerritory());
		}

		const board = state.getBoard();
		const shape = state.getShape();
		const totalPositions = shape.x * shape.y * shape.z;

		// Count stones (1 = Black, 2 = White)
		let stoneCount = 0;
		for (let x = 0; x < shape.x; x++) {
			for (let y = 0; y < shape.y; y++) {
				for (let z = 0; z < shape.z; z++) {
					const stone = board[x][y][z];
					if (stone === 1 || stone === 2) {
						stoneCount++;
					}
				}
			}
		}

		if (stoneCount / totalPositions > 0.5 && state.isNaturallyTerminal()) {
			return this.calculateTerminalValue(state.getTerritory());
		}

		return null;
	}

	/**
	 * Convert territory counts to a terminal value with logarithmic margin scaling:
	 * value = sign(d) * (1 + ln|d|) * factor, where d = white - black (0 for |d| < 1e-6).
	 * Per the original comment, this mirrors `_expand_value_targets` in valueCausalLoss.py.
	 *
	 * @param territory Territory counts from the game
	 * @returns Value (white-positive: positive = White winning)
	 */
	private calculateTerminalValue(territory: { black: number; white: number; neutral: number }): number {
		const scoreDiff = territory.white - territory.black;

		if (Math.abs(scoreDiff) < 1e-6) {
			return 0.0; // Draw
		}

		const territoryValueFactor = 1.0; // Default from training config
		return Math.sign(scoreDiff) * (1 + Math.log(Math.abs(scoreDiff))) * territoryValueFactor;
	}

	/**
	 * Create a new MCTS node.
	 *
	 * Player to move is taken, in priority order, from `playerToMove`, from `state`,
	 * by flipping the parent's player (assumes strict alternation), or defaults to Black.
	 *
	 * @param state Game state (only provided for the root; null elsewhere to save memory)
	 * @param parent Parent node
	 * @param action Action that led to this node
	 * @param playerToMove Explicit player to move at this node (1=Black, 2=White)
	 */
	private createNode(state: TrigoGame | null, parent: MCTSNode | null, action: Move | null, playerToMove?: number): MCTSNode {
		let player: number;
		if (playerToMove !== undefined) {
			player = playerToMove;
		} else if (state) {
			player = state.getCurrentPlayer();
		} else if (parent) {
			player = parent.playerToMove === 1 ? 2 : 1;
		} else {
			player = 1;
		}

		return {
			state,
			parent,
			action,
			N: new Map(),
			W: new Map(),
			Q: new Map(),
			P: new Map(),
			children: new Map(),
			expanded: false,
			terminalValue: null,
			depth: parent ? parent.depth + 1 : 0,
			playerToMove: player
		};
	}

	/**
	 * Encode a move as a string key: "pass" or "x,y,z".
	 * Player information is not encoded.
	 */
	private encodeAction(move: Move): string {
		if (move.isPass) {
			return PASS_KEY;
		}
		return `${move.x},${move.y},${move.z}`;
	}

	/**
	 * Decode a string key back to a move.
	 * The player is a placeholder ("black"); callers must set the real player
	 * before using the move externally (selectMove does this).
	 */
	private decodeAction(key: string): Move {
		if (key === PASS_KEY) {
			return { player: "black", isPass: true };
		}
		const [x, y, z] = key.split(",").map(Number);
		return { player: "black", x, y, z };
	}
}