Spaces:
Sleeping
Sleeping
| /** | |
| * Self-Play Dataset Generation for Trigo AI | |
| * | |
| * This script generates a dataset of self-play games for training the Trigo AI model. | |
| * The AI plays against itself using either tree attention or MCTS-based agents. | |
| * | |
| * Features: | |
| * - Configurable number of games to generate | |
| * - Random or fixed board shapes (2D: 2x1x1 to 13x13x1, 3D: 2x2x2 to 5x5x5) | |
| * - Temperature-based sampling for move diversity | |
| * - MCTS (Monte Carlo Tree Search) mode with AlphaGo Zero algorithm | |
| * - Automatic game termination detection (50% coverage threshold) | |
| * - TGN format output with score notation for each game | |
| * - Optional visit count statistics for MCTS training data | |
| * - Per-board-shape statistics | |
| * - Progress tracking | |
| * | |
| * Usage: | |
| * npx tsx tools/selfPlayGames.ts [options] | |
| * | |
| * Options: | |
| * --games <n> Number of games to generate (default: 10) | |
| * --output <dir> Output directory (default: ./tools/output/selfplay) | |
| * --board <shape> Board shape "X*Y*Z" or "random" (default: "random") | |
| * --temperature <t> Sampling temperature (default: 1.0) | |
| * --max-moves <n> Maximum moves per game (default: 300) | |
| * --model <path> Path to tree model ONNX file | |
| * --eval-model <path> Path to evaluation model ONNX file (for MCTS) | |
| * --verbose Enable verbose logging | |
| * | |
| * MCTS Options: | |
| * --use-mcts Enable MCTS for move selection | |
| * --mcts-simulations <n> MCTS simulations per move (default: 600) | |
| * --mcts-cpuct <f> PUCT exploration constant (default: 1.0) | |
| * --mcts-dirichlet-alpha <f> Dirichlet noise alpha (default: 0.03) | |
| * --mcts-dirichlet-epsilon <f> Dirichlet noise epsilon (default: 0.25) | |
| * --save-visit-counts Save visit count statistics | |
| * | |
| * Output: | |
| * - Each game saved as game_<hash>.tgn (hash based on content) | |
| * - Dataset statistics in <date>_dataset_stats.json | |
| * - Visit counts saved as game_<id>_visit_counts.json (if --save-visit-counts) | |
| */ | |
| import * as ort from "onnxruntime-node"; | |
| import * as path from "path"; | |
| import * as fs from "fs"; | |
| import * as crypto from "crypto"; | |
| import { fileURLToPath } from "url"; | |
| import { TrigoGame, StoneType } from "../inc/trigo/game"; | |
| import { ModelInferencer } from "../inc/modelInferencer"; | |
| import { loadEnvConfig, getOnnxModelPaths, getAbsoluteModelPath, getOnnxSessionOptions } from "../inc/config"; | |
| import { TrigoTreeAgent } from "../inc/trigoTreeAgent"; | |
| import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent"; | |
| import { MCTSAgent, type MCTSConfig } from "../inc/mctsAgent"; | |
| import type { Move, BoardShape } from "../inc/trigo/types"; | |
| import {encodeAb0yz} from "../inc/trigo/ab0yz"; | |
// ES modules do not define __filename / __dirname, so derive both from import.meta.url
const __filename = fileURLToPath(import.meta.url);
const __dirname = path.dirname(__filename);

// Load the environment configuration first; the model paths below are read after it completes
await loadEnvConfig();

// Default ONNX model paths as reported by the project config helper
const defaultModelPaths = getOnnxModelPaths();
// A board shape written as an [x, y, z] tuple
type BoardShapeTuple = [number, number, number];

/**
 * Enumerate every board shape between two corner shapes, both ends inclusive.
 *
 * Shapes are produced with x varying slowest and z varying fastest. If any
 * component of `min` exceeds the matching component of `max`, the result is empty.
 *
 * @param min - Smallest [x, y, z] shape to include
 * @param max - Largest [x, y, z] shape to include
 * @returns All shapes in the inclusive range
 */
const arangeShape = (min: BoardShapeTuple, max: BoardShapeTuple): BoardShapeTuple[] => {
	// Inclusive run lo, lo + 1, ... that never goes past hi
	const span = (lo: number, hi: number): number[] => {
		const values: number[] = [];
		for (let value = lo; value <= hi; value++) {
			values.push(value);
		}
		return values;
	};

	return span(min[0], max[0]).flatMap((x) =>
		span(min[1], max[1]).flatMap((y) =>
			span(min[2], max[2]).map((z): BoardShapeTuple => [x, y, z])
		)
	);
};
/**
 * Pool of board shapes that random selection draws from
 * - 2D boards: 2x1x1 through 13x13x1
 * - 3D boards: 2x2x2 through 5x5x5
 */
const CANDIDATE_BOARD_SHAPES = arangeShape([2, 1, 1], [13, 13, 1]).concat(
	arangeShape([2, 2, 2], [5, 5, 5])
);
/**
 * Settings for one dataset-generation run, filled in from the command line.
 */
interface GenerationConfig {
	/** How many self-play games to generate */
	numGames: number;
	/** Directory that receives the game files and the statistics file */
	outputDir: string;
	/** Sampling temperature used when choosing moves */
	temperature: number;
	/** Upper bound on the number of moves in a single game */
	maxMoves: number;
	/** Emit per-move logging when true */
	verbose: boolean;
	/** Path of the tree model ONNX file */
	modelPath: string;
	/** Path of the evaluation model ONNX file (only loaded in MCTS mode) */
	evaluationModelPath: string;
	/** Vocabulary size handed to the model inferencer */
	vocabSize: number;
	/** Sequence length handed to the model inferencer */
	seqLen: number;
	/** Fixed board shape, or "random" to pick a new shape for every game */
	boardShape: BoardShape | "random";

	// MCTS configuration

	/** Select moves with the MCTS agent instead of the tree agent */
	useMCTS: boolean;
	/** Number of MCTS simulations per move */
	mctsSimulations: number;
	/** PUCT exploration constant */
	mctsCPuct: number;
	/** Dirichlet noise alpha */
	mctsDirichletAlpha: number;
	/** Dirichlet noise epsilon */
	mctsDirichletEpsilon: number;
	/** Record the MCTS visit counts of every move */
	saveVisitCounts: boolean;
}
/**
 * Statistics recorded for a single generated game.
 */
interface GameStats {
	/** Identifier of the game within the run */
	gameId: number;
	/** Board shape label for the game */
	boardShape: string;
	/** Number of moves that were played */
	moveCount: number;
	/** True when the move count reached the configured maximum */
	maxMovesReached: boolean;
	/** Wall-clock time of the whole game, in milliseconds */
	duration: number;
	/** Mean time spent selecting a move, in milliseconds */
	averageMoveTime: number;
	/** White minus black (positive: white wins, negative: black wins) */
	scoreDiff: number;
}
/**
 * Aggregated statistics for all games that share one board shape.
 */
interface BoardShapeStats {
	/** Board shape label the aggregate belongs to */
	boardShape: string;
	/** Number of games played on this shape */
	gameCount: number;
	/** Mean score difference (white minus black) over those games */
	averageScoreDiff: number;
	/** Mean number of moves over those games */
	averageMoveCount: number;
}
/**
 * Statistics for a whole generation run; serialized to the dataset stats JSON file.
 */
interface DatasetStats {
	/** Number of games generated */
	totalGames: number;
	/** Sum of the move counts of all games */
	totalMoves: number;
	/** Count of games won by black */
	blackWins: number;
	/** Count of games won by white */
	whiteWins: number;
	/** Count of games ended by resignation */
	resignations: number;
	/** Count of games that hit the move limit */
	maxMovesReached: number;
	/** Mean number of moves per game */
	averageGameLength: number;
	/** Mean of the per-game average move times, in milliseconds */
	averageMoveTime: number;
	/** Wall-clock time of the generation loop, in milliseconds */
	generationTime: number;
	/** Mean score difference, white minus black */
	averageScoreDiff: number;
	/** One aggregate entry per board shape */
	boardShapeStats: BoardShapeStats[];
	/** Per-game statistics, in generation order */
	games: GameStats[];
}
/**
 * Parse a board shape string such as "5*5*5" or "9*9*1".
 *
 * Any run of non-digit characters is accepted as a separator, so "5x5x5"
 * parses the same way as "5*5*5". The special value "random"
 * (case-insensitive) is passed through so the caller can pick a shape per game.
 *
 * @param shapeStr - Raw value of the board option
 * @returns The parsed shape, or the literal "random"
 * @throws Error when the string does not contain exactly three dimensions,
 *         or when a dimension is zero or not a safe integer
 */
function parseBoardShape(shapeStr: string): BoardShape | "random" {
	// Handle random selection
	if (shapeStr.toLowerCase() === "random") {
		return "random";
	}

	// Parse explicit board shape
	const parts = shapeStr.split(/[^0-9]+/).filter(Boolean).map(Number);
	if (parts.length !== 3) {
		throw new Error(`Invalid board shape: ${shapeStr}. Expected format: "X*Y*Z" or "random"`);
	}

	// Only digits survive the split, so values are never negative; reject
	// zero-sized axes (a board with no positions) and values too large to be exact.
	if (parts.some((size) => !Number.isSafeInteger(size) || size < 1)) {
		throw new Error(`Invalid board shape: ${shapeStr}. Every dimension must be a positive integer`);
	}

	const [x, y, z] = parts;
	return { x, y, z };
}
/**
 * Pick one entry of CANDIDATE_BOARD_SHAPES uniformly at random and
 * return it as a BoardShape object.
 */
function selectRandomBoardShape(): BoardShape {
	const pick = CANDIDATE_BOARD_SHAPES[Math.floor(Math.random() * CANDIDATE_BOARD_SHAPES.length)];
	return { x: pick[0], y: pick[1], z: pick[2] };
}
/**
 * Parse command line arguments into a GenerationConfig.
 *
 * Starts from the built-in defaults and overrides them with the options found
 * in process.argv. `--help` prints the usage text and exits with status 0; an
 * unknown `--option` prints the usage text and exits with status 1. Arguments
 * that do not start with "--" and are not consumed as an option value are ignored.
 *
 * @returns The effective configuration
 * @throws Error when an option is missing its value or a numeric value is invalid
 */
function parseArgs(): GenerationConfig {
	const args = process.argv.slice(2);
	const config: GenerationConfig = {
		numGames: 10,
		outputDir: path.join(__dirname, "output/selfplay"),
		temperature: 1.0,
		maxMoves: 300,
		verbose: false,
		modelPath: getAbsoluteModelPath(defaultModelPaths.treeModel),
		evaluationModelPath: getAbsoluteModelPath(defaultModelPaths.evaluationModel),
		vocabSize: 128,
		seqLen: 256,
		boardShape: "random",
		// MCTS defaults
		useMCTS: false,
		mctsSimulations: 600,
		mctsCPuct: 1.0,
		mctsDirichletAlpha: 0.03,
		mctsDirichletEpsilon: 0.25,
		saveVisitCounts: false
	};

	let index = 0;

	// Consume the argument after the current flag; a missing value (or another
	// option in its place) is an error instead of silently becoming undefined/NaN.
	const takeValue = (flag: string): string => {
		const value = args[index + 1];
		if (value === undefined || value.startsWith("--")) {
			throw new Error(`Missing value for option ${flag}`);
		}
		index++;
		return value;
	};

	// Consume an integer value that must be at least `min`
	const takeInt = (flag: string, min: number): number => {
		const raw = takeValue(flag);
		const parsed = parseInt(raw, 10);
		if (!Number.isInteger(parsed) || parsed < min) {
			throw new Error(`Invalid value for ${flag}: "${raw}". Expected an integer >= ${min}`);
		}
		return parsed;
	};

	// Consume a finite, non-negative floating point value
	const takeFloat = (flag: string): number => {
		const raw = takeValue(flag);
		const parsed = parseFloat(raw);
		if (!Number.isFinite(parsed) || parsed < 0) {
			throw new Error(`Invalid value for ${flag}: "${raw}". Expected a non-negative number`);
		}
		return parsed;
	};

	for (; index < args.length; index++) {
		const flag = args[index];
		switch (flag) {
			case "--games":
				config.numGames = takeInt(flag, 1);
				break;
			case "--output":
				config.outputDir = takeValue(flag);
				break;
			case "--temperature":
				config.temperature = takeFloat(flag);
				break;
			case "--max-moves":
				config.maxMoves = takeInt(flag, 1);
				break;
			case "--board":
				config.boardShape = parseBoardShape(takeValue(flag));
				break;
			case "--model":
				config.modelPath = takeValue(flag);
				break;
			case "--eval-model":
				config.evaluationModelPath = takeValue(flag);
				break;
			case "--use-mcts":
				config.useMCTS = true;
				break;
			case "--mcts-simulations":
				config.mctsSimulations = takeInt(flag, 1);
				break;
			case "--mcts-cpuct":
				config.mctsCPuct = takeFloat(flag);
				break;
			case "--mcts-dirichlet-alpha":
				config.mctsDirichletAlpha = takeFloat(flag);
				break;
			case "--mcts-dirichlet-epsilon":
				config.mctsDirichletEpsilon = takeFloat(flag);
				break;
			case "--save-visit-counts":
				config.saveVisitCounts = true;
				break;
			case "--verbose":
				config.verbose = true;
				break;
			case "--help":
				printHelp();
				process.exit(0);
			default:
				if (flag.startsWith("--")) {
					console.error(`Unknown option: ${flag}`);
					printHelp();
					process.exit(1);
				}
		}
	}

	return config;
}
/**
 * Write the usage text (options, board shape examples, sample invocations)
 * to standard output with a single console.log call.
 */
function printHelp(): void {
	const helpLines = [
		"Usage: npx tsx tools/selfPlayGames.ts [options]",
		"Options:",
		"--games <n> Number of games to generate (default: 10)",
		"--output <dir> Output directory (default: ./tools/output/selfplay)",
		'--board <shape> Board shape "X*Y*Z" or "random" (default: "random")',
		"--temperature <t> Sampling temperature (default: 1.0)",
		"--max-moves <n> Maximum moves per game (default: 300)",
		"--model <path> Path to tree model ONNX file",
		"--eval-model <path> Path to evaluation model ONNX file (for MCTS)",
		"--verbose Enable verbose logging",
		"MCTS Options:",
		"--use-mcts Enable MCTS for move selection",
		"--mcts-simulations <n> MCTS simulations per move (default: 600)",
		"--mcts-cpuct <f> PUCT exploration constant (default: 1.0)",
		"--mcts-dirichlet-alpha <f> Dirichlet noise alpha (default: 0.03)",
		"--mcts-dirichlet-epsilon <f> Dirichlet noise epsilon (default: 0.25)",
		"--save-visit-counts Save visit count statistics",
		"--help Show this help message",
		"Board Shape Examples:",
		'--board "5*5*5" Fixed 5x5x5 board for all games',
		'--board "9*9*1" Fixed 9x9x1 (2D) board for all games',
		"--board random Random board shape for each game (default)",
		"Examples:",
		"# Generate 100 games with random board shapes (no MCTS)",
		"npx tsx tools/selfPlayGames.ts --games 100",
		"# Generate 50 games on 5x5x5 board with MCTS",
		'npx tsx tools/selfPlayGames.ts --games 50 --board "5*5*5" --use-mcts --mcts-simulations 600',
		"# Generate games with custom models",
		"npx tsx tools/selfPlayGames.ts --games 20 --model ./models/tree.onnx --eval-model ./models/eval.onnx",
		"# Generate games with custom output directory and save visit counts",
		"npx tsx tools/selfPlayGames.ts --games 20 --output ./my_dataset --use-mcts --save-visit-counts"
	];

	// The text starts and ends with a line break, exactly like a template
	// literal whose content begins and finishes on its own line.
	console.log("\n" + helpLines.join("\n") + "\n");
}
/**
 * Build the agent that will play the games.
 *
 * The tree model is always loaded. When `config.useMCTS` is false the tree
 * agent is returned directly; otherwise the evaluation model is loaded as well
 * and both agents are wrapped in an MCTS agent.
 *
 * @param config - Run configuration (model paths, model dimensions, MCTS settings)
 * @returns The tree agent, or an MCTS agent built on top of it
 */
async function initializeAgent(config: GenerationConfig): Promise<TrigoTreeAgent | MCTSAgent> {
	const searchEnabled = config.useMCTS;

	console.log("Initializing AI Agent...");
	console.log(` Mode: ${searchEnabled ? "MCTS Search" : "Tree Attention"}`);
	console.log(` Tree Model: ${config.modelPath}`);
	if (searchEnabled) {
		console.log(` Evaluation Model: ${config.evaluationModelPath}`);
		console.log(` MCTS Simulations: ${config.mctsSimulations}`);
		console.log(` C-PUCT: ${config.mctsCPuct}`);
	}
	console.log(` Vocab Size: ${config.vocabSize}`);
	console.log(` Sequence Length: ${config.seqLen}`);

	const sessionOptions = getOnnxSessionOptions();

	// Open an ONNX session for one model file and attach it to a new inferencer
	const loadInferencer = async (modelPath: string): Promise<ModelInferencer> => {
		const session = await ort.InferenceSession.create(modelPath, sessionOptions);
		const inferencer = new ModelInferencer(ort.Tensor as any, {
			vocabSize: config.vocabSize,
			seqLen: config.seqLen,
			modelPath
		});
		inferencer.setSession(session as any);
		return inferencer;
	};

	// The tree model is needed in both modes
	const treeAgent = new TrigoTreeAgent(await loadInferencer(config.modelPath));
	if (!searchEnabled) {
		console.log("✓ Tree Agent initialized\n");
		return treeAgent;
	}

	// MCTS additionally needs the evaluation model
	const evaluationAgent = new TrigoEvaluationAgent(await loadInferencer(config.evaluationModelPath));
	const mctsConfig: MCTSConfig = {
		numSimulations: config.mctsSimulations,
		cPuct: config.mctsCPuct,
		temperature: config.temperature,
		dirichletAlpha: config.mctsDirichletAlpha,
		dirichletEpsilon: config.mctsDirichletEpsilon
	};
	const mctsAgent = new MCTSAgent(treeAgent, evaluationAgent, mctsConfig);
	console.log("✓ MCTS Agent initialized\n");
	return mctsAgent;
}
/**
 * Sample a move from scored candidates using temperature-scaled softmax.
 *
 * Scores are treated as log-probabilities: each is divided by `temperature`
 * and turned into a probability via softmax. A temperature that is not
 * greater than zero (0, negative, NaN) selects the highest-scoring move
 * deterministically instead of dividing by zero.
 *
 * @param scoredMoves - Candidate moves with their scores; must not be empty
 * @param temperature - Sampling temperature; larger values flatten the distribution
 * @returns The selected move
 * @throws Error when `scoredMoves` is empty
 */
function sampleMove(scoredMoves: Array<{ move: Move; score: number; notation: string }>, temperature: number): Move {
	if (scoredMoves.length === 0) {
		throw new Error("sampleMove requires at least one scored move");
	}

	// Locate the highest score (earliest entry wins ties)
	let bestIndex = 0;
	for (let i = 1; i < scoredMoves.length; i++) {
		if (scoredMoves[i].score > scoredMoves[bestIndex].score) {
			bestIndex = i;
		}
	}

	// Greedy choice when the temperature cannot be used as a divisor
	if (!(temperature > 0)) {
		return scoredMoves[bestIndex].move;
	}

	// Softmax weights, shifted by the largest scaled score for numerical stability
	const scaledScores = scoredMoves.map((m) => m.score / temperature);
	const peak = scaledScores[bestIndex];
	const weights = scaledScores.map((score) => Math.exp(score - peak));
	const totalWeight = weights.reduce((sum, weight) => sum + weight, 0);

	// Degenerate scores (e.g. infinities) produce an unusable total; fall back to greedy
	if (!(totalWeight > 0) || !Number.isFinite(totalWeight)) {
		return scoredMoves[bestIndex].move;
	}

	// Walk the cumulative weights until they pass a uniform threshold.
	// Math.random() is in [0, 1), so the strict comparison never picks a zero-weight move.
	const threshold = Math.random() * totalWeight;
	let cumulative = 0;
	let lastPossible = bestIndex;
	for (let i = 0; i < scoredMoves.length; i++) {
		if (!(weights[i] > 0)) {
			continue;
		}
		cumulative += weights[i];
		lastPossible = i;
		if (threshold < cumulative) {
			return scoredMoves[i].move;
		}
	}

	// Only reachable through floating point round-off in the running sum
	return scoredMoves[lastPossible].move;
}
/**
 * Play one complete self-play game and collect its statistics.
 *
 * The loop stops when the move limit is reached, when two passes happen in a
 * row, when the game reports natural termination (checked only once at least
 * half as many moves as board positions have been played), when no move can
 * be selected, or when applying a move fails.
 *
 * @param agent - Tree agent or MCTS agent used for both players
 * @param gameId - Identifier stored in the returned statistics
 * @param config - Run configuration
 * @returns The finished game, its statistics and, when `config.saveVisitCounts`
 *          is set, the visit counts recorded for each MCTS move
 */
async function playSelfPlayGame(
	agent: TrigoTreeAgent | MCTSAgent,
	gameId: number,
	config: GenerationConfig
): Promise<{ game: TrigoGame; stats: GameStats; visitCounts?: number[][] }> {
	// Fixed shape from the configuration, or a fresh random shape for this game
	const boardShape: BoardShape =
		config.boardShape === "random" ? selectRandomBoardShape() : config.boardShape;
	const { x: sizeX, y: sizeY, z: sizeZ } = boardShape;

	const game = new TrigoGame(boardShape, {});
	const startedAt = Date.now();

	let playedMoves = 0;
	let thinkingTime = 0;
	let passStreak = 0;
	const visitLog: number[][] = [];

	// Territory checks begin once the move count reaches half the board size (same as MCTS)
	const coverageThreshold = Math.floor(sizeX * sizeY * sizeZ * 0.5);
	let territoryCheckStarted = false;

	if (config.verbose) {
		console.log(`\nGame ${gameId} started [Board: ${sizeX}×${sizeY}×${sizeZ}]`);
	}

	while (playedMoves < config.maxMoves) {
		if (!territoryCheckStarted && playedMoves >= coverageThreshold) {
			territoryCheckStarted = true;
			if (config.verbose) {
				console.log(` Reached 50% coverage (${playedMoves} moves), starting territory check`);
			}
		}

		const currentPlayer = game.getCurrentPlayer() === StoneType.BLACK ? "black" : "white";
		const thinkStart = Date.now();
		let selectedMove: Move;

		if (agent instanceof MCTSAgent) {
			// MCTS picks the move itself and reports the visit counts behind the choice
			const { move, visitCounts } = await agent.selectMove(game, playedMoves);
			selectedMove = move;
			if (config.saveVisitCounts && visitCounts) {
				visitLog.push(Array.from(visitCounts.values()));
			}
		} else {
			// Tree agent: score every legal placement plus a pass, then sample
			const openPositions = game.validMovePositions();
			if (openPositions.length === 0) {
				// Nothing but a pass is available, so the game is over
				game.pass();
				break;
			}

			const candidates: Move[] = [
				...openPositions.map((pos) => ({ x: pos.x, y: pos.y, z: pos.z, player: currentPlayer })),
				{ player: currentPlayer, isPass: true }
			];

			const ranked = await agent.scoreMoves(game, candidates);
			if (ranked.length === 0) {
				break;
			}
			ranked.sort((a, b) => b.score - a.score);
			selectedMove = sampleMove(ranked, config.temperature);
		}

		thinkingTime += Date.now() - thinkStart;

		// Apply the chosen move to the game
		let applied = false;
		let notation = "";
		if (selectedMove.isPass) {
			applied = game.pass();
			notation = "Pass";
			passStreak++;
		} else if (
			selectedMove.x !== undefined &&
			selectedMove.y !== undefined &&
			selectedMove.z !== undefined
		) {
			applied = game.drop({ x: selectedMove.x, y: selectedMove.y, z: selectedMove.z });
			notation = encodeAb0yz([selectedMove.x, selectedMove.y, selectedMove.z], [sizeX, sizeY, sizeZ]);
			passStreak = 0;
		}

		// The move notation is streamed to stdout regardless of the verbose flag
		process.stdout.write(notation + " ");

		if (!applied) {
			console.error(`Failed to apply move: ${notation}`);
			break;
		}

		playedMoves++;
		if (config.verbose) {
			const side = currentPlayer === "black" ? "Black" : "White";
			console.log(` Move ${playedMoves}: ${side} plays ${notation}`);
		}

		// Two passes in a row end the game
		if (passStreak >= 2) {
			if (config.verbose) {
				console.log(" Game ended: Two consecutive passes");
			}
			break;
		}

		// Natural termination is only examined after a stone placement past the threshold
		if (!territoryCheckStarted || selectedMove.isPass) {
			continue;
		}

		if (game.isNaturallyTerminal()) {
			if (config.verbose) {
				const settled = game.getTerritory();
				console.log(` Game ended: No neutral territory and no captures possible (settled)`);
				console.log(` Black: ${settled.black}, White: ${settled.white}`);
			}
			break;
		}

		if (config.verbose) {
			const pending = game.getTerritory();
			if (pending.neutral === 0) {
				console.log(` Territory settled but captures still possible (continuing...)`);
			}
		}
	}

	const duration = Date.now() - startedAt;
	const averageMoveTime = playedMoves > 0 ? thinkingTime / playedMoves : 0;
	const maxMovesReached = playedMoves >= config.maxMoves;

	// Final score difference: white minus black
	const finalTerritory = game.getTerritory();
	const scoreDiff = finalTerritory.white - finalTerritory.black;

	const stats: GameStats = {
		gameId,
		boardShape: `${sizeX}×${sizeY}×${sizeZ}`,
		moveCount: playedMoves,
		maxMovesReached,
		duration,
		averageMoveTime,
		scoreDiff
	};

	return {
		game,
		stats,
		visitCounts: config.saveVisitCounts ? visitLog : undefined
	};
}
/**
 * Write a game to `outputDir` as a TGN file named after a hash of its content.
 *
 * The name is `game_` followed by the first 16 hex characters of the
 * SHA-256 digest of the TGN text, so identical games map to the same file.
 *
 * @param game - Game to serialize
 * @param outputDir - Existing directory that receives the file
 * @returns The file name (without directory) that was written
 */
function saveGame(game: TrigoGame, outputDir: string): string {
	const tgn = game.toTGN({}, { markResult: true });

	// Content-derived name (same as generateRandomGames.ts)
	const digest = crypto.createHash('sha256').update(tgn).digest('hex');
	const filename = "game_" + digest.slice(0, 16) + ".tgn";

	fs.writeFileSync(path.join(outputDir, filename), tgn, "utf-8");
	return filename;
}
/**
 * Generate the whole dataset: play the configured number of self-play games,
 * write each game (and optionally its visit counts) to the output directory,
 * then write and print the aggregated statistics.
 *
 * @param config - Run configuration
 */
async function generateDataset(config: GenerationConfig): Promise<void> {
	const rule = "=".repeat(80);

	console.log(rule);
	console.log("Trigo Self-Play Dataset Generation");
	console.log(rule);
	console.log(`Configuration:`);
	console.log(` Number of games: ${config.numGames}`);
	console.log(` Output directory: ${config.outputDir}`);
	console.log(` Temperature: ${config.temperature}`);
	console.log(` Max moves per game: ${config.maxMoves}`);
	console.log(` Verbose: ${config.verbose}`);
	console.log();

	// Create output directory
	if (!fs.existsSync(config.outputDir)) {
		fs.mkdirSync(config.outputDir, { recursive: true });
		console.log(`✓ Created output directory: ${config.outputDir}\n`);
	}

	// Initialize agent
	const agent = await initializeAgent(config);

	// Generate games
	const startTime = Date.now();
	const datasetStats: DatasetStats = {
		totalGames: 0,
		totalMoves: 0,
		blackWins: 0,
		whiteWins: 0,
		resignations: 0,
		maxMovesReached: 0,
		averageGameLength: 0,
		averageMoveTime: 0,
		generationTime: 0,
		averageScoreDiff: 0,
		boardShapeStats: [],
		games: []
	};

	console.log("Generating games...");
	console.log(rule);

	for (let i = 1; i <= config.numGames; i++) {
		const gameStartTime = Date.now();

		// Play game
		const { game, stats, visitCounts } = await playSelfPlayGame(agent, i, config);

		// Save game with hash-based filename
		saveGame(game, config.outputDir);

		// Save visit counts if available.
		// NOTE(review): these files are keyed by the game index while the game
		// files are keyed by content hash, so the two cannot be matched by name.
		if (visitCounts && config.saveVisitCounts) {
			const visitCountsPath = path.join(config.outputDir, `game_${i}_visit_counts.json`);
			fs.writeFileSync(visitCountsPath, JSON.stringify(visitCounts, null, 2), "utf-8");
		}

		const gameDuration = Date.now() - gameStartTime;

		// Update statistics
		datasetStats.totalGames++;
		datasetStats.totalMoves += stats.moveCount;
		if (stats.maxMovesReached) datasetStats.maxMovesReached++;
		// scoreDiff is white minus black: positive means white won, negative black; zero is a draw
		if (stats.scoreDiff > 0) {
			datasetStats.whiteWins++;
		} else if (stats.scoreDiff < 0) {
			datasetStats.blackWins++;
		}
		datasetStats.games.push(stats);

		// Progress update
		const progress = ((i / config.numGames) * 100).toFixed(1);
		const result = stats.scoreDiff > 0 ? `White +${stats.scoreDiff}` : stats.scoreDiff < 0 ? `Black +${Math.abs(stats.scoreDiff)}` : "Draw";
		console.log(
			`[${progress}%] Game ${i}/${config.numGames} [${stats.boardShape}]: ` +
			`${stats.moveCount} moves, ${result}, ${(gameDuration / 1000).toFixed(1)}s`
		);
	}

	datasetStats.generationTime = Date.now() - startTime;

	// Averages over all games; a run without games reports 0 instead of NaN
	const gameCount = datasetStats.totalGames;
	const perGame = (total: number): number => (gameCount > 0 ? total / gameCount : 0);
	datasetStats.averageGameLength = perGame(datasetStats.totalMoves);
	datasetStats.averageMoveTime = perGame(
		datasetStats.games.reduce((sum, g) => sum + g.averageMoveTime, 0)
	);
	datasetStats.averageScoreDiff = perGame(
		datasetStats.games.reduce((sum, g) => sum + g.scoreDiff, 0)
	);

	// Calculate per-board-shape statistics
	const shapeMap = new Map<string, GameStats[]>();
	for (const game of datasetStats.games) {
		const bucket = shapeMap.get(game.boardShape);
		if (bucket) {
			bucket.push(game);
		} else {
			shapeMap.set(game.boardShape, [game]);
		}
	}
	datasetStats.boardShapeStats = Array.from(shapeMap.entries()).map(([boardShape, games]) => ({
		boardShape,
		gameCount: games.length,
		averageScoreDiff: games.reduce((sum, g) => sum + g.scoreDiff, 0) / games.length,
		averageMoveCount: games.reduce((sum, g) => sum + g.moveCount, 0) / games.length
	}));

	// Sort by board shape for consistent output
	datasetStats.boardShapeStats.sort((a, b) => a.boardShape.localeCompare(b.boardShape));

	// Save dataset statistics; the file name carries the date part of the ISO timestamp
	const timestamp = new Date().toISOString().split("T")[0];
	const statsFilepath = path.join(config.outputDir, `${timestamp}_dataset_stats.json`);
	fs.writeFileSync(statsFilepath, JSON.stringify(datasetStats, null, 2), "utf-8");

	// Share of all games as a percentage string with one decimal
	const percentOfGames = (count: number): string => (perGame(count) * 100).toFixed(1);

	// Print summary
	console.log(rule);
	console.log("Dataset Generation Complete!");
	console.log(rule);
	console.log(`Total games: ${datasetStats.totalGames}`);
	console.log(`Total moves: ${datasetStats.totalMoves}`);
	console.log(`Average game length: ${datasetStats.averageGameLength.toFixed(1)} moves`);
	console.log(`Average move time: ${datasetStats.averageMoveTime.toFixed(1)}ms`);
	console.log(`Average score diff (W-B): ${datasetStats.averageScoreDiff.toFixed(2)}`);

	// Print per-board-shape statistics
	if (datasetStats.boardShapeStats.length > 0) {
		console.log(`\nBoard Shape Statistics:`);
		for (const shapeStats of datasetStats.boardShapeStats) {
			console.log(` [${shapeStats.boardShape}] ${shapeStats.gameCount} games, avg score: ${shapeStats.averageScoreDiff.toFixed(2)}, avg moves: ${shapeStats.averageMoveCount.toFixed(1)}`);
		}
	}

	console.log(`\nBlack wins: ${datasetStats.blackWins} (${percentOfGames(datasetStats.blackWins)}%)`);
	console.log(`White wins: ${datasetStats.whiteWins} (${percentOfGames(datasetStats.whiteWins)}%)`);
	console.log(`Resignations: ${datasetStats.resignations}`);
	console.log(`Max moves reached: ${datasetStats.maxMovesReached}`);
	console.log(`Total time: ${(datasetStats.generationTime / 1000 / 60).toFixed(1)} minutes`);
	console.log(`\nOutput directory: ${config.outputDir}`);
	console.log(`Statistics file: ${statsFilepath}`);
	console.log(rule);
}
/**
 * Script entry point: parse the command line and generate the dataset.
 * Any error is reported on stderr and ends the process with exit status 1.
 */
async function main(): Promise<void> {
	try {
		await generateDataset(parseArgs());
	} catch (failure) {
		console.error("Error:", failure);
		process.exit(1);
	}
}

// Start the script; failures are already handled inside main
void main();