Spaces:
Sleeping
Sleeping
| ; | |
| /** | |
| * Trigo Tree Agent - AI agent using tree attention for efficient move evaluation | |
| * | |
| * Uses evaluation mode ONNX model to score all valid moves in parallel. | |
| * Organizes moves as a prefix tree where branches with same head token are merged. | |
| */ | |
| Object.defineProperty(exports, "__esModule", { value: true }); | |
| exports.TrigoTreeAgent = void 0; | |
| const game_1 = require("./trigo/game"); | |
| const ab0yz_1 = require("./trigo/ab0yz"); | |
/**
 * AI agent that scores every candidate move in one evaluation-mode inference
 * call by arranging the moves' notation tokens as a prefix tree.
 *
 * The injected `inferencer` must provide `isReady()`,
 * `runEvaluationInference({ prefixIds, evaluatedIds, evaluatedMask })`
 * (resolving to `{ logits, numEvaluated }`) and `softmax(logits, outputPos)`;
 * those are the only members this class calls on it.
 */
class TrigoTreeAgent {
    /**
     * @param inferencer Inference backend used for scoring (see class docs).
     */
    constructor(inferencer) {
        // Special token constants (must match TGN tokenizer)
        this.START_TOKEN = 1;
        this.inferencer = inferencer;
    }

    /**
     * Convert a stone type to its player string.
     *
     * @param stone A `StoneType` value.
     * @returns "black" or "white".
     * @throws {Error} If `stone` is neither BLACK nor WHITE.
     */
    stoneToPlayer(stone) {
        if (stone === game_1.StoneType.BLACK)
            return "black";
        if (stone === game_1.StoneType.WHITE)
            return "white";
        throw new Error(`Invalid stone type: ${stone}`);
    }

    /**
     * Encode a board position as TGN notation by delegating to `encodeAb0yz`.
     *
     * @param pos Position with `x`, `y`, `z` fields.
     * @param shape Board shape with `x`, `y`, `z` fields.
     * @returns The notation string produced by `encodeAb0yz`.
     */
    positionToTGN(pos, shape) {
        const posArray = [pos.x, pos.y, pos.z];
        const shapeArray = [shape.x, shape.y, shape.z];
        return (0, ab0yz_1.encodeAb0yz)(posArray, shapeArray);
    }

    /**
     * Convert a string to one token per UTF-16 code unit (`charCodeAt`).
     * For pure-ASCII input this is the same as the ASCII byte values.
     *
     * @param str Input string.
     * @returns Array of numeric tokens, one per code unit of `str`.
     */
    stringToTokens(str) {
        const tokens = [];
        for (let i = 0; i < str.length; i++) {
            tokens.push(str.charCodeAt(i));
        }
        return tokens;
    }

    /**
     * Build a prefix tree from token arrays, merging sequences that share the
     * same token at every level.
     *
     * Algorithm:
     * 1. Group sequences by their first token (empty sequences are skipped).
     * 2. For each group create one node for the shared token, then recurse on
     *    the remaining tokens of the sequences that are not finished yet.
     * 3. Flatten the nodes and build the ancestor attention mask.
     *
     * Node positions are assigned in depth-first pre-order: a node is numbered
     * before its subtree, and its subtree before its next sibling.
     *
     * Example for ["aa", "ab", "ba", "bb"]: two root nodes 'a' and 'b', each
     * with children 'a' and 'b'.
     *
     * @param tokenArrays Array of token arrays, one per move.
     * @returns `{ evaluatedIds, mask, moveToLeafPos, parent }` where
     *   `evaluatedIds[p]` is the token of node `p` (length m),
     *   `mask` is a flat row-major m×m array with `mask[i * m + j] === 1` when
     *   `j` is `i` itself or one of its ancestors,
     *   `moveToLeafPos[k]` is the node where move `k` ends (-1 when its token
     *   array is empty), and `parent[p]` is the parent node index or `null`
     *   for a root.
     */
    buildPrefixTree(tokenArrays) {
        let nextPos = 0;

        // --- Build prefix tree through recursive grouping ---
        function build(seqs, parent) {
            // Group by first token; Map keeps first-seen order, which fixes node numbering.
            const groups = new Map();
            for (const s of seqs) {
                if (s.tokens.length === 0)
                    continue;
                const t = s.tokens[0];
                if (!groups.has(t))
                    groups.set(t, []);
                groups.get(t).push(s);
            }

            const levelNodes = [];
            for (const [token, group] of groups) {
                const pos = nextPos++;
                const node = {
                    token,
                    pos,
                    parent,
                    children: [],
                    moveEnds: []
                };

                // Sequences that end on this token vs. those that continue below it.
                const ends = [];
                const residues = [];
                for (const g of group) {
                    if (g.tokens.length === 1)
                        ends.push(g.moveIndex);
                    else
                        residues.push({ moveIndex: g.moveIndex, tokens: g.tokens.slice(1) });
                }
                node.moveEnds = ends;

                // Create sub nodes recursively
                if (residues.length > 0) {
                    node.children = build(residues, pos);
                }
                levelNodes.push(node);
            }
            return levelNodes;
        }

        // Build roots
        const seqs = tokenArrays.map((t, i) => ({ moveIndex: i, tokens: t }));
        const roots = build(seqs, null);
        const total = nextPos;

        // --- Flatten tree ---
        const evaluatedIds = new Array(total);
        const parent = new Array(total).fill(null);
        const moveToLeafPos = new Array(tokenArrays.length).fill(-1);

        function dfs(n) {
            evaluatedIds[n.pos] = n.token;
            parent[n.pos] = n.parent;
            for (const m of n.moveEnds)
                moveToLeafPos[m] = n.pos;
            for (const c of n.children)
                dfs(c);
        }
        for (const r of roots)
            dfs(r);

        // NOTE: moveToLeafPos[i] stays -1 when move i has an empty token array.
        // Callers then score that move from the prefix output alone.

        // --- Build ancestor mask ---
        const mask = new Array(total * total).fill(0);
        for (let i = 0; i < total; i++) {
            // Walk from node i up to its root, marking every node on the way.
            let p = i;
            while (p !== null) {
                mask[i * total + p] = 1;
                p = parent[p];
            }
        }
        return { evaluatedIds, mask, moveToLeafPos, parent };
    }

    /**
     * Build the evaluation inputs for a set of candidate moves.
     *
     * The prefix is the trimmed `game.toTGN()` text followed by a space (or by
     * a move number when the text is empty and it is black's turn), tokenized
     * and preceded by `START_TOKEN`. Each move's notation is tokenized with its
     * LAST token removed; those shortened sequences form the prefix tree.
     *
     * @param game Game object providing `toTGN()`, `getShape()` and
     *   `getCurrentPlayer()`.
     * @param moves Moves with either `isPass` set or `x`, `y`, `z` coordinates.
     * @returns `{ prefixTokens, evaluatedIds, mask, parent, moveData }` where
     *   each `moveData` entry is `{ move, notation, leafPos }`.
     * @throws {Error} If a non-pass move lacks a coordinate.
     */
    buildMoveTree(game, moves) {
        // Get current TGN as prefix
        const currentTGN = game.toTGN().trim();

        // Build prefix (everything up to next move)
        const lines = currentTGN.split("\n");
        const lastLine = lines[lines.length - 1];
        let prefix;
        if (lastLine.match(/^\d+\./)) {
            // Last line starts with a move number, include it
            prefix = currentTGN + " ";
        }
        else if (lastLine.trim() === "") {
            // Empty line, add move number
            const moveMatches = currentTGN.match(/\d+\.\s/g);
            const moveNumber = moveMatches ? moveMatches.length + 1 : 1;
            const isBlackTurn = game.getCurrentPlayer() === game_1.StoneType.BLACK;
            if (isBlackTurn) {
                prefix = currentTGN + `${moveNumber}. `;
            }
            else {
                prefix = currentTGN + " ";
            }
        }
        else {
            // Last line has other content, add space
            prefix = currentTGN + " ";
        }
        const prefixTokens = [this.START_TOKEN, ...this.stringToTokens(prefix)];

        // Encode each move to tokens (all tokens except the last one)
        const shape = game.getShape();
        const movesWithTokens = moves.map((move) => {
            let notation;
            if (move.isPass) {
                notation = "Pass";
            }
            else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
                notation = this.positionToTGN({ x: move.x, y: move.y, z: move.z }, shape);
            }
            else {
                throw new Error("Invalid move: missing coordinates");
            }
            // Exclude the last token: it is never fed to the model as input, only
            // predicted from the leaf's output. For single-char notations this
            // leaves an empty array, which is scored from the prefix output.
            const fullTokens = this.stringToTokens(notation);
            const tokens = fullTokens.slice(0, fullTokens.length - 1);
            return { move, notation, tokens };
        });

        // Build prefix tree
        const tokenArrays = movesWithTokens.map((m) => m.tokens);
        const { evaluatedIds, mask, moveToLeafPos, parent } = this.buildPrefixTree(tokenArrays);

        // Build move data with leaf positions only
        const moveData = movesWithTokens.map((m, index) => ({
            move: m.move,
            notation: m.notation,
            leafPos: moveToLeafPos[index]
        }));
        return { prefixTokens, evaluatedIds, mask, parent, moveData };
    }

    /**
     * Get tree structure for visualization (public wrapper around
     * `buildMoveTree`).
     */
    getTreeStructure(game, moves) {
        return this.buildMoveTree(game, moves);
    }

    /**
     * Select a move using tree attention with optional temperature sampling.
     *
     * Pass is never scored by the model: it is returned only when there are no
     * valid positions or scoring yields no results.
     *
     * @param game Current game state.
     * @param temperature Sampling temperature; values <= 0.01 select greedily,
     *   higher values sample more randomly.
     * @returns Selected move (a position, or a pass move).
     * @throws {Error} If the inferencer reports it is not ready.
     */
    async selectMove(game, temperature = 0) {
        if (!this.inferencer.isReady()) {
            throw new Error("Inferencer not initialized");
        }

        // Get current player as string
        const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());

        // Get all valid position moves (excluding Pass)
        const validMoves = game.validMovePositions().map((pos) => ({
            x: pos.x,
            y: pos.y,
            z: pos.z,
            player: currentPlayer
        }));

        // If no position moves available, return Pass directly
        if (validMoves.length === 0) {
            return { player: currentPlayer, isPass: true };
        }

        // Score only position moves (Pass excluded from inference)
        const scoredMoves = await this.scoreMoves(game, validMoves);

        // Fallback to Pass if scoring fails
        if (scoredMoves.length === 0) {
            return { player: currentPlayer, isPass: true };
        }

        // Select move based on temperature
        if (temperature <= 0.01) {
            // Greedy selection (reduce avoids mutating scoredMoves); ties keep the earliest move
            const best = scoredMoves.reduce((a, b) => (b.score > a.score ? b : a));
            return best.move;
        }

        // Temperature sampling
        return this.sampleMove(scoredMoves, temperature);
    }

    /**
     * Select the best move greedily (equivalent to `selectMove(game, 0)`).
     */
    async selectBestMove(game) {
        return this.selectMove(game, 0);
    }

    /**
     * Sample a move from scored moves using a temperature-scaled softmax over
     * their scores.
     *
     * @param scoredMoves Non-empty array of `{ move, score }` entries, where
     *   `score` is a log probability.
     * @param temperature Divisor applied to each score before the softmax.
     * @returns The `move` of the sampled entry.
     * @throws {Error} If `scoredMoves` is empty.
     */
    sampleMove(scoredMoves, temperature) {
        if (scoredMoves.length === 0) {
            throw new Error("Cannot sample from an empty move list");
        }

        // Apply temperature scaling to log probabilities
        const adjustedScores = scoredMoves.map((m) => m.score / temperature);
        // Subtract the maximum before exponentiating for numerical stability.
        const maxScore = Math.max(...adjustedScores);
        const expScores = adjustedScores.map((score) => Math.exp(score - maxScore));
        const sumExp = expScores.reduce((sum, exp) => sum + exp, 0);
        if (sumExp === 0 || !isFinite(sumExp)) {
            // Fallback to uniform random
            const idx = Math.floor(Math.random() * scoredMoves.length);
            return scoredMoves[idx].move;
        }
        const probabilities = expScores.map((exp) => exp / sumExp);

        // Weighted random sampling
        const random = Math.random();
        let cumulative = 0;
        for (let i = 0; i < scoredMoves.length; i++) {
            cumulative += probabilities[i];
            if (random <= cumulative) {
                return scoredMoves[i].move;
            }
        }
        // Rounding can leave `cumulative` marginally below `random`; use the last move.
        return scoredMoves[scoredMoves.length - 1].move;
    }

    /**
     * Score all moves with a single evaluation inference call.
     *
     * A move's score is the sum of log probabilities of every token of its
     * notation, each probability clipped to at least 1e-10. Indexing convention
     * relied on here: output 0 (the prefix output) predicts a root token, and
     * output `p + 1` is the output of tree node `p`, predicting the token that
     * follows it.
     *
     * @param game Current game state (see `buildMoveTree`).
     * @param moves Candidate moves.
     * @returns Array of `{ move, score, notation }` in the order of `moves`;
     *   every `score` is a finite number >= (token count) * log(1e-10) as long
     *   as the softmax values are valid probabilities.
     */
    async scoreMoves(game, moves) {
        if (moves.length === 0) {
            return [];
        }

        // Build tree structure
        const { prefixTokens, evaluatedIds, mask, parent, moveData } = this.buildMoveTree(game, moves);

        // Prepare inputs for evaluation
        const inputs = {
            prefixIds: prefixTokens,
            evaluatedIds: evaluatedIds,
            evaluatedMask: mask
        };

        // Run inference
        const output = await this.inferencer.runEvaluationInference(inputs);
        const { logits, numEvaluated } = output;

        // Minimum probability threshold to avoid log(0) while preserving small probabilities
        const MIN_PROB = 1e-10; // log(1e-10) ≈ -23

        // Cache softmax results for each output position to avoid recomputation
        const softmaxCache = new Map();
        const getSoftmax = (outputPos) => {
            if (!softmaxCache.has(outputPos)) {
                softmaxCache.set(outputPos, this.inferencer.softmax(logits, outputPos));
            }
            return softmaxCache.get(outputPos);
        };

        // Clipped log probability of `token` at `outputPos`.
        // An explicit comparison is used instead of Math.max because
        // Math.max(undefined | NaN, MIN_PROB) is NaN, which would poison the
        // move's score and break the greedy comparison in selectMove.
        const clippedLogProb = (outputPos, token) => {
            const prob = getSoftmax(outputPos)[token];
            return Math.log(prob > MIN_PROB ? prob : MIN_PROB);
        };

        // Log probability of the token that follows tree node `nodePos`.
        // Output indices run 0..numEvaluated; anything beyond gets the floor value.
        const logProbAfterNode = (nodePos, token) => {
            const logitsIndex = nodePos + 1;
            if (logitsIndex <= numEvaluated) {
                return clippedLogProb(logitsIndex, token);
            }
            return Math.log(MIN_PROB);
        };

        const scoredMoves = [];
        for (const data of moveData) {
            let logProb = 0;

            // Special case: leafPos = -1 means the move contributed no tree tokens
            // (single-char notation), so the prefix output predicts it directly.
            if (data.leafPos === -1) {
                const notationTokens = this.stringToTokens(data.notation);
                if (notationTokens.length === 1) {
                    logProb = clippedLogProb(0, notationTokens[0]);
                }
                else {
                    console.error(`Unexpected: leafPos=-1 but notation length=${notationTokens.length}`);
                    logProb = Math.log(MIN_PROB);
                }
                scoredMoves.push({
                    move: data.move,
                    score: logProb,
                    notation: data.notation
                });
                continue; // Skip the normal path processing
            }

            // Build path from leaf to root using parent array, then reverse
            const pathReverse = [];
            let pos = data.leafPos;
            const visited = new Set();

            // Safety checks: prevent infinite loops and invalid indices
            while (pos !== null && pos !== undefined) {
                // Check for cycles
                if (visited.has(pos)) {
                    console.error(`Cycle detected in parent array at position ${pos}`);
                    break;
                }
                // Check for valid index
                if (pos < 0 || pos >= parent.length) {
                    console.error(`Invalid position ${pos}, parent array length: ${parent.length}`);
                    break;
                }
                visited.add(pos);
                pathReverse.push(pos);
                pos = parent[pos];
                // Safety limit to prevent runaway loops
                if (pathReverse.length > 10000) {
                    console.error(`Path too long (>10000), possible infinite loop. leafPos: ${data.leafPos}`);
                    break;
                }
            }

            // Reverse to get root→leaf path (indices in evaluatedIds array)
            const path = pathReverse.reverse();

            if (path.length > 0) {
                // Root token is predicted by the prefix output (index 0).
                logProb += clippedLogProb(0, evaluatedIds[path[0]]);

                // Subsequent transitions: each child is predicted by its parent's output.
                for (let i = 1; i < path.length; i++) {
                    logProb += logProbAfterNode(path[i - 1], evaluatedIds[path[i]]);
                }

                // The last character of the notation was excluded from the tree;
                // it is predicted by the leaf node's output.
                const lastToken = this.stringToTokens(data.notation).pop();
                logProb += logProbAfterNode(path[path.length - 1], lastToken);
            }

            scoredMoves.push({
                move: data.move,
                score: logProb,
                notation: data.notation
            });
        }
        return scoredMoves;
    }
}
| exports.TrigoTreeAgent = TrigoTreeAgent; | |
| //# sourceMappingURL=trigoTreeAgent.js.map |