/**
 * Test script for Trigo AI Agent using onnxruntime-node
 *
 * This script tests the AI agent's ability to:
 * 1. Load the ONNX evaluation model
 * 2. Initialize the model inferencer
 * 3. Score all valid moves using tree attention
 * 4. Select and display the best move
 *
 * Usage: npx tsx tools/testAIAgent.ts
 */

import * as ort from "onnxruntime-node";
import * as path from "path";
import { fileURLToPath } from "url";
import { TrigoGame, StoneType } from "../inc/trigo/game";
import { ModelInferencer } from "../inc/modelInferencer";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import type { Move } from "../inc/trigo/types";
import {
	loadEnvConfig,
	getOnnxModelPaths,
	getAbsoluteModelPath,
	getOnnxSessionOptions
} from "../inc/config";

// ES module equivalent of __dirname (not currently referenced by this script)
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);

// Load environment variables before any model configuration is read
await loadEnvConfig();

// Configuration
const modelPaths = getOnnxModelPaths();
const MODEL_PATH = getAbsoluteModelPath(modelPaths.evaluationModel);
const VOCAB_SIZE = 128;
const SEQ_LEN = 256;

/** Width of the "=" / "-" separator lines between report sections. */
const RULE_WIDTH = 80;

/** Column width of the "Position" cell in the top-moves table. */
const POSITION_COL_WIDTH = 15;

/** Print a section banner: rule, title, rule. */
function printBanner(title: string): void {
	console.log("=".repeat(RULE_WIDTH));
	console.log(title);
	console.log("=".repeat(RULE_WIDTH));
}

/** Print a thin "-" separator line. */
function printRule(): void {
	console.log("-".repeat(RULE_WIDTH));
}

/** Create the fresh 5×5×5 game used by every test (a new size object per game). */
function createTestGame(): TrigoGame {
	return new TrigoGame({ x: 5, y: 5, z: 5 }, {});
}

/** Display name ("Black" / "White") of the side to move in `game`. */
function playerName(game: TrigoGame): "Black" | "White" {
	return game.getCurrentPlayer() === StoneType.BLACK ? "Black" : "White";
}

/**
 * Build the candidate list for the side to move: every valid board position
 * (only the first `limit` of them when `limit` is given) followed by a pass.
 */
function buildCandidateMoves(game: TrigoGame, limit?: number): Move[] {
	const player = game.getCurrentPlayer() === StoneType.BLACK ? "black" : "white";
	const positions = game.validMovePositions();
	const selected = limit === undefined ? positions : positions.slice(0, limit);
	const moves: Move[] = selected.map((pos) => ({ x: pos.x, y: pos.y, z: pos.z, player }));
	moves.push({ player, isPass: true }); // Add pass
	return moves;
}

/**
 * Numerically stable softmax: the max score is subtracted before exponentiating
 * so large log-scores cannot overflow. Returns [] for empty input.
 */
function softmax(scores: readonly number[]): number[] {
	if (scores.length === 0) return [];
	const maxScore = scores.reduce((max, s) => Math.max(max, s), -Infinity);
	const exps = scores.map((s) => Math.exp(s - maxScore));
	const total = exps.reduce((sum, e) => sum + e, 0);
	return exps.map((e) => e / total);
}

/** Format a move as "Pass" or "(x, y, z)", right-padded to `width` characters. */
function formatPosition(move: Move, width = POSITION_COL_WIDTH): string {
	const text = move.isPass ? "Pass" : `(${move.x}, ${move.y}, ${move.z})`;
	return text.padEnd(width);
}

/**
 * Initialize the AI agent: load the ONNX session, wrap it in a ModelInferencer,
 * and build a TrigoTreeAgent on top of it.
 */
async function initializeAgent(): Promise<TrigoTreeAgent> {
	printBanner("Initializing AI Agent...");
	console.log(`Model Path: ${MODEL_PATH}`);
	console.log(`Vocab Size: ${VOCAB_SIZE}`);
	console.log(`Sequence Length: ${SEQ_LEN}`);
	console.log();

	// Load ONNX model
	console.log("Loading ONNX model...");
	const sessionOptions = getOnnxSessionOptions();
	const session = await ort.InferenceSession.create(MODEL_PATH, sessionOptions);
	console.log("✓ Model loaded successfully");

	// Create inferencer
	// NOTE(review): the `as any` casts presumably bridge ModelInferencer's own
	// tensor/session types to onnxruntime-node's — confirm and tighten if possible.
	console.log("Creating model inferencer...");
	const inferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: VOCAB_SIZE,
		seqLen: SEQ_LEN,
		modelPath: MODEL_PATH
	});
	inferencer.setSession(session as any);
	console.log("✓ Inferencer created");

	// Create agent
	console.log("Creating Trigo tree agent...");
	const agent = new TrigoTreeAgent(inferencer);
	console.log("✓ Agent initialized");
	console.log();

	return agent;
}

/**
 * Test AI move generation on a fresh game: score every legal first move plus
 * pass, then print the top 10 with softmax probabilities.
 */
async function testFreshGame(agent: TrigoTreeAgent): Promise<void> {
	printBanner("Test 1: Fresh Game (First Move)");

	const game = createTestGame();
	console.log("Game created: 5×5×5 board");
	console.log(`Current player: ${playerName(game)}`);
	console.log(`Valid move positions: ${game.validMovePositions().length}`);
	console.log();

	const moves = buildCandidateMoves(game);
	console.log(`Scoring ${moves.length} moves (${moves.length - 1} positions + pass)...`);

	const startTime = Date.now();
	const scoredMoves = await agent.scoreMoves(game, moves);
	const timeMs = Date.now() - startTime;

	console.log(`✓ Scored ${scoredMoves.length} moves in ${timeMs}ms`);
	if (scoredMoves.length === 0) {
		// Nothing to rank; avoid NaN averages and a softmax over an empty set.
		console.log("✗ No moves were scored");
		console.log();
		return;
	}
	console.log(` Average: ${(timeMs / scoredMoves.length).toFixed(2)}ms per move`);
	console.log();

	// Rank a copy (highest score first) so the agent's returned array is untouched;
	// probabilities are computed over the ranked order so indices stay aligned.
	const ranked = [...scoredMoves].sort((a, b) => b.score - a.score);
	const probabilities = softmax(ranked.map((m) => m.score));

	// Display top 10 moves; header cells are padded to the same widths as the rows.
	console.log("Top 10 Moves:");
	printRule();
	console.log(
		[
			"Rank",
			"Notation",
			"Position".padEnd(POSITION_COL_WIDTH),
			"Log Prob".padStart(9),
			"Probability"
		].join(" | ")
	);
	printRule();
	ranked.slice(0, 10).forEach((scored, i) => {
		const prob = probabilities[i] ?? 0;
		console.log(
			[
				(i + 1).toString().padStart(4),
				scored.notation.padEnd(8),
				formatPosition(scored.move),
				scored.score.toFixed(4).padStart(9),
				`${(prob * 100).toFixed(4)}%`
			].join(" | ")
		);
	});
	console.log();
}

/**
 * Test AI move generation after several moves: play a fixed opening, print
 * the resulting TGN and state, then ask the agent for its best move.
 */
async function testMidGame(agent: TrigoTreeAgent): Promise<void> {
	printBanner("Test 2: Mid Game (After Several Moves)");

	const game = createTestGame();

	// Make some test moves
	const testMoves = [
		{ x: 2, y: 2, z: 2 }, // Center
		{ x: 1, y: 1, z: 1 }, // Near corner
		{ x: 3, y: 2, z: 2 }, // Adjacent to center
		{ x: 1, y: 2, z: 1 } // Another move
	];

	console.log("Playing test moves:");
	for (const pos of testMoves) {
		// Read the mover BEFORE dropping: deriving it afterwards by inverting the
		// current player mislabels a rejected drop (the turn presumably only
		// advances on success).
		const mover = playerName(game);
		const success = game.drop(pos);
		console.log(` ${mover} plays (${pos.x}, ${pos.y}, ${pos.z}): ${success ? "✓" : "✗"}`);
	}
	console.log();

	// Display current TGN
	console.log("Current TGN:");
	printRule();
	console.log(game.toTGN().trim());
	printRule();
	console.log();

	// Get current state
	console.log(`Current player: ${playerName(game)}`);
	console.log(`Move count: ${game.stepHistory.length}`);
	console.log(`Valid move positions: ${game.validMovePositions().length}`);
	console.log();

	// Select best move
	console.log("Selecting best move...");
	const startTime = Date.now();
	const bestMove = await agent.selectBestMove(game);
	const elapsedMs = Date.now() - startTime;

	if (bestMove) {
		console.log(`✓ Best move: ${formatPosition(bestMove, 0)} (selected in ${elapsedMs}ms)`);
	} else {
		console.log("✗ No valid move found");
	}
	console.log();
}

/**
 * Test tree structure visualization: build the tree for a small candidate set
 * (first 5 positions + pass) and print its tokens, per-move leaf/parent
 * positions, and the attention mask as a grid.
 */
async function testTreeStructure(agent: TrigoTreeAgent): Promise<void> {
	printBanner("Test 3: Tree Structure Visualization");

	const game = createTestGame();

	// Limited move set (first 5 positions + pass) for a readable visualization
	const moves = buildCandidateMoves(game, 5);

	console.log(`Analyzing tree structure for ${moves.length} moves...`);
	console.log();

	const treeStructure = agent.getTreeStructure(game, moves);

	console.log(`Tree nodes: ${treeStructure.evaluatedIds.length}`);
	console.log(`Move count: ${treeStructure.moveData.length}`);
	console.log();

	// Display token sequence. Array.from(..., mapFn) yields strings for both plain
	// arrays and typed arrays (TypedArray.prototype.map would coerce back to numbers).
	console.log("Evaluated Token Sequence:");
	printRule();
	const tokens = Array.from(treeStructure.evaluatedIds, (id) => String.fromCharCode(id)).join("");
	console.log(`Tokens: ${tokens}`);
	console.log(`Token IDs: [${treeStructure.evaluatedIds.join(", ")}]`);
	console.log();

	// Display move details; header cells padded to the row column widths.
	console.log("Move Details:");
	printRule();
	console.log(["Notation", "Leaf Pos", "Parent".padStart(10)].join(" | "));
	printRule();
	for (const data of treeStructure.moveData) {
		const parentPos = treeStructure.parent[data.leafPos];
		// `== null` also covers a missing entry, which would otherwise throw on toString().
		const parentStr = parentPos == null ? "null" : String(parentPos);
		console.log(
			[
				data.notation.padEnd(8),
				data.leafPos.toString().padStart(8),
				parentStr.padStart(10)
			].join(" | ")
		);
	}
	console.log();

	// Display attention mask (row-major, size × size)
	console.log("Attention Mask (1 = can attend, 0 = cannot attend):");
	printRule();
	const size = treeStructure.evaluatedIds.length;
	// Row labels are a 3-wide index plus one space, so the header is indented by 4.
	const rowLabelWidth = 4;
	const columnHeader = Array.from({ length: size }, (_, col) => col.toString().padStart(2)).join("");
	console.log(" ".repeat(rowLabelWidth) + columnHeader);
	for (let row = 0; row < size; row++) {
		const cells = Array.from({ length: size }, (_, col) =>
			String(treeStructure.mask[row * size + col]).padStart(2)
		).join("");
		console.log(row.toString().padStart(rowLabelWidth - 1) + " " + cells);
	}
	console.log();
}

/**
 * Main test function: initialize the agent and run all tests in order,
 * exiting with status 1 on the first failure.
 */
async function main(): Promise<void> {
	try {
		const agent = await initializeAgent();

		await testFreshGame(agent);
		await testMidGame(agent);
		await testTreeStructure(agent);

		printBanner("All tests completed successfully!");
	} catch (error) {
		console.error("Error:", error);
		process.exit(1);
	}
}

// The module already relies on top-level await, so await main() instead of
// leaving a floating promise.
await main();