// trigo / trigo-web / backend / dist / inc / trigoTreeAgent.js
// Last commit 6f4808d (k-l-lambda): feat: room rename and room switching with confirmation
"use strict";
/**
* Trigo Tree Agent - AI agent using tree attention for efficient move evaluation
*
* Uses evaluation mode ONNX model to score all valid moves in parallel.
* Organizes moves as a prefix tree where branches with same head token are merged.
*/
// CommonJS preamble: set the `__esModule` flag on the exports object and
// initialise the export slot to undefined until the class is defined below.
Object.defineProperty(exports, "__esModule", { value: true });
exports.TrigoTreeAgent = void 0;
// game_1 provides StoneType; ab0yz_1 provides encodeAb0yz (both used by TrigoTreeAgent).
const game_1 = require("./trigo/game");
const ab0yz_1 = require("./trigo/ab0yz");
class TrigoTreeAgent {
    /**
     * @param inferencer - Inference backend. This class calls `isReady()`,
     *   `runEvaluationInference(inputs)` and `softmax(logits, position)` on it.
     */
    constructor(inferencer) {
        // Special token constants (must match TGN tokenizer)
        this.START_TOKEN = 1;
        this.inferencer = inferencer;
    }
    /**
     * Convert a stone type to the player string used in move objects.
     * @param stone - `StoneType.BLACK` or `StoneType.WHITE`
     * @returns "black" or "white"
     * @throws {Error} for any other stone value
     */
    stoneToPlayer(stone) {
        if (stone === game_1.StoneType.BLACK) {
            return "black";
        }
        if (stone === game_1.StoneType.WHITE) {
            return "white";
        }
        throw new Error(`Invalid stone type: ${stone}`);
    }
    /**
     * Encode a board position as TGN notation by delegating to `encodeAb0yz`.
     * @param pos - Position with `x`, `y`, `z` fields
     * @param shape - Board shape with `x`, `y`, `z` fields
     * @returns Notation string produced by `encodeAb0yz`
     */
    positionToTGN(pos, shape) {
        const posArray = [pos.x, pos.y, pos.z];
        const shapeArray = [shape.x, shape.y, shape.z];
        return (0, ab0yz_1.encodeAb0yz)(posArray, shapeArray);
    }
    /**
     * Convert a string to tokens, one per UTF-16 code unit (`charCodeAt`).
     * For ASCII text this is the ASCII byte value of each character.
     * @param str - Text to tokenize
     * @returns Array of numeric tokens, same length as `str`
     */
    stringToTokens(str) {
        return Array.from({ length: str.length }, (_, i) => str.charCodeAt(i));
    }
    /**
     * Build a prefix tree from token arrays, merging branches that share the
     * same token at EVERY level.
     *
     * Algorithm:
     *   1. Group sequences by their first token.
     *   2. For each group, create one node for the shared token, then recurse
     *      on the remaining tokens (residues) of the group's members.
     *   3. Flatten the tree and derive the ancestor attention mask.
     *
     * Example: [[a,a],[a,b],[b,a],[b,b]] yields two roots (a, b), each with
     * two children; node positions are assigned in depth-first pre-order.
     *
     * @param tokenArrays - One token array per move
     * @returns `evaluatedIds` (token per node, length m), `mask` (row-major
     *   m×m array where mask[i*m + j] = 1 iff node j is node i or one of its
     *   ancestors), `moveToLeafPos` (node index where each move's tokens end,
     *   or -1 for an empty token array) and `parent` (parent node index per
     *   node, null for roots)
     */
    buildPrefixTree(tokenArrays) {
        let nextPos = 0;
        // --- Build prefix tree through recursive grouping ---
        const build = (seqs, parent) => {
            // Group by leading token; Map keeps first-seen order.
            const groups = new Map();
            for (const seq of seqs) {
                if (seq.tokens.length === 0) {
                    continue;
                }
                const head = seq.tokens[0];
                if (!groups.has(head)) {
                    groups.set(head, []);
                }
                groups.get(head).push(seq);
            }
            const levelNodes = [];
            for (const [token, group] of groups) {
                // Position is claimed before recursing, giving pre-order numbering.
                const pos = nextPos++;
                const moveEnds = [];
                const residues = [];
                for (const member of group) {
                    if (member.tokens.length === 1) {
                        moveEnds.push(member.moveIndex);
                    }
                    else {
                        residues.push({ moveIndex: member.moveIndex, tokens: member.tokens.slice(1) });
                    }
                }
                const children = residues.length > 0 ? build(residues, pos) : [];
                levelNodes.push({ token, pos, parent, children, moveEnds });
            }
            return levelNodes;
        };
        const seqs = tokenArrays.map((tokens, moveIndex) => ({ moveIndex, tokens }));
        const roots = build(seqs, null);
        const total = nextPos;
        // --- Flatten tree ---
        const evaluatedIds = new Array(total);
        const parent = new Array(total).fill(null);
        // -1 marks a move with no tokens in the tree; such moves are scored
        // from the prefix output directly.
        const moveToLeafPos = new Array(tokenArrays.length).fill(-1);
        const flatten = (node) => {
            evaluatedIds[node.pos] = node.token;
            parent[node.pos] = node.parent;
            for (const moveIndex of node.moveEnds) {
                moveToLeafPos[moveIndex] = node.pos;
            }
            for (const child of node.children) {
                flatten(child);
            }
        };
        for (const root of roots) {
            flatten(root);
        }
        // --- Build ancestor mask ---
        const mask = new Array(total * total).fill(0);
        for (let i = 0; i < total; i++) {
            let p = i;
            while (p !== null) {
                mask[i * total + p] = 1;
                p = parent[p];
            }
        }
        return { evaluatedIds, mask, moveToLeafPos, parent };
    }
    /**
     * Build the evaluation inputs for a set of candidate moves.
     *
     * @param game - Game object; `toTGN()`, `getShape()` and (only when the
     *   TGN text is empty) `getCurrentPlayer()` are called on it
     * @param moves - Candidate moves: either `{ isPass: true }` or objects
     *   carrying `x`, `y`, `z`
     * @returns `prefixTokens` (START token followed by the TGN prefix),
     *   the prefix-tree arrays from `buildPrefixTree`, and `moveData`
     *   (`move`, `notation`, `leafPos` per input move, in input order)
     * @throws {Error} when a non-pass move lacks a coordinate
     */
    buildMoveTree(game, moves) {
        const currentTGN = game.toTGN().trim();
        const lines = currentTGN.split("\n");
        const lastLine = lines[lines.length - 1];
        // Default: continue the existing text after a single space. This covers
        // both "last line is a move number" and "last line already has moves".
        let prefix = currentTGN + " ";
        // Only when the last line is blank and black is to move do we emit a
        // fresh move number first.
        if (lastLine.trim() === "" && game.getCurrentPlayer() === game_1.StoneType.BLACK) {
            const moveMatches = currentTGN.match(/\d+\.\s/g);
            const moveNumber = moveMatches ? moveMatches.length + 1 : 1;
            prefix = currentTGN + `${moveNumber}. `;
        }
        const prefixTokens = [this.START_TOKEN, ...this.stringToTokens(prefix)];
        const shape = game.getShape();
        const movesWithTokens = moves.map((move) => {
            let notation;
            if (move.isPass) {
                notation = "Pass";
            }
            else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
                notation = this.positionToTGN({ x: move.x, y: move.y, z: move.z }, shape);
            }
            else {
                throw new Error("Invalid move: missing coordinates");
            }
            // The tree holds every token EXCEPT the last one: the last token is
            // only ever predicted, never fed as input. A single-character
            // notation therefore contributes no tree nodes at all.
            const fullTokens = this.stringToTokens(notation);
            const tokens = fullTokens.slice(0, fullTokens.length - 1);
            return { move, notation, tokens };
        });
        const tokenArrays = movesWithTokens.map((m) => m.tokens);
        const { evaluatedIds, mask, moveToLeafPos, parent } = this.buildPrefixTree(tokenArrays);
        const moveData = movesWithTokens.map((m, index) => ({
            move: m.move,
            notation: m.notation,
            leafPos: moveToLeafPos[index]
        }));
        return { prefixTokens, evaluatedIds, mask, parent, moveData };
    }
    /**
     * Get tree structure for visualization (public wrapper of `buildMoveTree`).
     */
    getTreeStructure(game, moves) {
        return this.buildMoveTree(game, moves);
    }
    /**
     * Select a move using tree attention with optional temperature sampling.
     * Pass is never scored by the model; it is returned only when there are
     * no valid positions or scoring yields nothing.
     *
     * @param game - Current game state
     * @param temperature - Sampling temperature (<= 0.01 means greedy)
     * @returns Selected move (position move, or a pass move)
     * @throws {Error} when the inferencer reports it is not ready
     */
    async selectMove(game, temperature = 0) {
        if (!this.inferencer.isReady()) {
            throw new Error("Inferencer not initialized");
        }
        const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());
        const validMoves = game.validMovePositions().map((pos) => ({
            x: pos.x,
            y: pos.y,
            z: pos.z,
            player: currentPlayer
        }));
        if (validMoves.length === 0) {
            return { player: currentPlayer, isPass: true };
        }
        const scoredMoves = await this.scoreMoves(game, validMoves);
        if (scoredMoves.length === 0) {
            return { player: currentPlayer, isPass: true };
        }
        if (temperature <= 0.01) {
            // Greedy: first move with the strictly highest score wins.
            const best = scoredMoves.reduce((a, b) => (b.score > a.score ? b : a));
            return best.move;
        }
        return this.sampleMove(scoredMoves, temperature);
    }
    /**
     * Select the best move greedily (equivalent to `selectMove(game, 0)`).
     */
    async selectBestMove(game) {
        return this.selectMove(game, 0);
    }
    /**
     * Sample a move from scored moves using a temperature-scaled softmax over
     * their log-probability scores. Falls back to a uniform random pick when
     * the softmax normaliser is zero or not finite.
     *
     * @param scoredMoves - Non-empty array of `{ move, score }`
     * @param temperature - Divisor applied to every score before the softmax
     * @returns The `move` of the sampled entry
     */
    sampleMove(scoredMoves, temperature) {
        const adjustedScores = scoredMoves.map((m) => m.score / temperature);
        // Fold instead of Math.max(...spread) so array size is not limited by
        // the engine's argument-count ceiling.
        const maxScore = adjustedScores.reduce((a, b) => Math.max(a, b), -Infinity);
        // Subtract the maximum before exponentiating for numerical stability.
        const expScores = adjustedScores.map((score) => Math.exp(score - maxScore));
        const sumExp = expScores.reduce((sum, exp) => sum + exp, 0);
        if (sumExp === 0 || !Number.isFinite(sumExp)) {
            const idx = Math.floor(Math.random() * scoredMoves.length);
            return scoredMoves[idx].move;
        }
        const probabilities = expScores.map((exp) => exp / sumExp);
        const random = Math.random();
        let cumulative = 0;
        for (let i = 0; i < scoredMoves.length; i++) {
            cumulative += probabilities[i];
            if (random <= cumulative) {
                return scoredMoves[i].move;
            }
        }
        // Rounding can leave cumulative marginally below 1; take the last move.
        return scoredMoves[scoredMoves.length - 1].move;
    }
    /**
     * Score all moves with a single batched inference call.
     *
     * Each move's score is the sum of log-probabilities of its notation
     * tokens. Output index convention (as used by this method):
     *   - output 0 is the prefix's last position and predicts a root token;
     *   - output k+1 belongs to tree node k and predicts that node's child
     *     token, or the move's final (non-tree) token when k is the leaf.
     * Every probability is floored at 1e-10 before the logarithm; missing or
     * non-finite probabilities are treated as the floor as well, so scores
     * are always finite.
     *
     * @param game - Current game state (forwarded to `buildMoveTree`)
     * @param moves - Candidate moves
     * @returns Array of `{ move, score, notation }` in input order
     */
    async scoreMoves(game, moves) {
        if (moves.length === 0) {
            return [];
        }
        const { prefixTokens, evaluatedIds, mask, parent, moveData } = this.buildMoveTree(game, moves);
        const inputs = {
            prefixIds: prefixTokens,
            evaluatedIds: evaluatedIds,
            evaluatedMask: mask
        };
        const output = await this.inferencer.runEvaluationInference(inputs);
        const { logits, numEvaluated } = output;
        // Minimum probability threshold to avoid log(0) while preserving small probabilities
        const MIN_PROB = 1e-10; // log(1e-10) ≈ -23
        const FLOOR_LOG_PROB = Math.log(MIN_PROB);
        // Cache softmax results per output position to avoid recomputation.
        const softmaxCache = new Map();
        const getSoftmax = (outputPos) => {
            if (!softmaxCache.has(outputPos)) {
                softmaxCache.set(outputPos, this.inferencer.softmax(logits, outputPos));
            }
            return softmaxCache.get(outputPos);
        };
        // Log-probability of `token` at output `outputPos`. A missing entry
        // (token id beyond the distribution) or NaN would otherwise turn the
        // whole score into NaN, so those are floored too.
        const logProbAt = (outputPos, token) => {
            const p = getSoftmax(outputPos)[token];
            return Number.isFinite(p) && p > MIN_PROB ? Math.log(p) : FLOOR_LOG_PROB;
        };
        // Log-probability of `token` following tree node `nodePos`; the
        // node's output lives at index nodePos + 1 and must be within range.
        const logProbAfterNode = (nodePos, token) => {
            const outputPos = nodePos + 1;
            return outputPos <= numEvaluated ? logProbAt(outputPos, token) : FLOOR_LOG_PROB;
        };
        // Walk leaf→root through `parent`; returns the root→leaf path, or
        // null if the walk had to be aborted.
        const pathToLeaf = (leafPos) => {
            const reversed = [];
            const visited = new Set();
            let pos = leafPos;
            while (pos !== null && pos !== undefined) {
                if (visited.has(pos)) {
                    console.error(`Cycle detected in parent array at position ${pos}`);
                    return null;
                }
                if (pos < 0 || pos >= parent.length) {
                    console.error(`Invalid position ${pos}, parent array length: ${parent.length}`);
                    return null;
                }
                visited.add(pos);
                reversed.push(pos);
                pos = parent[pos];
                // Safety limit to prevent runaway loops
                if (reversed.length > 10000) {
                    console.error(`Path too long (>10000), possible infinite loop. leafPos: ${leafPos}`);
                    return null;
                }
            }
            return reversed.reverse();
        };
        const scoreOne = (data) => {
            const notationTokens = this.stringToTokens(data.notation);
            // leafPos = -1: the move has no tree nodes (single-char notation),
            // so its only token is predicted straight from the prefix output.
            if (data.leafPos === -1) {
                if (notationTokens.length === 1) {
                    return logProbAt(0, notationTokens[0]);
                }
                console.error(`Unexpected: leafPos=-1 but notation length=${notationTokens.length}`);
                return FLOOR_LOG_PROB;
            }
            const path = pathToLeaf(data.leafPos);
            // A broken or empty path must not score as log-prob 0 (certainty);
            // give it the floor so it cannot outrank properly scored moves.
            if (path === null || path.length === 0) {
                return FLOOR_LOG_PROB;
            }
            // Root token is predicted by the prefix's last position.
            let logProb = logProbAt(0, evaluatedIds[path[0]]);
            // Each further tree token is predicted by its parent node's output.
            for (let i = 1; i < path.length; i++) {
                logProb += logProbAfterNode(path[i - 1], evaluatedIds[path[i]]);
            }
            // The notation's last token is not in the tree; the leaf predicts it.
            const lastToken = notationTokens[notationTokens.length - 1];
            logProb += logProbAfterNode(path[path.length - 1], lastToken);
            return logProb;
        };
        return moveData.map((data) => ({
            move: data.move,
            score: scoreOne(data),
            notation: data.notation
        }));
    }
}
// Publish the class as this module's named export `TrigoTreeAgent`.
exports.TrigoTreeAgent = TrigoTreeAgent;
//# sourceMappingURL=trigoTreeAgent.js.map