File size: 4,325 Bytes
15f353f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
/**
 * Test terminal propagation with debug output
 * Runs a single 5x1x1 game with MCTS and debug mode enabled
 */

import * as ort from "onnxruntime-node";
import { TrigoGame } from "../inc/trigo/game.js";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent.js";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent.js";
import { MCTSAgent } from "../inc/mctsAgent.js";
import { ModelInferencer } from "../inc/modelInferencer.js";
import { encodeAb0yz } from "../inc/trigo/ab0yz.js";


/**
 * Plays a single self-play game on a 5x1x1 board with the MCTS agent in debug
 * mode, printing every selected move, the final territory (when the game
 * finishes) and the game's TGN.
 *
 * The directory holding the ONNX models defaults to the original development
 * path and can be overridden with the `TRIGO_MODEL_DIR` environment variable,
 * so the script is runnable on other machines.
 *
 * @returns A promise that resolves once the game has been played and printed.
 * @throws Error when the agent returns a non-pass move without complete
 *   coordinates. Rejections from model loading or move selection propagate
 *   to the caller unchanged.
 */
async function main(): Promise<void> {
	// Single source of truth for the board dimensions; used for the banner,
	// the game itself and the move notation.
	const boardShape = { x: 5, y: 1, z: 1 };
	// Upper bound on the number of moves so a non-terminating game cannot hang the test.
	const maxMoves = 20;

	console.log("================================================================================");
	console.log(`MCTS Terminal Propagation Test - ${boardShape.x}x${boardShape.y}x${boardShape.z} Board`);
	console.log("================================================================================\n");

	// Model paths. `||` (not `??`) is deliberate: an empty TRIGO_MODEL_DIR
	// should fall back to the default as well.
	const modelDir = process.env.TRIGO_MODEL_DIR
		|| "/home/camus/work/trigo/trigo-web/public/onnx/20251204-trigo-value-gpt2-l6-h64-251125-lr500";
	const treeModelPath = `${modelDir}/GPT2CausalLM_ep0019_tree.onnx`;
	const evalModelPath = `${modelDir}/GPT2CausalLM_ep0019_evaluation.onnx`;

	console.log("Loading models...");
	console.log(`  Tree: ${treeModelPath}`);
	console.log(`  Eval: ${evalModelPath}\n`);

	// Create ONNX sessions
	const sessionOptions: ort.InferenceSession.SessionOptions = {
		executionProviders: ["cpu"]
	};

	const treeSession = await ort.InferenceSession.create(treeModelPath, sessionOptions);
	const evalSession = await ort.InferenceSession.create(evalModelPath, sessionOptions);

	// Create model inferencers (one per model, same configuration)
	const treeInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: 128,
		seqLen: 256
	});
	treeInferencer.setSession(treeSession as any);

	const evalInferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: 128,
		seqLen: 256
	});
	evalInferencer.setSession(evalSession as any);

	// Create agents
	const treeAgent = new TrigoTreeAgent(treeInferencer);
	const evaluationAgent = new TrigoEvaluationAgent(evalInferencer);

	// Create MCTS agent
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
		numSimulations: 200,
		cPuct: 1.0,
		temperature: 1.0,
		dirichletAlpha: 0.03,
		dirichletEpsilon: 0.25
	});

	// Enable debug mode so the agent emits its debug output during the search
	mctsAgent.debugMode = true;

	console.log("✓ MCTS Agent created with debug mode enabled\n");

	// Create game
	const game = new TrigoGame(boardShape);
	game.startGame();

	console.log("================================================================================");
	console.log("Starting Self-Play Game");
	console.log("================================================================================\n");

	let moveNumber = 0;

	while (game.getGameStatus() === "playing" && moveNumber < maxMoves) {
		const currentPlayer = game.getCurrentPlayer() === 1 ? "Black" : "White";
		console.log(`\n[Move ${moveNumber + 1}] ${currentPlayer} to move`);
		console.log("Current board:", game.getBoard().flat().flat());

		// Select move with MCTS
		const startTime = Date.now();
		const result = await mctsAgent.selectMove(game, moveNumber);
		const duration = Date.now() - startTime;

		// Validate the coordinates once instead of asserting non-null at every use.
		// `null` stands for a pass.
		let position: { x: number; y: number; z: number } | null = null;
		if (!result.move.isPass) {
			const { x, y, z } = result.move;
			if (x == null || y == null || z == null) {
				throw new Error("MCTS returned a non-pass move without complete coordinates");
			}
			position = { x, y, z };
		}

		// Log move selection
		const moveStr = position === null
			? "Pass"
			: encodeAb0yz([position.x, position.y, position.z], [boardShape.x, boardShape.y, boardShape.z]);

		console.log(`\nSelected: ${moveStr}`);
		console.log(`Root value: ${result.rootValue.toFixed(4)}`);
		console.log(`Time: ${duration}ms`);
		// Prints the sum of all per-move visit counts
		console.log(`Visit counts: ${Array.from(result.visitCounts.values()).reduce((a, b) => a + b, 0)}`);

		// Apply move
		if (position === null) {
			game.pass();
		} else {
			game.drop(position);
		}

		moveNumber++;

		// Check terminal status
		if (game.getGameStatus() === "finished") {
			console.log("\n" + "=".repeat(80));
			console.log("Game finished!");
			const territory = game.getTerritory();
			console.log(`Final territory: Black=${territory.black}, White=${territory.white}, Neutral=${territory.neutral}`);
			console.log(`Score diff (W-B): ${territory.white - territory.black}`);
			break;
		}
	}

	// Make a truncated game visible instead of ending silently.
	if (game.getGameStatus() === "playing" && moveNumber >= maxMoves) {
		console.log(`\nMove limit (${maxMoves}) reached before the game finished.`);
	}

	// Print game notation
	console.log("\n" + "=".repeat(80));
	console.log("Game TGN:");
	console.log("=".repeat(80));
	console.log(game.toTGN());

	console.log("Test complete!");
}


// Run the test. On failure, log the error and mark the process as failed so
// callers (shells, CI) do not see exit status 0 for a crashed run.
main().catch((error: unknown) => {
	console.error(error);
	process.exitCode = 1;
});