Spaces:
Sleeping
Sleeping
File size: 5,269 Bytes
15f353f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 |
/**
* Test MCTS value sign correction
*
* This script verifies that the MCTS agent correctly handles value signs
* from both Black and White perspectives.
*/
import { TrigoGame } from "../inc/trigo/game";
import { MCTSAgent } from "../inc/mctsAgent";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent";
import { ModelInferencer } from "../inc/modelInferencer";
import { loadEnvConfig, getOnnxSessionOptions } from "../inc/config";
import * as ort from "onnxruntime-node";
/** Maximum allowed |rootValue - expected| before a value-sign mismatch is reported. */
const VALUE_TOLERANCE = 0.1;

/** Minimal structural view of an MCTS move as used by this script. */
interface MoveLike {
	x?: number | undefined;
	y?: number | undefined;
	z?: number | undefined;
	isPass?: boolean | undefined;
}

/**
 * Read a positive integer from an environment variable.
 *
 * An unset or empty variable yields `fallback` (matching the previous `|| "default"` behaviour).
 *
 * @throws Error if the variable is set but does not parse to a positive integer.
 */
function readIntEnv(name: string, fallback: number): number {
	const raw = process.env[name];
	if (raw === undefined || raw === "") {
		return fallback;
	}
	const value = Number.parseInt(raw, 10);
	if (!Number.isInteger(value) || value <= 0) {
		throw new Error(`Invalid ${name}: expected a positive integer, got "${raw}"`);
	}
	return value;
}

/**
 * Create an ONNX session for `modelPath` and wrap it in a ModelInferencer.
 */
async function createInferencer(
	modelPath: string,
	vocabSize: number,
	seqLen: number,
	sessionOptions: ReturnType<typeof getOnnxSessionOptions>
): Promise<ModelInferencer> {
	const session = await ort.InferenceSession.create(modelPath, sessionOptions);
	// The `as any` casts are kept from the original script.
	// NOTE(review): they presumably bridge onnxruntime-node's types to ModelInferencer's parameter types — confirm.
	const inferencer = new ModelInferencer(ort.Tensor as any, { vocabSize, seqLen, modelPath });
	inferencer.setSession(session as any);
	return inferencer;
}

/** Name of the side to move; this script treats player id 1 as Black and anything else as White. */
function playerName(game: TrigoGame): "Black" | "White" {
	return game.getCurrentPlayer() === 1 ? "Black" : "White";
}

/** Render a move as "(x, y, z)", or "PASS" for a pass move. */
function formatMove(move: MoveLike): string {
	return move.isPass ? "PASS" : `(${move.x}, ${move.y}, ${move.z})`;
}

/**
 * Check that an MCTS root value agrees in sign/magnitude with the static evaluation.
 *
 * @param player - Side to move at the root.
 * @param blackValue - Evaluation value from Black's perspective.
 * @param rootValue - MCTS root value from the side-to-move's perspective.
 * @returns true when the values agree within VALUE_TOLERANCE; logs a warning otherwise.
 */
function checkRootValue(player: "Black" | "White", blackValue: number, rootValue: number): boolean {
	// Black's root value should match the evaluation directly; White's should be its negation.
	const expected = player === "Black" ? blackValue : -blackValue;
	// Written as !(diff <= tol) so a NaN value is reported instead of silently passing.
	if (!(Math.abs(rootValue - expected) <= VALUE_TOLERANCE)) {
		console.log(` ⚠️ Warning: Value mismatch! Expected ~${expected.toFixed(4)}, got ${rootValue.toFixed(4)}`);
		return false;
	}
	return true;
}

/**
 * Run the MCTS value-sign checks.
 *
 * Sets `process.exitCode = 1` when any mismatch is detected.
 */
async function main(): Promise<void> {
	// Load environment config
	await loadEnvConfig();

	console.log("=== MCTS Value Sign Test ===\n");

	// Resolve model paths and model dimensions
	const policyModelPath = process.env.ONNX_TREE_MODEL || "./models/policy_model.onnx";
	const valueModelPath = process.env.ONNX_EVALUATION_MODEL || "./models/value_model.onnx";
	const vocabSize = readIntEnv("VOCAB_SIZE", 342);
	const seqLen = readIntEnv("SEQ_LEN", 256);

	console.log("Loading models...");
	console.log(`Policy: ${policyModelPath}`);
	console.log(`Value: ${valueModelPath}`);
	console.log(`Vocab Size: ${vocabSize}`);
	console.log(`Seq Length: ${seqLen}\n`);

	// Load both sessions concurrently; they are independent.
	const sessionOptions = getOnnxSessionOptions();
	const [treeInferencer, evalInferencer] = await Promise.all([
		createInferencer(policyModelPath, vocabSize, seqLen, sessionOptions),
		createInferencer(valueModelPath, vocabSize, seqLen, sessionOptions)
	]);

	// Create agents
	const treeAgent = new TrigoTreeAgent(treeInferencer);
	const evaluationAgent = new TrigoEvaluationAgent(evalInferencer);

	// Create MCTS agent with fewer simulations for quick test
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
		numSimulations: 100,
		cPuct: 1.0,
		temperature: 1.0
	});

	let mismatches = 0;

	// Test 1: Black's first move
	console.log("--- Test 1: Black's First Move ---");
	const game1 = new TrigoGame({ x: 5, y: 5, z: 5 });
	const player1 = playerName(game1);
	console.log(`Current player: ${player1} (${game1.getCurrentPlayer()})`);
	const eval1 = await evaluationAgent.evaluatePosition(game1);
	console.log(`Evaluation value (Black perspective): ${eval1.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval1.interpretation}\n`);
	const result1 = await mctsAgent.selectMove(game1);
	console.log(`Root value (${player1} perspective): ${result1.rootValue.toFixed(4)}`);
	if (!checkRootValue(player1, eval1.value, result1.rootValue)) {
		mismatches++;
	}
	console.log(`Selected move: ${formatMove(result1.move)}\n`);

	// Test 2: White's first move (after Black plays)
	console.log("--- Test 2: White's Response ---");
	const game2 = new TrigoGame({ x: 5, y: 5, z: 5 });
	game2.dropStone({ x: 2, y: 2, z: 2 }); // Black plays center
	const player2 = playerName(game2);
	console.log(`Current player: ${player2} (${game2.getCurrentPlayer()})`);
	const eval2 = await evaluationAgent.evaluatePosition(game2);
	console.log(`Evaluation value (Black perspective): ${eval2.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval2.interpretation}`);
	console.log(`From White's perspective: ${(-eval2.value).toFixed(4)}\n`);
	const result2 = await mctsAgent.selectMove(game2);
	console.log(`Root value (${player2} perspective): ${result2.rootValue.toFixed(4)}`);
	if (!checkRootValue(player2, eval2.value, result2.rootValue)) {
		mismatches++;
	}
	console.log(`Selected move: ${formatMove(result2.move)}\n`);

	// Test 3: Play a few moves and check consistency
	console.log("--- Test 3: Short Game ---");
	const game3 = new TrigoGame({ x: 5, y: 5, z: 5 });
	for (let i = 0; i < 6; i++) {
		const player = playerName(game3);
		console.log(`\nMove ${i + 1} - ${player}'s turn`);

		const evalBefore = await evaluationAgent.evaluatePosition(game3);
		console.log(` Position value (Black perspective): ${evalBefore.value.toFixed(4)}`);

		const result = await mctsAgent.selectMove(game3);
		console.log(` Root value (${player} perspective): ${result.rootValue.toFixed(4)}`);
		if (!checkRootValue(player, evalBefore.value, result.rootValue)) {
			mismatches++;
		}

		const { move } = result;
		if (move.isPass) {
			game3.pass();
		} else {
			game3.dropStone({ x: move.x!, y: move.y!, z: move.z! });
		}
		console.log(` Played: ${formatMove(move)}`);
	}

	console.log("\n=== Test Complete ===");
	if (mismatches === 0) {
		console.log("No value mismatches detected - value signs are handled correctly!");
	} else {
		console.log(`${mismatches} value mismatch(es) detected - see warnings above.`);
		// Signal failure to CI/wrappers without aborting pending output.
		process.exitCode = 1;
	}
}
// Run the script; a fatal error must produce a non-zero exit code so callers see the failure.
main().catch((err: unknown) => {
	console.error("MCTS value sign test failed:", err);
	process.exitCode = 1;
});
|