trigo / trigo-web /tools /testAIAgent.ts
k-l-lambda's picture
Update trigo-web with VS People multiplayer mode
15f353f
/**
* Test script for Trigo AI Agent using onnxruntime-node
*
* This script tests the AI agent's ability to:
* 1. Load the ONNX evaluation model
* 2. Initialize the model inferencer
* 3. Score all valid moves using tree attention
* 4. Select and display the best move
*
* Usage: npx tsx tools/testAIAgent.ts
*/
import * as ort from "onnxruntime-node";
import * as path from "path";
import { fileURLToPath } from "url";
import { TrigoGame, StoneType } from "../inc/trigo/game";
import { ModelInferencer } from "../inc/modelInferencer";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import type { Move } from "../inc/trigo/types";
import { loadEnvConfig, getOnnxModelPaths, getAbsoluteModelPath, getOnnxSessionOptions } from "../inc/config";
// ES module equivalent of __dirname: both values are derived from this
// module's own URL (import.meta.url).
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
// Load environment variables. The top-level await completes before the model
// paths below are read (presumably they depend on that configuration — TODO confirm).
await loadEnvConfig();
// Configuration
const modelPaths = getOnnxModelPaths();
// Path of the evaluation model as resolved by getAbsoluteModelPath(); it is used
// both to create the ONNX session and as the inferencer's modelPath.
const MODEL_PATH = getAbsoluteModelPath(modelPaths.evaluationModel);
// Passed to ModelInferencer as vocabSize / seqLen.
// NOTE(review): presumably these must match the exported ONNX model — confirm.
const VOCAB_SIZE = 128;
const SEQ_LEN = 256;
/**
 * Build the agent that every test in this script exercises.
 *
 * Each stage is announced on the console: an ONNX Runtime session is created
 * for MODEL_PATH, a ModelInferencer configured with VOCAB_SIZE / SEQ_LEN is
 * given that session, and the inferencer is handed to a new TrigoTreeAgent.
 *
 * @returns The constructed agent.
 */
async function initializeAgent(): Promise<TrigoTreeAgent> {
	const banner = "=".repeat(80);
	const settings = [
		`Model Path: ${MODEL_PATH}`,
		`Vocab Size: ${VOCAB_SIZE}`,
		`Sequence Length: ${SEQ_LEN}`
	];

	console.log(banner);
	console.log("Initializing AI Agent...");
	console.log(banner);
	for (const line of settings) {
		console.log(line);
	}
	console.log();

	// Stage 1: ONNX session (options are fetched only after the progress line is printed)
	console.log("Loading ONNX model...");
	const session = await ort.InferenceSession.create(MODEL_PATH, getOnnxSessionOptions());
	console.log("✓ Model loaded successfully");

	// Stage 2: inferencer wrapping that session
	console.log("Creating model inferencer...");
	const inferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: VOCAB_SIZE,
		seqLen: SEQ_LEN,
		modelPath: MODEL_PATH
	});
	inferencer.setSession(session as any);
	console.log("✓ Inferencer created");

	// Stage 3: the agent itself
	console.log("Creating Trigo tree agent...");
	const treeAgent = new TrigoTreeAgent(inferencer);
	console.log("✓ Agent initialized");
	console.log();

	return treeAgent;
}
/**
 * Test 1: score every candidate first move on a fresh 5×5×5 board and print
 * the ten highest-scoring ones as a fixed-width table.
 *
 * Candidates are every position reported by `game.validMovePositions()` for
 * the side to move, plus a pass. The scores returned by `agent.scoreMoves()`
 * are ranked in descending order and converted to probabilities with a
 * softmax over the scored candidates.
 *
 * @param agent - Agent whose `scoreMoves()` is exercised.
 * @returns Resolves once the report has been printed.
 */
async function testFreshGame(agent: TrigoTreeAgent): Promise<void> {
	// Column widths shared by the header and every row so the table stays aligned.
	const RANK_WIDTH = 4;
	const NOTATION_WIDTH = 8;
	const POSITION_WIDTH = 15;
	const SCORE_WIDTH = 9;
	const TOP_N = 10;
	const banner = "=".repeat(80);
	const rule = "-".repeat(80);

	console.log(banner);
	console.log("Test 1: Fresh Game (First Move)");
	console.log(banner);

	// Create a new game
	const game = new TrigoGame({ x: 5, y: 5, z: 5 }, {});
	const blackToMove = game.getCurrentPlayer() === StoneType.BLACK;
	const validPositions = game.validMovePositions();
	console.log("Game created: 5×5×5 board");
	console.log(`Current player: ${blackToMove ? "Black" : "White"}`);
	console.log(`Valid move positions: ${validPositions.length}`);
	console.log();

	// Candidate list: every valid position for the side to move, then a pass
	const currentPlayer = blackToMove ? "black" : "white";
	const moves: Move[] = validPositions.map((pos) => ({
		x: pos.x,
		y: pos.y,
		z: pos.z,
		player: currentPlayer
	}));
	moves.push({ player: currentPlayer, isPass: true });

	console.log(`Scoring ${moves.length} moves (${moves.length - 1} positions + pass)...`);
	const startTime = Date.now();
	const scoredMoves = await agent.scoreMoves(game, moves);
	const timeMs = Date.now() - startTime;
	console.log(`✓ Scored ${scoredMoves.length} moves in ${timeMs}ms`);

	// Without this guard an empty result would print a NaN/Infinity average.
	if (scoredMoves.length === 0) {
		console.log("✗ No moves were scored; nothing to rank");
		console.log();
		return;
	}
	console.log(` Average: ${(timeMs / scoredMoves.length).toFixed(2)}ms per move`);
	console.log();

	// Rank on a copy so the array handed back by the agent is left untouched.
	const ranked = [...scoredMoves].sort((a, b) => b.score - a.score);

	// Softmax; subtracting the maximum keeps Math.exp() from overflowing.
	const scores = ranked.map((candidate) => candidate.score);
	const maxScore = Math.max(...scores);
	const weights = scores.map((score) => Math.exp(score - maxScore));
	const totalWeight = weights.reduce((sum, weight) => sum + weight, 0);
	const probabilities = weights.map((weight) => weight / totalWeight);

	// Display top 10 moves
	const header = [
		"Rank".padStart(RANK_WIDTH),
		"Notation".padEnd(NOTATION_WIDTH),
		"Position".padEnd(POSITION_WIDTH),
		"Log Prob".padStart(SCORE_WIDTH),
		"Probability"
	].join(" | ");
	console.log("Top 10 Moves:");
	console.log(rule);
	console.log(header);
	console.log(rule);
	ranked.slice(0, TOP_N).forEach((candidate, index) => {
		const { move } = candidate;
		const coords = move.isPass ? "Pass" : `(${move.x}, ${move.y}, ${move.z})`;
		const cells = [
			String(index + 1).padStart(RANK_WIDTH),
			candidate.notation.padEnd(NOTATION_WIDTH),
			// Pad AND clip: slice() alone only shortens, so short cells used to
			// leave the columns to their right misaligned.
			coords.padEnd(POSITION_WIDTH).slice(0, POSITION_WIDTH),
			candidate.score.toFixed(4).padStart(SCORE_WIDTH),
			`${(probabilities[index] * 100).toFixed(4)}%`
		];
		console.log(cells.join(" | "));
	});
	console.log();
}
/**
 * Test 2: play a short fixed opening on a 5×5×5 board, print the resulting
 * TGN and game state, then ask the agent for its best move and time the call.
 *
 * A setup move rejected by `game.drop()` is reported with ✗ and the test
 * carries on with the remaining moves.
 *
 * @param agent - Agent whose `selectBestMove()` is exercised.
 * @returns Resolves once the report has been printed.
 */
async function testMidGame(agent: TrigoTreeAgent): Promise<void> {
	const banner = "=".repeat(80);
	const rule = "-".repeat(80);

	console.log(banner);
	console.log("Test 2: Mid Game (After Several Moves)");
	console.log(banner);

	// Create a new game and make some moves
	const game = new TrigoGame({ x: 5, y: 5, z: 5 }, {});
	const testMoves = [
		{ x: 2, y: 2, z: 2 }, // Center
		{ x: 1, y: 1, z: 1 }, // Near corner
		{ x: 3, y: 2, z: 2 }, // Adjacent to center
		{ x: 1, y: 2, z: 1 }, // Another move
	];

	console.log("Playing test moves:");
	for (const pos of testMoves) {
		// Read the side to move BEFORE dropping. Deriving it afterwards by
		// inverting the current player is only right when the drop succeeded;
		// a rejected drop would be attributed to the wrong colour.
		// NOTE(review): assumes a rejected drop leaves the turn unchanged — confirm.
		const mover = game.getCurrentPlayer() === StoneType.BLACK ? "Black" : "White";
		const success = game.drop(pos);
		console.log(` ${mover} plays (${pos.x}, ${pos.y}, ${pos.z}): ${success ? "✓" : "✗"}`);
	}
	console.log();

	// Display current TGN
	console.log("Current TGN:");
	console.log(rule);
	console.log(game.toTGN().trim());
	console.log(rule);
	console.log();

	// Get current state
	console.log(`Current player: ${game.getCurrentPlayer() === StoneType.BLACK ? "Black" : "White"}`);
	console.log(`Move count: ${game.stepHistory.length}`);
	console.log(`Valid move positions: ${game.validMovePositions().length}`);
	console.log();

	// Select best move
	console.log("Selecting best move...");
	const startTime = Date.now();
	const bestMove = await agent.selectBestMove(game);
	const elapsedMs = Date.now() - startTime;
	if (bestMove) {
		const moveStr = bestMove.isPass
			? "Pass"
			: `(${bestMove.x}, ${bestMove.y}, ${bestMove.z})`;
		console.log(`✓ Best move: ${moveStr} (selected in ${elapsedMs}ms)`);
	} else {
		console.log("✗ No valid move found");
	}
	console.log();
}
/**
 * Test 3: print the tree structure the agent builds for a handful of moves.
 *
 * Only the first five valid positions (plus a pass) are used so the output
 * stays readable. Three views of `agent.getTreeStructure()` are printed: the
 * evaluated token sequence, one line per move with its leaf position and that
 * leaf's parent, and the attention mask as a square grid.
 *
 * @param agent - Agent whose `getTreeStructure()` is exercised.
 * @returns Resolves once the report has been printed.
 */
async function testTreeStructure(agent: TrigoTreeAgent): Promise<void> {
	const banner = "=".repeat(80);
	const rule = "-".repeat(80);

	console.log(banner);
	console.log("Test 3: Tree Structure Visualization");
	console.log(banner);

	// Create a simple game
	const game = new TrigoGame({ x: 5, y: 5, z: 5 }, {});

	// Get limited moves for clearer visualization
	const currentPlayer = game.getCurrentPlayer() === StoneType.BLACK ? "black" : "white";
	const validPositions = game.validMovePositions().slice(0, 5); // Only first 5 positions
	const moves: Move[] = validPositions.map((pos) => ({
		x: pos.x,
		y: pos.y,
		z: pos.z,
		player: currentPlayer
	}));
	moves.push({ player: currentPlayer, isPass: true }); // Add pass
	console.log(`Analyzing tree structure for ${moves.length} moves...`);
	console.log();

	// Get tree structure
	const treeStructure = agent.getTreeStructure(game, moves);
	console.log(`Tree nodes: ${treeStructure.evaluatedIds.length}`);
	console.log(`Move count: ${treeStructure.moveData.length}`);
	console.log();

	// Display token sequence.
	// Array.from maps into a plain array, so the characters survive whether
	// evaluatedIds is a regular array or a numeric typed array.
	// NOTE(review): treats each token ID as a character code — confirm against the tokenizer.
	console.log("Evaluated Token Sequence:");
	console.log(rule);
	const tokens = Array.from(treeStructure.evaluatedIds, (id: number) => String.fromCharCode(id)).join("");
	console.log(`Tokens: ${tokens}`);
	console.log(`Token IDs: [${treeStructure.evaluatedIds.join(", ")}]`);
	console.log();

	// Display move details
	console.log("Move Details:");
	console.log(rule);
	console.log(["Notation".padEnd(8), "Leaf Pos".padStart(8), "Parent".padStart(10)].join(" | "));
	console.log(rule);
	for (const data of treeStructure.moveData) {
		// `== null` also covers a leaf position with no entry in the parent
		// array, which would otherwise throw on `.toString()`.
		const parentPos = treeStructure.parent[data.leafPos];
		const parentStr = parentPos == null ? "null" : String(parentPos);
		console.log(
			`${data.notation.padEnd(8)} | ` +
			`${String(data.leafPos).padStart(8)} | ` +
			`${parentStr.padStart(10)}`
		);
	}
	console.log();

	// Display attention mask; it is read as a flat size×size grid (row * size + col).
	console.log("Attention Mask (1 = can attend, 0 = cannot attend):");
	console.log(rule);
	const size = treeStructure.evaluatedIds.length;
	// Widths grow with the largest index so two-digit columns keep a separator;
	// up to 10 nodes this is the original 2-wide cell / 3-wide row label.
	const indexDigits = String(Math.max(size - 1, 0)).length;
	const cellWidth = Math.max(2, indexDigits + 1);
	const labelWidth = Math.max(3, indexDigits);
	const indices = Array.from({ length: size }, (_, i) => i);

	// Header: indented by the row-label width plus its trailing space so the
	// column indices sit directly above the mask values.
	console.log(
		" ".repeat(labelWidth + 1) +
		indices.map((col) => String(col).padStart(cellWidth)).join("")
	);
	// Rows
	for (const row of indices) {
		const cells = indices.map((col) => String(treeStructure.mask[row * size + col]).padStart(cellWidth));
		console.log(`${String(row).padStart(labelWidth)} ${cells.join("")}`);
	}
	console.log();
}
/**
 * Script entry point: build the agent once, then run the three test
 * scenarios against it one after another.
 *
 * Any error thrown along the way is logged and the process exits with
 * status 1.
 */
async function main() {
	const banner = "=".repeat(80);
	try {
		// Initialize agent
		const agent = await initializeAgent();

		// Run tests sequentially, in the order they are listed
		const scenarios = [testFreshGame, testMidGame, testTreeStructure];
		for (const runScenario of scenarios) {
			await runScenario(agent);
		}

		console.log(banner);
		console.log("All tests completed successfully!");
		console.log(banner);
	} catch (error) {
		console.error("Error:", error);
		process.exit(1);
	}
}
// Run main function. The returned promise is not awaited here: main() catches
// its own errors, logs them and exits the process with status 1.
main();