trigo / trigo-web /tools /testMCTSValueSign.ts
k-l-lambda's picture
Update trigo-web with VS People multiplayer mode
15f353f
/**
* Test MCTS value sign correction
*
* This script verifies that the MCTS agent correctly handles value signs
* from both Black and White perspectives.
*/
import { TrigoGame } from "../inc/trigo/game";
import { MCTSAgent } from "../inc/mctsAgent";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent";
import { ModelInferencer } from "../inc/modelInferencer";
import { loadEnvConfig, getOnnxSessionOptions } from "../inc/config";
import * as ort from "onnxruntime-node";
/** Edge length of the cubic board used by every test in this script. */
const BOARD_SIZE = 5;

/** Largest |rootValue - expectedValue| tolerated before a mismatch warning is printed. */
const VALUE_TOLERANCE = 0.1;

/** Number of moves played in the short self-play game of Test 3. */
const SHORT_GAME_MOVES = 6;

/** Move type exactly as returned by `MCTSAgent.selectMove`. */
type SelectedMove = Awaited<ReturnType<MCTSAgent["selectMove"]>>["move"];

/**
 * Read a positive integer from an environment variable.
 *
 * @param name - Environment variable name.
 * @param fallback - Value used when the variable is unset or empty.
 * @returns The parsed integer, or `fallback`.
 * @throws Error when the variable is set but is not a positive integer,
 *         so a typo cannot silently turn into `NaN` model dimensions.
 */
function readPositiveIntEnv(name: string, fallback: number): number {
	const raw = process.env[name];
	if (!raw) {
		return fallback;
	}
	const parsed = Number.parseInt(raw, 10);
	if (!Number.isInteger(parsed) || parsed <= 0) {
		throw new Error(`Invalid ${name}: expected a positive integer, got "${raw}"`);
	}
	return parsed;
}

/** Render a move for logging: `PASS` for a pass, `(x, y, z)` otherwise. */
function formatMove(move: SelectedMove): string {
	return move.isPass ? "PASS" : `(${move.x}, ${move.y}, ${move.z})`;
}

/**
 * Root value this script expects from MCTS for the side to move, given the
 * evaluation value reported from Black's perspective: unchanged for Black,
 * negated for White.
 */
function expectedRootValue(blackPerspectiveValue: number, isBlackToMove: boolean): number {
	return isBlackToMove ? blackPerspectiveValue : -blackPerspectiveValue;
}

/**
 * Run the MCTS value-sign diagnostic.
 *
 * Loads the policy and value ONNX models (paths and dimensions come from the
 * environment, with defaults), then:
 *   1. evaluates and searches the empty board (Black to move),
 *   2. evaluates and searches after one Black stone (White to move),
 *   3. plays a short game, comparing each MCTS root value with the evaluation
 *      value (negated on White's turns) and warning when they differ by more
 *      than `VALUE_TOLERANCE`.
 *
 * All output goes to the console; the final line reports whether any
 * mismatch warnings were raised.
 *
 * @throws Error if an environment value is malformed, a model fails to load,
 *         or a non-pass move is returned without coordinates.
 */
async function main(): Promise<void> {
	// Load environment config before reading any process.env value below.
	await loadEnvConfig();
	console.log("=== MCTS Value Sign Test ===\n");

	// Resolve model locations and dimensions.
	const policyModelPath = process.env.ONNX_TREE_MODEL || "./models/policy_model.onnx";
	const valueModelPath = process.env.ONNX_EVALUATION_MODEL || "./models/value_model.onnx";
	const vocabSize = readPositiveIntEnv("VOCAB_SIZE", 342);
	const seqLen = readPositiveIntEnv("SEQ_LEN", 256);

	console.log("Loading models...");
	console.log(`Policy: ${policyModelPath}`);
	console.log(`Value: ${valueModelPath}`);
	console.log(`Vocab Size: ${vocabSize}`);
	console.log(`Seq Length: ${seqLen}\n`);

	// Load sessions
	const sessionOptions = getOnnxSessionOptions();
	const treeSession = await ort.InferenceSession.create(policyModelPath, sessionOptions);
	const evalSession = await ort.InferenceSession.create(valueModelPath, sessionOptions);

	// Create inferencers. The `as any` casts are kept from the original wiring;
	// NOTE(review): presumably they bridge onnxruntime-node types to the
	// inferencer's own interfaces — confirm and tighten if possible.
	const treeInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize,
		seqLen,
		modelPath: policyModelPath
	});
	treeInferencer.setSession(treeSession as any);

	const evalInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize,
		seqLen,
		modelPath: valueModelPath
	});
	evalInferencer.setSession(evalSession as any);

	// Create agents
	const treeAgent = new TrigoTreeAgent(treeInferencer);
	const evaluationAgent = new TrigoEvaluationAgent(evalInferencer);

	// Create MCTS agent with fewer simulations for quick test
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
		numSimulations: 100,
		cPuct: 1.0,
		temperature: 1.0
	});

	// Each game gets its own shape object so no game shares state with another.
	const newBoardShape = () => ({ x: BOARD_SIZE, y: BOARD_SIZE, z: BOARD_SIZE });

	// Test 1: Black's first move
	console.log("--- Test 1: Black's First Move ---");
	const game1 = new TrigoGame(newBoardShape());
	console.log(`Current player: Black (${game1.getCurrentPlayer()})`);
	const eval1 = await evaluationAgent.evaluatePosition(game1);
	console.log(`Evaluation value (Black perspective): ${eval1.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval1.interpretation}\n`);
	const result1 = await mctsAgent.selectMove(game1);
	console.log(`Root value (Black perspective): ${result1.rootValue.toFixed(4)}`);
	// formatMove prints PASS instead of "(undefined, undefined, undefined)".
	console.log(`Selected move: ${formatMove(result1.move)}\n`);

	// Test 2: White's first move (after Black plays)
	console.log("--- Test 2: White's Response ---");
	const game2 = new TrigoGame(newBoardShape());
	game2.dropStone({ x: 2, y: 2, z: 2 }); // Black plays center
	console.log(`Current player: White (${game2.getCurrentPlayer()})`);
	const eval2 = await evaluationAgent.evaluatePosition(game2);
	console.log(`Evaluation value (Black perspective): ${eval2.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval2.interpretation}`);
	console.log(`From White's perspective: ${(-eval2.value).toFixed(4)}\n`);
	const result2 = await mctsAgent.selectMove(game2);
	console.log(`Root value (White perspective): ${result2.rootValue.toFixed(4)}`);
	console.log(`Selected move: ${formatMove(result2.move)}\n`);

	// Test 3: Play a few moves and check consistency
	console.log("--- Test 3: Short Game ---");
	const game3 = new TrigoGame(newBoardShape());
	let warningCount = 0;

	for (let i = 0; i < SHORT_GAME_MOVES; i++) {
		// NOTE(review): player id 1 is treated as Black, as in the original check — confirm against TrigoGame.
		const isBlackToMove = game3.getCurrentPlayer() === 1;
		const player = isBlackToMove ? "Black" : "White";
		console.log(`\nMove ${i + 1} - ${player}'s turn`);

		const evalBefore = await evaluationAgent.evaluatePosition(game3);
		console.log(` Position value (Black perspective): ${evalBefore.value.toFixed(4)}`);
		const result = await mctsAgent.selectMove(game3);
		console.log(` Root value (${player} perspective): ${result.rootValue.toFixed(4)}`);

		// Check sign consistency: same value on Black's turn, negated on White's.
		const expected = expectedRootValue(evalBefore.value, isBlackToMove);
		if (Math.abs(result.rootValue - expected) > VALUE_TOLERANCE) {
			warningCount++;
			console.log(` ⚠️ Warning: Value mismatch! Expected ~${expected.toFixed(4)}, got ${result.rootValue.toFixed(4)}`);
		}

		// Apply the selected move so the next iteration evaluates the new position.
		if (result.move.isPass) {
			game3.pass();
		} else {
			const { x, y, z } = result.move;
			if (x == null || y == null || z == null) {
				throw new Error(`MCTS returned a non-pass move without coordinates at move ${i + 1}`);
			}
			game3.dropStone({ x, y, z });
		}
		console.log(` Played: ${formatMove(result.move)}`);
	}

	console.log("\n=== Test Complete ===");
	if (warningCount === 0) {
		console.log("No warnings appeared: value signs are handled correctly!");
	} else {
		console.log(`${warningCount} value mismatch warning(s) appeared: value sign handling needs review.`);
	}
}
// Run the script. On failure, log the error and mark the process as failed so
// shells and CI see a non-zero exit status instead of a silent success.
main().catch((error: unknown) => {
	console.error(error);
	process.exitCode = 1;
});