// trigo / trigo-web / tools / selfPlayGames.ts
// Commit 15f353f (k-l-lambda): "Update trigo-web with VS People multiplayer mode" (file size 24.1 kB)
/**
* Self-Play Dataset Generation for Trigo AI
*
* This script generates a dataset of self-play games for training the Trigo AI model.
* The AI plays against itself using either tree attention or MCTS-based agents.
*
* Features:
* - Configurable number of games to generate
* - Random or fixed board shapes (2D: 2x1x1 to 13x13x1, 3D: 2x2x2 to 5x5x5)
* - Temperature-based sampling for move diversity
* - MCTS (Monte Carlo Tree Search) mode with AlphaGo Zero algorithm
* - Automatic game termination: two consecutive passes, or a naturally settled position (checked once the move count reaches 50% of the board's positions)
* - TGN format output with score notation for each game
* - Optional visit count statistics for MCTS training data
* - Per-board-shape statistics
* - Progress tracking
*
* Usage:
* npx tsx tools/selfPlayGames.ts [options]
*
* Options:
* --games <n> Number of games to generate (default: 10)
* --output <dir> Output directory (default: ./tools/output/selfplay)
* --board <shape> Board shape "X*Y*Z" or "random" (default: "random")
* --temperature <t> Sampling temperature (default: 1.0)
* --max-moves <n> Maximum moves per game (default: 300)
* --model <path> Path to tree model ONNX file
* --eval-model <path> Path to evaluation model ONNX file (for MCTS)
* --verbose Enable verbose logging
* --help Show this help message
*
* MCTS Options:
* --use-mcts Enable MCTS for move selection
* --mcts-simulations <n> MCTS simulations per move (default: 600)
* --mcts-cpuct <f> PUCT exploration constant (default: 1.0)
* --mcts-dirichlet-alpha <f> Dirichlet noise alpha (default: 0.03)
* --mcts-dirichlet-epsilon <f> Dirichlet noise epsilon (default: 0.25)
* --save-visit-counts Save visit count statistics
*
* Output:
* - Each game saved as game_<hash>.tgn (hash based on content)
* - Dataset statistics in <date>_dataset_stats.json
* - Visit counts saved as game_<id>_visit_counts.json (if --save-visit-counts)
*/
import * as ort from "onnxruntime-node";
import * as path from "path";
import * as fs from "fs";
import * as crypto from "crypto";
import { fileURLToPath } from "url";
import { TrigoGame, StoneType } from "../inc/trigo/game";
import { ModelInferencer } from "../inc/modelInferencer";
import { loadEnvConfig, getOnnxModelPaths, getAbsoluteModelPath, getOnnxSessionOptions } from "../inc/config";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent";
import { MCTSAgent, type MCTSConfig } from "../inc/mctsAgent";
import type { Move, BoardShape } from "../inc/trigo/types";
import {encodeAb0yz} from "../inc/trigo/ab0yz";
// ES modules have no CommonJS __filename/__dirname, so derive them from import.meta.url.
// __dirname is used below to build the default output directory (output/selfplay next to this script).
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);
// Load environment configuration via top-level await.
// NOTE(review): presumably getOnnxModelPaths() reads values loaded here, so keep this order — confirm in ../inc/config.
await loadEnvConfig();
// Default ONNX model paths from the environment config; they become the --model / --eval-model defaults in parseArgs.
const defaultModelPaths = getOnnxModelPaths();
/** Board dimensions as an [x, y, z] tuple. */
type BoardShapeTuple = [number, number, number];

/**
 * Enumerate every board shape between `min` and `max`, inclusive on each axis.
 *
 * Shapes are ordered with x varying slowest and z fastest. An axis whose
 * minimum exceeds its maximum contributes no values, so the result is empty.
 */
const arangeShape = (min: BoardShapeTuple, max: BoardShapeTuple): BoardShapeTuple[] => {
	// Inclusive list of values from lo to hi, stepping by 1.
	const span = (lo: number, hi: number): number[] => {
		const values: number[] = [];
		for (let v = lo; v <= hi; v++) {
			values.push(v);
		}
		return values;
	};

	const xs = span(min[0], max[0]);
	const ys = span(min[1], max[1]);
	const zs = span(min[2], max[2]);

	return xs.flatMap((x) => ys.flatMap((y) => zs.map((z): BoardShapeTuple => [x, y, z])));
};
/**
 * Candidate board shapes for `--board random` (every shape in each inclusive range):
 * - 2D boards: X in 2..13, Y in 1..13, Z = 1 (156 shapes, including non-square ones such as 13×1×1)
 * - 3D boards: X, Y and Z each in 2..5 (64 shapes)
 * selectRandomBoardShape() picks uniformly from this list, so roughly 71% of random games are 2D.
 */
const CANDIDATE_BOARD_SHAPES: readonly BoardShapeTuple[] = [
	...arangeShape([2, 1, 1], [13, 13, 1]),
	...arangeShape([2, 2, 2], [5, 5, 5]),
];
/** Settings for one dataset-generation run, produced by parseArgs(). */
interface GenerationConfig {
	// --- Run settings ---
	/** Number of games to generate. */
	numGames: number;
	/** Directory receiving the TGN files, visit-count files and the stats JSON. */
	outputDir: string;
	/** Fixed board shape, or "random" to pick a candidate shape per game. */
	boardShape: BoardShape | "random";
	/** Upper bound on the number of moves in one game. */
	maxMoves: number;
	/** Sampling temperature for tree-agent moves; also forwarded to the MCTS agent. */
	temperature: number;
	/** Print per-game and per-move details. */
	verbose: boolean;

	// --- Model settings ---
	/** Tree (policy) model ONNX file. */
	modelPath: string;
	/** Evaluation model ONNX file; loaded only in MCTS mode. */
	evaluationModelPath: string;
	/** Vocabulary size passed to every model inferencer. */
	vocabSize: number;
	/** Sequence length passed to every model inferencer. */
	seqLen: number;

	// --- MCTS settings (only read when useMCTS is true) ---
	/** Use MCTS search instead of direct tree-agent sampling. */
	useMCTS: boolean;
	/** Simulations per move. */
	mctsSimulations: number;
	/** PUCT exploration constant. */
	mctsCPuct: number;
	/** Dirichlet noise alpha. */
	mctsDirichletAlpha: number;
	/** Dirichlet noise mixing weight. */
	mctsDirichletEpsilon: number;
	/** Record per-move visit counts and write them next to each game. */
	saveVisitCounts: boolean;
}
/** Statistics recorded for one finished self-play game. */
interface GameStats {
	/** 1-based index of the game within the run. */
	gameId: number;
	/** Board dimensions formatted as "X×Y×Z". */
	boardShape: string;
	/** Moves actually played (passes included). */
	moveCount: number;
	/** Whether the game stopped because it hit the --max-moves limit. */
	maxMovesReached: boolean;
	/** Wall-clock duration of the game, in milliseconds. */
	duration: number;
	/** Mean time spent choosing a move, in milliseconds. */
	averageMoveTime: number;
	/** Final territory difference, white - black (positive: white ahead, negative: black ahead). */
	scoreDiff: number;
}
/** Averages over all games played on one board shape. */
interface BoardShapeStats {
	/** Board dimensions formatted as "X×Y×Z". */
	boardShape: string;
	/** Number of games played on this shape. */
	gameCount: number;
	/** Mean white - black score difference. */
	averageScoreDiff: number;
	/** Mean number of moves per game. */
	averageMoveCount: number;
}
/** Aggregate statistics for a whole run, written to `<date>_dataset_stats.json`. */
interface DatasetStats {
	// --- Counts ---
	/** Games generated. */
	totalGames: number;
	/** Moves across all games. */
	totalMoves: number;
	/** Games won by black (negative scoreDiff). */
	blackWins: number;
	/** Games won by white (positive scoreDiff). */
	whiteWins: number;
	/** Resigned games; nothing in this script increments it. */
	resignations: number;
	/** Games that stopped at the --max-moves limit. */
	maxMovesReached: number;

	// --- Averages and timing ---
	/** Mean moves per game. */
	averageGameLength: number;
	/** Mean of the per-game average move times, in milliseconds. */
	averageMoveTime: number;
	/** Wall-clock time for the whole run, in milliseconds. */
	generationTime: number;
	/** Mean white - black score difference. */
	averageScoreDiff: number;

	// --- Breakdowns ---
	/** Averages per board shape. */
	boardShapeStats: BoardShapeStats[];
	/** Per-game records, in generation order. */
	games: GameStats[];
}
/**
 * Parse a board shape string such as "5*5*5" or "9*9*1".
 *
 * Any run of non-digit characters separates the three dimensions, so "5x5x5"
 * is accepted as well. The special value "random" (case-insensitive; surrounding
 * whitespace ignored) makes each game pick its own shape from CANDIDATE_BOARD_SHAPES.
 *
 * @throws Error if the string does not contain exactly three dimensions, or if any dimension is 0.
 */
function parseBoardShape(shapeStr: string): BoardShape | "random" {
	const trimmed = shapeStr.trim();

	if (trimmed.toLowerCase() === "random") {
		return "random";
	}

	const parts = trimmed.split(/[^0-9]+/).filter(Boolean).map(Number);
	if (parts.length !== 3) {
		throw new Error(`Invalid board shape: ${shapeStr}. Expected format: "X*Y*Z" or "random"`);
	}
	// A zero-sized axis yields a board with no positions; reject it up front.
	if (parts.some((dim) => dim < 1)) {
		throw new Error(`Invalid board shape: ${shapeStr}. Every dimension must be at least 1`);
	}

	const [x, y, z] = parts;
	return { x, y, z };
}
/**
 * Pick one entry of CANDIDATE_BOARD_SHAPES uniformly at random and return it as a BoardShape object.
 */
function selectRandomBoardShape(): BoardShape {
	const candidates = CANDIDATE_BOARD_SHAPES;
	const pick = candidates[Math.floor(Math.random() * candidates.length)];
	return { x: pick[0], y: pick[1], z: pick[2] };
}
/**
 * Parse `process.argv` into a GenerationConfig, starting from the defaults below.
 *
 * `--help` prints usage and exits with status 0. An unknown `--option` prints an
 * error plus usage and exits with status 1. Bare arguments that are not option
 * values are ignored with a warning.
 *
 * @throws Error when an option's value is missing, is not a valid number, or is out of
 *   range, or when the board shape is malformed. main() reports such errors.
 */
function parseArgs(): GenerationConfig {
	const args = process.argv.slice(2);
	const config: GenerationConfig = {
		numGames: 10,
		outputDir: path.join(__dirname, "output/selfplay"),
		temperature: 1.0,
		maxMoves: 300,
		verbose: false,
		modelPath: getAbsoluteModelPath(defaultModelPaths.treeModel),
		evaluationModelPath: getAbsoluteModelPath(defaultModelPaths.evaluationModel),
		vocabSize: 128,
		seqLen: 256,
		boardShape: "random",
		// MCTS defaults
		useMCTS: false,
		mctsSimulations: 600,
		mctsCPuct: 1.0,
		mctsDirichletAlpha: 0.03,
		mctsDirichletEpsilon: 0.25,
		saveVisitCounts: false
	};

	let i = 0;

	// Consume the argument following `flag`. A missing value, or one that is itself an
	// option, is an error rather than silently becoming undefined/NaN or eating the next flag.
	const nextValue = (flag: string): string => {
		const value = args[i + 1];
		if (value === undefined || value.startsWith("--")) {
			throw new Error(`Missing value for option ${flag}`);
		}
		i++;
		return value;
	};

	// Integer option value that must be at least `min`.
	const integerValue = (flag: string, min: number): number => {
		const raw = nextValue(flag);
		const value = Number(raw);
		if (!Number.isInteger(value) || value < min) {
			throw new Error(`Invalid value for ${flag}: "${raw}" (expected an integer >= ${min})`);
		}
		return value;
	};

	// Finite numeric option value accepted by `isValid`; `expected` describes the allowed range.
	const numberValue = (flag: string, isValid: (value: number) => boolean, expected: string): number => {
		const raw = nextValue(flag);
		const value = Number(raw);
		if (raw.trim() === "" || !Number.isFinite(value) || !isValid(value)) {
			throw new Error(`Invalid value for ${flag}: "${raw}" (expected ${expected})`);
		}
		return value;
	};

	for (; i < args.length; i++) {
		const arg = args[i];
		switch (arg) {
			case "--games":
				config.numGames = integerValue(arg, 1);
				break;
			case "--output":
				config.outputDir = nextValue(arg);
				break;
			case "--temperature":
				// 0 is allowed: sampleMove treats temperature <= 0 as greedy selection.
				config.temperature = numberValue(arg, (t) => t >= 0, "a number >= 0");
				break;
			case "--max-moves":
				config.maxMoves = integerValue(arg, 1);
				break;
			case "--board":
				config.boardShape = parseBoardShape(nextValue(arg));
				break;
			case "--model":
				config.modelPath = nextValue(arg);
				break;
			case "--eval-model":
				config.evaluationModelPath = nextValue(arg);
				break;
			case "--use-mcts":
				config.useMCTS = true;
				break;
			case "--mcts-simulations":
				config.mctsSimulations = integerValue(arg, 1);
				break;
			case "--mcts-cpuct":
				config.mctsCPuct = numberValue(arg, (c) => c >= 0, "a number >= 0");
				break;
			case "--mcts-dirichlet-alpha":
				// The Dirichlet concentration parameter must be strictly positive.
				config.mctsDirichletAlpha = numberValue(arg, (a) => a > 0, "a number > 0");
				break;
			case "--mcts-dirichlet-epsilon":
				// Mixing weight between the prior and the noise.
				config.mctsDirichletEpsilon = numberValue(arg, (e) => e >= 0 && e <= 1, "a number in [0, 1]");
				break;
			case "--save-visit-counts":
				config.saveVisitCounts = true;
				break;
			case "--verbose":
				config.verbose = true;
				break;
			case "--help":
				printHelp();
				process.exit(0);
			default:
				if (arg.startsWith("--")) {
					console.error(`Unknown option: ${arg}`);
					printHelp();
					process.exit(1);
				}
				console.warn(`Ignoring unexpected argument: ${arg}`);
		}
	}

	return config;
}
/**
 * Usage text shown by --help and after an unknown option.
 * It is kept as one template literal so the printed layout matches this source exactly.
 */
const HELP_TEXT = `
Usage: npx tsx tools/selfPlayGames.ts [options]
Options:
--games <n> Number of games to generate (default: 10)
--output <dir> Output directory (default: ./tools/output/selfplay)
--board <shape> Board shape "X*Y*Z" or "random" (default: "random")
--temperature <t> Sampling temperature (default: 1.0)
--max-moves <n> Maximum moves per game (default: 300)
--model <path> Path to tree model ONNX file
--eval-model <path> Path to evaluation model ONNX file (for MCTS)
--verbose Enable verbose logging
MCTS Options:
--use-mcts Enable MCTS for move selection
--mcts-simulations <n> MCTS simulations per move (default: 600)
--mcts-cpuct <f> PUCT exploration constant (default: 1.0)
--mcts-dirichlet-alpha <f> Dirichlet noise alpha (default: 0.03)
--mcts-dirichlet-epsilon <f> Dirichlet noise epsilon (default: 0.25)
--save-visit-counts Save visit count statistics
--help Show this help message
Board Shape Examples:
--board "5*5*5" Fixed 5x5x5 board for all games
--board "9*9*1" Fixed 9x9x1 (2D) board for all games
--board random Random board shape for each game (default)
Examples:
# Generate 100 games with random board shapes (no MCTS)
npx tsx tools/selfPlayGames.ts --games 100
# Generate 50 games on 5x5x5 board with MCTS
npx tsx tools/selfPlayGames.ts --games 50 --board "5*5*5" --use-mcts --mcts-simulations 600
# Generate games with custom models
npx tsx tools/selfPlayGames.ts --games 20 --model ./models/tree.onnx --eval-model ./models/eval.onnx
# Generate games with custom output directory and save visit counts
npx tsx tools/selfPlayGames.ts --games 20 --output ./my_dataset --use-mcts --save-visit-counts
`;

/**
 * Write the command-line usage text to stdout.
 */
function printHelp(): void {
	console.log(HELP_TEXT);
}
/**
 * Build the self-play agent described by `config`.
 *
 * The configuration is echoed to the console first. The tree model is always
 * loaded and wrapped in a TrigoTreeAgent. When `config.useMCTS` is set, the
 * evaluation model is loaded as well, and both agents are combined into an
 * MCTSAgent configured from the MCTS options and `config.temperature`.
 */
async function initializeAgent(config: GenerationConfig): Promise<TrigoTreeAgent | MCTSAgent> {
	const { useMCTS } = config;

	console.log("Initializing AI Agent...");
	console.log(` Mode: ${useMCTS ? "MCTS Search" : "Tree Attention"}`);
	console.log(` Tree Model: ${config.modelPath}`);
	if (useMCTS) {
		console.log(` Evaluation Model: ${config.evaluationModelPath}`);
		console.log(` MCTS Simulations: ${config.mctsSimulations}`);
		console.log(` C-PUCT: ${config.mctsCPuct}`);
	}
	console.log(` Vocab Size: ${config.vocabSize}`);
	console.log(` Sequence Length: ${config.seqLen}`);

	const sessionOptions = getOnnxSessionOptions();

	// Both models share the vocab/sequence settings; only the model file differs.
	const loadInferencer = async (modelPath: string): Promise<ModelInferencer> => {
		const session = await ort.InferenceSession.create(modelPath, sessionOptions);
		const inferencer = new ModelInferencer(ort.Tensor as any, {
			vocabSize: config.vocabSize,
			seqLen: config.seqLen,
			modelPath
		});
		inferencer.setSession(session as any);
		return inferencer;
	};

	const treeAgent = new TrigoTreeAgent(await loadInferencer(config.modelPath));

	if (!useMCTS) {
		console.log("✓ Tree Agent initialized\n");
		return treeAgent;
	}

	const evaluationAgent = new TrigoEvaluationAgent(await loadInferencer(config.evaluationModelPath));

	const mctsConfig: MCTSConfig = {
		numSimulations: config.mctsSimulations,
		cPuct: config.mctsCPuct,
		temperature: config.temperature,
		dirichletAlpha: config.mctsDirichletAlpha,
		dirichletEpsilon: config.mctsDirichletEpsilon
	};
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, mctsConfig);

	console.log("✓ MCTS Agent initialized\n");
	return mctsAgent;
}
/**
 * Sample a move from scored candidates using a temperature-scaled softmax.
 *
 * Scores are treated as log-probabilities. Each is divided by `temperature`, then
 * normalised with a numerically stable softmax (the maximum is subtracted before
 * exponentiating), and one candidate is drawn with `Math.random()`.
 *
 * A temperature <= 0 (or NaN) picks the highest-scoring candidate deterministically.
 * Dividing by it would otherwise produce NaN probabilities. The same greedy choice
 * is used when the scaled scores have no finite maximum.
 *
 * @param scoredMoves - Candidates with their scores; must be non-empty.
 * @param temperature - Below 1 sharpens the distribution, above 1 flattens it.
 * @returns The chosen candidate's move.
 * @throws Error if `scoredMoves` is empty.
 */
function sampleMove(scoredMoves: Array<{ move: Move; score: number; notation: string }>, temperature: number): Move {
	if (scoredMoves.length === 0) {
		throw new Error("sampleMove: no candidate moves to sample from");
	}

	// Highest score wins; ties go to the earliest candidate.
	const greedy = (): Move =>
		scoredMoves.reduce((best, candidate) => (candidate.score > best.score ? candidate : best)).move;

	if (!(temperature > 0)) {
		return greedy();
	}

	// Apply temperature to scores (log probabilities)
	const adjustedScores = scoredMoves.map((m) => m.score / temperature);
	const maxScore = Math.max(...adjustedScores);
	if (!Number.isFinite(maxScore)) {
		// All -Infinity, some +Infinity, or NaN: the softmax is undefined.
		return greedy();
	}

	// Softmax; the maximum term is exp(0) = 1, so sumExp >= 1.
	const expScores = adjustedScores.map((score) => Math.exp(score - maxScore));
	const sumExp = expScores.reduce((sum, e) => sum + e, 0);
	const probabilities = expScores.map((e) => e / sumExp);

	// Inverse-CDF sampling.
	const random = Math.random();
	let cumulative = 0;
	for (let i = 0; i < scoredMoves.length; i++) {
		cumulative += probabilities[i];
		if (random <= cumulative) {
			return scoredMoves[i].move;
		}
	}

	// Rounding can leave the total just below `random`; fall back to the last candidate.
	return scoredMoves[scoredMoves.length - 1].move;
}
/**
 * Choose the tree agent's move for `player`.
 *
 * When no stone placement is legal, passing is the only option, so a pass is returned
 * without querying the model. Otherwise every legal placement plus a pass is scored,
 * and one move is sampled with sampleMove() at `temperature`.
 *
 * @returns The selected move, or null if the agent returned no scored candidates.
 */
async function selectTreeAgentMove(
	agent: TrigoTreeAgent,
	game: TrigoGame,
	player: "black" | "white",
	temperature: number
): Promise<Move | null> {
	const validPositions = game.validMovePositions();
	const pass: Move = { player, isPass: true };
	if (validPositions.length === 0) {
		return pass;
	}

	const candidates: Move[] = validPositions.map((pos) => ({
		x: pos.x,
		y: pos.y,
		z: pos.z,
		player
	}));
	candidates.push(pass);

	const scoredMoves = await agent.scoreMoves(game, candidates);
	if (scoredMoves.length === 0) {
		return null;
	}
	// Sampling does not depend on candidate order, so the scores are used as returned.
	return sampleMove(scoredMoves, temperature);
}

/**
 * Apply `move` to `game`.
 *
 * @returns Whether the game accepted the move, and its notation: "Pass", the ab0yz
 *   coordinate encoding, or "" for a move that is neither a pass nor fully specified.
 */
function applySelfPlayMove(
	game: TrigoGame,
	move: Move,
	boardShape: BoardShape
): { success: boolean; notation: string } {
	if (move.isPass) {
		return { success: game.pass(), notation: "Pass" };
	}
	if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
		const success = game.drop({ x: move.x, y: move.y, z: move.z });
		const notation = encodeAb0yz([move.x, move.y, move.z], [boardShape.x, boardShape.y, boardShape.z]);
		return { success, notation };
	}
	return { success: false, notation: "" };
}

/**
 * Play a single self-play game, with `agent` controlling both colours.
 *
 * The game stops at the first of these:
 * - two consecutive passes;
 * - natural termination, checked after every stone once the move count reaches 50% of
 *   the board's positions;
 * - a move the game rejects;
 * - the tree agent returning no scored moves;
 * - `config.maxMoves` moves.
 *
 * Only the last case sets `stats.maxMovesReached`. In non-verbose mode, a compact
 * one-line move trace is written to stdout.
 *
 * @param agent - Tree agent (temperature sampling over scored moves) or MCTS agent.
 * @param gameId - 1-based game index; used for logging and stored in the stats.
 * @param config - Generation settings (board shape, move limit, temperature, verbosity, visit-count capture).
 * @returns The finished game, its statistics, and — when `config.saveVisitCounts` is set — one
 *   array of MCTS visit-count values per MCTS move.
 */
async function playSelfPlayGame(
	agent: TrigoTreeAgent | MCTSAgent,
	gameId: number,
	config: GenerationConfig
): Promise<{ game: TrigoGame; stats: GameStats; visitCounts?: number[][] }> {
	// Select board shape (random or fixed)
	const boardShape: BoardShape =
		config.boardShape === "random" ? selectRandomBoardShape() : config.boardShape;
	const game = new TrigoGame(boardShape, {});
	const startTime = Date.now();
	let moveCount = 0;
	let totalMoveTime = 0;
	let consecutivePasses = 0;
	// True when the loop stops for any reason other than exhausting config.maxMoves.
	let endedEarly = false;
	const visitCountsHistory: number[][] = [];

	// Territory checks start once the move count reaches half the board's positions (same rule as MCTS).
	const totalPositions = boardShape.x * boardShape.y * boardShape.z;
	const coverageThreshold = Math.floor(totalPositions * 0.5);
	let territoryCheckStarted = false;

	if (config.verbose) {
		console.log(`\nGame ${gameId} started [Board: ${boardShape.x}×${boardShape.y}×${boardShape.z}]`);
	}

	while (moveCount < config.maxMoves) {
		if (!territoryCheckStarted && moveCount >= coverageThreshold) {
			territoryCheckStarted = true;
			if (config.verbose) {
				console.log(` Reached 50% coverage (${moveCount} moves), starting territory check`);
			}
		}

		const currentPlayer = game.getCurrentPlayer() === StoneType.BLACK ? "black" : "white";
		const moveStartTime = Date.now();
		let selectedMove: Move;

		if (agent instanceof MCTSAgent) {
			const result = await agent.selectMove(game, moveCount);
			selectedMove = result.move;
			if (config.saveVisitCounts && result.visitCounts) {
				// NOTE(review): only the counts are kept, not the moves they belong to. This relies on the
				// map's iteration order being meaningful to whoever reads the file — confirm.
				visitCountsHistory.push(Array.from(result.visitCounts.values()));
			}
		} else {
			const treeMove = await selectTreeAgentMove(agent, game, currentPlayer, config.temperature);
			if (treeMove === null) {
				endedEarly = true;
				break;
			}
			selectedMove = treeMove;
		}

		totalMoveTime += Date.now() - moveStartTime;

		const { success, notation } = applySelfPlayMove(game, selectedMove, boardShape);
		if (!success) {
			console.error(`Failed to apply move: ${notation}`);
			endedEarly = true;
			break;
		}

		moveCount++;
		consecutivePasses = selectedMove.isPass ? consecutivePasses + 1 : 0;

		if (config.verbose) {
			const player = currentPlayer === "black" ? "Black" : "White";
			console.log(` Move ${moveCount}: ${player} plays ${notation}`);
		} else {
			// Compact trace on a single line; terminated with a newline after the loop.
			process.stdout.write(notation + " ");
		}

		// Check for game end (two consecutive passes)
		if (consecutivePasses >= 2) {
			if (config.verbose) {
				console.log(" Game ended: Two consecutive passes");
			}
			endedEarly = true;
			break;
		}

		// Check territory after 50% coverage (same as MCTS)
		if (territoryCheckStarted && !selectedMove.isPass) {
			if (game.isNaturallyTerminal()) {
				if (config.verbose) {
					const territory = game.getTerritory();
					console.log(` Game ended: No neutral territory and no captures possible (settled)`);
					console.log(` Black: ${territory.black}, White: ${territory.white}`);
				}
				endedEarly = true;
				break;
			} else if (config.verbose) {
				const territory = game.getTerritory();
				if (territory.neutral === 0) {
					console.log(` Territory settled but captures still possible (continuing...)`);
				}
			}
		}
	}

	if (!config.verbose && moveCount > 0) {
		process.stdout.write("\n");
	}

	const duration = Date.now() - startTime;
	const averageMoveTime = moveCount > 0 ? totalMoveTime / moveCount : 0;

	// Final territory decides the score difference (white - black).
	const territory = game.getTerritory();
	const scoreDiff = territory.white - territory.black;

	const stats: GameStats = {
		gameId,
		boardShape: `${boardShape.x}×${boardShape.y}×${boardShape.z}`,
		moveCount,
		// Only a game cut off by the move limit counts; one that ends naturally on its last allowed move does not.
		maxMovesReached: !endedEarly,
		duration,
		averageMoveTime,
		scoreDiff
	};

	return {
		game,
		stats,
		visitCounts: config.saveVisitCounts ? visitCountsHistory : undefined
	};
}
/**
 * Write `game` as a TGN file in `outputDir`, named after its content.
 *
 * The filename is `game_<first 16 hex chars of SHA-256(TGN text)>.tgn`, so an identical
 * game maps to the same file, and writing it again replaces it with identical content.
 *
 * @returns The bare filename (not the full path).
 */
function saveGame(game: TrigoGame, outputDir: string): string {
	const tgn = game.toTGN({}, { markResult: true });
	const digest = crypto.createHash('sha256').update(tgn).digest('hex').slice(0, 16);
	const filename = `game_${digest}.tgn`;
	fs.writeFileSync(path.join(outputDir, filename), tgn, "utf-8");
	return filename;
}
/**
 * Group per-game statistics by board shape, and average score difference and move count per group.
 * Groups are ordered by shape string with numeric-aware comparison, so "2×2×1" sorts before "10×10×1".
 */
function summarizeByBoardShape(games: GameStats[]): BoardShapeStats[] {
	const groups = new Map<string, GameStats[]>();
	for (const g of games) {
		const group = groups.get(g.boardShape);
		if (group) {
			group.push(g);
		} else {
			groups.set(g.boardShape, [g]);
		}
	}

	const summaries = Array.from(groups, ([boardShape, group]): BoardShapeStats => ({
		boardShape,
		gameCount: group.length,
		averageScoreDiff: group.reduce((sum, g) => sum + g.scoreDiff, 0) / group.length,
		averageMoveCount: group.reduce((sum, g) => sum + g.moveCount, 0) / group.length
	}));

	return summaries.sort((a, b) => a.boardShape.localeCompare(b.boardShape, undefined, { numeric: true }));
}

/**
 * Generate a dataset of self-play games.
 *
 * For each game, this writes:
 * - `game_<hash>.tgn`, via saveGame();
 * - `game_<hash>_visit_counts.json` when `config.saveVisitCounts` is set and visit counts
 *   were recorded. It is named after the game's TGN file so the pair stays linked and later
 *   runs do not overwrite it.
 *
 * Aggregate statistics go to `<YYYY-MM-DD>_dataset_stats.json` in `config.outputDir`. The name
 * carries only the date, so a later run on the same day into the same directory replaces it.
 * A game counts as a white win when scoreDiff > 0, a black win when scoreDiff < 0, and a draw
 * otherwise.
 */
async function generateDataset(config: GenerationConfig): Promise<void> {
	const boardLabel =
		config.boardShape === "random"
			? "random"
			: `${config.boardShape.x}×${config.boardShape.y}×${config.boardShape.z}`;

	console.log("=".repeat(80));
	console.log("Trigo Self-Play Dataset Generation");
	console.log("=".repeat(80));
	console.log(`Configuration:`);
	console.log(` Number of games: ${config.numGames}`);
	console.log(` Output directory: ${config.outputDir}`);
	console.log(` Board shape: ${boardLabel}`);
	console.log(` Temperature: ${config.temperature}`);
	console.log(` Max moves per game: ${config.maxMoves}`);
	console.log(` Verbose: ${config.verbose}`);
	console.log();

	if (!fs.existsSync(config.outputDir)) {
		fs.mkdirSync(config.outputDir, { recursive: true });
		console.log(`✓ Created output directory: ${config.outputDir}\n`);
	}

	const agent = await initializeAgent(config);

	const startTime = Date.now();
	const datasetStats: DatasetStats = {
		totalGames: 0,
		totalMoves: 0,
		blackWins: 0,
		whiteWins: 0,
		resignations: 0,
		maxMovesReached: 0,
		averageGameLength: 0,
		averageMoveTime: 0,
		generationTime: 0,
		averageScoreDiff: 0,
		boardShapeStats: [],
		games: []
	};

	console.log("Generating games...");
	console.log("=".repeat(80));

	for (let i = 1; i <= config.numGames; i++) {
		const gameStartTime = Date.now();

		const { game, stats, visitCounts } = await playSelfPlayGame(agent, i, config);

		const tgnFilename = saveGame(game, config.outputDir);
		if (visitCounts && config.saveVisitCounts) {
			const visitCountsPath = path.join(
				config.outputDir,
				tgnFilename.replace(/\.tgn$/, "_visit_counts.json")
			);
			fs.writeFileSync(visitCountsPath, JSON.stringify(visitCounts, null, 2), "utf-8");
		}

		const gameDuration = Date.now() - gameStartTime;

		datasetStats.totalGames++;
		datasetStats.totalMoves += stats.moveCount;
		if (stats.maxMovesReached) datasetStats.maxMovesReached++;
		// scoreDiff is white - black.
		if (stats.scoreDiff > 0) {
			datasetStats.whiteWins++;
		} else if (stats.scoreDiff < 0) {
			datasetStats.blackWins++;
		}
		datasetStats.games.push(stats);

		const progress = ((i / config.numGames) * 100).toFixed(1);
		const result = stats.scoreDiff > 0 ? `White +${stats.scoreDiff}` : stats.scoreDiff < 0 ? `Black +${Math.abs(stats.scoreDiff)}` : "Draw";
		console.log(
			`[${progress}%] Game ${i}/${config.numGames} [${stats.boardShape}]: ` +
			`${stats.moveCount} moves, ${result}, ${(gameDuration / 1000).toFixed(1)}s`
		);
	}

	datasetStats.generationTime = Date.now() - startTime;

	// Mean of a list; 0 for an empty run (--games 0) instead of NaN.
	const mean = (values: number[]): number =>
		values.length > 0 ? values.reduce((sum, v) => sum + v, 0) / values.length : 0;

	datasetStats.averageGameLength = mean(datasetStats.games.map((g) => g.moveCount));
	datasetStats.averageMoveTime = mean(datasetStats.games.map((g) => g.averageMoveTime));
	datasetStats.averageScoreDiff = mean(datasetStats.games.map((g) => g.scoreDiff));
	datasetStats.boardShapeStats = summarizeByBoardShape(datasetStats.games);

	// Save dataset statistics with timestamp (date part only)
	const timestamp = new Date().toISOString().replace(/[:.]/g, "-").split("T")[0];
	const statsFilepath = path.join(config.outputDir, `${timestamp}_dataset_stats.json`);
	fs.writeFileSync(statsFilepath, JSON.stringify(datasetStats, null, 2), "utf-8");

	const percentOfGames = (count: number): string =>
		datasetStats.totalGames > 0 ? ((count / datasetStats.totalGames) * 100).toFixed(1) : "0.0";
	const draws = datasetStats.totalGames - datasetStats.blackWins - datasetStats.whiteWins;

	console.log("=".repeat(80));
	console.log("Dataset Generation Complete!");
	console.log("=".repeat(80));
	console.log(`Total games: ${datasetStats.totalGames}`);
	console.log(`Total moves: ${datasetStats.totalMoves}`);
	console.log(`Average game length: ${datasetStats.averageGameLength.toFixed(1)} moves`);
	console.log(`Average move time: ${datasetStats.averageMoveTime.toFixed(1)}ms`);
	console.log(`Average score diff (W-B): ${datasetStats.averageScoreDiff.toFixed(2)}`);

	if (datasetStats.boardShapeStats.length > 0) {
		console.log(`\nBoard Shape Statistics:`);
		for (const shapeStats of datasetStats.boardShapeStats) {
			console.log(` [${shapeStats.boardShape}] ${shapeStats.gameCount} games, avg score: ${shapeStats.averageScoreDiff.toFixed(2)}, avg moves: ${shapeStats.averageMoveCount.toFixed(1)}`);
		}
	}

	console.log(`\nBlack wins: ${datasetStats.blackWins} (${percentOfGames(datasetStats.blackWins)}%)`);
	console.log(`White wins: ${datasetStats.whiteWins} (${percentOfGames(datasetStats.whiteWins)}%)`);
	console.log(`Draws: ${draws} (${percentOfGames(draws)}%)`);
	console.log(`Resignations: ${datasetStats.resignations}`);
	console.log(`Max moves reached: ${datasetStats.maxMovesReached}`);
	console.log(`Total time: ${(datasetStats.generationTime / 1000 / 60).toFixed(1)} minutes`);
	console.log(`\nOutput directory: ${config.outputDir}`);
	console.log(`Statistics file: ${statsFilepath}`);
	console.log("=".repeat(80));
}
/**
 * Entry point: parse the command line and generate the dataset.
 * Any error is logged and the process exits with status 1.
 */
async function main(): Promise<void> {
	try {
		await generateDataset(parseArgs());
	} catch (error) {
		console.error("Error:", error);
		process.exit(1);
	}
}

// main() handles its own errors, so the returned promise is deliberately not awaited.
void main();