File size: 18,752 Bytes
63a8db2
 
 
 
 
 
 
 
 
 
 
 
 
6f4808d
 
63a8db2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f4808d
 
63a8db2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f4808d
63a8db2
 
 
 
 
 
 
 
 
 
 
 
 
 
6f4808d
 
63a8db2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f4808d
 
 
 
63a8db2
6f4808d
63a8db2
 
 
 
 
6f4808d
63a8db2
 
 
 
 
 
6f4808d
63a8db2
6f4808d
63a8db2
6f4808d
63a8db2
6f4808d
63a8db2
6f4808d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63a8db2
6f4808d
63a8db2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
6f4808d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
63a8db2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
342
343
344
345
346
347
348
349
350
351
352
353
354
355
356
357
358
359
360
361
362
363
364
365
366
367
368
369
370
371
372
373
374
375
376
377
378
379
380
381
382
383
384
385
386
387
388
389
390
391
392
393
394
395
396
397
398
399
400
401
402
403
404
405
406
407
408
409
410
411
412
413
414
415
416
417
418
419
420
421
422
423
424
425
426
427
428
429
430
431
432
433
434
435
436
437
438
439
"use strict";
/**
 * Trigo Tree Agent - AI agent using tree attention for efficient move evaluation
 *
 * Uses evaluation mode ONNX model to score all valid moves in parallel.
 * Organizes moves as a prefix tree where branches with same head token are merged.
 */
Object.defineProperty(exports, "__esModule", { value: true });
exports.TrigoTreeAgent = void 0;
const game_1 = require("./trigo/game");
const ab0yz_1 = require("./trigo/ab0yz");
class TrigoTreeAgent {
    /**
     * @param inferencer - Model wrapper. This class calls `isReady()`,
     *   `runEvaluationInference({ prefixIds, evaluatedIds, evaluatedMask })`
     *   (expected to resolve to `{ logits, numEvaluated }`) and
     *   `softmax(logits, outputPos)` on it.
     */
    constructor(inferencer) {
        // Special token constants (must match TGN tokenizer)
        this.START_TOKEN = 1;
        this.inferencer = inferencer;
    }
    /**
     * Convert a stone type to the player string used in move objects.
     *
     * @param stone - Value compared against `StoneType.BLACK` / `StoneType.WHITE`
     * @returns "black" or "white"
     * @throws Error when the stone is neither black nor white
     */
    stoneToPlayer(stone) {
        if (stone === game_1.StoneType.BLACK) {
            return "black";
        }
        if (stone === game_1.StoneType.WHITE) {
            return "white";
        }
        throw new Error(`Invalid stone type: ${stone}`);
    }
    /**
     * Encode a board position as TGN notation by delegating to `encodeAb0yz`.
     *
     * @param pos - Position with `x`, `y`, `z` fields
     * @param shape - Board shape with `x`, `y`, `z` fields
     * @returns Whatever `encodeAb0yz([x, y, z], [sx, sy, sz])` returns
     */
    positionToTGN(pos, shape) {
        const posArray = [pos.x, pos.y, pos.z];
        const shapeArray = [shape.x, shape.y, shape.z];
        return (0, ab0yz_1.encodeAb0yz)(posArray, shapeArray);
    }
    /**
     * Convert a string to tokens, one per UTF-16 code unit.
     * For pure-ASCII input each token equals the character's ASCII byte.
     *
     * @param str - Text to tokenize
     * @returns Array of numeric char codes, same length as `str`
     */
    stringToTokens(str) {
        const tokens = [];
        for (let i = 0; i < str.length; i++) {
            tokens.push(str.charCodeAt(i));
        }
        return tokens;
    }
    /**
     * Build a flattened prefix tree (trie) from token sequences.
     *
     * Sequences sharing the same token at the same depth under the same parent
     * are merged into one node, at every level. Nodes are numbered in
     * depth-first pre-order: a node's whole subtree is numbered before its next
     * sibling, and siblings appear in order of first occurrence in the input.
     *
     * Example for ["aa", "ab", "ba", "bb"]:
     *   evaluatedIds = a a b b a b
     *   parent       = - 0 0 - 3 3
     *
     * @param tokenArrays - Array of token arrays, one per move
     * @returns An object with:
     *   - `evaluatedIds`: token of each node (length m)
     *   - `mask`: row-major m×m 0/1 matrix; `mask[i*m + j] === 1` iff node j is
     *     node i itself or one of its ancestors
     *   - `moveToLeafPos`: for each input sequence, the node where it ends, or
     *     -1 when the sequence is empty
     *   - `parent`: parent node index of each node, `null` for roots
     */
    buildPrefixTree(tokenArrays) {
        const evaluatedIds = [];
        const parent = [];
        // -1 marks a move with no tokens in the tree; callers score such moves
        // from the prefix output directly.
        const moveToLeafPos = new Array(tokenArrays.length).fill(-1);
        // Insert all nodes at `depth` for the given moves, below `parentPos`.
        const insertLevel = (moveIndices, depth, parentPos) => {
            // Map preserves insertion order, so siblings keep first-seen order.
            const groups = new Map();
            for (const moveIndex of moveIndices) {
                const tokens = tokenArrays[moveIndex];
                if (depth >= tokens.length) {
                    continue;
                }
                const token = tokens[depth];
                const bucket = groups.get(token);
                if (bucket === undefined) {
                    groups.set(token, [moveIndex]);
                }
                else {
                    bucket.push(moveIndex);
                }
            }
            for (const [token, members] of groups) {
                // The node index is assigned before descending, which yields pre-order.
                const pos = evaluatedIds.length;
                evaluatedIds.push(token);
                parent.push(parentPos);
                const continuing = [];
                for (const moveIndex of members) {
                    if (tokenArrays[moveIndex].length === depth + 1) {
                        moveToLeafPos[moveIndex] = pos;
                    }
                    else {
                        continuing.push(moveIndex);
                    }
                }
                if (continuing.length > 0) {
                    insertLevel(continuing, depth + 1, pos);
                }
            }
        };
        insertLevel(tokenArrays.map((_, index) => index), 0, null);
        // --- Build ancestor mask ---
        const total = evaluatedIds.length;
        const mask = new Array(total * total).fill(0);
        for (let row = 0; row < total; row++) {
            for (let col = row; col !== null; col = parent[col]) {
                mask[row * total + col] = 1;
            }
        }
        return { evaluatedIds, mask, moveToLeafPos, parent };
    }
    /**
     * Build the model inputs for scoring a set of candidate moves.
     *
     * The prefix is the game's current TGN text plus the separator (and, for an
     * empty record on Black's turn, the move number) that precedes the next
     * move. Every move contributes all of its notation tokens except the last
     * one to the prefix tree; the last token is scored from the leaf's output.
     *
     * @param game - Game object; `toTGN()`, `getShape()` and (for an empty
     *   record) `getCurrentPlayer()` are called on it
     * @param moves - Candidate moves: either `{ isPass: true }` or objects with
     *   `x`, `y`, `z`
     * @returns `{ prefixTokens, evaluatedIds, mask, parent, moveData }` where
     *   each `moveData` entry is `{ move, notation, leafPos }`
     * @throws Error when a non-pass move lacks a coordinate
     */
    buildMoveTree(game, moves) {
        // Get current TGN as prefix
        const currentTGN = game.toTGN().trim();
        const lines = currentTGN.split("\n");
        const lastLine = lines[lines.length - 1];
        let prefix;
        if (lastLine.match(/^\d+\./)) {
            // Last line starts with a move number: continue on the same line
            prefix = `${currentTGN} `;
        }
        else if (lastLine.trim() === "") {
            // currentTGN is trimmed, so a blank last line means the record is empty.
            const moveMatches = currentTGN.match(/\d+\.\s/g);
            const moveNumber = moveMatches ? moveMatches.length + 1 : 1;
            const isBlackTurn = game.getCurrentPlayer() === game_1.StoneType.BLACK;
            prefix = isBlackTurn ? `${currentTGN}${moveNumber}. ` : `${currentTGN} `;
        }
        else {
            // Last line has other content, add a separating space
            prefix = `${currentTGN} `;
        }
        const prefixTokens = [this.START_TOKEN, ...this.stringToTokens(prefix)];
        const shape = game.getShape();
        const movesWithTokens = moves.map((move) => {
            let notation;
            if (move.isPass) {
                notation = "Pass";
            }
            else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
                notation = this.positionToTGN({ x: move.x, y: move.y, z: move.z }, shape);
            }
            else {
                throw new Error("Invalid move: missing coordinates");
            }
            // Drop the last token: it is never fed to the model as input, only
            // predicted. A single-char notation therefore yields an empty
            // array and is scored from the prefix output alone.
            const fullTokens = this.stringToTokens(notation);
            const tokens = fullTokens.slice(0, fullTokens.length - 1);
            return { move, notation, tokens };
        });
        const { evaluatedIds, mask, moveToLeafPos, parent } = this.buildPrefixTree(movesWithTokens.map((entry) => entry.tokens));
        const moveData = movesWithTokens.map((entry, index) => ({
            move: entry.move,
            notation: entry.notation,
            leafPos: moveToLeafPos[index]
        }));
        return { prefixTokens, evaluatedIds, mask, parent, moveData };
    }
    /**
     * Get tree structure for visualization (public wrapper around buildMoveTree).
     *
     * @param game - Current game state
     * @param moves - Candidate moves
     * @returns Same object as `buildMoveTree(game, moves)`
     */
    getTreeStructure(game, moves) {
        return this.buildMoveTree(game, moves);
    }
    /**
     * Select a move using tree attention with optional temperature sampling.
     *
     * Only position moves are scored; Pass is returned directly when the game
     * reports no valid positions or scoring yields nothing.
     *
     * @param game - Current game state
     * @param temperature - Sampling temperature; values <= 0.01 select greedily,
     *   higher values sample more randomly
     * @returns Selected move (position, or `{ player, isPass: true }`)
     * @throws Error when the inferencer reports it is not ready
     */
    async selectMove(game, temperature = 0) {
        if (!this.inferencer.isReady()) {
            throw new Error("Inferencer not initialized");
        }
        const currentPlayer = this.stoneToPlayer(game.getCurrentPlayer());
        // Get all valid position moves (excluding Pass)
        const validMoves = game.validMovePositions().map((pos) => ({
            x: pos.x,
            y: pos.y,
            z: pos.z,
            player: currentPlayer
        }));
        if (validMoves.length === 0) {
            return { player: currentPlayer, isPass: true };
        }
        const scoredMoves = await this.scoreMoves(game, validMoves);
        // Fallback to Pass if scoring produced nothing
        if (scoredMoves.length === 0) {
            return { player: currentPlayer, isPass: true };
        }
        if (temperature <= 0.01) {
            // Greedy: first move with the strictly highest score (no mutation of scoredMoves)
            const best = scoredMoves.reduce((a, b) => (b.score > a.score ? b : a));
            return best.move;
        }
        return this.sampleMove(scoredMoves, temperature);
    }
    /**
     * Select best move using tree attention (greedy, temperature=0).
     *
     * @param game - Current game state
     * @returns Result of `selectMove(game, 0)`
     */
    async selectBestMove(game) {
        return this.selectMove(game, 0);
    }
    /**
     * Sample a move from scored moves using a temperature-scaled softmax over
     * the scores.
     *
     * @param scoredMoves - Entries with `move` and `score` (log probability)
     * @param temperature - Divisor applied to every score before the softmax
     * @returns The sampled entry's `move`; falls back to a uniformly random
     *   entry when the softmax normaliser is zero or not finite
     */
    sampleMove(scoredMoves, temperature) {
        const adjustedScores = scoredMoves.map((m) => m.score / temperature);
        // Folded instead of spread so long move lists cannot exceed the
        // engine's argument-count limit; the result equals Math.max(...scores).
        const maxScore = adjustedScores.reduce((best, score) => Math.max(best, score), -Infinity);
        // Subtracting the max keeps Math.exp from overflowing.
        const expScores = adjustedScores.map((score) => Math.exp(score - maxScore));
        const sumExp = expScores.reduce((sum, exp) => sum + exp, 0);
        if (sumExp === 0 || !isFinite(sumExp)) {
            // Fallback to uniform random
            const idx = Math.floor(Math.random() * scoredMoves.length);
            return scoredMoves[idx].move;
        }
        // Weighted random sampling over the cumulative distribution
        const random = Math.random();
        let cumulative = 0;
        for (let i = 0; i < scoredMoves.length; i++) {
            cumulative += expScores[i] / sumExp;
            if (random <= cumulative) {
                return scoredMoves[i].move;
            }
        }
        // Rounding can leave the cumulative sum slightly below `random`.
        return scoredMoves[scoredMoves.length - 1].move;
    }
    /**
     * Score all moves with a single evaluation-inference call.
     *
     * A move's score is the sum of log probabilities of its notation tokens:
     * the first token is read from output 0 (the prefix output), and every
     * later token from output `nodeIndex + 1` of the tree node holding the
     * preceding token. Probabilities are clamped to at least 1e-10; a missing
     * or NaN probability (e.g. a token id outside the softmax vector) is
     * treated as that minimum so scores never become NaN.
     *
     * @param game - Current game state
     * @param moves - Candidate moves
     * @returns One `{ move, score, notation }` per input move, in input order;
     *   an empty array when `moves` is empty
     */
    async scoreMoves(game, moves) {
        if (moves.length === 0) {
            return [];
        }
        const { prefixTokens, evaluatedIds, mask, parent, moveData } = this.buildMoveTree(game, moves);
        const output = await this.inferencer.runEvaluationInference({
            prefixIds: prefixTokens,
            evaluatedIds: evaluatedIds,
            evaluatedMask: mask
        });
        const { logits, numEvaluated } = output;
        // Minimum probability threshold to avoid log(0) while preserving small probabilities
        const MIN_PROB = 1e-10; // log(1e-10) ≈ -23
        const MIN_LOG_PROB = Math.log(MIN_PROB);
        // Cache softmax results per output position to avoid recomputation
        const softmaxCache = new Map();
        const getSoftmax = (outputPos) => {
            if (!softmaxCache.has(outputPos)) {
                softmaxCache.set(outputPos, this.inferencer.softmax(logits, outputPos));
            }
            return softmaxCache.get(outputPos);
        };
        // Log probability of `token` at `outputPos`, clamped to MIN_PROB.
        // The comparison is false for undefined and NaN, so those clamp too.
        const tokenLogProb = (outputPos, token) => {
            const prob = getSoftmax(outputPos)[token];
            return Math.log(prob > MIN_PROB ? prob : MIN_PROB);
        };
        // Same, but an output position beyond numEvaluated scores the minimum.
        const boundedTokenLogProb = (outputPos, token) => outputPos <= numEvaluated ? tokenLogProb(outputPos, token) : MIN_LOG_PROB;
        // Walk the parent links from a leaf and return the root→leaf node path.
        const pathFromRoot = (leafPos) => {
            const pathReverse = [];
            const visited = new Set();
            let pos = leafPos;
            while (pos !== null && pos !== undefined) {
                if (visited.has(pos)) {
                    console.error(`Cycle detected in parent array at position ${pos}`);
                    break;
                }
                if (pos < 0 || pos >= parent.length) {
                    console.error(`Invalid position ${pos}, parent array length: ${parent.length}`);
                    break;
                }
                visited.add(pos);
                pathReverse.push(pos);
                pos = parent[pos];
                // Safety limit to prevent runaway loops
                if (pathReverse.length > 10000) {
                    console.error(`Path too long (>10000), possible infinite loop. leafPos: ${leafPos}`);
                    break;
                }
            }
            return pathReverse.reverse();
        };
        const scoredMoves = [];
        for (const data of moveData) {
            const notationTokens = this.stringToTokens(data.notation);
            let logProb = 0;
            if (data.leafPos === -1) {
                // No tree node: the notation should be a single character,
                // predicted directly by the prefix output.
                if (notationTokens.length === 1) {
                    logProb = tokenLogProb(0, notationTokens[0]);
                }
                else {
                    console.error(`Unexpected: leafPos=-1 but notation length=${notationTokens.length}`);
                    logProb = MIN_LOG_PROB;
                }
            }
            else {
                const path = pathFromRoot(data.leafPos);
                if (path.length > 0) {
                    // Root token is predicted by the prefix output (index 0)
                    logProb += tokenLogProb(0, evaluatedIds[path[0]]);
                }
                // Each child token is predicted by its parent's output (parent index + 1)
                for (let i = 1; i < path.length; i++) {
                    logProb += boundedTokenLogProb(path[i - 1] + 1, evaluatedIds[path[i]]);
                }
                if (path.length > 0) {
                    // The notation's last token was kept out of the tree; it is
                    // predicted by the leaf's output.
                    const lastToken = notationTokens[notationTokens.length - 1];
                    logProb += boundedTokenLogProb(path[path.length - 1] + 1, lastToken);
                }
            }
            scoredMoves.push({
                move: data.move,
                score: logProb,
                notation: data.notation
            });
        }
        return scoredMoves;
    }
}
exports.TrigoTreeAgent = TrigoTreeAgent;
//# sourceMappingURL=trigoTreeAgent.js.map