"use strict";
/**
 * Trigo Tree Agent - AI agent using tree attention for efficient move evaluation
 *
 * Uses evaluation mode ONNX model to score all valid moves in parallel.
 * Organizes moves as a prefix tree where branches with same head token are merged.
 */
Object.defineProperty(exports, "__esModule", { value: true });
exports.TrigoTreeAgent = void 0;
const game_1 = require("./trigo/game");
const ab0yz_1 = require("./trigo/ab0yz");
/**
 * Lower bound applied to token probabilities before taking the log, so that
 * log(0) = -Infinity never enters a score while small probabilities are kept.
 * log(1e-10) ≈ -23.
 */
const MIN_PROB = 1e-10;
/**
 * Return log(probs[token]) clipped below at log(MIN_PROB).
 *
 * For every numeric probability this equals
 * `Math.log(Math.max(probs[token], MIN_PROB))`. Unlike that form, a missing
 * entry (token index outside the probability vector, i.e. `undefined`) or a
 * NaN probability also falls back to MIN_PROB instead of yielding NaN.
 * A NaN score would poison move selection: `b.score > NaN` is always false,
 * so a NaN-scored first move would always win greedy selection.
 *
 * @param probs - Indexable probability vector (as returned by inferencer.softmax)
 * @param token - Token id to look up
 * @returns Log probability, never below log(MIN_PROB) and never NaN for finite input
 */
function clippedLogProb(probs, token) {
    const p = probs[token];
    // `p > MIN_PROB` is false for undefined and NaN, routing both to MIN_PROB.
    return Math.log(p > MIN_PROB ? p : MIN_PROB);
}
class TrigoTreeAgent {
    constructor(inferencer) {
        // Special token constants (must match TGN tokenizer)
        this.START_TOKEN = 1;
        this.inferencer = inferencer;
    }
    /**
     * Convert Stone type to player string.
     * @throws Error if the stone is neither BLACK nor WHITE
     */
    stoneToPlayer(stone) {
        if (stone === game_1.StoneType.BLACK)
            return "black";
        if (stone === game_1.StoneType.WHITE)
            return "white";
        throw new Error(`Invalid stone type: ${stone}`);
    }
    /**
     * Encode a position to TGN notation via encodeAb0yz
     * (e.g. 3 characters for a 5×5×5 board; length presumably depends on shape).
     */
    positionToTGN(pos, shape) {
        const posArray = [pos.x, pos.y, pos.z];
        const shapeArray = [shape.x, shape.y, shape.z];
        return (0, ab0yz_1.encodeAb0yz)(posArray, shapeArray);
    }
    /**
     * Convert string to byte tokens: one token per UTF-16 code unit (charCodeAt),
     * which equals the ASCII byte for ASCII text.
     */
    stringToTokens(str) {
        const tokens = [];
        for (let i = 0; i < str.length; i++) {
            tokens.push(str.charCodeAt(i));
        }
        return tokens;
    }
    /**
     * Build prefix tree from token arrays using recursive merging.
     * Merges branches with the same token at EVERY level.
     *
     * Algorithm:
     * 1. Group sequences by their first token
     * 2. For each group:
     *    - Create one node for the shared first token
     *    - Extract remaining tokens (residues)
     *    - Recursively build subtree from residues
     * 3. Combine all subtrees and build attention mask
     *
     * Example for ["aa", "ab", "ba", "bb"]:
     *   Level 1: Group by first token → 'a': ["a","b"], 'b': ["a","b"]
     *   Level 2: Within 'a' group, build subtree for ["a","b"]
     *            Within 'b' group, build subtree for ["a","b"]
     *   Result: Two branches, each with properly merged second-level nodes
     *
     * @param tokenArrays - Array of token arrays (one per move)
     * @returns
     *   evaluatedIds  - flattened node tokens (length m), nodes numbered in pre-order
     *   mask          - flattened row-major m×m array; mask[i*m + j] = 1 iff j is i or an ancestor of i
     *   moveToLeafPos - node index where each move's tokens end (-1 for an empty token array)
     *   parent        - parent node index per node (null for roots)
     */
    buildPrefixTree(tokenArrays) {
        let nextPos = 0;
        // --- Build prefix tree through recursive grouping ---
        function build(seqs, parent) {
            // Group by first token; Map iteration follows first-seen order,
            // which determines node numbering.
            const groups = new Map();
            for (const s of seqs) {
                if (s.tokens.length === 0)
                    continue;
                const t = s.tokens[0];
                if (!groups.has(t))
                    groups.set(t, []);
                groups.get(t).push(s);
            }
            const levelNodes = [];
            for (const [token, group] of groups) {
                // A node takes its position before its descendants (pre-order numbering)
                const pos = nextPos++;
                const node = { token, pos, parent, children: [], moveEnds: [] };
                // Split into moves ending at this node and residues continuing below it
                const ends = [];
                const residues = [];
                for (const g of group) {
                    if (g.tokens.length === 1)
                        ends.push(g.moveIndex);
                    else
                        residues.push({ moveIndex: g.moveIndex, tokens: g.tokens.slice(1) });
                }
                node.moveEnds = ends;
                // Create sub nodes recursively
                if (residues.length > 0) {
                    node.children = build(residues, pos);
                }
                levelNodes.push(node);
            }
            return levelNodes;
        }
        // Build roots
        const seqs = tokenArrays.map((t, i) => ({ moveIndex: i, tokens: t }));
        const roots = build(seqs, null);
        const total = nextPos;
        // --- Flatten tree ---
        const evaluatedIds = new Array(total);
        const parent = new Array(total).fill(null);
        const moveToLeafPos = new Array(tokenArrays.length).fill(-1);
        function dfs(n) {
            evaluatedIds[n.pos] = n.token;
            parent[n.pos] = n.parent;
            for (const m of n.moveEnds)
                moveToLeafPos[m] = n.pos;
            for (const c of n.children)
                dfs(c);
        }
        for (const r of roots)
            dfs(r);
        // NOTE: moveToLeafPos[i] = -1 means the move has empty tokens (e.g., single-char notation).
        // In this case, the caller scores the move from the prefix logits directly.
        // --- Build ancestor mask ---
        // Each node attends to itself and every ancestor up to its root.
        const mask = new Array(total * total).fill(0);
        for (let i = 0; i < total; i++) {
            let p = i;
            while (p !== null) {
                mask[i * total + p] = 1;
                p = parent[p];
            }
        }
        return { evaluatedIds, mask, moveToLeafPos, parent };
    }
    /**
     * Build tree structure for all valid moves.
     * Returns prefix tokens (START_TOKEN + current TGN text) and the tree
     * structure for batch evaluation.
     * @throws Error if a non-pass move lacks x/y/z coordinates
     */
    buildMoveTree(game, moves) {
        // Get current TGN as prefix
        const currentTGN = game.toTGN().trim();
        // Build prefix (everything up to next move)
        const lines = currentTGN.split("\n");
        const lastLine = lines[lines.length - 1];
        let prefix;
        if (lastLine.match(/^\d+\./)) {
            // Last line starts with a move number, include it
            prefix = currentTGN + " ";
        }
        else if (lastLine.trim() === "") {
            // NOTE(review): currentTGN was trimmed above, so the last line can only be
            // empty when the whole TGN is empty — confirm this branch is reachable as intended.
            // Empty line, add move number
            const moveMatches = currentTGN.match(/\d+\.\s/g);
            const moveNumber = moveMatches ? moveMatches.length + 1 : 1;
            const isBlackTurn = game.getCurrentPlayer() === game_1.StoneType.BLACK;
            if (isBlackTurn) {
                prefix = currentTGN + `${moveNumber}. `;
            }
            else {
                prefix = currentTGN + " ";
            }
        }
        else {
            // Last line has moves, add space
            prefix = currentTGN + " ";
        }
        const prefixTokens = [this.START_TOKEN, ...this.stringToTokens(prefix)];
        // Encode each move's notation to tokens, keeping all but the last token
        const shape = game.getShape();
        const movesWithTokens = moves.map((move) => {
            let notation;
            if (move.isPass) {
                notation = "Pass";
            }
            else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
                notation = this.positionToTGN({ x: move.x, y: move.y, z: move.z }, shape);
            }
            else {
                throw new Error("Invalid move: missing coordinates");
            }
            // Exclude the last token: it is predicted from the leaf node's output.
            // For single-char notations this yields an empty token array,
            // which means the prefix logits are used directly for scoring.
            const fullTokens = this.stringToTokens(notation);
            const tokens = fullTokens.slice(0, fullTokens.length - 1);
            return { move, notation, tokens };
        });
        // Build prefix tree
        const tokenArrays = movesWithTokens.map((m) => m.tokens);
        const { evaluatedIds, mask, moveToLeafPos, parent } = this.buildPrefixTree(tokenArrays);
        // Build move data with leaf positions only
        const moveData = movesWithTokens.map((m, index) => {
            const leafPos = moveToLeafPos[index];
            return {
                move: m.move,
                notation: m.notation,
                leafPos
            };
        });
        return {
            prefixTokens,
            evaluatedIds,
            mask,
            parent,
            moveData
        };
    }
    /**
     * Get tree structure for visualization (public method)
     */
    getTreeStructure(game, moves) {
        return this.buildMoveTree(game, moves);
    }
    /**
     * Select move using tree attention with optional temperature sampling
     * @param game Current game state
     * @param temperature Sampling temperature (≤ 0.01 = greedy, higher = more random)
     * @returns Selected move (position or Pass if no valid positions)
     * @throws Error if the inferencer is not ready
     */
    async selectMove(game, temperature = 0) {
        if (!this.inferencer.isReady()) {
            throw new Error("Inferencer not initialized");
        }
        // Get current player as string
        const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());
        // Get all valid position moves (excluding Pass)
        const validMoves = game.validMovePositions().map((pos) => ({
            x: pos.x,
            y: pos.y,
            z: pos.z,
            player: currentPlayer
        }));
        // If no position moves available, return Pass directly
        if (validMoves.length === 0) {
            return {
                player: currentPlayer,
                isPass: true
            };
        }
        // Score only position moves (Pass excluded from inference)
        const scoredMoves = await this.scoreMoves(game, validMoves);
        // Fallback to Pass if scoring returns nothing
        if (scoredMoves.length === 0) {
            return {
                player: currentPlayer,
                isPass: true
            };
        }
        // Select move based on temperature
        if (temperature <= 0.01) {
            // Greedy selection (reduce avoids mutating scoredMoves; ties keep the earlier move)
            const best = scoredMoves.reduce((a, b) => (b.score > a.score ? b : a));
            return best.move;
        }
        // Temperature sampling
        return this.sampleMove(scoredMoves, temperature);
    }
    /**
     * Select best move using tree attention (greedy, temperature=0)
     * Evaluates all valid moves in a single inference call
     * Pass is excluded from model prediction - returned directly if no positions available
     */
    async selectBestMove(game) {
        return this.selectMove(game, 0);
    }
    /**
     * Sample a move from scored moves using temperature-scaled softmax over the scores.
     * Falls back to a uniform random choice if the normalizer is 0 or non-finite.
     */
    sampleMove(scoredMoves, temperature) {
        // Apply temperature scaling to log probabilities
        const adjustedScores = scoredMoves.map((m) => m.score / temperature);
        // Subtract the max before exponentiating for numerical stability
        const maxScore = Math.max(...adjustedScores);
        const expScores = adjustedScores.map((score) => Math.exp(score - maxScore));
        const sumExp = expScores.reduce((sum, exp) => sum + exp, 0);
        if (sumExp === 0 || !Number.isFinite(sumExp)) {
            // Fallback to uniform random
            const idx = Math.floor(Math.random() * scoredMoves.length);
            return scoredMoves[idx].move;
        }
        const probabilities = expScores.map((exp) => exp / sumExp);
        // Weighted random sampling
        const random = Math.random();
        let cumulative = 0;
        for (let i = 0; i < scoredMoves.length; i++) {
            cumulative += probabilities[i];
            if (random <= cumulative) {
                return scoredMoves[i].move;
            }
        }
        // Guard against floating-point rounding leaving cumulative slightly below 1
        return scoredMoves[scoredMoves.length - 1].move;
    }
    /**
     * Score all moves using tree attention (batch evaluation).
     *
     * Each move's score is the sum of clipped log probabilities of its notation
     * tokens: the root token from the prefix output, each tree transition from
     * its parent node's output, and the excluded last token from the leaf output.
     *
     * @returns Array of { move, score, notation } in the same order as `moves`
     */
    async scoreMoves(game, moves) {
        if (moves.length === 0) {
            return [];
        }
        // Build tree structure
        const { prefixTokens, evaluatedIds, mask, parent, moveData } = this.buildMoveTree(game, moves);
        // Prepare inputs for evaluation
        const inputs = {
            prefixIds: prefixTokens,
            evaluatedIds: evaluatedIds,
            evaluatedMask: mask
        };
        // Run inference
        const output = await this.inferencer.runEvaluationInference(inputs);
        const { logits, numEvaluated } = output;
        const logMinProb = Math.log(MIN_PROB);
        // Cache softmax results per output position: moves sharing ancestors reuse them
        const softmaxCache = new Map();
        const getSoftmax = (outputPos) => {
            if (!softmaxCache.has(outputPos)) {
                softmaxCache.set(outputPos, this.inferencer.softmax(logits, outputPos));
            }
            return softmaxCache.get(outputPos);
        };
        // Log probability of `token` from the output row `outputPos`.
        // logits has numEvaluated+1 rows (indices 0..numEvaluated); rows beyond that
        // are treated as out of bounds and contribute log(MIN_PROB).
        const boundedLogProb = (outputPos, token) => outputPos <= numEvaluated
            ? clippedLogProb(getSoftmax(outputPos), token)
            : logMinProb;
        const scoredMoves = [];
        for (const data of moveData) {
            let logProb = 0;
            // Special case: leafPos = -1 means empty tokens (single-char notation).
            // Use prefix logits directly to predict the single character.
            if (data.leafPos === -1) {
                const notationTokens = this.stringToTokens(data.notation);
                if (notationTokens.length === 1) {
                    // Single-char notation: prefix output (logits[0]) predicts it
                    logProb = clippedLogProb(getSoftmax(0), notationTokens[0]);
                }
                else {
                    console.error(`Unexpected: leafPos=-1 but notation length=${notationTokens.length}`);
                    logProb = logMinProb;
                }
                scoredMoves.push({
                    move: data.move,
                    score: logProb,
                    notation: data.notation
                });
                continue;
            }
            // Build path from leaf to root using parent array, then reverse
            const pathReverse = [];
            let pos = data.leafPos;
            const visited = new Set();
            // Safety checks: prevent infinite loops and invalid indices
            while (pos !== null && pos !== undefined) {
                // Check for cycles
                if (visited.has(pos)) {
                    console.error(`Cycle detected in parent array at position ${pos}`);
                    break;
                }
                // Check for valid index
                if (pos < 0 || pos >= parent.length) {
                    console.error(`Invalid position ${pos}, parent array length: ${parent.length}`);
                    break;
                }
                visited.add(pos);
                pathReverse.push(pos);
                pos = parent[pos];
                // Safety limit to prevent runaway loops
                if (pathReverse.length > 10000) {
                    console.error(`Path too long (>10000), possible infinite loop. leafPos: ${data.leafPos}`);
                    break;
                }
            }
            // Reverse to get root→leaf path (indices in evaluatedIds array)
            const path = pathReverse.reverse();
            // Output layout (TreeLM):
            //   logits[0] = output at the last prefix position → predicts evaluatedIds[0]
            //   node k sits at input position (n + k); its output is logits[k + 1]
            // So a parent→child transition uses softmax(logits[parentIdx + 1])[evaluatedIds[childIdx]].
            // Root token: predicted by the prefix's last output (logits[0])
            if (path.length > 0) {
                logProb += clippedLogProb(getSoftmax(0), evaluatedIds[path[0]]);
            }
            // Subsequent transitions: parent→child in tree
            for (let i = 1; i < path.length; i++) {
                logProb += boundedLogProb(path[i - 1] + 1, evaluatedIds[path[i]]);
            }
            // CRITICAL: the last notation character was excluded from evaluatedIds;
            // predict it from the leaf node's output (logits[leafPos + 1]).
            if (path.length > 0) {
                const leafPos = path[path.length - 1];
                const lastToken = this.stringToTokens(data.notation).pop();
                logProb += boundedLogProb(leafPos + 1, lastToken);
            }
            scoredMoves.push({
                move: data.move,
                score: logProb,
                notation: data.notation
            });
        }
        return scoredMoves;
    }
}
exports.TrigoTreeAgent = TrigoTreeAgent;
//# sourceMappingURL=trigoTreeAgent.js.map