File size: 5,269 Bytes
15f353f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
/**
 * Test MCTS value sign correction
 *
 * This script verifies that the MCTS agent correctly handles value signs
 * from both Black and White perspectives.
 */

import { TrigoGame } from "../inc/trigo/game";
import { MCTSAgent } from "../inc/mctsAgent";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent";
import { ModelInferencer } from "../inc/modelInferencer";
import { loadEnvConfig, getOnnxSessionOptions } from "../inc/config";
import * as ort from "onnxruntime-node";


/**
 * Read a positive integer setting from the environment.
 *
 * @param name - Environment variable name.
 * @param fallback - Value used when the variable is unset or empty.
 * @returns The parsed integer, or `fallback` when the variable is unset/empty.
 * @throws Error when the variable is set but does not parse to a positive integer,
 *         so a typo fails fast instead of reaching the models as NaN.
 */
function readPositiveIntEnv(name: string, fallback: number): number {
	const raw = process.env[name];
	if (raw === undefined || raw === "") {
		return fallback;
	}

	const parsed = parseInt(raw, 10);
	if (!Number.isInteger(parsed) || parsed <= 0) {
		throw new Error(`Invalid ${name}: expected a positive integer, got "${raw}"`);
	}
	return parsed;
}


/**
 * Render a selected move for logging.
 *
 * @param move - Move as returned by the MCTS agent; coordinates may be absent for a pass.
 * @returns `"PASS"` for a pass move, otherwise `"(x, y, z)"`.
 */
function formatMove(move: { isPass?: boolean; x?: number; y?: number; z?: number }): string {
	return move.isPass ? "PASS" : `(${move.x}, ${move.y}, ${move.z})`;
}


/**
 * Run the MCTS value sign test.
 *
 * Loads the environment config, creates ONNX sessions for the policy and value
 * models, wires them into the tree/evaluation/MCTS agents, and then runs three
 * scenarios on a 5x5x5 board:
 *   1. Black's first move on an empty board.
 *   2. White's response after Black has played (2, 2, 2).
 *   3. A six-move game in which the MCTS root value is compared against the
 *      evaluation value (negated on White's turn); mismatches above the
 *      tolerance are logged as warnings and counted.
 *
 * All results are written to the console; nothing is returned.
 *
 * @throws Error when VOCAB_SIZE or SEQ_LEN is set to a non-positive-integer value.
 *         Failures from model loading or inference propagate as rejections.
 */
async function main(): Promise<void> {
	// Load environment config (must happen before any process.env lookups below)
	await loadEnvConfig();

	console.log("=== MCTS Value Sign Test ===\n");

	// Resolve model locations and dimensions, falling back to the defaults
	const policyModelPath = process.env.ONNX_TREE_MODEL || "./models/policy_model.onnx";
	const valueModelPath = process.env.ONNX_EVALUATION_MODEL || "./models/value_model.onnx";
	const vocabSize = readPositiveIntEnv("VOCAB_SIZE", 342);
	const seqLen = readPositiveIntEnv("SEQ_LEN", 256);

	console.log("Loading models...");
	console.log(`Policy: ${policyModelPath}`);
	console.log(`Value: ${valueModelPath}`);
	console.log(`Vocab Size: ${vocabSize}`);
	console.log(`Seq Length: ${seqLen}\n`);

	// Load sessions
	const sessionOptions = getOnnxSessionOptions();
	const treeSession = await ort.InferenceSession.create(policyModelPath, sessionOptions);
	const evalSession = await ort.InferenceSession.create(valueModelPath, sessionOptions);

	// Create inferencers
	const treeInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize,
		seqLen,
		modelPath: policyModelPath
	});
	treeInferencer.setSession(treeSession as any);

	const evalInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize,
		seqLen,
		modelPath: valueModelPath
	});
	evalInferencer.setSession(evalSession as any);

	// Create agents
	const treeAgent = new TrigoTreeAgent(treeInferencer);
	const evaluationAgent = new TrigoEvaluationAgent(evalInferencer);

	// Create MCTS agent with fewer simulations for quick test
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
		numSimulations: 100,
		cPuct: 1.0,
		temperature: 1.0
	});

	// Test 1: Black's first move
	console.log("--- Test 1: Black's First Move ---");
	const game1 = new TrigoGame({ x: 5, y: 5, z: 5 });
	console.log(`Current player: Black (${game1.getCurrentPlayer()})`);

	const eval1 = await evaluationAgent.evaluatePosition(game1);
	console.log(`Evaluation value (Black perspective): ${eval1.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval1.interpretation}\n`);

	const result1 = await mctsAgent.selectMove(game1);
	console.log(`Root value (Black perspective): ${result1.rootValue.toFixed(4)}`);
	console.log(`Selected move: ${formatMove(result1.move)}\n`);


	// Test 2: White's first move (after Black plays)
	console.log("--- Test 2: White's Response ---");
	const game2 = new TrigoGame({ x: 5, y: 5, z: 5 });
	game2.dropStone({ x: 2, y: 2, z: 2 }); // Black plays center
	console.log(`Current player: White (${game2.getCurrentPlayer()})`);

	const eval2 = await evaluationAgent.evaluatePosition(game2);
	console.log(`Evaluation value (Black perspective): ${eval2.value.toFixed(4)}`);
	console.log(`Interpretation: ${eval2.interpretation}`);
	console.log(`From White's perspective: ${(-eval2.value).toFixed(4)}\n`);

	const result2 = await mctsAgent.selectMove(game2);
	console.log(`Root value (White perspective): ${result2.rootValue.toFixed(4)}`);
	console.log(`Selected move: ${formatMove(result2.move)}\n`);


	// Test 3: Play a few moves and check consistency
	console.log("--- Test 3: Short Game ---");
	const game3 = new TrigoGame({ x: 5, y: 5, z: 5 });

	// Largest accepted gap between the MCTS root value and the expected value
	const valueTolerance = 0.1;
	let warningCount = 0;

	for (let i = 0; i < 6; i++) {
		const isBlack = game3.getCurrentPlayer() === 1;
		const player = isBlack ? "Black" : "White";
		console.log(`\nMove ${i + 1} - ${player}'s turn`);

		const evalBefore = await evaluationAgent.evaluatePosition(game3);
		console.log(`  Position value (Black perspective): ${evalBefore.value.toFixed(4)}`);

		const result = await mctsAgent.selectMove(game3);
		console.log(`  Root value (${player} perspective): ${result.rootValue.toFixed(4)}`);

		// Check sign consistency: the root value is expected to match the
		// evaluation value on Black's turn and its negation on White's turn.
		const expectedRootValue = isBlack ? evalBefore.value : -evalBefore.value;
		if (Math.abs(result.rootValue - expectedRootValue) > valueTolerance) {
			warningCount++;
			console.log(`  ⚠️  Warning: Value mismatch! Expected ~${expectedRootValue.toFixed(4)}, got ${result.rootValue.toFixed(4)}`);
		}

		// Advance the game with the selected move
		if (!result.move.isPass) {
			game3.dropStone({
				x: result.move.x!,
				y: result.move.y!,
				z: result.move.z!
			});
		} else {
			game3.pass();
		}
		console.log(`  Played: ${formatMove(result.move)}`);
	}

	console.log("\n=== Test Complete ===");
	if (warningCount === 0) {
		console.log("No value mismatch warnings: value signs are handled correctly!");
	} else {
		console.log(`${warningCount} value mismatch warning(s) appeared: value signs may be handled incorrectly.`);
	}
}


// Run the script. On failure, log the error and mark the process as failed so
// that shells and CI can detect it; exitCode (rather than process.exit) lets
// pending console output flush before the process ends.
main().catch((error: unknown) => {
	console.error(error);
	process.exitCode = 1;
});