// trigo / trigo-web / tools / testTerminalPropagation.ts
// Last commit by k-l-lambda (15f353f): "Update trigo-web with VS People multiplayer mode"
/**
 * Test MCTS terminal-value propagation with debug output.
 * Plays a single self-play game on a 5x1x1 board using an MCTS agent with debug mode enabled.
 */
import * as ort from "onnxruntime-node";
import { TrigoGame } from "../inc/trigo/game.js";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent.js";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent.js";
import { MCTSAgent } from "../inc/mctsAgent.js";
import { ModelInferencer } from "../inc/modelInferencer.js";
import { encodeAb0yz } from "../inc/trigo/ab0yz.js";
/** Board coordinates of a stone placement. */
type Coordinates = { x: number; y: number; z: number };

/**
 * Extract the coordinates of a non-pass move, failing loudly if any are missing.
 *
 * Replaces scattered non-null assertions (`move.x!`), which would otherwise let
 * `undefined` flow silently into move encoding and `game.drop`.
 *
 * @param move - Move returned by the MCTS agent (assumed non-pass).
 * @returns A fresh `{ x, y, z }` object.
 * @throws Error if any coordinate is `null`/`undefined`.
 */
function requireCoordinates(move: {
	x?: number | null;
	y?: number | null;
	z?: number | null;
}): Coordinates {
	const { x, y, z } = move;
	if (x == null || y == null || z == null) {
		throw new Error(`MCTS returned a non-pass move without full coordinates: ${JSON.stringify(move)}`);
	}
	return { x, y, z };
}

/**
 * Build a ModelInferencer bound to an already-created ONNX session.
 *
 * Each call constructs its own config object, so the two inferencers never
 * share mutable state.
 *
 * @param session - ONNX inference session for the tree or evaluation model.
 */
function createInferencer(session: ort.InferenceSession) {
	// `as any` casts kept from the original: ModelInferencer presumably declares its
	// own tensor/session interfaces that ort's types do not match exactly — TODO confirm.
	const inferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: 128,
		seqLen: 256
	});
	inferencer.setSession(session as any);
	return inferencer;
}

/**
 * Run the terminal-propagation test.
 *
 * Steps:
 * 1. Load the tree and evaluation ONNX models.
 * 2. Build an MCTS agent with debug logging enabled.
 * 3. Play one self-play game on a 5x1x1 board (capped at `maxMoves` moves),
 *    logging every move.
 * 4. Print the final TGN record.
 *
 * The model directory defaults to the original developer path. Override it with
 * the `TRIGO_MODEL_DIR` environment variable to run the script elsewhere.
 */
async function main(): Promise<void> {
	console.log("================================================================================");
	console.log("MCTS Terminal Propagation Test - 5x1x1 Board");
	console.log("================================================================================\n");

	// Single source of truth for the board shape. It is used both to create the
	// game and to encode moves, so the two cannot disagree.
	const boardShape = { x: 5, y: 1, z: 1 };
	const boardDims = [boardShape.x, boardShape.y, boardShape.z];
	const maxMoves = 20;

	// Model paths. `||` (not `??`) is deliberate: an empty TRIGO_MODEL_DIR also
	// falls back to the default, since "" would produce a root-relative path.
	const modelDir =
		process.env.TRIGO_MODEL_DIR ||
		"/home/camus/work/trigo/trigo-web/public/onnx/20251204-trigo-value-gpt2-l6-h64-251125-lr500";
	const treeModelPath = `${modelDir}/GPT2CausalLM_ep0019_tree.onnx`;
	const evalModelPath = `${modelDir}/GPT2CausalLM_ep0019_evaluation.onnx`;
	console.log("Loading models...");
	console.log(` Tree: ${treeModelPath}`);
	console.log(` Eval: ${evalModelPath}\n`);

	// Create ONNX sessions
	const sessionOptions: ort.InferenceSession.SessionOptions = {
		executionProviders: ["cpu"]
	};
	const treeSession = await ort.InferenceSession.create(treeModelPath, sessionOptions);
	const evalSession = await ort.InferenceSession.create(evalModelPath, sessionOptions);

	// Create agents
	const treeAgent = new TrigoTreeAgent(createInferencer(treeSession));
	const evaluationAgent = new TrigoEvaluationAgent(createInferencer(evalSession));

	// Create MCTS agent with DEBUG MODE ENABLED
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
		numSimulations: 200,
		cPuct: 1.0,
		temperature: 1.0,
		dirichletAlpha: 0.03,
		dirichletEpsilon: 0.25
	});
	mctsAgent.debugMode = true;
	console.log("✓ MCTS Agent created with debug mode enabled\n");

	// Create game
	const game = new TrigoGame(boardShape);
	game.startGame();
	console.log("================================================================================");
	console.log("Starting Self-Play Game");
	console.log("================================================================================\n");

	let moveNumber = 0;
	while (game.getGameStatus() === "playing" && moveNumber < maxMoves) {
		const currentPlayer = game.getCurrentPlayer() === 1 ? "Black" : "White";
		console.log(`\n[Move ${moveNumber + 1}] ${currentPlayer} to move`);
		console.log("Current board:", game.getBoard().flat().flat());

		// Select move with MCTS
		const startTime = Date.now();
		const result = await mctsAgent.selectMove(game, moveNumber);
		const duration = Date.now() - startTime;

		// `null` means pass. Otherwise the coordinates have been validated once here.
		const position = result.move.isPass ? null : requireCoordinates(result.move);
		const moveStr =
			position === null ? "Pass" : encodeAb0yz([position.x, position.y, position.z], boardDims);

		// Log move selection (the last line is the total visit count over all children)
		console.log(`\nSelected: ${moveStr}`);
		console.log(`Root value: ${result.rootValue.toFixed(4)}`);
		console.log(`Time: ${duration}ms`);
		console.log(`Visit counts: ${Array.from(result.visitCounts.values()).reduce((a, b) => a + b, 0)}`);

		// Apply move
		if (position === null) {
			game.pass();
		} else {
			game.drop(position);
		}
		moveNumber++;

		// Check terminal status
		if (game.getGameStatus() === "finished") {
			console.log("\n" + "=".repeat(80));
			console.log("Game finished!");
			const territory = game.getTerritory();
			console.log(`Final territory: Black=${territory.black}, White=${territory.white}, Neutral=${territory.neutral}`);
			console.log(`Score diff (W-B): ${territory.white - territory.black}`);
			break;
		}
	}

	// Make it explicit when the move cap, not a terminal state, ended the loop.
	if (game.getGameStatus() !== "finished") {
		console.log(
			`\nStopped after ${moveNumber} moves without reaching a terminal state (status: ${game.getGameStatus()}).`
		);
	}

	// Print game notation
	console.log("\n" + "=".repeat(80));
	console.log("Game TGN:");
	console.log("=".repeat(80));
	console.log(game.toTGN());
	console.log("Test complete!");
}

main().catch((error: unknown) => {
	console.error(error);
	// Report failure to the shell/CI instead of exiting with status 0.
	process.exitCode = 1;
});