Spaces:
Sleeping
Sleeping
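/**
 * Test MCTS value sign correction.
 *
 * This script verifies that the MCTS agent reports its root value from the
 * perspective of the player to move. The evaluation network's value is from
 * Black's perspective, so for Black the MCTS root value should match it, and
 * for White it should match its negation.
 */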
| import { TrigoGame } from "../inc/trigo/game"; | |
| import { MCTSAgent } from "../inc/mctsAgent"; | |
| import { TrigoTreeAgent } from "../inc/trigoTreeAgent"; | |
| import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent"; | |
| import { ModelInferencer } from "../inc/modelInferencer"; | |
| import { loadEnvConfig, getOnnxSessionOptions } from "../inc/config"; | |
| import * as ort from "onnxruntime-node"; | |
/**
 * Maximum allowed |rootValue - expected| before a value mismatch is reported.
 * The MCTS root value aggregates a search rather than being a single network call,
 * so a large difference suggests (but does not prove) a sign error.
 */
const VALUE_TOLERANCE = 0.1;

/**
 * Value returned by `TrigoGame.getCurrentPlayer()` that this script treats as Black.
 * NOTE(review): assumes Black is encoded as 1 — confirm against TrigoGame.
 */
const BLACK_PLAYER = 1;

/**
 * Read a positive integer from an environment variable.
 *
 * @param name - Environment variable name.
 * @param fallback - Value used when the variable is unset or empty.
 * @returns The parsed integer, or `fallback`.
 * @throws Error if the variable is set but does not parse to a positive integer
 *   (previously this silently produced NaN).
 */
function readIntEnv(name: string, fallback: number): number {
	const raw = process.env[name];
	if (raw === undefined || raw === "") {
		return fallback;
	}
	const parsed = Number.parseInt(raw, 10);
	if (!Number.isFinite(parsed) || parsed <= 0) {
		throw new Error(`Environment variable ${name} must be a positive integer, got "${raw}"`);
	}
	return parsed;
}

/** Render a move for logging; passes are shown as "PASS" rather than "(undefined, undefined, undefined)". */
function formatMove(move: { isPass?: boolean; x?: number; y?: number; z?: number }): string {
	return move.isPass ? "PASS" : `(${move.x}, ${move.y}, ${move.z})`;
}

/**
 * Compare an MCTS root value against the evaluation network's value.
 *
 * The evaluation value is from Black's perspective. The root value is expected to be
 * from the perspective of the player to move, so it is compared with the negated
 * evaluation when White is to move.
 *
 * @param isBlackToMove - Whether Black is the player to move.
 * @param evalValueBlack - Evaluation network value (Black's perspective).
 * @param rootValue - MCTS root value for the same position.
 * @returns `true` if the values agree within VALUE_TOLERANCE; otherwise logs a warning and returns `false`.
 */
function checkValueSign(isBlackToMove: boolean, evalValueBlack: number, rootValue: number): boolean {
	const expected = isBlackToMove ? evalValueBlack : -evalValueBlack;
	const diff = Math.abs(rootValue - expected);
	// Written as !(diff <= tol) so that a NaN value counts as a mismatch instead of passing silently.
	if (!(diff <= VALUE_TOLERANCE)) {
		console.log(` ⚠️ Warning: Value mismatch! Expected ~${expected.toFixed(4)}, got ${rootValue.toFixed(4)}`);
		return false;
	}
	return true;
}

/**
 * Load the policy (tree) and value (evaluation) ONNX models and wire them into agents.
 *
 * @returns The evaluation agent and an MCTS agent configured for a quick test.
 */
async function createAgents(config: {
	policyModelPath: string;
	valueModelPath: string;
	vocabSize: number;
	seqLen: number;
}) {
	const { policyModelPath, valueModelPath, vocabSize, seqLen } = config;

	// Load sessions
	const sessionOptions = getOnnxSessionOptions();
	const treeSession = await ort.InferenceSession.create(policyModelPath, sessionOptions);
	const evalSession = await ort.InferenceSession.create(valueModelPath, sessionOptions);

	// Create inferencers
	const treeInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize,
		seqLen,
		modelPath: policyModelPath
	});
	treeInferencer.setSession(treeSession as any);

	const evalInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize,
		seqLen,
		modelPath: valueModelPath
	});
	evalInferencer.setSession(evalSession as any);

	// Create agents
	const treeAgent = new TrigoTreeAgent(treeInferencer);
	const evaluationAgent = new TrigoEvaluationAgent(evalInferencer);

	// Create MCTS agent with fewer simulations for quick test
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
		numSimulations: 100,
		cPuct: 1.0,
		temperature: 1.0
	});

	return { evaluationAgent, mctsAgent };
}

/**
 * Run the value-sign checks and report the result.
 * Sets `process.exitCode = 1` when any mismatch is detected, so callers such as CI can
 * detect failure without parsing the log output.
 */
async function main(): Promise<void> {
	// Load environment config
	await loadEnvConfig();
	console.log("=== MCTS Value Sign Test ===\n");

	// Initialize agents
	const policyModelPath = process.env.ONNX_TREE_MODEL || "./models/policy_model.onnx";
	const valueModelPath = process.env.ONNX_EVALUATION_MODEL || "./models/value_model.onnx";
	const vocabSize = readIntEnv("VOCAB_SIZE", 342);
	const seqLen = readIntEnv("SEQ_LEN", 256);
	console.log("Loading models...");
	console.log(`Policy: ${policyModelPath}`);
	console.log(`Value: ${valueModelPath}`);
	console.log(`Vocab Size: ${vocabSize}`);
	console.log(`Seq Length: ${seqLen}\n`);

	const { evaluationAgent, mctsAgent } = await createAgents({
		policyModelPath,
		valueModelPath,
		vocabSize,
		seqLen
	});

	let mismatches = 0;

	// Test 1: Black's first move
	console.log("--- Test 1: Black's First Move ---");
	const game1 = new TrigoGame({ x: 5, y: 5, z: 5 });
	const black1 = game1.getCurrentPlayer() === BLACK_PLAYER;
	console.log(`Current player: Black (${game1.getCurrentPlayer()})`);
	const eval1 = await evaluationAgent.evaluatePosition(game1);
	console.log(`Evaluation value (Black perspective): ${eval1.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval1.interpretation}\n`);
	const result1 = await mctsAgent.selectMove(game1);
	console.log(`Root value (Black perspective): ${result1.rootValue.toFixed(4)}`);
	if (!checkValueSign(black1, eval1.value, result1.rootValue)) {
		mismatches++;
	}
	console.log(`Selected move: ${formatMove(result1.move)}\n`);

	// Test 2: White's first move (after Black plays)
	console.log("--- Test 2: White's Response ---");
	const game2 = new TrigoGame({ x: 5, y: 5, z: 5 });
	game2.dropStone({ x: 2, y: 2, z: 2 }); // Black plays center
	const black2 = game2.getCurrentPlayer() === BLACK_PLAYER;
	console.log(`Current player: White (${game2.getCurrentPlayer()})`);
	const eval2 = await evaluationAgent.evaluatePosition(game2);
	console.log(`Evaluation value (Black perspective): ${eval2.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval2.interpretation}`);
	console.log(`From White's perspective: ${(-eval2.value).toFixed(4)}\n`);
	const result2 = await mctsAgent.selectMove(game2);
	console.log(`Root value (White perspective): ${result2.rootValue.toFixed(4)}`);
	if (!checkValueSign(black2, eval2.value, result2.rootValue)) {
		mismatches++;
	}
	console.log(`Selected move: ${formatMove(result2.move)}\n`);

	// Test 3: Play a few moves and check consistency
	console.log("--- Test 3: Short Game ---");
	const game3 = new TrigoGame({ x: 5, y: 5, z: 5 });
	for (let i = 0; i < 6; i++) {
		const isBlack = game3.getCurrentPlayer() === BLACK_PLAYER;
		const player = isBlack ? "Black" : "White";
		console.log(`\nMove ${i + 1} - ${player}'s turn`);

		const evalBefore = await evaluationAgent.evaluatePosition(game3);
		console.log(` Position value (Black perspective): ${evalBefore.value.toFixed(4)}`);

		const result = await mctsAgent.selectMove(game3);
		console.log(` Root value (${player} perspective): ${result.rootValue.toFixed(4)}`);

		// Check sign consistency
		if (!checkValueSign(isBlack, evalBefore.value, result.rootValue)) {
			mismatches++;
		}

		const { move } = result;
		if (move.isPass) {
			game3.pass();
		} else {
			game3.dropStone({ x: move.x!, y: move.y!, z: move.z! });
		}
		console.log(` Played: ${formatMove(move)}`);
	}

	console.log("\n=== Test Complete ===");
	if (mismatches === 0) {
		console.log("No value mismatches detected: value signs are handled correctly!");
	} else {
		console.log(`${mismatches} value mismatch(es) detected: value signs may be handled incorrectly.`);
		// Non-zero exit code so the failure is visible to CI / calling scripts.
		process.exitCode = 1;
	}
}
// Entry point: log any failure and exit non-zero so the error is not reported as success.
main().catch((error: unknown) => {
	console.error(error);
	process.exitCode = 1;
});