Spaces:
Sleeping
Sleeping
| /** | |
| * Test script for Trigo AI Agent using onnxruntime-node | |
| * | |
| * This script tests the AI agent's ability to: | |
| * 1. Load the ONNX evaluation model | |
| * 2. Initialize the model inferencer | |
| * 3. Score all valid moves using tree attention | |
| * 4. Select and display the best move | |
| * | |
| * Usage: npx tsx tools/testAIAgent.ts | |
| */ | |
| import * as ort from "onnxruntime-node"; | |
| import * as path from "path"; | |
| import { fileURLToPath } from "url"; | |
| import { TrigoGame, StoneType } from "../inc/trigo/game"; | |
| import { ModelInferencer } from "../inc/modelInferencer"; | |
| import { TrigoTreeAgent } from "../inc/trigoTreeAgent"; | |
| import type { Move } from "../inc/trigo/types"; | |
| import { loadEnvConfig, getOnnxModelPaths, getAbsoluteModelPath, getOnnxSessionOptions } from "../inc/config"; | |
// ESM provides no CommonJS __filename/__dirname globals; derive them from import.meta.url.
// NOTE(review): neither binding is referenced elsewhere in this script.
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);

// Load environment configuration before resolving model paths below.
// The path helpers presumably read values loaded here (TODO confirm), so keep this order.
await loadEnvConfig();

// Model / tokenizer configuration used by every test below.
const MODEL_PATH = getAbsoluteModelPath(getOnnxModelPaths().evaluationModel);
const VOCAB_SIZE = 128;
const SEQ_LEN = 256;
/**
 * Build a TrigoTreeAgent backed by the ONNX evaluation model at MODEL_PATH.
 *
 * Prints the configuration, opens an onnxruntime-node InferenceSession using
 * the configured session options, wraps it in a ModelInferencer, and hands that
 * inferencer to a new TrigoTreeAgent.
 *
 * @returns the initialized agent
 * @throws propagates any rejection from ort.InferenceSession.create
 *   (for instance when the model cannot be loaded)
 */
async function initializeAgent(): Promise<TrigoTreeAgent> {
	const banner = "=".repeat(80);
	console.log(banner);
	console.log("Initializing AI Agent...");
	console.log(banner);
	console.log(`Model Path: ${MODEL_PATH}`);
	console.log(`Vocab Size: ${VOCAB_SIZE}`);
	console.log(`Sequence Length: ${SEQ_LEN}`);
	console.log();

	// Step 1: open the ONNX inference session.
	console.log("Loading ONNX model...");
	const session = await ort.InferenceSession.create(MODEL_PATH, getOnnxSessionOptions());
	console.log("✓ Model loaded successfully");

	// Step 2: wrap the session in an inferencer.
	// NOTE(review): the `as any` casts bridge onnxruntime-node's types to whatever
	// ModelInferencer declares; a typed adapter would be safer.
	console.log("Creating model inferencer...");
	const inferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: VOCAB_SIZE,
		seqLen: SEQ_LEN,
		modelPath: MODEL_PATH
	});
	inferencer.setSession(session as any);
	console.log("✓ Inferencer created");

	// Step 3: the agent drives the inferencer.
	console.log("Creating Trigo tree agent...");
	const treeAgent = new TrigoTreeAgent(inferencer);
	console.log("✓ Agent initialized");
	console.log();

	return treeAgent;
}
/**
 * Test 1: score every legal first move (plus a pass) on a fresh 5×5×5 board.
 *
 * Builds the candidate list from game.validMovePositions() for the side to
 * move, appends a pass, scores everything with agent.scoreMoves(), then prints
 * the highest-scoring moves along with a softmax over all returned scores.
 *
 * @param agent - agent whose scoreMoves() is exercised
 */
async function testFreshGame(agent: TrigoTreeAgent): Promise<void> {
	const TOP_N = 10;
	// Fixed column width for the "Position" column; padEnd (not slice) keeps rows aligned.
	const POSITION_WIDTH = 15;

	console.log("=".repeat(80));
	console.log("Test 1: Fresh Game (First Move)");
	console.log("=".repeat(80));

	const game = new TrigoGame({ x: 5, y: 5, z: 5 }, {});
	const isBlack = game.getCurrentPlayer() === StoneType.BLACK;
	// Query legal positions once and reuse them for both the summary and the move list.
	const validPositions = game.validMovePositions();

	console.log("Game created: 5×5×5 board");
	console.log(`Current player: ${isBlack ? "Black" : "White"}`);
	console.log(`Valid move positions: ${validPositions.length}`);
	console.log();

	const currentPlayer = isBlack ? "black" : "white";
	const moves: Move[] = validPositions.map((pos) => ({
		x: pos.x,
		y: pos.y,
		z: pos.z,
		player: currentPlayer
	}));
	moves.push({ player: currentPlayer, isPass: true }); // Add pass

	console.log(`Scoring ${moves.length} moves (${moves.length - 1} positions + pass)...`);
	const startTime = Date.now();
	const scoredMoves = await agent.scoreMoves(game, moves);
	const timeMs = Date.now() - startTime;

	console.log(`✓ Scored ${scoredMoves.length} moves in ${timeMs}ms`);
	// Guard: with no results the average would be NaN and there is nothing to rank.
	if (scoredMoves.length === 0) {
		console.log("✗ Agent returned no scored moves");
		console.log();
		return;
	}
	console.log(`  Average: ${(timeMs / scoredMoves.length).toFixed(2)}ms per move`);
	console.log();

	// Highest score first.
	scoredMoves.sort((a, b) => b.score - a.score);

	// Softmax over all scores; subtracting the maximum keeps Math.exp from overflowing.
	const maxScore = scoredMoves.reduce((best, m) => Math.max(best, m.score), -Infinity);
	const sumExp = scoredMoves.reduce((sum, m) => sum + Math.exp(m.score - maxScore), 0);

	const rule = "-".repeat(80);
	console.log(`Top ${TOP_N} Moves:`);
	console.log(rule);
	console.log(
		`${"Rank".padStart(4)} | ` +
			`${"Notation".padEnd(8)} | ` +
			`${"Position".padEnd(POSITION_WIDTH)} | ` +
			`${"Log Prob".padStart(9)} | ` +
			"Probability"
	);
	console.log(rule);

	for (const [i, scored] of scoredMoves.slice(0, TOP_N).entries()) {
		const prob = Math.exp(scored.score - maxScore) / sumExp;
		const position = scored.move.isPass
			? "Pass"
			: `(${scored.move.x}, ${scored.move.y}, ${scored.move.z})`;
		console.log(
			`${String(i + 1).padStart(4)} | ` +
				`${scored.notation.padEnd(8)} | ` +
				`${position.padEnd(POSITION_WIDTH)} | ` +
				`${scored.score.toFixed(4).padStart(9)} | ` +
				`${(prob * 100).toFixed(4)}%`
		);
	}
	console.log();
}
/**
 * Test 2: play a short scripted opening, then ask the agent for its best move.
 *
 * Attempts four fixed drops on a fresh 5×5×5 board, reporting for each whether
 * it was accepted. It then prints the resulting TGN record and game state, and
 * times a single agent.selectBestMove() call.
 *
 * @param agent - agent whose selectBestMove() is exercised
 */
async function testMidGame(agent: TrigoTreeAgent): Promise<void> {
	console.log("=".repeat(80));
	console.log("Test 2: Mid Game (After Several Moves)");
	console.log("=".repeat(80));

	const game = new TrigoGame({ x: 5, y: 5, z: 5 }, {});
	const sideToMove = (): string =>
		game.getCurrentPlayer() === StoneType.BLACK ? "Black" : "White";

	const testMoves = [
		{ x: 2, y: 2, z: 2 }, // Center
		{ x: 1, y: 1, z: 1 }, // Near corner
		{ x: 3, y: 2, z: 2 }, // Adjacent to center
		{ x: 1, y: 2, z: 1 } // Another move
	];

	console.log("Playing test moves:");
	for (const pos of testMoves) {
		// Read the mover BEFORE dropping. Inferring it afterwards by inverting the
		// current player only works when the drop succeeds. A rejected drop
		// presumably leaves the turn unchanged and would be mislabelled.
		const mover = sideToMove();
		const success = game.drop(pos);
		console.log(`  ${mover} plays (${pos.x}, ${pos.y}, ${pos.z}): ${success ? "✓" : "✗"}`);
	}
	console.log();

	console.log("Current TGN:");
	console.log("-".repeat(80));
	console.log(game.toTGN().trim());
	console.log("-".repeat(80));
	console.log();

	console.log(`Current player: ${sideToMove()}`);
	console.log(`Move count: ${game.stepHistory.length}`);
	console.log(`Valid move positions: ${game.validMovePositions().length}`);
	console.log();

	console.log("Selecting best move...");
	const startTime = Date.now();
	const bestMove = await agent.selectBestMove(game);
	const elapsedMs = Date.now() - startTime;

	if (bestMove) {
		const moveStr = bestMove.isPass ? "Pass" : `(${bestMove.x}, ${bestMove.y}, ${bestMove.z})`;
		console.log(`✓ Best move: ${moveStr} (selected in ${elapsedMs}ms)`);
	} else {
		console.log("✗ No valid move found");
	}
	console.log();
}
/**
 * Test 3: dump the tree-attention structure the agent builds for a small move set.
 *
 * Takes the first five legal positions of a fresh 5×5×5 board plus a pass and
 * calls agent.getTreeStructure(). It prints the evaluated token sequence, each
 * move's leaf position with its parent entry, and the attention mask as a
 * size×size grid.
 *
 * @param agent - agent whose getTreeStructure() is exercised
 */
async function testTreeStructure(agent: TrigoTreeAgent): Promise<void> {
	console.log("=".repeat(80));
	console.log("Test 3: Tree Structure Visualization");
	console.log("=".repeat(80));

	const game = new TrigoGame({ x: 5, y: 5, z: 5 }, {});

	// Only a handful of candidates, so the printed tree and mask stay readable.
	const currentPlayer = game.getCurrentPlayer() === StoneType.BLACK ? "black" : "white";
	const moves: Move[] = game
		.validMovePositions()
		.slice(0, 5)
		.map((pos) => ({ x: pos.x, y: pos.y, z: pos.z, player: currentPlayer }));
	moves.push({ player: currentPlayer, isPass: true }); // Add pass

	console.log(`Analyzing tree structure for ${moves.length} moves...`);
	console.log();

	const tree = agent.getTreeStructure(game, moves);
	console.log(`Tree nodes: ${tree.evaluatedIds.length}`);
	console.log(`Move count: ${tree.moveData.length}`);
	console.log();

	console.log("Evaluated Token Sequence:");
	console.log("-".repeat(80));
	// NOTE(review): assumes token IDs are character codes (VOCAB_SIZE is 128); confirm against the tokenizer.
	const tokens = tree.evaluatedIds.map((id) => String.fromCharCode(id)).join("");
	console.log(`Tokens: ${tokens}`);
	console.log(`Token IDs: [${tree.evaluatedIds.join(", ")}]`);
	console.log();

	console.log("Move Details:");
	console.log("-".repeat(80));
	console.log(`${"Notation".padEnd(8)} | ${"Leaf Pos".padStart(8)} | ${"Parent".padStart(10)}`);
	console.log("-".repeat(80));
	for (const data of tree.moveData) {
		const parentPos = tree.parent[data.leafPos];
		// `== null` covers both an explicit null entry and an out-of-range index (undefined),
		// which previously crashed on `.toString()`.
		const parentStr = parentPos == null ? "null" : String(parentPos);
		console.log(
			`${data.notation.padEnd(8)} | ` +
				`${String(data.leafPos).padStart(8)} | ` +
				`${parentStr.padStart(10)}`
		);
	}
	console.log();

	console.log("Attention Mask (1 = can attend, 0 = cannot attend):");
	console.log("-".repeat(80));
	const size = tree.evaluatedIds.length;
	// Widths grow with the largest index so multi-digit labels remain separated.
	// For size <= 10 this matches the fixed 2-wide cells / 3-wide row labels.
	const maxIndexDigits = String(Math.max(size - 1, 0)).length;
	const cellWidth = Math.max(2, maxIndexDigits + 1);
	const labelWidth = Math.max(3, maxIndexDigits);
	const cell = (value: unknown): string => String(value).padStart(cellWidth);

	// Header prefix is as wide as a row label plus its trailing space.
	let header = " ".repeat(labelWidth + 1);
	for (let col = 0; col < size; col++) {
		header += cell(col);
	}
	console.log(header);

	// The mask is read as a flat row-major size×size buffer (index row * size + col).
	for (let row = 0; row < size; row++) {
		let line = String(row).padStart(labelWidth) + " ";
		for (let col = 0; col < size; col++) {
			line += cell(tree.mask[row * size + col]);
		}
		console.log(line);
	}
	console.log();
}
/**
 * Script entry point: build the agent once, then run each test scenario in order.
 *
 * Any error thrown during setup or by a test is logged and the process exits
 * with status 1. Otherwise a completion banner is printed.
 */
async function main() {
	const scenarios: Array<(agent: TrigoTreeAgent) => Promise<void>> = [
		testFreshGame,
		testMidGame,
		testTreeStructure
	];

	try {
		const agent = await initializeAgent();

		// Scenarios run sequentially; each awaits its own inference work.
		for (const runScenario of scenarios) {
			await runScenario(agent);
		}

		const banner = "=".repeat(80);
		console.log(banner);
		console.log("All tests completed successfully!");
		console.log(banner);
	} catch (error) {
		console.error("Error:", error);
		process.exit(1);
	}
}

// main() handles its own errors, so its promise is deliberately not awaited.
void main();