/**
 * Monte Carlo Tree Search (MCTS) Agent for Trigo
 *
 * Implements AlphaGo Zero-style MCTS with:
 * - PUCT (Polynomial Upper Confidence Trees) selection
 * - Neural network guidance for policy and value
 * - Visit count statistics for training data generation
 *
 * Based on: Silver et al., "Mastering the Game of Go without Human Knowledge"
 */
import { TrigoGame } from "./trigo/game";
import type { Move } from "./trigo/types";
import { TrigoTreeAgent } from "./trigoTreeAgent";
import { TrigoEvaluationAgent } from "./trigoEvaluationAgent";
/**
 * MCTS Configuration
 *
 * Every field is optional when passed to the MCTSAgent constructor; the
 * defaults quoted below are the ones the constructor fills in.
 */
export interface MCTSConfig {
  /** Number of MCTS simulations per move (default: 600). */
  numSimulations: number;
  /** PUCT exploration constant (default: 1.0). */
  cPuct: number;
  /** Selection temperature used for the first 30 moves (default: 1.0). */
  temperature: number;
  /** Dirichlet noise alpha parameter (default: 0.03). */
  dirichletAlpha: number;
  /** Dirichlet noise mixing weight epsilon (default: 0.25). */
  dirichletEpsilon: number;
}
/**
 * MCTS Tree Node
 * Stores search statistics for all legal actions from a given game state.
 *
 * Memory optimization: only the root node stores the full game state.
 * Non-root nodes only store the action that led to them.
 * During a simulation, a working state is cloned once from the root and
 * mutated along the selected path.
 */
interface MCTSNode {
  /** Game state (only stored at the root node, null elsewhere). */
  state: TrigoGame | null;
  /** Parent node (null for root). */
  parent: MCTSNode | null;
  /** Action that led to this node (null for root). */
  action: Move | null;

  // MCTS statistics per action (action key -> value)
  /** Visit counts N(s,a). */
  N: Map<string, number>;
  /** Total action-value W(s,a). */
  W: Map<string, number>;
  /** Mean action-value Q(s,a) = W(s,a) / N(s,a). */
  Q: Map<string, number>;
  /** Prior probabilities P(s,a) from the policy network. */
  P: Map<string, number>;

  /** Child nodes (action key -> child node). */
  children: Map<string, MCTSNode>;
  /** Whether this node has been expanded. */
  expanded: boolean;
  /** Cached terminal value (null if not terminal or not yet computed). */
  terminalValue: number | null;

  // Bookkeeping used by terminal-value propagation
  /** Distance from root (0 for root). */
  depth: number;
  /** Player to move at this node (1 = Black, 2 = White). */
  playerToMove: number;
}
/**
 * MCTS Agent
 * Combines PUCT tree search with neural network evaluation.
 *
 * Two collaborators are injected: a tree agent that scores candidate moves
 * (policy priors) and an evaluation agent that estimates the value of a
 * position. Every value stored in the tree is white-positive: terminal values
 * are computed as (white territory - black territory), and the evaluation
 * agent's value is used unmodified, so it is assumed to follow the same
 * convention.
 */
export class MCTSAgent {
  /** Temperatures at or below this threshold pick the most-visited action deterministically. */
  private static readonly GREEDY_TEMPERATURE = 0.01;
  /** Number of opening moves played with the configured (exploratory) temperature. */
  private static readonly EXPLORATION_MOVES = 30;

  private treeAgent: TrigoTreeAgent; // For policy priors
  private evaluationAgent: TrigoEvaluationAgent; // For value evaluation
  private config: MCTSConfig;
  public debugMode: boolean = false; // Enable debug logging

  /**
   * @param treeAgent Agent used to score candidate moves (policy priors)
   * @param evaluationAgent Agent used to evaluate positions (value estimate)
   * @param config Optional overrides; missing fields fall back to the defaults below
   */
  constructor(
    treeAgent: TrigoTreeAgent,
    evaluationAgent: TrigoEvaluationAgent,
    config: Partial<MCTSConfig> = {}
  ) {
    this.treeAgent = treeAgent;
    this.evaluationAgent = evaluationAgent;
    // Default configuration (AlphaGo Zero-inspired)
    this.config = {
      numSimulations: config.numSimulations ?? 600,
      cPuct: config.cPuct ?? 1.0,
      temperature: config.temperature ?? 1.0,
      dirichletAlpha: config.dirichletAlpha ?? 0.03,
      dirichletEpsilon: config.dirichletEpsilon ?? 0.25
    };
  }

  /**
   * Select best move using MCTS
   *
   * @param game Current game state (not mutated; each simulation works on a clone)
   * @param moveNumber Move number (for temperature schedule)
   * @returns Selected move with visit counts, the normalized search policy
   *          π(a|s) and the root value estimate (white-positive)
   */
  async selectMove(game: TrigoGame, moveNumber: number): Promise<{
    move: Move;
    visitCounts: Map<string, number>;
    searchPolicy: Map<string, number>; // Normalized visit counts π(a|s)
    rootValue: number;
  }> {
    const playerName = game.getCurrentPlayer() === 1 ? "black" : "white";

    // If the game is already over there is nothing to search: answer with a pass
    const terminalResult = this.checkTerminal(game);
    if (terminalResult !== null) {
      return {
        move: { player: playerName, isPass: true },
        visitCounts: new Map(),
        searchPolicy: new Map(),
        rootValue: terminalResult
      };
    }

    const root = this.createNode(game, null, null);

    // Run MCTS simulations
    for (let i = 0; i < this.config.numSimulations; i++) {
      await this.runSimulation(root, i);
      // Once the root itself is proven terminal, every further simulation
      // stops at the root with an empty path and changes no statistics.
      if (root.terminalValue !== null) {
        break;
      }
    }

    // Temperature schedule: configured τ for the opening moves, greedy afterwards
    const temperature =
      moveNumber < MCTSAgent.EXPLORATION_MOVES
        ? this.config.temperature
        : MCTSAgent.GREEDY_TEMPERATURE;

    // Select move based on visit counts; decodeAction returns a placeholder
    // player, so stamp the real one before handing the move out.
    const move = this.selectPlayAction(root, temperature);
    move.player = playerName;

    return {
      move,
      visitCounts: new Map(root.N),
      searchPolicy: this.computeSearchPolicy(root, temperature),
      rootValue: this.getRootValue(root)
    };
  }

  /**
   * Run a single MCTS simulation
   * Select -> Expand & Evaluate -> Backup
   *
   * Memory optimization: the root state is cloned once and mutated along the
   * selected path, instead of storing a state in every node.
   *
   * @param root Root node; must carry a non-null state
   * @param simIndex Zero-based simulation index, used only for debug logging
   * @throws Error if the root node has no state
   */
  private async runSimulation(root: MCTSNode, simIndex?: number): Promise<void> {
    // Invariant: root node must always have a non-null state
    if (!root.state) {
      throw new Error("runSimulation: root node must have a non-null state");
    }

    // Clone root state once for this simulation
    const workingState = root.state.clone();

    // 1. Selection: traverse tree using PUCT until reaching an unexpanded node
    const { node, path } = this.select(root, workingState);

    // 2. Expand and evaluate: get value from the neural network (or terminal score)
    const value = await this.expandAndEvaluate(node, workingState);

    // Debug logging (first 10 simulations only)
    if (this.debugMode && simIndex !== undefined && simIndex < 10) {
      const pathStr = path.map(p => p.actionKey).join(" → ");
      const terminalStr = node.terminalValue !== null ? " [TERMINAL]" : "";
      console.log(`Sim ${simIndex + 1}: ${pathStr || "(root)"} → value=${value.toFixed(4)}${terminalStr}`);
    }

    // 3. Backup: propagate value up the tree
    this.backup(path, value);
  }

  /**
   * Selection phase: Traverse tree using PUCT
   *
   * @param root Root node to start selection from
   * @param workingState Mutable game state that gets updated along the path
   * @returns Leaf node (unexpanded or terminal) and the path taken to reach it
   */
  private select(root: MCTSNode, workingState: TrigoGame): {
    node: MCTSNode;
    path: Array<{ node: MCTSNode; actionKey: string }>;
  } {
    const path: Array<{ node: MCTSNode; actionKey: string }> = [];
    let node = root;

    // Descend while the node is expanded and not known to be terminal.
    // Terminal nodes are returned as leaves so their cached value is reused.
    while (node.expanded && node.terminalValue === null) {
      const actionKeys = Array.from(node.P.keys());
      // Expanded but without actions: treat as a leaf
      if (actionKeys.length === 0) {
        break;
      }

      // Both players pick the HIGHEST PUCT score; calculatePUCT negates Q for
      // Black so that maximizing the score minimizes the white-positive Q.
      const isWhite = workingState.getCurrentPlayer() === 2;

      // Σ_b N(s,b) is identical for every action of this node: compute it once
      let totalN = 0;
      for (const visits of node.N.values()) {
        totalN += visits;
      }

      let bestActionKey = actionKeys[0];
      let bestPuct = this.calculatePUCT(node, bestActionKey, isWhite, totalN);
      for (let i = 1; i < actionKeys.length; i++) {
        const puct = this.calculatePUCT(node, actionKeys[i], isWhite, totalN);
        if (puct > bestPuct) {
          bestPuct = puct;
          bestActionKey = actionKeys[i];
        }
      }

      // Record path
      path.push({ node, actionKey: bestActionKey });

      // Apply action to working state (instead of cloning)
      const action = this.decodeAction(bestActionKey);
      if (action.isPass) {
        workingState.pass();
      } else if (action.x !== undefined && action.y !== undefined && action.z !== undefined) {
        workingState.drop({ x: action.x, y: action.y, z: action.z });
      }

      // Move to child (create if it doesn't exist). The child stores no state;
      // its player-to-move is read from the working state after the move so it
      // agrees with the player used for PUCT when the child is later visited.
      let child = node.children.get(bestActionKey);
      if (!child) {
        child = this.createNode(null, node, action, workingState.getCurrentPlayer());
        node.children.set(bestActionKey, child);
      }
      node = child;
    }

    return { node, path };
  }

  /**
   * Expand and evaluate leaf node using neural networks
   *
   * Terminal positions are not sent to the networks: their value is computed
   * from territory and cached on the node.
   *
   * @param node Leaf node to expand
   * @param workingState Current game state at this node (passed from simulation)
   * @returns Value estimate (white-positive)
   */
  private async expandAndEvaluate(node: MCTSNode, workingState: TrigoGame): Promise<number> {
    // Terminal value already cached
    if (node.terminalValue !== null) {
      return node.terminalValue;
    }

    // Game over: mark the node expanded with an empty action set and cache the value
    const terminalValue = this.checkTerminal(workingState);
    if (terminalValue !== null) {
      node.expanded = true;
      node.terminalValue = terminalValue;
      node.P = new Map();
      node.N = new Map();
      node.W = new Map();
      node.Q = new Map();
      node.children = new Map();
      return terminalValue;
    }

    // Non-terminal state: every valid position plus Pass is a candidate move
    const currentPlayer = workingState.getCurrentPlayer() === 1 ? "black" : "white";
    const moves: Move[] = [];
    for (const pos of workingState.validMovePositions()) {
      moves.push({ x: pos.x, y: pos.y, z: pos.z, player: currentPlayer });
    }
    moves.push({ player: currentPlayer, isPass: true });

    // Get policy scores from the tree agent and turn them into probabilities
    // with a numerically stable softmax (subtract the maximum first).
    const scoredMoves = await this.treeAgent.scoreMoves(workingState, moves);
    const scores: number[] = [];
    for (const scored of scoredMoves) {
      scores.push(scored.score);
    }
    const maxScore = Math.max(...scores);
    const expScores = scores.map(score => Math.exp(score - maxScore));
    const sumExp = expScores.reduce((sum, e) => sum + e, 0);

    // If normalization is impossible (all scores -Infinity, NaN, ...) use uniform priors
    const useFallback = !isFinite(sumExp) || sumExp < 1e-10;

    // Initialize per-action statistics
    node.P = new Map();
    node.N = new Map();
    node.W = new Map();
    node.Q = new Map();
    for (let i = 0; i < scoredMoves.length; i++) {
      const actionKey = this.encodeAction(scoredMoves[i].move);
      const prior = useFallback ? 1.0 / scoredMoves.length : expScores[i] / sumExp;
      node.P.set(actionKey, prior);
      node.N.set(actionKey, 0);
      node.W.set(actionKey, 0);
      node.Q.set(actionKey, 0);
    }

    // Add Dirichlet noise at root only
    if (node.parent === null) {
      this.addDirichletNoise(node.P);
    }

    node.expanded = true;

    // Value estimate from the evaluation agent, returned unmodified
    // (assumed white-positive, like the terminal values).
    const evaluation = await this.evaluationAgent.evaluatePosition(workingState);
    return evaluation.value;
  }

  /**
   * Backup phase: Propagate value up the tree
   *
   * Values are white-positive throughout, so no sign flipping happens here;
   * the side to move is accounted for in calculatePUCT instead. After each
   * update the node is checked for terminal-value propagation.
   *
   * @param path Path from root to leaf
   * @param value Value to propagate (white-positive: positive = white winning)
   */
  private backup(path: Array<{ node: MCTSNode; actionKey: string }>, value: number): void {
    for (let i = path.length - 1; i >= 0; i--) {
      const { node, actionKey } = path[i];

      const visits = (node.N.get(actionKey) ?? 0) + 1;
      const total = (node.W.get(actionKey) ?? 0) + value;
      node.N.set(actionKey, visits);
      node.W.set(actionKey, total);
      node.Q.set(actionKey, total / visits);

      this.propagateTerminal(node);
    }
  }

  /**
   * Mark a node as terminal once every one of its actions leads to a terminal child.
   *
   * The node receives the minimax value of its children (white-positive):
   * White to move takes the maximum, Black to move takes the minimum.
   * Nothing happens if the node is unexpanded, already terminal, has no
   * actions, or still has an unexplored or non-terminal child.
   */
  private propagateTerminal(node: MCTSNode): void {
    if (!node.expanded || node.terminalValue !== null || node.P.size === 0) {
      return;
    }

    const childTerminalValues: number[] = [];
    for (const key of node.P.keys()) {
      const child = node.children.get(key);
      if (!child || child.terminalValue === null) {
        return; // at least one action is not resolved yet
      }
      childTerminalValues.push(child.terminalValue);
    }

    const isWhiteTurn = node.playerToMove === 2; // 2 = White, 1 = Black
    const terminalValue = isWhiteTurn
      ? Math.max(...childTerminalValues)
      : Math.min(...childTerminalValues);
    node.terminalValue = terminalValue;

    if (this.debugMode) {
      const playerName = isWhiteTurn ? 'White' : 'Black';
      console.log(
        `[Terminal Propagation] Node at depth ${node.depth} (${playerName}) marked terminal: ` +
        `value=${terminalValue.toFixed(4)}, children=[${childTerminalValues.map(v => v.toFixed(2)).join(', ')}]`
      );
    }
  }

  /**
   * Calculate PUCT value for action selection
   *
   * PUCT =  Q(s,a) + U(s,a)   [for White, who maximizes]
   * PUCT = -Q(s,a) + U(s,a)   [for Black, who minimizes]
   * where U(s,a) = c_puct * P(s,a) * sqrt(Σ_b N(s,b) + 1) / (1 + N(s,a))
   * (the +1 under the root keeps exploration non-zero on a freshly expanded node)
   *
   * @param node Current node
   * @param actionKey Action to evaluate
   * @param isWhite Whether current player is White
   * @param totalN Optional precomputed Σ_b N(s,b); computed from the node when omitted
   * @returns PUCT value
   */
  private calculatePUCT(node: MCTSNode, actionKey: string, isWhite: boolean, totalN?: number): number {
    const q = node.Q.get(actionKey) ?? 0;
    const n = node.N.get(actionKey) ?? 0;
    const p = node.P.get(actionKey) ?? 0;

    let visitSum = totalN;
    if (visitSum === undefined) {
      visitSum = 0;
      for (const visits of node.N.values()) {
        visitSum += visits;
      }
    }

    const u = this.config.cPuct * p * Math.sqrt(visitSum + 1) / (1 + n);
    // Black minimizes Q (flips sign), White maximizes Q
    return (isWhite ? q : -q) + u;
  }

  /**
   * Select action to play based on visit counts
   * Uses temperature to control exploration vs exploitation
   *
   * The returned move carries a placeholder player; the caller must set it.
   *
   * @param node Root node
   * @param temperature Selection temperature (τ=1 samples proportionally to
   *        visits; τ at or below the greedy threshold picks the most-visited action)
   * @returns Selected move
   */
  private selectPlayAction(node: MCTSNode, temperature: number): Move {
    const actionKeys = Array.from(node.N.keys());

    // Edge case: no visit statistics (unexpanded root or terminal state)
    if (actionKeys.length === 0) {
      const priorKeys = Array.from(node.P.keys());
      if (priorKeys.length === 0) {
        // No actions at all - return Pass as last resort
        return { player: "black", isPass: true };
      }
      // Sample from the priors (uniformly if they carry no mass)
      const priors = priorKeys.map(key => node.P.get(key) ?? 0);
      return this.decodeAction(priorKeys[this.sampleIndex(priors)]);
    }

    const counts = actionKeys.map(key => node.N.get(key) ?? 0);

    if (temperature <= MCTSAgent.GREEDY_TEMPERATURE) {
      // Greedy: first action with the highest visit count
      let best = 0;
      for (let i = 1; i < counts.length; i++) {
        if (counts[i] > counts[best]) {
          best = i;
        }
      }
      return this.decodeAction(actionKeys[best]);
    }

    // Temperature-based sampling: π(a|s) ∝ N(s,a)^(1/τ)
    const weights = this.temperatureWeights(counts, temperature);
    return this.decodeAction(actionKeys[this.sampleIndex(weights)]);
  }

  /**
   * Compute search policy from visit counts
   * π(a|s) = N(s,a)^(1/τ) / Σ_b N(s,b)^(1/τ)
   *
   * Falls back to a uniform distribution when no action has been visited.
   *
   * @param node Root node
   * @param temperature Selection temperature
   * @returns Normalized policy distribution (empty if the node has no actions)
   */
  private computeSearchPolicy(node: MCTSNode, temperature: number): Map<string, number> {
    const policy = new Map<string, number>();
    const actionKeys = Array.from(node.N.keys());
    if (actionKeys.length === 0) {
      return policy;
    }

    const counts = actionKeys.map(key => node.N.get(key) ?? 0);
    const weights = this.temperatureWeights(counts, temperature);
    const total = weights.reduce((sum, w) => sum + w, 0);
    const degenerate = !isFinite(total) || total <= 0;

    for (let i = 0; i < actionKeys.length; i++) {
      policy.set(actionKeys[i], degenerate ? 1 / actionKeys.length : weights[i] / total);
    }
    return policy;
  }

  /**
   * Unnormalized temperature weights for a list of visit counts.
   *
   * Computes (N / max N)^(1/τ), which is proportional to N^(1/τ) but cannot
   * overflow: the largest weight is exactly 1. A non-positive (or NaN)
   * temperature, or one so small that 1/τ is not finite, yields the greedy
   * limit: weight 1 on the first most-visited action and 0 elsewhere.
   * If no count is positive, all weights are 0.
   */
  private temperatureWeights(counts: number[], temperature: number): number[] {
    let maxCount = 0;
    for (const count of counts) {
      if (count > maxCount) {
        maxCount = count;
      }
    }
    if (maxCount <= 0) {
      return counts.map(() => 0);
    }

    const exponent = 1 / temperature;
    if (!(temperature > 0) || !isFinite(exponent)) {
      const best = counts.indexOf(maxCount);
      return counts.map((_, i) => (i === best ? 1 : 0));
    }
    return counts.map(count => Math.pow(count / maxCount, exponent));
  }

  /**
   * Draw an index with probability proportional to its weight.
   *
   * Falls back to a uniformly random index when the weights have no usable
   * mass (sum is zero, negative or not finite). Indices with zero weight are
   * never returned while some weight is positive.
   *
   * @param weights Non-empty list of non-negative weights
   */
  private sampleIndex(weights: number[]): number {
    let total = 0;
    for (const w of weights) {
      total += w;
    }
    if (!isFinite(total) || total <= 0) {
      return Math.floor(Math.random() * weights.length);
    }

    let remaining = Math.random() * total;
    let lastPositive = weights.length - 1;
    for (let i = 0; i < weights.length; i++) {
      if (weights[i] > 0) {
        lastPositive = i;
        remaining -= weights[i];
        if (remaining <= 0) {
          return i;
        }
      }
    }
    // Only reachable through floating point rounding
    return lastPositive;
  }

  /**
   * Get root value estimate: visit-weighted average of the Q-values
   * (white-positive). Returns 0 when nothing has been visited.
   */
  private getRootValue(node: MCTSNode): number {
    let totalN = 0;
    let weightedSum = 0;
    for (const [actionKey, n] of node.N) {
      totalN += n;
      weightedSum += (node.Q.get(actionKey) ?? 0) * n;
    }
    return totalN === 0 ? 0 : weightedSum / totalN;
  }

  /**
   * Add Dirichlet noise to prior probabilities at root
   * P(s,a) = (1 - ε) * p_a + ε * η_a
   * where η ~ Dir(α)
   *
   * The Pass move is excluded: its prior is left untouched and it receives
   * no noise. Because Pass keeps its full prior while the other actions are
   * mixed with a noise vector that sums to 1, the resulting priors sum to
   * 1 + ε·P(pass) rather than exactly 1.
   *
   * @param priors Prior map, modified in place
   */
  private addDirichletNoise(priors: Map<string, number>): void {
    const actionKeys = Array.from(priors.keys()).filter(key => key !== "pass");
    // If only Pass is available, no noise to add
    if (actionKeys.length === 0) {
      return;
    }

    const alpha = this.config.dirichletAlpha;
    const epsilon = this.config.dirichletEpsilon;

    // A Dirichlet sample is a vector of independent Gamma(α, 1) draws divided by their sum
    const noise = actionKeys.map(() => this.sampleGamma(alpha));
    const noiseSum = noise.reduce((sum, g) => sum + g, 0);

    // Degenerate draw (all samples underflowed to 0): keep the original priors
    if (!isFinite(noiseSum) || noiseSum <= 0) {
      return;
    }

    for (let i = 0; i < actionKeys.length; i++) {
      const prior = priors.get(actionKeys[i]) ?? 0;
      priors.set(actionKeys[i], (1 - epsilon) * prior + epsilon * (noise[i] / noiseSum));
    }
  }

  /**
   * Sample from Gamma(alpha, 1) using the Marsaglia and Tsang method (2000)
   * Used for Dirichlet noise generation
   *
   * @param alpha Shape parameter, must be > 0
   * @throws Error if alpha <= 0
   */
  private sampleGamma(alpha: number): number {
    if (alpha <= 0) {
      throw new Error("Gamma distribution alpha must be > 0");
    }

    // For alpha < 1: Gamma(alpha) = Gamma(alpha + 1) * U^(1/alpha)
    if (alpha < 1) {
      const u = Math.random();
      return this.sampleGamma(alpha + 1) * Math.pow(u, 1 / alpha);
    }

    // For alpha >= 1: Marsaglia and Tsang's squeeze/rejection method
    const d = alpha - 1 / 3;
    const c = 1 / Math.sqrt(9 * d);
    while (true) {
      let x: number;
      let v: number;
      do {
        x = this.randomNormal();
        v = 1 + c * x;
      } while (v <= 0);
      v = v * v * v;

      const u = Math.random();
      // Fast acceptance check (squeeze)
      if (u < 1 - 0.0331 * x * x * x * x) {
        return d * v;
      }
      // Full acceptance check
      if (Math.log(u) < 0.5 * x * x + d * (1 - v + Math.log(v))) {
        return d * v;
      }
    }
  }

  /**
   * Sample from standard normal distribution (Box-Muller transform)
   */
  private randomNormal(): number {
    // Math.random() lies in [0, 1); use 1 - u so the logarithm never sees 0
    const u1 = 1 - Math.random();
    const u2 = Math.random();
    return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
  }

  /**
   * Check if game state is terminal and return value if so
   *
   * Terminal conditions (checked in order of cost):
   * 1. Game status is "finished" - just a status read
   * 2. More than 50% of the board is occupied AND isNaturallyTerminal() holds
   *
   * The coverage gate avoids calling isNaturallyTerminal() on sparse boards,
   * where natural termination is unlikely.
   *
   * @param state Game state to check
   * @returns Terminal value (white-positive) if terminal, null otherwise
   */
  private checkTerminal(state: TrigoGame): number | null {
    // 1. Game already finished
    if (state.getGameStatus() === "finished") {
      return this.calculateTerminalValue(state.getTerritory());
    }

    // 2. Natural end of game, only considered on a reasonably full board
    const board = state.getBoard();
    const shape = state.getShape();
    const totalPositions = shape.x * shape.y * shape.z;

    let stoneCount = 0;
    for (let x = 0; x < shape.x; x++) {
      for (let y = 0; y < shape.y; y++) {
        for (let z = 0; z < shape.z; z++) {
          const stone = board[x][y][z];
          if (stone === 1 || stone === 2) { // black or white stone
            stoneCount++;
          }
        }
      }
    }

    if (stoneCount / totalPositions > 0.5 && state.isNaturallyTerminal()) {
      return this.calculateTerminalValue(state.getTerritory());
    }

    return null; // Not terminal
  }

  /**
   * Calculate terminal value from territory scores
   *
   * value = sign(diff) * (1 + ln|diff|) * territory_value_factor,
   * with diff = white - black; a (near-)zero difference is a draw (0).
   * The log term rewards larger margins logarithmically. The original source
   * states that this matches the training-side target formula.
   *
   * NOTE(review): for a non-integer |diff| below 1/e the log term exceeds -1
   * in magnitude and would flip the sign; territory counts are presumably
   * integers, so this should not occur - confirm.
   *
   * @param territory Territory counts from game
   * @returns Value (white-positive: positive = white winning)
   */
  private calculateTerminalValue(territory: { black: number; white: number; neutral: number }): number {
    const scoreDiff = territory.white - territory.black;
    if (Math.abs(scoreDiff) < 1e-6) {
      return 0.0; // Draw/tie
    }

    const territoryValueFactor = 1.0; // Default from training config
    return Math.sign(scoreDiff) * (1 + Math.log(Math.abs(scoreDiff))) * territoryValueFactor;
  }

  /**
   * Create a new MCTS node
   *
   * The player to move is resolved in this order: the explicit argument, the
   * given state, the opposite of the parent's player (assumes strictly
   * alternating turns), and finally Black (1).
   *
   * @param state Game state (only provided for root node, null for others to save memory)
   * @param parent Parent node
   * @param action Action that led to this node
   * @param playerToMove Player to move at this node (1 = Black, 2 = White)
   */
  private createNode(state: TrigoGame | null, parent: MCTSNode | null, action: Move | null, playerToMove?: number): MCTSNode {
    let player: number;
    if (playerToMove !== undefined) {
      player = playerToMove;
    } else if (state) {
      player = state.getCurrentPlayer();
    } else if (parent) {
      player = parent.playerToMove === 1 ? 2 : 1;
    } else {
      player = 1;
    }

    return {
      state,
      parent,
      action,
      N: new Map(),
      W: new Map(),
      Q: new Map(),
      P: new Map(),
      children: new Map(),
      expanded: false,
      terminalValue: null,
      depth: parent ? parent.depth + 1 : 0,
      playerToMove: player
    };
  }

  /**
   * Encode move to string key for storage in maps
   * Only the position (or "pass") is encoded; the player is not part of the key.
   */
  private encodeAction(move: Move): string {
    return move.isPass ? "pass" : `${move.x},${move.y},${move.z}`;
  }

  /**
   * Decode string key back to move
   * The returned move always carries the placeholder player "black"; callers
   * that hand the move to the outside must overwrite it (selectMove does).
   */
  private decodeAction(key: string): Move {
    if (key === "pass") {
      return { player: "black", isPass: true };
    }
    const [x, y, z] = key.split(",").map(Number);
    return { player: "black", x, y, z };
  }
}