/**
* Monte Carlo Tree Search (MCTS) Agent for Trigo
*
* Implements AlphaGo Zero-style MCTS with:
* - PUCT (Polynomial Upper Confidence Trees) selection
* - Neural network guidance for policy and value
* - Visit count statistics for training data generation
*
* Based on: Silver et al., "Mastering the Game of Go without Human Knowledge"
*/
import { TrigoGame } from "./trigo/game";
import type { Move } from "./trigo/types";
import { TrigoTreeAgent } from "./trigoTreeAgent";
import { TrigoEvaluationAgent } from "./trigoEvaluationAgent";
/**
* Tunable parameters of the MCTS search.
*
* Every field is required on this interface; the MCTSAgent constructor accepts a
* Partial<MCTSConfig> and fills in the defaults quoted below for any missing field.
*/
export interface MCTSConfig {
numSimulations: number; // Number of MCTS simulations run for each selectMove call (default: 600)
cPuct: number; // PUCT exploration constant multiplying the prior/visit-count bonus (default: 1.0)
temperature: number; // Sampling temperature applied to visit counts while moveNumber < 30 (default: 1.0)
dirichletAlpha: number; // Concentration parameter α of the Dirichlet noise added to root priors (default: 0.03)
dirichletEpsilon: number; // Weight ε of the Dirichlet noise when mixed with root priors (default: 0.25)
}
/**
* MCTS Tree Node
* Stores search statistics for all legal actions from a given game state.
* The per-action maps (N, W, Q, P, children) are all keyed by the encoded action
* string: "pass" or "x,y,z".
*
* Memory optimization: Only the root node stores the full game state.
* Non-root nodes only store the action that led to them.
* During simulation, a working state is cloned once from the root and mutated along the path.
*
* Value convention: all values are white-positive (positive = White is ahead).
*/
interface MCTSNode {
state: TrigoGame | null; // Game state (only stored at root node for memory efficiency; null elsewhere)
parent: MCTSNode | null; // Parent node (null for root)
action: Move | null; // Action that led to this node (null for root)
// MCTS statistics per action (action key -> value)
N: Map<string, number>; // Visit counts N(s,a)
W: Map<string, number>; // Total action-value W(s,a)
Q: Map<string, number>; // Mean action-value Q(s,a) = W(s,a) / N(s,a)
P: Map<string, number>; // Prior probabilities P(s,a) from policy network
children: Map<string, MCTSNode>; // Child nodes (action key -> child node), created lazily during selection
expanded: boolean; // Whether this node has been expanded (priors filled in, or marked terminal)
terminalValue: number | null; // Cached terminal value; set when the state is terminal or when all children are terminal (null otherwise)
// Fields used by terminal-value propagation during backup
depth: number; // Distance from root (0 for root)
playerToMove: number; // Player to move at this node (1=Black, 2=White)
}
/**
 * MCTS Agent
 *
 * Combines PUCT tree search with two neural-network agents: a tree agent that
 * scores candidate moves (policy priors) and an evaluation agent that estimates
 * the value of a position.
 *
 * Conventions used throughout the class:
 * - Players are encoded as 1 = Black, 2 = White.
 * - All values (Q, W, terminal values, network values) are white-positive:
 *   a positive number means White is ahead. White therefore maximizes Q and
 *   Black minimizes it.
 * - Actions are keyed by the strings produced by encodeAction: "pass" or "x,y,z".
 */
export class MCTSAgent {
	/** Temperatures at or below this value are treated as the greedy limit τ → 0. */
	private static readonly GREEDY_TEMPERATURE = 0.01;
	/** Number of opening moves that are sampled with the configured temperature. */
	private static readonly EXPLORATION_MOVES = 30;

	private readonly treeAgent: TrigoTreeAgent; // For policy priors
	private readonly evaluationAgent: TrigoEvaluationAgent; // For value evaluation
	private readonly config: MCTSConfig;
	public debugMode: boolean = false; // Enable debug logging

	/**
	 * @param treeAgent Agent used to score candidate moves (policy priors)
	 * @param evaluationAgent Agent used to evaluate positions (value estimates)
	 * @param config Optional overrides; missing fields fall back to the defaults below
	 */
	constructor(
		treeAgent: TrigoTreeAgent,
		evaluationAgent: TrigoEvaluationAgent,
		config: Partial<MCTSConfig> = {}
	) {
		this.treeAgent = treeAgent;
		this.evaluationAgent = evaluationAgent;
		// Default configuration (AlphaGo Zero-inspired)
		this.config = {
			numSimulations: config.numSimulations ?? 600,
			cPuct: config.cPuct ?? 1.0,
			temperature: config.temperature ?? 1.0,
			dirichletAlpha: config.dirichletAlpha ?? 0.03,
			dirichletEpsilon: config.dirichletEpsilon ?? 0.25
		};
	}

	/**
	 * Select a move for the side to play using MCTS.
	 *
	 * If the position is already terminal, a pass is returned together with empty
	 * statistics and the terminal value.
	 *
	 * @param game Current game state (simulations run on clones of it)
	 * @param moveNumber Move number (for temperature schedule)
	 * @returns Selected move, raw visit counts, normalized search policy π(a|s)
	 *          and the visit-weighted root value (white-positive)
	 */
	async selectMove(game: TrigoGame, moveNumber: number): Promise<{
		move: Move;
		visitCounts: Map<string, number>;
		searchPolicy: Map<string, number>; // Normalized visit counts π(a|s)
		rootValue: number;
	}> {
		const root = this.createNode(game, null, null);

		// Game already over: nothing to search.
		const terminalResult = this.checkTerminal(game);
		if (terminalResult !== null) {
			return {
				move: { player: game.getCurrentPlayer() === 1 ? "black" : "white", isPass: true },
				visitCounts: new Map(),
				searchPolicy: new Map(),
				rootValue: terminalResult
			};
		}

		// Simulations are sequential on purpose: each one depends on the statistics
		// written by the previous one.
		for (let i = 0; i < this.config.numSimulations; i++) {
			await this.runSimulation(root, i);
		}

		// Temperature schedule: configured τ for the opening, greedy (τ→0) afterward
		const temperature = moveNumber < MCTSAgent.EXPLORATION_MOVES
			? this.config.temperature
			: MCTSAgent.GREEDY_TEMPERATURE;

		const move = this.selectPlayAction(root, temperature);
		// decodeAction returns a placeholder player; set the real one here
		move.player = game.getCurrentPlayer() === 1 ? "black" : "white";

		return {
			move,
			visitCounts: new Map(root.N),
			searchPolicy: this.computeSearchPolicy(root, temperature),
			rootValue: this.getRootValue(root)
		};
	}

	/**
	 * Run a single MCTS simulation: Select -> Expand & Evaluate -> Backup.
	 *
	 * The root state is cloned once and then mutated along the selected path, so
	 * tree nodes never need to hold their own copy of the game.
	 *
	 * @param root Root node; must carry a non-null state
	 * @param simIndex Zero-based simulation index, used only for debug logging
	 * @throws Error if the root node has no state
	 */
	private async runSimulation(root: MCTSNode, simIndex?: number): Promise<void> {
		if (!root.state) {
			throw new Error("runSimulation: root node must have a non-null state");
		}
		const workingState = root.state.clone();

		// 1. Selection
		const { node, path } = this.select(root, workingState);
		// 2. Expansion + evaluation
		const value = await this.expandAndEvaluate(node, workingState);

		// Only the first ten simulations are logged to keep debug output readable
		if (this.debugMode && simIndex !== undefined && simIndex < 10) {
			const pathStr = path.map(p => p.actionKey).join(" → ");
			const terminalStr = node.terminalValue !== null ? " [TERMINAL]" : "";
			console.log(`Sim ${simIndex + 1}: ${pathStr || "(root)"} → value=${value.toFixed(4)}${terminalStr}`);
		}

		// 3. Backup
		this.backup(path, value);
	}

	/**
	 * Selection phase: descend from the root, always taking the action with the
	 * highest PUCT score, until an unexpanded or terminal node is reached.
	 *
	 * @param root Root node to start selection from
	 * @param workingState Mutable game state; every selected action is applied to it
	 * @returns The leaf node reached and the (node, action) pairs leading to it
	 */
	private select(root: MCTSNode, workingState: TrigoGame): {
		node: MCTSNode;
		path: Array<{ node: MCTSNode; actionKey: string }>;
	} {
		const path: Array<{ node: MCTSNode; actionKey: string }> = [];
		let node = root;

		while (node.expanded) {
			// Terminal nodes are never descended into; their cached value is used
			if (node.terminalValue !== null) {
				break;
			}
			const actionKeys = Array.from(node.P.keys());
			// Expanded but without actions: treat as a terminal leaf
			if (actionKeys.length === 0) {
				break;
			}

			// Both players pick the HIGHEST PUCT score; calculatePUCT negates Q for
			// Black so that maximizing the score minimizes white-positive Q.
			const isWhite = workingState.getCurrentPlayer() === 2;

			// The visit total is the same for every action of this node: sum it once
			let totalVisits = 0;
			for (const visits of node.N.values()) {
				totalVisits += visits;
			}

			let bestActionKey = actionKeys[0];
			let bestPuct = this.calculatePUCT(node, bestActionKey, isWhite, totalVisits);
			for (let i = 1; i < actionKeys.length; i++) {
				const puct = this.calculatePUCT(node, actionKeys[i], isWhite, totalVisits);
				if (puct > bestPuct) {
					bestPuct = puct;
					bestActionKey = actionKeys[i];
				}
			}

			path.push({ node, actionKey: bestActionKey });

			// Apply the action to the shared working state (no per-node clone)
			const action = this.decodeAction(bestActionKey);
			if (action.isPass) {
				workingState.pass();
			} else if (action.x !== undefined && action.y !== undefined && action.z !== undefined) {
				workingState.drop({ x: action.x, y: action.y, z: action.z });
			}

			// Descend, creating the child lazily. The child stores no state, but it
			// takes its side-to-move from the actual working state rather than
			// assuming strict alternation.
			let child = node.children.get(bestActionKey);
			if (!child) {
				child = this.createNode(null, node, action, workingState.getCurrentPlayer());
				node.children.set(bestActionKey, child);
			}
			node = child;
		}

		return { node, path };
	}

	/**
	 * Expand a leaf node and return a value for it.
	 *
	 * - Cached terminal value: returned immediately.
	 * - Terminal state: the node is marked expanded with no actions and the
	 *   terminal value is cached and returned.
	 * - Otherwise: priors are computed from the tree agent's move scores via a
	 *   stable softmax (uniform if the scores cannot be normalized), Dirichlet
	 *   noise is mixed in at the root, and the evaluation agent's value is returned.
	 *
	 * @param node Leaf node to expand
	 * @param workingState Current game state at this node (passed from simulation)
	 * @returns White-positive value estimate
	 */
	private async expandAndEvaluate(node: MCTSNode, workingState: TrigoGame): Promise<number> {
		if (node.terminalValue !== null) {
			return node.terminalValue;
		}

		const terminalValue = this.checkTerminal(workingState);
		if (terminalValue !== null) {
			// Expanded + empty action set keeps selection from descending further
			node.expanded = true;
			node.terminalValue = terminalValue;
			node.P = new Map();
			node.N = new Map();
			node.W = new Map();
			node.Q = new Map();
			node.children = new Map();
			return terminalValue;
		}

		// Candidate moves: every valid position plus pass
		const currentPlayer = workingState.getCurrentPlayer() === 1 ? "black" : "white";
		const validPositions = workingState.validMovePositions();
		const moves: Move[] = validPositions.map(pos => ({
			x: pos.x,
			y: pos.y,
			z: pos.z,
			player: currentPlayer
		}));
		moves.push({ player: currentPlayer, isPass: true });

		const scoredMoves = await this.treeAgent.scoreMoves(workingState, moves);

		// Stable softmax: subtract the maximum score before exponentiating
		const maxScore = scoredMoves.reduce((best, m) => Math.max(best, m.score), -Infinity);
		const expScores = scoredMoves.map(m => Math.exp(m.score - maxScore));
		const sumExp = expScores.reduce((sum, e) => sum + e, 0);
		// All scores -Infinity / NaN, or a vanishing sum: fall back to uniform priors
		const useFallback = !isFinite(sumExp) || sumExp < 1e-10;

		node.P = new Map();
		node.N = new Map();
		node.W = new Map();
		node.Q = new Map();
		for (let i = 0; i < scoredMoves.length; i++) {
			const actionKey = this.encodeAction(scoredMoves[i].move);
			const prior = useFallback ? (1.0 / scoredMoves.length) : (expScores[i] / sumExp);
			node.P.set(actionKey, prior);
			node.N.set(actionKey, 0);
			node.W.set(actionKey, 0);
			node.Q.set(actionKey, 0);
		}

		// Exploration noise is applied only at the root
		if (node.parent === null) {
			this.addDirichletNoise(node.P);
		}
		node.expanded = true;

		// The value is used as returned (white-positive), without sign flipping
		const evaluation = await this.evaluationAgent.evaluatePosition(workingState);
		return evaluation.value;
	}

	/**
	 * Backup phase: add the value to every edge on the path, from leaf to root.
	 *
	 * Values are white-positive everywhere, so no sign flipping happens here.
	 * After updating an edge, the node is additionally marked terminal when every
	 * one of its actions leads to a terminal child; its terminal value is then the
	 * minimax of the children (max on White's turn, min on Black's turn).
	 *
	 * @param path Path from root to leaf
	 * @param value Value to propagate (white-positive: positive = white winning)
	 */
	private backup(path: Array<{ node: MCTSNode; actionKey: string }>, value: number): void {
		for (let i = path.length - 1; i >= 0; i--) {
			const { node, actionKey } = path[i];

			const visits = (node.N.get(actionKey) ?? 0) + 1;
			const total = (node.W.get(actionKey) ?? 0) + value;
			node.N.set(actionKey, visits);
			node.W.set(actionKey, total);
			node.Q.set(actionKey, total / visits);

			// Terminal propagation applies only to expanded nodes not yet marked terminal
			if (!node.expanded || node.terminalValue !== null) {
				continue;
			}
			const childTerminalValues = this.collectTerminalChildValues(node);
			if (childTerminalValues === null) {
				continue;
			}

			const isWhiteTurn = node.playerToMove === 2; // 2 = White, 1 = Black
			const terminalValue = isWhiteTurn
				? Math.max(...childTerminalValues) // White maximizes white-positive value
				: Math.min(...childTerminalValues); // Black minimizes it
			node.terminalValue = terminalValue;

			if (this.debugMode) {
				const playerName = isWhiteTurn ? 'White' : 'Black';
				console.log(
					`[Terminal Propagation] Node at depth ${node.depth} (${playerName}) marked terminal: ` +
					`value=${terminalValue.toFixed(4)}, children=[${childTerminalValues.map(v => v.toFixed(2)).join(', ')}]`
				);
			}
		}
	}

	/**
	 * Collect the terminal values of all children of a node.
	 *
	 * @returns The values (one per action) when the node has at least one action and
	 *          every action has a child with a terminal value; null otherwise
	 */
	private collectTerminalChildValues(node: MCTSNode): number[] | null {
		const values: number[] = [];
		for (const key of node.P.keys()) {
			const child = node.children.get(key);
			// Unexplored or non-terminal child: the node cannot be resolved yet
			if (!child || child.terminalValue === null) {
				return null;
			}
			values.push(child.terminalValue);
		}
		return values.length > 0 ? values : null;
	}

	/**
	 * Calculate PUCT value for action selection
	 *
	 * PUCT = Q(s,a) + U(s,a) [for White, who maximizes]
	 * PUCT = -Q(s,a) + U(s,a) [for Black, who minimizes]
	 * where U(s,a) = c_puct * P(s,a) * sqrt(Σ_b N(s,b) + 1) / (1 + N(s,a))
	 *
	 * The +1 inside the square root keeps the exploration term non-zero on the
	 * first visit of a freshly expanded node.
	 *
	 * @param node Current node
	 * @param actionKey Action to evaluate
	 * @param isWhite Whether current player is White
	 * @param totalN Optional precomputed Σ_b N(s,b); summed from the node when omitted
	 * @returns PUCT value
	 */
	private calculatePUCT(node: MCTSNode, actionKey: string, isWhite: boolean, totalN?: number): number {
		const Q = node.Q.get(actionKey) ?? 0;
		const N = node.N.get(actionKey) ?? 0;
		const P = node.P.get(actionKey) ?? 0;

		let visitTotal = totalN;
		if (visitTotal === undefined) {
			visitTotal = 0;
			for (const n of node.N.values()) {
				visitTotal += n;
			}
		}

		const U = this.config.cPuct * P * Math.sqrt(visitTotal + 1) / (1 + N);
		// Black minimizes Q (flips sign), White maximizes Q
		return (isWhite ? Q : -Q) + U;
	}

	/**
	 * Select the action to play from the root's visit counts.
	 *
	 * - No visit statistics: sample from the priors if there are any (uniformly if
	 *   they sum to zero), otherwise return a pass.
	 * - temperature <= 0.01: greedy, the first action with the highest visit count.
	 * - Otherwise: sample with π(a|s) ∝ N(s,a)^(1/τ); uniformly if no action was visited.
	 *
	 * The returned move carries a placeholder player; the caller sets the real one.
	 *
	 * @param node Root node
	 * @param temperature Selection temperature (τ=1 for exploration, τ→0 for greedy)
	 * @returns Selected move
	 */
	private selectPlayAction(node: MCTSNode, temperature: number): Move {
		const actionKeys = Array.from(node.N.keys());

		if (actionKeys.length === 0) {
			const priorKeys = Array.from(node.P.keys());
			if (priorKeys.length === 0) {
				// No actions at all - return Pass as last resort
				return { player: "black", isPass: true };
			}
			const priors = priorKeys.map(key => node.P.get(key) ?? 0);
			const sumP = priors.reduce((sum, p) => sum + p, 0);
			if (sumP > 0) {
				return this.decodeAction(priorKeys[this.sampleIndex(priors, sumP)]);
			}
			return this.decodeAction(priorKeys[Math.floor(Math.random() * priorKeys.length)]);
		}

		const counts = actionKeys.map(key => node.N.get(key) ?? 0);

		// Greedy limit. The comparison is inclusive so that the 0.01 handed over by
		// the temperature schedule really selects the most visited action.
		if (!(temperature > MCTSAgent.GREEDY_TEMPERATURE)) {
			let best = 0;
			for (let i = 1; i < counts.length; i++) {
				if (counts[i] > counts[best]) {
					best = i;
				}
			}
			return this.decodeAction(actionKeys[best]);
		}

		const weights = this.temperedVisitWeights(counts, temperature);
		if (weights === null) {
			// Nothing was visited: uniform random choice
			return this.decodeAction(actionKeys[Math.floor(Math.random() * actionKeys.length)]);
		}
		const total = weights.reduce((sum, w) => sum + w, 0);
		return this.decodeAction(actionKeys[this.sampleIndex(weights, total)]);
	}

	/**
	 * Compute search policy from visit counts
	 * π(a|s) = N(s,a)^(1/τ) / Σ_b N(s,b)^(1/τ)
	 *
	 * Falls back to a uniform distribution when no action has been visited, and
	 * returns an empty map when the node has no actions.
	 *
	 * @param node Root node
	 * @param temperature Selection temperature
	 * @returns Normalized policy distribution
	 */
	private computeSearchPolicy(node: MCTSNode, temperature: number): Map<string, number> {
		const policy = new Map<string, number>();
		const actionKeys = Array.from(node.N.keys());
		const counts = actionKeys.map(key => node.N.get(key) ?? 0);

		const weights = this.temperedVisitWeights(counts, temperature);
		if (weights === null) {
			const uniform = 1 / actionKeys.length;
			for (const key of actionKeys) {
				policy.set(key, uniform);
			}
			return policy;
		}

		const total = weights.reduce((sum, w) => sum + w, 0);
		for (let i = 0; i < actionKeys.length; i++) {
			policy.set(actionKeys[i], weights[i] / total);
		}
		return policy;
	}

	/**
	 * Unnormalized weights proportional to N^(1/τ), computed without overflow.
	 *
	 * Counts are divided by the largest count before exponentiation. This leaves
	 * the normalized distribution unchanged but keeps every weight in [0, 1], so
	 * large counts with a small τ cannot overflow to Infinity. For τ <= 0 (or a τ
	 * so small that 1/τ is not finite) all weight goes to the first most-visited
	 * action.
	 *
	 * @param counts Visit counts, one per action
	 * @param temperature Temperature τ
	 * @returns Weights aligned with counts (the largest is 1), or null when no
	 *          count is positive and finite
	 */
	private temperedVisitWeights(counts: readonly number[], temperature: number): number[] | null {
		let maxCount = 0;
		for (const n of counts) {
			if (n > maxCount) {
				maxCount = n;
			}
		}
		if (maxCount <= 0 || !isFinite(maxCount)) {
			return null;
		}

		const exponent = 1 / temperature;
		if (!(temperature > 0) || !isFinite(exponent)) {
			const best = counts.indexOf(maxCount);
			return counts.map((_, i) => (i === best ? 1 : 0));
		}
		return counts.map(n => Math.pow(n / maxCount, exponent));
	}

	/**
	 * Sample an index with probability proportional to its weight.
	 *
	 * @param weights Non-negative weights
	 * @param total Sum of the weights
	 * @returns Sampled index; the last index if rounding leaves a remainder
	 */
	private sampleIndex(weights: readonly number[], total: number): number {
		let rand = Math.random() * total;
		for (let i = 0; i < weights.length; i++) {
			rand -= weights[i];
			if (rand <= 0) {
				return i;
			}
		}
		return weights.length - 1;
	}

	/**
	 * Get root value estimate: the visit-weighted average of the root's Q-values.
	 *
	 * @returns Σ_a Q(a)·N(a) / Σ_a N(a), or 0 when nothing has been visited
	 */
	private getRootValue(node: MCTSNode): number {
		let totalN = 0;
		let weightedSum = 0;
		for (const [actionKey, n] of node.N) {
			totalN += n;
			weightedSum += (node.Q.get(actionKey) ?? 0) * n;
		}
		return totalN === 0 ? 0 : weightedSum / totalN;
	}

	/**
	 * Add Dirichlet noise to prior probabilities at root (mutates the map)
	 * P(s,a) = (1 - ε) * p_a + ε * M * η_a
	 * where η ~ Dir(α) over the non-pass actions and M is their total prior mass.
	 *
	 * The pass move is excluded from the noise and keeps its prior unchanged.
	 * Scaling the noise by M keeps the sum of all priors unchanged; without it the
	 * priors would sum to 1 + ε·P(pass) after mixing.
	 *
	 * @param priors Prior map of the root node (action key -> probability)
	 */
	private addDirichletNoise(priors: Map<string, number>): void {
		const noisyKeys = Array.from(priors.keys()).filter(key => key !== "pass");
		// Only pass available: nothing to perturb
		if (noisyKeys.length === 0) {
			return;
		}

		const alpha = this.config.dirichletAlpha;
		const epsilon = this.config.dirichletEpsilon;

		// Dirichlet sample = independent Gamma(α, 1) draws normalized by their sum
		const noise = noisyKeys.map(() => this.sampleGamma(alpha));
		const noiseSum = noise.reduce((sum, g) => sum + g, 0);
		// Degenerate draw (e.g. every sample underflowed to 0): keep the priors as-is
		if (!isFinite(noiseSum) || noiseSum <= 0) {
			return;
		}

		const noisyMass = noisyKeys.reduce((sum, key) => sum + (priors.get(key) ?? 0), 0);
		for (let i = 0; i < noisyKeys.length; i++) {
			const prior = priors.get(noisyKeys[i]) ?? 0;
			const noiseFraction = noise[i] / noiseSum;
			priors.set(noisyKeys[i], (1 - epsilon) * prior + epsilon * noisyMass * noiseFraction);
		}
	}

	/**
	 * Sample from Gamma(alpha, 1) using the Marsaglia and Tsang method (2000).
	 * Used for Dirichlet noise generation.
	 *
	 * @param alpha Shape parameter, must be > 0
	 * @throws Error if alpha <= 0
	 */
	private sampleGamma(alpha: number): number {
		if (alpha <= 0) {
			throw new Error("Gamma distribution alpha must be > 0");
		}
		// For alpha < 1: sample Gamma(alpha + 1) and multiply by U^(1/alpha)
		if (alpha < 1) {
			const u = Math.random();
			const g = this.sampleGamma(alpha + 1);
			return g * Math.pow(u, 1 / alpha);
		}

		// For alpha >= 1: rejection sampling on a transformed normal variate
		const d = alpha - 1 / 3;
		const c = 1 / Math.sqrt(9 * d);
		while (true) {
			let x: number;
			let v: number;
			do {
				x = this.randomNormal();
				v = 1 + c * x;
			} while (v <= 0);
			v = v * v * v;
			const u = Math.random();
			// Fast acceptance check (avoids the logarithms most of the time)
			if (u < 1 - 0.0331 * x * x * x * x) {
				return d * v;
			}
			// Exact acceptance check
			if (Math.log(u) < 0.5 * x * x + d * (1 - v + Math.log(v))) {
				return d * v;
			}
		}
	}

	/**
	 * Sample from standard normal distribution (Box-Muller transform)
	 */
	private randomNormal(): number {
		// Math.random() is in [0, 1); using 1 - u keeps the argument of log in (0, 1]
		const u1 = 1 - Math.random();
		const u2 = Math.random();
		return Math.sqrt(-2 * Math.log(u1)) * Math.cos(2 * Math.PI * u2);
	}

	/**
	 * Check if game state is terminal and return value if so
	 *
	 * Terminal conditions (checked in order of cost):
	 * 1. Game status is "finished" - a single status read
	 * 2. More than 50% of the board is covered by stones AND
	 *    state.isNaturallyTerminal() reports true
	 *
	 * The coverage test is an optimization: the natural-termination check is only
	 * attempted once the board is reasonably full.
	 *
	 * @param state Game state to check
	 * @returns Terminal value (white-positive) if terminal, null otherwise
	 */
	private checkTerminal(state: TrigoGame): number | null {
		if (state.getGameStatus() === "finished") {
			return this.calculateTerminalValue(state.getTerritory());
		}

		const board = state.getBoard();
		const shape = state.getShape();
		const totalPositions = shape.x * shape.y * shape.z;

		let stoneCount = 0;
		for (let x = 0; x < shape.x; x++) {
			for (let y = 0; y < shape.y; y++) {
				for (let z = 0; z < shape.z; z++) {
					const stone = board[x][y][z];
					// 1 and 2 are the stone values for Black and White
					if (stone === 1 || stone === 2) {
						stoneCount++;
					}
				}
			}
		}

		if (stoneCount / totalPositions > 0.5 && state.isNaturallyTerminal()) {
			return this.calculateTerminalValue(state.getTerritory());
		}
		return null; // Not terminal
	}

	/**
	 * Calculate terminal value from territory scores
	 *
	 * value = sign(diff) * (1 + ln|diff|) * factor, with diff = white - black and
	 * factor = 1.0; a difference smaller than 1e-6 counts as a draw (0). The
	 * formula is stated to mirror the training-side target computation.
	 *
	 * NOTE(review): for 0 < |diff| < 1/e the term (1 + ln|diff|) is negative and
	 * would flip the sign. Territory counts are presumably whole numbers, which
	 * rules this out - confirm.
	 *
	 * @param territory Territory counts from game
	 * @returns Value (white-positive: positive = white winning)
	 */
	private calculateTerminalValue(territory: { black: number; white: number; neutral: number }): number {
		const scoreDiff = territory.white - territory.black;
		if (Math.abs(scoreDiff) < 1e-6) {
			return 0.0; // Draw
		}
		// The log term rewards larger winning margins, but only logarithmically
		const territoryValueFactor = 1.0;
		return Math.sign(scoreDiff) * (1 + Math.log(Math.abs(scoreDiff))) * territoryValueFactor;
	}

	/**
	 * Create a new, unexpanded MCTS node
	 *
	 * The side to move is resolved in this order: explicit playerToMove argument,
	 * the given state, the opposite of the parent's player (assumes strictly
	 * alternating turns), and finally Black.
	 *
	 * @param state Game state (only provided for root node, null for others to save memory)
	 * @param parent Parent node
	 * @param action Action that led to this node
	 * @param playerToMove Player to move at this node (derived as described above if omitted)
	 */
	private createNode(state: TrigoGame | null, parent: MCTSNode | null, action: Move | null, playerToMove?: number): MCTSNode {
		let player: number;
		if (playerToMove !== undefined) {
			player = playerToMove;
		} else if (state) {
			player = state.getCurrentPlayer();
		} else if (parent) {
			player = parent.playerToMove === 1 ? 2 : 1;
		} else {
			player = 1;
		}

		return {
			state,
			parent,
			action,
			N: new Map(),
			W: new Map(),
			Q: new Map(),
			P: new Map(),
			children: new Map(),
			expanded: false,
			terminalValue: null,
			depth: parent ? parent.depth + 1 : 0,
			playerToMove: player
		};
	}

	/**
	 * Encode move to string key for storage in maps
	 * Note: Only encodes position ("pass" or "x,y,z"); the player is not part of the key
	 */
	private encodeAction(move: Move): string {
		if (move.isPass) {
			return "pass";
		}
		return `${move.x},${move.y},${move.z}`;
	}

	/**
	 * Decode string key back to move
	 * Note: Returns move with placeholder player "black" - caller must set the
	 * correct player before using the move externally (selectMove does this)
	 */
	private decodeAction(key: string): Move {
		if (key === "pass") {
			return { player: "black", isPass: true };
		}
		const [x, y, z] = key.split(",").map(Number);
		return { player: "black", x, y, z };
	}
}
|