/**
 * Test MCTS value sign correction
 *
 * This script verifies that the MCTS agent correctly handles value signs
 * from both Black and White perspectives.
 *
 * For each probed position it compares the evaluation agent's value (logged as
 * Black's perspective) with the MCTS root value (logged as the perspective of the
 * player to move). The expected root value is the eval value for Black and its
 * negation for White. Any difference larger than VALUE_TOLERANCE is reported.
 *
 * The process exits with a non-zero code when a mismatch is found or when the
 * script throws, so it can be used as an automated check.
 */
import { TrigoGame } from "../inc/trigo/game";
import { MCTSAgent } from "../inc/mctsAgent";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent";
import { ModelInferencer } from "../inc/modelInferencer";
import { loadEnvConfig, getOnnxSessionOptions } from "../inc/config";
import * as ort from "onnxruntime-node";

/**
 * Largest accepted |rootValue - expectedValue| before a mismatch is reported.
 * This is a heuristic: 100 simulations of search can legitimately move the root
 * value away from the raw evaluation, so a warning is a hint, not proof of a sign bug.
 */
const VALUE_TOLERANCE = 0.1;

type PlayerName = "Black" | "White";

/**
 * Read a positive integer from an environment variable.
 *
 * @param name - environment variable to read
 * @param fallback - value used when the variable is unset or empty
 * @returns the parsed integer
 * @throws Error when the value does not parse to a positive integer (instead of passing NaN downstream)
 */
function readPositiveIntEnv(name: string, fallback: number): number {
	const raw = process.env[name];
	const value = raw ? parseInt(raw, 10) : fallback;
	if (!Number.isInteger(value) || value <= 0) {
		throw new Error(`${name} must be a positive integer, got "${raw}"`);
	}
	return value;
}

/**
 * Render a move for logging.
 *
 * @returns "PASS" for pass moves, otherwise "(x, y, z)", so a pass never prints as undefined coordinates
 */
function formatMove(move: { isPass?: boolean; x?: number | null; y?: number | null; z?: number | null }): string {
	return move.isPass ? "PASS" : `(${move.x}, ${move.y}, ${move.z})`;
}

/**
 * Compare an MCTS root value against the evaluation agent's value.
 *
 * @param player - side to move; the root value is expected from this side's perspective
 * @param evalValue - evaluation value from Black's perspective
 * @param rootValue - MCTS root value
 * @returns a description of the mismatch, or null when the values agree within VALUE_TOLERANCE
 */
function checkRootValue(player: PlayerName, evalValue: number, rootValue: number): string | null {
	const expected = player === "Black" ? evalValue : -evalValue;
	const diff = Math.abs(rootValue - expected);
	// `!(diff <= tol)` instead of `diff > tol`: a NaN value must count as a mismatch, not silently pass.
	if (!(diff <= VALUE_TOLERANCE)) {
		return `Expected ~${expected.toFixed(4)}, got ${rootValue.toFixed(4)}`;
	}
	return null;
}

async function main(): Promise<void> {
	// Load environment config
	await loadEnvConfig();

	console.log("=== MCTS Value Sign Test ===\n");

	// Model locations and shapes, overridable via environment variables
	const policyModelPath = process.env.ONNX_TREE_MODEL || "./models/policy_model.onnx";
	const valueModelPath = process.env.ONNX_EVALUATION_MODEL || "./models/value_model.onnx";
	const vocabSize = readPositiveIntEnv("VOCAB_SIZE", 342);
	const seqLen = readPositiveIntEnv("SEQ_LEN", 256);

	console.log("Loading models...");
	console.log(`Policy: ${policyModelPath}`);
	console.log(`Value: ${valueModelPath}`);
	console.log(`Vocab Size: ${vocabSize}`);
	console.log(`Seq Length: ${seqLen}\n`);

	// Load sessions
	const sessionOptions = getOnnxSessionOptions();
	const treeSession = await ort.InferenceSession.create(policyModelPath, sessionOptions);
	const evalSession = await ort.InferenceSession.create(valueModelPath, sessionOptions);

	// Create inferencers. The `as any` casts bridge onnxruntime-node types to whatever
	// ModelInferencer declares; kept from the original, revisit if those types are aligned.
	const treeInferencer = new ModelInferencer(ort.Tensor as any, { vocabSize, seqLen, modelPath: policyModelPath });
	treeInferencer.setSession(treeSession as any);
	const evalInferencer = new ModelInferencer(ort.Tensor as any, { vocabSize, seqLen, modelPath: valueModelPath });
	evalInferencer.setSession(evalSession as any);

	// Create agents
	const treeAgent = new TrigoTreeAgent(treeInferencer);
	const evaluationAgent = new TrigoEvaluationAgent(evalInferencer);

	// Create MCTS agent with fewer simulations for quick test
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
		numSimulations: 100,
		cPuct: 1.0,
		temperature: 1.0
	});

	let mismatches = 0;

	/** Log and count a mismatch between the root value and the evaluation value. */
	const report = (player: PlayerName, evalValue: number, rootValue: number): void => {
		const mismatch = checkRootValue(player, evalValue, rootValue);
		if (mismatch !== null) {
			mismatches++;
			console.log(` ⚠️ Warning: Value mismatch! ${mismatch}`);
		}
	};

	// Test 1: Black's first move
	console.log("--- Test 1: Black's First Move ---");
	const game1 = new TrigoGame({ x: 5, y: 5, z: 5 });
	console.log(`Current player: Black (${game1.getCurrentPlayer()})`);

	const eval1 = await evaluationAgent.evaluatePosition(game1);
	console.log(`Evaluation value (Black perspective): ${eval1.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval1.interpretation}\n`);

	const result1 = await mctsAgent.selectMove(game1);
	console.log(`Root value (Black perspective): ${result1.rootValue.toFixed(4)}`);
	report("Black", eval1.value, result1.rootValue);
	console.log(`Selected move: ${formatMove(result1.move)}\n`);

	// Test 2: White's first move (after Black plays)
	console.log("--- Test 2: White's Response ---");
	const game2 = new TrigoGame({ x: 5, y: 5, z: 5 });
	game2.dropStone({ x: 2, y: 2, z: 2 }); // Black plays center
	console.log(`Current player: White (${game2.getCurrentPlayer()})`);

	const eval2 = await evaluationAgent.evaluatePosition(game2);
	console.log(`Evaluation value (Black perspective): ${eval2.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval2.interpretation}`);
	console.log(`From White's perspective: ${(-eval2.value).toFixed(4)}\n`);

	const result2 = await mctsAgent.selectMove(game2);
	console.log(`Root value (White perspective): ${result2.rootValue.toFixed(4)}`);
	report("White", eval2.value, result2.rootValue);
	console.log(`Selected move: ${formatMove(result2.move)}\n`);

	// Test 3: Play a few moves and check consistency
	console.log("--- Test 3: Short Game ---");
	const game3 = new TrigoGame({ x: 5, y: 5, z: 5 });

	for (let i = 0; i < 6; i++) {
		// NOTE(review): assumes getCurrentPlayer() returns 1 for Black — confirm against TrigoGame.
		const player: PlayerName = game3.getCurrentPlayer() === 1 ? "Black" : "White";
		console.log(`\nMove ${i + 1} - ${player}'s turn`);

		const evalBefore = await evaluationAgent.evaluatePosition(game3);
		console.log(` Position value (Black perspective): ${evalBefore.value.toFixed(4)}`);

		const result = await mctsAgent.selectMove(game3);
		console.log(` Root value (${player} perspective): ${result.rootValue.toFixed(4)}`);

		// Check sign consistency: Black expects the eval value, White its negation
		report(player, evalBefore.value, result.rootValue);

		if (!result.move.isPass) {
			game3.dropStone({ x: result.move.x!, y: result.move.y!, z: result.move.z! });
		} else {
			game3.pass();
		}
		console.log(` Played: ${formatMove(result.move)}`);
	}

	console.log("\n=== Test Complete ===");
	if (mismatches === 0) {
		console.log("No value mismatches detected — value signs are handled correctly!");
	} else {
		console.log(`${mismatches} value mismatch(es) detected — check MCTS value sign handling.`);
		// Signal failure to callers/CI without cutting off pending output.
		process.exitCode = 1;
	}
}

main().catch((err: unknown) => {
	console.error(err);
	// Previously errors were only logged and the process still exited 0.
	process.exitCode = 1;
});