Spaces:
Sleeping
Sleeping
/**
 * Test MCTS terminal propagation with debug output.
 * Plays a single self-play game on a 5x1x1 board using MCTS with debug mode enabled.
 */
| import * as ort from "onnxruntime-node"; | |
| import { TrigoGame } from "../inc/trigo/game.js"; | |
| import { TrigoTreeAgent } from "../inc/trigoTreeAgent.js"; | |
| import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent.js"; | |
| import { MCTSAgent } from "../inc/mctsAgent.js"; | |
| import { ModelInferencer } from "../inc/modelInferencer.js"; | |
| import { encodeAb0yz } from "../inc/trigo/ab0yz.js"; | |
/** Board dimensions, shared by game construction and move encoding so they cannot drift apart. */
const BOARD_SHAPE = { x: 5, y: 1, z: 1 };

/** Default location of the exported ONNX models; override with the TRIGO_MODEL_DIR environment variable. */
const DEFAULT_MODEL_DIR =
  "/home/camus/work/trigo/trigo-web/public/onnx/20251204-trigo-value-gpt2-l6-h64-251125-lr500";

/** Upper bound on self-play moves so a game that never reaches a terminal state still ends. */
const MAX_MOVES = 20;

/** Horizontal rule used for section banners in the log output. */
const RULE = "=".repeat(80);

/**
 * Wrap an ONNX session in a ModelInferencer using the vocabulary size (128)
 * and sequence length (256) this test uses for both the tree and evaluation models.
 */
function createInferencer(session: ort.InferenceSession): ModelInferencer {
  // The casts mirror the original wiring. NOTE(review): presumably ModelInferencer
  // declares its own tensor/session interfaces — confirm and replace `any` with them.
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  const inferencer = new ModelInferencer(ort.Tensor as any, { vocabSize: 128, seqLen: 256 });
  // eslint-disable-next-line @typescript-eslint/no-explicit-any
  inferencer.setSession(session as any);
  return inferencer;
}

/**
 * Play one MCTS self-play game (debug mode on) and log every move, the final
 * territory if the game finishes, and the game's TGN.
 *
 * @param treeInferencer - inferencer backing the TrigoTreeAgent (policy side).
 * @param evalInferencer - inferencer backing the TrigoEvaluationAgent (value side).
 */
async function playSelfPlayGame(
  treeInferencer: ModelInferencer,
  evalInferencer: ModelInferencer
): Promise<void> {
  const treeAgent = new TrigoTreeAgent(treeInferencer);
  const evaluationAgent = new TrigoEvaluationAgent(evalInferencer);

  const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
    numSimulations: 200,
    cPuct: 1.0,
    temperature: 1.0,
    dirichletAlpha: 0.03,
    dirichletEpsilon: 0.25
  });
  // Debug mode is the point of this test: it makes the agent emit its search internals.
  mctsAgent.debugMode = true;
  console.log("✓ MCTS Agent created with debug mode enabled\n");

  // Pass a copy so the game never aliases the shared constant.
  const game = new TrigoGame({ ...BOARD_SHAPE });
  game.startGame();
  const boardDims = [BOARD_SHAPE.x, BOARD_SHAPE.y, BOARD_SHAPE.z];

  console.log(RULE);
  console.log("Starting Self-Play Game");
  console.log(RULE + "\n");

  let moveNumber = 0;
  while (game.getGameStatus() === "playing" && moveNumber < MAX_MOVES) {
    const currentPlayer = game.getCurrentPlayer() === 1 ? "Black" : "White";
    console.log(`\n[Move ${moveNumber + 1}] ${currentPlayer} to move`);
    console.log("Current board:", game.getBoard().flat().flat());

    const startTime = Date.now();
    const result = await mctsAgent.selectMove(game, moveNumber);
    const duration = Date.now() - startTime;

    const { move } = result;
    const moveStr = move.isPass ? "Pass" : encodeAb0yz([move.x!, move.y!, move.z!], boardDims);
    const totalVisits = Array.from(result.visitCounts.values()).reduce((a, b) => a + b, 0);

    console.log(`\nSelected: ${moveStr}`);
    console.log(`Root value: ${result.rootValue.toFixed(4)}`);
    console.log(`Time: ${duration}ms`);
    console.log(`Visit counts: ${totalVisits}`);

    if (move.isPass) {
      game.pass();
    } else {
      game.drop({ x: move.x!, y: move.y!, z: move.z! });
    }
    moveNumber++;
  }

  if (game.getGameStatus() === "finished") {
    console.log("\n" + RULE);
    console.log("Game finished!");
    const territory = game.getTerritory();
    console.log(`Final territory: Black=${territory.black}, White=${territory.white}, Neutral=${territory.neutral}`);
    console.log(`Score diff (W-B): ${territory.white - territory.black}`);
  } else {
    // Previously the script fell through silently when the move limit was hit.
    console.log(
      `\nStopped after ${moveNumber} moves without finishing (status: ${game.getGameStatus()}, limit: ${MAX_MOVES}).`
    );
  }

  console.log("\n" + RULE);
  console.log("Game TGN:");
  console.log(RULE);
  console.log(game.toTGN());
}

/**
 * Entry point: load the tree and evaluation ONNX models, then play and log one
 * MCTS self-play game on a 5x1x1 board with debug output enabled.
 *
 * Native ONNX Runtime sessions are released in `finally`, so they are freed
 * even if loading the second model or playing the game throws.
 */
async function main(): Promise<void> {
  console.log(RULE);
  console.log("MCTS Terminal Propagation Test - 5x1x1 Board");
  console.log(RULE + "\n");

  const modelDir = process.env.TRIGO_MODEL_DIR ?? DEFAULT_MODEL_DIR;
  const treeModelPath = `${modelDir}/GPT2CausalLM_ep0019_tree.onnx`;
  const evalModelPath = `${modelDir}/GPT2CausalLM_ep0019_evaluation.onnx`;
  console.log("Loading models...");
  console.log(` Tree: ${treeModelPath}`);
  console.log(` Eval: ${evalModelPath}\n`);

  const sessionOptions: ort.InferenceSession.SessionOptions = {
    executionProviders: ["cpu"]
  };

  // Track every session as soon as it exists so a later failure cannot leak an earlier one.
  const sessions: ort.InferenceSession[] = [];
  try {
    const treeSession = await ort.InferenceSession.create(treeModelPath, sessionOptions);
    sessions.push(treeSession);
    const evalSession = await ort.InferenceSession.create(evalModelPath, sessionOptions);
    sessions.push(evalSession);

    await playSelfPlayGame(createInferencer(treeSession), createInferencer(evalSession));
    console.log("Test complete!");
  } finally {
    await Promise.all(sessions.map((session) => session.release()));
  }
}
// Run the test. On failure, log the error and set a non-zero exit status so
// shell scripts / CI notice; a bare `.catch(console.error)` would exit with 0.
main().catch((error: unknown) => {
  console.error(error);
  process.exitCode = 1;
});