Spaces:
Sleeping
Sleeping
File size: 4,325 Bytes
15f353f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 |
/**
 * Test MCTS terminal propagation with debug output.
 * Runs a single self-play game on a 5x1x1 board (capped at 20 moves) with the
 * MCTS agent's debug mode enabled, then prints the game's TGN notation.
 */
import * as ort from "onnxruntime-node";
import { TrigoGame } from "../inc/trigo/game.js";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent.js";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent.js";
import { MCTSAgent } from "../inc/mctsAgent.js";
import { ModelInferencer } from "../inc/modelInferencer.js";
import { encodeAb0yz } from "../inc/trigo/ab0yz.js";
/**
 * Wraps an ONNX session in a ModelInferencer configured with the vocabulary
 * size (128) and sequence length (256) this script uses for both models.
 */
function createInferencer(session: ort.InferenceSession): ModelInferencer {
	// NOTE(review): the `as any` casts are kept from the original wiring; presumably
	// ModelInferencer's tensor/session types differ from onnxruntime-node's — confirm.
	const inferencer = new ModelInferencer(ort.Tensor as any, {
		vocabSize: 128,
		seqLen: 256
	});
	inferencer.setSession(session as any);
	return inferencer;
}

/** Prints a section title framed by two horizontal rules, followed by a blank line. */
function printBanner(title: string): void {
	console.log("================================================================================");
	console.log(title);
	console.log("================================================================================\n");
}

/**
 * Runs one MCTS self-play game on a 5x1x1 board with the agent's debug flag set.
 *
 * Loads the tree and evaluation ONNX models, builds the agents, then plays until
 * the game status is no longer "playing" or the move cap is reached. Each move's
 * selection, root value, timing and total visit count are logged; the final
 * territory (if the game finished) and the game's TGN are printed at the end.
 *
 * @returns A promise that resolves once all output has been written. It rejects
 *          if model loading, move selection or any game call throws.
 */
async function main(): Promise<void> {
	printBanner("MCTS Terminal Propagation Test - 5x1x1 Board");

	// Model paths
	// NOTE(review): absolute path into one developer's home directory — the script
	// only runs on a machine that has the models at this exact location.
	const modelDir = "/home/camus/work/trigo/trigo-web/public/onnx/20251204-trigo-value-gpt2-l6-h64-251125-lr500";
	const treeModelPath = `${modelDir}/GPT2CausalLM_ep0019_tree.onnx`;
	const evalModelPath = `${modelDir}/GPT2CausalLM_ep0019_evaluation.onnx`;
	console.log("Loading models...");
	console.log(` Tree: ${treeModelPath}`);
	console.log(` Eval: ${evalModelPath}\n`);

	// Create ONNX sessions (CPU execution provider only)
	const sessionOptions: ort.InferenceSession.SessionOptions = {
		executionProviders: ["cpu"]
	};
	const treeSession = await ort.InferenceSession.create(treeModelPath, sessionOptions);
	const evalSession = await ort.InferenceSession.create(evalModelPath, sessionOptions);

	// Create agents on top of identically configured inferencers
	const treeAgent = new TrigoTreeAgent(createInferencer(treeSession));
	const evaluationAgent = new TrigoEvaluationAgent(createInferencer(evalSession));

	// Create MCTS agent
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, {
		numSimulations: 200,
		cPuct: 1.0,
		temperature: 1.0,
		dirichletAlpha: 0.03,
		dirichletEpsilon: 0.25
	});
	// ENABLE DEBUG MODE
	mctsAgent.debugMode = true;
	console.log("✓ MCTS Agent created with debug mode enabled\n");

	// Single source of truth for the board shape: used both to create the game and
	// to encode move coordinates, so the two can never disagree.
	const boardShape = { x: 5, y: 1, z: 1 };
	const boardShapeTuple: [number, number, number] = [boardShape.x, boardShape.y, boardShape.z];

	// Create game
	const game = new TrigoGame(boardShape);
	game.startGame();

	printBanner("Starting Self-Play Game");

	const maxMoves = 20;
	let moveNumber = 0;
	while (game.getGameStatus() === "playing" && moveNumber < maxMoves) {
		const currentPlayer = game.getCurrentPlayer() === 1 ? "Black" : "White";
		console.log(`\n[Move ${moveNumber + 1}] ${currentPlayer} to move`);
		console.log("Current board:", game.getBoard().flat().flat());

		// Select move with MCTS
		const startTime = Date.now();
		const result = await mctsAgent.selectMove(game, moveNumber);
		const duration = Date.now() - startTime;

		// Resolve the move once: `null` stands for a pass, otherwise the drop position.
		// The coordinates are asserted non-null only here, in the non-pass case.
		const { move } = result;
		const position = move.isPass ? null : { x: move.x!, y: move.y!, z: move.z! };

		// Log move selection
		const moveStr =
			position === null
				? "Pass"
				: encodeAb0yz([position.x, position.y, position.z], boardShapeTuple);
		// The "Visit counts" line reports the sum over all entries of result.visitCounts.
		const totalVisits = Array.from(result.visitCounts.values()).reduce((a, b) => a + b, 0);
		console.log(`\nSelected: ${moveStr}`);
		console.log(`Root value: ${result.rootValue.toFixed(4)}`);
		console.log(`Time: ${duration}ms`);
		console.log(`Visit counts: ${totalVisits}`);

		// Apply move
		if (position === null) {
			game.pass();
		} else {
			game.drop(position);
		}
		moveNumber++;
	}

	// Report how the loop ended: either the game reached its terminal state, or the
	// move cap cut it short (previously that case exited without any indication).
	if (game.getGameStatus() === "finished") {
		console.log("\n" + "=".repeat(80));
		console.log("Game finished!");
		const territory = game.getTerritory();
		console.log(`Final territory: Black=${territory.black}, White=${territory.white}, Neutral=${territory.neutral}`);
		console.log(`Score diff (W-B): ${territory.white - territory.black}`);
	} else if (moveNumber >= maxMoves) {
		console.log(`\nMove limit (${maxMoves}) reached before the game finished`);
	}

	// Print game notation
	console.log("\n" + "=".repeat(80));
	console.log("Game TGN:");
	console.log("=".repeat(80));
	console.log(game.toTGN());
	console.log("Test complete!");
}
// Run the test. On failure, log the error and mark the process as failed so a
// shell or CI job does not see exit code 0 for a run that threw.
main().catch((error: unknown) => {
	console.error(error);
	process.exitCode = 1;
});
|