Spaces:
Sleeping
Sleeping
| /** | |
| * Model vs Model Battle Script | |
| * | |
| * This script enables different ONNX model versions to battle against each other. | |
| * Features: | |
| * - Fair color rotation (each model plays equal games as Black and White) | |
| * - Fixed board size | |
| * - Win rate statistics | |
| * - Optional MCTS search (disabled by default) | |
| * | |
| * Usage: npx tsx tools/modelBattle.ts [options] | |
| */ | |
| import * as ort from "onnxruntime-node"; | |
| import * as path from "path"; | |
| import * as fs from "fs"; | |
| import { TrigoGame, StoneType } from "../inc/trigo/game"; | |
| import { ModelInferencer } from "../inc/modelInferencer"; | |
| import { TrigoTreeAgent } from "../inc/trigoTreeAgent"; | |
| import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent"; | |
| import { MCTSAgent, MCTSConfig } from "../inc/mctsAgent"; | |
| import type { Move, BoardShape } from "../inc/trigo/types"; | |
| import { loadEnvConfig, getOnnxSessionOptions } from "../inc/config"; | |
// ============================================================================
// Constants
// ============================================================================

// Sequence geometry handed to every ModelInferencer (see createAgent).
const DEFAULT_VOCAB_SIZE = 128;
const DEFAULT_SEQ_LEN = 256;

// Match defaults; each one can be overridden on the command line.
const DEFAULT_NUM_GAMES = 10;
const DEFAULT_TEMPERATURE = 1.0;
// Initial move cap only: when --max-moves is not given, parseArgs replaces it
// with twice the number of board positions.
const DEFAULT_MAX_MOVES = 300;

// MCTS defaults (only used with --use-mcts).
const DEFAULT_MCTS_SIMULATIONS = 600;
const DEFAULT_MCTS_CPUCT = 1.0;
const DEFAULT_MCTS_DIRICHLET_ALPHA = 0.03;
// No Dirichlet noise by default, so evaluation runs are not perturbed by exploration noise.
const DEFAULT_MCTS_DIRICHLET_EPSILON = 0.0;

// Komi: points added to White's final score when a game is scored
// (conventionally offsetting Black's first-move advantage).
const DEFAULT_KOMI = 0.5;
// ============================================================================
// Interfaces
// ============================================================================

/** Fully resolved settings for one battle run, as produced by parseArgs(). */
interface BattleConfig {
	/** Directory component of the model 1 base path. */
	model1Dir: string;
	/** Directory component of the model 2 base path. */
	model2Dir: string;
	/** Model 1 tree model file (`<base>_tree.onnx`). */
	model1TreePath: string;
	/** Model 2 tree model file (`<base>_tree.onnx`). */
	model2TreePath: string;
	/** Model 1 evaluation model (`<base>_evaluation.onnx`), when it exists on disk. */
	model1EvalPath?: string;
	/** Model 2 evaluation model (`<base>_evaluation.onnx`), when it exists on disk. */
	model2EvalPath?: string;
	/** Board dimensions for every game. */
	boardShape: BoardShape;
	/** Number of games; parseArgs rounds it up to an even value. */
	numGames: number;
	/** Sampling temperature for the tree agent; also forwarded to MCTS. */
	temperature: number;
	/** Upper bound on moves (including passes) per game. */
	maxMoves: number;
	/** Select moves with MCTS (requires both evaluation models). */
	useMCTS: boolean;
	/** MCTS simulations per move. */
	mctsSimulations: number;
	/** PUCT exploration constant for MCTS. */
	mctsCPuct: number;
	/** Dirichlet noise alpha for MCTS. */
	mctsDirichletAlpha: number;
	/** Dirichlet noise mixing weight for MCTS. */
	mctsDirichletEpsilon: number;
	/** Destination for the JSON summary and TGN records; nothing is written when unset. */
	outputDir?: string;
	/** Log move-selection errors and failed moves. */
	verbose: boolean;
	/** Vocabulary size passed to each ModelInferencer. */
	vocabSize: number;
	/** Sequence length passed to each ModelInferencer. */
	seqLen: number;
}

/** Outcome of a single game. */
interface BattleResult {
	/** 1-based game index. */
	gameId: number;
	/** True when model 1 played Black in this game. */
	model1AsBlack: boolean;
	/** Winner after komi has been added to White's score. */
	winner: "model1" | "model2" | "draw";
	/** Model 1's final score (includes komi when it played White). */
	model1Score: number;
	/** Model 2's final score (includes komi when it played White). */
	model2Score: number;
	/** Black score minus White score (komi included); positive means Black won. */
	scoreDiff: number;
	/** Moves played, passes included. */
	moveCount: number;
	/** Wall-clock duration in milliseconds. */
	duration: number;
	/** TGN format game record (only produced when an output directory is configured). */
	tgn?: string;
}

/** Aggregate statistics over all games of a battle. */
interface BattleStatistics {
	totalGames: number;
	model1Wins: number;
	model2Wins: number;
	draws: number;
	/** Percentage (0-100) of totalGames won by model 1. */
	model1WinRate: number;
	/** Percentage (0-100) of totalGames won by model 2. */
	model2WinRate: number;
	model1AsBlackWins: number;
	model1AsWhiteWins: number;
	model2AsBlackWins: number;
	model2AsWhiteWins: number;
	/** Mean moveCount across games. */
	averageGameLength: number;
	/** Mean scoreDiff (Black minus White) across games. */
	averageScoreDiff: number;
	/** Total battle wall-clock time in milliseconds. */
	totalTime: number;
	/** Per-game results in play order. */
	results: BattleResult[];
}

/** Either agent kind; playBattleGame dispatches on the concrete class. */
type Agent = TrigoTreeAgent | MCTSAgent;
// ============================================================================
// Argument Parsing
// ============================================================================

/**
 * Print command-line help to stdout.
 *
 * Default values are interpolated from the module-level DEFAULT_* constants so
 * the help text cannot drift from the values parseArgs() actually uses (it
 * previously advertised a Dirichlet epsilon default of 0.25 while the real
 * default is 0.0).
 */
function printUsage(): void {
	console.log(`
Usage: npx tsx tools/modelBattle.ts [options]

Required:
  --model1 <path>               Model 1 base path (without suffix, e.g., ./dir/ModelName)
  --model2 <path>               Model 2 base path (without suffix, e.g., ./dir/ModelName)
  --board <shape>               Board shape "X*Y*Z" (e.g., "5*5*5")

Options:
  --games <n>                   Number of games (default: ${DEFAULT_NUM_GAMES}; odd values are rounded up to an even number)
  --temperature <t>             Sampling temperature (default: ${DEFAULT_TEMPERATURE})
  --max-moves <n>               Maximum moves per game (default: 2 × number of board positions)
  --output <dir>                Output directory for results
  --verbose                     Enable verbose logging

MCTS Options:
  --use-mcts                    Enable MCTS for move selection (requires evaluation model)
  --mcts-simulations <n>        MCTS simulations per move (default: ${DEFAULT_MCTS_SIMULATIONS})
  --mcts-cpuct <f>              PUCT exploration constant (default: ${DEFAULT_MCTS_CPUCT})
  --mcts-dirichlet-alpha <f>    Dirichlet noise alpha (default: ${DEFAULT_MCTS_DIRICHLET_ALPHA})
  --mcts-dirichlet-epsilon <f>  Dirichlet noise epsilon (default: ${DEFAULT_MCTS_DIRICHLET_EPSILON})

  --help, -h                    Show this help message

Model Path Format:
  Provide the base path WITHOUT the suffix. The script auto-appends:
    - _tree.onnx        -> Tree model (required)
    - _evaluation.onnx  -> Evaluation model (auto-detected, required for MCTS)

Examples:
  # Battle two models on 5x5x5 board for 20 games
  npx tsx tools/modelBattle.ts \\
    --model1 ./public/onnx/20251215-trigo-value-llama-l6-h64-251211/LlamaCausalLM_ep0045 \\
    --model2 ./public/onnx/20251204-trigo-value-gpt2-l6-h64-251125-lr500/GPT2LMHeadModel_ep0100 \\
    --board "5*5*5" --games 20

  # Battle with MCTS enabled
  npx tsx tools/modelBattle.ts \\
    --model1 ./public/onnx/model_v1/Model_ep0100 \\
    --model2 ./public/onnx/model_v2/Model_ep0100 \\
    --board "5*5*5" --games 20 --use-mcts --mcts-simulations 400
`);
}
/**
 * Parse a board shape string of the form "X*Y*Z" (e.g. "5*5*5").
 *
 * @param shapeStr - Raw command-line value; surrounding whitespace is ignored.
 * @returns The parsed shape, or null when the string is malformed or any
 *          dimension is zero (a zero-length axis would give an empty board and
 *          a derived move limit of 0).
 */
function parseBoardShape(shapeStr: string): BoardShape | null {
	const match = /^(\d+)\*(\d+)\*(\d+)$/.exec(shapeStr.trim());
	if (!match) return null;
	const x = parseInt(match[1], 10);
	const y = parseInt(match[2], 10);
	const z = parseInt(match[3], 10);
	if (x < 1 || y < 1 || z < 1) return null;
	return { x, y, z };
}
/**
 * Resolve the ONNX files that belong to a model base path.
 *
 * `<basePath>_tree.onnx` must exist. `<basePath>_evaluation.onnx` is optional
 * and is reported only when it is present on disk.
 *
 * @param basePath - Model path without the `_tree.onnx` / `_evaluation.onnx` suffix.
 * @returns `{ treePath, evalPath }` (evalPath is undefined when absent), or null
 *          after logging an error when the tree model is missing.
 */
function resolveModelPaths(basePath: string): { treePath: string; evalPath?: string } | null {
	const withSuffix = (suffix: string): string => `${basePath}${suffix}`;
	const treePath = withSuffix("_tree.onnx");

	if (fs.existsSync(treePath)) {
		const evalCandidate = withSuffix("_evaluation.onnx");
		return {
			treePath,
			evalPath: fs.existsSync(evalCandidate) ? evalCandidate : undefined
		};
	}

	console.error(`Tree model not found: ${treePath}`);
	return null;
}
/**
 * Return the value that follows the flag at `args[i]`.
 *
 * A missing value, an empty string, or a token beginning with "-" (which is
 * treated as the next option) is reported as an error, followed by the usage text.
 *
 * @returns The value, or null when no usable value follows the flag.
 */
function requireNextArg(args: string[], i: number, flag: string): string | null {
	const candidate = args[i + 1];
	if (candidate !== undefined && candidate !== "" && !candidate.startsWith("-")) {
		return candidate;
	}
	console.error(`Error: ${flag} requires a value`);
	printUsage();
	return null;
}
/**
 * Parse and validate an integer command-line value.
 *
 * The whole (trimmed) token must be an integer. `parseInt` alone would
 * silently accept "12abc" as 12 or "1.5" as 1.
 *
 * @param value - Raw argument text.
 * @param flag - Flag name, used in error messages.
 * @param minValue - Optional inclusive lower bound.
 * @returns The integer, or null after logging an error.
 */
function parseIntArg(value: string, flag: string, minValue?: number): number | null {
	const trimmed = value.trim();
	const n = Number(trimmed);
	if (!/^[+-]?\d+$/.test(trimmed) || !Number.isSafeInteger(n)) {
		console.error(`Error: Invalid integer for ${flag}: ${value}`);
		return null;
	}
	if (minValue !== undefined && n < minValue) {
		console.error(`Error: ${flag} must be >= ${minValue}`);
		return null;
	}
	return n;
}
/**
 * Parse and validate a finite floating-point command-line value.
 *
 * `Number()` is used instead of `parseFloat` so trailing garbage ("1.5abc")
 * is rejected rather than silently truncated. The empty-string guard is needed
 * because `Number("")` is 0.
 *
 * @param value - Raw argument text.
 * @param flag - Flag name, used in error messages.
 * @param minValue - Optional inclusive lower bound.
 * @returns The number, or null after logging an error.
 */
function parseFloatArg(value: string, flag: string, minValue?: number): number | null {
	const trimmed = value.trim();
	const n = trimmed === "" ? Number.NaN : Number(trimmed);
	// Number.isFinite also rejects NaN and ±Infinity.
	if (!Number.isFinite(n)) {
		console.error(`Error: Invalid number for ${flag}: ${value}`);
		return null;
	}
	if (minValue !== undefined && n < minValue) {
		console.error(`Error: ${flag} must be >= ${minValue}`);
		return null;
	}
	return n;
}
/**
 * Parse and validate process.argv into a BattleConfig.
 *
 * Returns null after printing a diagnostic (and, for most errors, the usage
 * text) in these cases: an argument is invalid, a required option is missing,
 * a model file cannot be found, or --help/-h is requested. main() exits with
 * status 1 on null.
 *
 * Derived settings:
 * - an odd --games value is rounded up to the next even number, so each model
 *   plays the same number of games with each colour;
 * - when --max-moves is not given, the limit is twice the number of board positions.
 */
function parseArgs(): BattleConfig | null {
	const args = process.argv.slice(2);
	if (args.includes("--help") || args.includes("-h")) {
		printUsage();
		return null;
	}

	let model1Base = "";
	let model2Base = "";
	let boardShape: BoardShape = { x: 5, y: 5, z: 5 }; // placeholder; --board is required
	let boardSpecified = false;
	let numGames = DEFAULT_NUM_GAMES;
	let temperature = DEFAULT_TEMPERATURE;
	let maxMoves = DEFAULT_MAX_MOVES;
	// Set when --max-moves is passed explicitly. Without it, an explicit value
	// equal to DEFAULT_MAX_MOVES was mistaken for "unset" and overwritten.
	let maxMovesSpecified = false;
	let useMCTS = false;
	let mctsSimulations = DEFAULT_MCTS_SIMULATIONS;
	let mctsCPuct = DEFAULT_MCTS_CPUCT;
	let mctsDirichletAlpha = DEFAULT_MCTS_DIRICHLET_ALPHA;
	let mctsDirichletEpsilon = DEFAULT_MCTS_DIRICHLET_EPSILON;
	let outputDir: string | undefined;
	let verbose = false;

	for (let i = 0; i < args.length; i++) {
		const arg = args[i];
		switch (arg) {
			case "--model1": {
				const value = requireNextArg(args, i, "--model1");
				if (!value) return null;
				model1Base = value;
				i++;
				break;
			}
			case "--model2": {
				const value = requireNextArg(args, i, "--model2");
				if (!value) return null;
				model2Base = value;
				i++;
				break;
			}
			case "--board": {
				const value = requireNextArg(args, i, "--board");
				if (!value) return null;
				const shape = parseBoardShape(value);
				if (!shape) {
					console.error(`Invalid board shape: ${value}. Use format "X*Y*Z" (e.g., "5*5*5")`);
					return null;
				}
				boardShape = shape;
				boardSpecified = true;
				i++;
				break;
			}
			case "--games": {
				const value = requireNextArg(args, i, "--games");
				if (!value) return null;
				const parsed = parseIntArg(value, "--games", 1);
				if (parsed === null) return null;
				numGames = parsed;
				i++;
				break;
			}
			case "--temperature": {
				const value = requireNextArg(args, i, "--temperature");
				if (!value) return null;
				const parsed = parseFloatArg(value, "--temperature", 0);
				if (parsed === null) return null;
				temperature = parsed;
				i++;
				break;
			}
			case "--max-moves": {
				const value = requireNextArg(args, i, "--max-moves");
				if (!value) return null;
				const parsed = parseIntArg(value, "--max-moves", 1);
				if (parsed === null) return null;
				maxMoves = parsed;
				maxMovesSpecified = true;
				i++;
				break;
			}
			case "--output": {
				const value = requireNextArg(args, i, "--output");
				if (!value) return null;
				outputDir = value;
				i++;
				break;
			}
			case "--verbose":
				verbose = true;
				break;
			case "--use-mcts":
				useMCTS = true;
				break;
			case "--mcts-simulations": {
				const value = requireNextArg(args, i, "--mcts-simulations");
				if (!value) return null;
				const parsed = parseIntArg(value, "--mcts-simulations", 1);
				if (parsed === null) return null;
				mctsSimulations = parsed;
				i++;
				break;
			}
			case "--mcts-cpuct": {
				const value = requireNextArg(args, i, "--mcts-cpuct");
				if (!value) return null;
				const parsed = parseFloatArg(value, "--mcts-cpuct", 0);
				if (parsed === null) return null;
				mctsCPuct = parsed;
				i++;
				break;
			}
			case "--mcts-dirichlet-alpha": {
				const value = requireNextArg(args, i, "--mcts-dirichlet-alpha");
				if (!value) return null;
				const parsed = parseFloatArg(value, "--mcts-dirichlet-alpha", 0);
				if (parsed === null) return null;
				mctsDirichletAlpha = parsed;
				i++;
				break;
			}
			case "--mcts-dirichlet-epsilon": {
				const value = requireNextArg(args, i, "--mcts-dirichlet-epsilon");
				if (!value) return null;
				const parsed = parseFloatArg(value, "--mcts-dirichlet-epsilon", 0);
				if (parsed === null) return null;
				mctsDirichletEpsilon = parsed;
				i++;
				break;
			}
			default:
				// Bare positional tokens are ignored; unknown flags are an error.
				if (arg.startsWith("-")) {
					console.error(`Unknown option: ${arg}`);
					printUsage();
					return null;
				}
		}
	}

	// Validation
	if (!model1Base) {
		console.error("Error: --model1 is required");
		printUsage();
		return null;
	}
	if (!model2Base) {
		console.error("Error: --model2 is required");
		printUsage();
		return null;
	}
	if (!boardSpecified) {
		console.error("Error: --board is required");
		printUsage();
		return null;
	}

	// Resolve model paths (resolveModelPaths logs its own error)
	const model1Paths = resolveModelPaths(model1Base);
	if (!model1Paths) {
		return null;
	}
	const model2Paths = resolveModelPaths(model2Base);
	if (!model2Paths) {
		return null;
	}

	// Ensure even number of games for fairness
	if (numGames % 2 !== 0) {
		numGames++;
		console.log(`Note: Adjusted games to ${numGames} (must be even for fair color rotation)`);
	}

	// Check MCTS requirements
	if (useMCTS) {
		if (!model1Paths.evalPath) {
			console.error(`Error: Evaluation model not found for model1 (expected: ${model1Base}_evaluation.onnx)`);
			return null;
		}
		if (!model2Paths.evalPath) {
			console.error(`Error: Evaluation model not found for model2 (expected: ${model2Base}_evaluation.onnx)`);
			return null;
		}
	}

	// Derive the move limit from the board size (2× positions) unless the user chose one.
	if (!maxMovesSpecified) {
		const boardSize = boardShape.x * boardShape.y * boardShape.z;
		maxMoves = boardSize * 2;
	}

	return {
		model1Dir: path.dirname(model1Base),
		model2Dir: path.dirname(model2Base),
		model1TreePath: model1Paths.treePath,
		model2TreePath: model2Paths.treePath,
		model1EvalPath: model1Paths.evalPath,
		model2EvalPath: model2Paths.evalPath,
		boardShape,
		numGames,
		temperature,
		maxMoves,
		useMCTS,
		mctsSimulations,
		mctsCPuct,
		mctsDirichletAlpha,
		mctsDirichletEpsilon,
		outputDir,
		verbose,
		vocabSize: DEFAULT_VOCAB_SIZE,
		seqLen: DEFAULT_SEQ_LEN
	};
}
// ============================================================================
// Agent Initialization
// ============================================================================

/**
 * Load one model and wrap it in the agent used for move selection.
 *
 * The tree model is always loaded into a TrigoTreeAgent. When config.useMCTS
 * is set, the evaluation model is loaded as well, and the two are combined
 * into an MCTSAgent.
 *
 * @param treeModelPath - Path to the `<base>_tree.onnx` file.
 * @param evalModelPath - Path to the `<base>_evaluation.onnx` file; only required in MCTS mode.
 * @param config - Battle settings (inferencer sizes and MCTS parameters).
 * @param name - Display name used in log and error messages.
 * @throws Error when MCTS is enabled but no evaluation model path was given.
 */
async function createAgent(
	treeModelPath: string,
	evalModelPath: string | undefined,
	config: BattleConfig,
	name: string
): Promise<Agent> {
	console.log(` Loading ${name}...`);
	const sessionOptions = getOnnxSessionOptions();

	// Both models go through the same setup: open an ONNX session, then attach
	// it to a ModelInferencer configured with the battle's vocab/sequence sizes.
	const loadInferencer = async (modelPath: string): Promise<ModelInferencer> => {
		const session = await ort.InferenceSession.create(modelPath, sessionOptions);
		const inferencer = new ModelInferencer(ort.Tensor as any, {
			vocabSize: config.vocabSize,
			seqLen: config.seqLen,
			modelPath
		});
		inferencer.setSession(session as any);
		return inferencer;
	};

	const policyAgent = new TrigoTreeAgent(await loadInferencer(treeModelPath));
	if (!config.useMCTS) {
		return policyAgent;
	}

	if (!evalModelPath) {
		throw new Error(`MCTS enabled but evaluation model path is missing for ${name}`);
	}
	const valueAgent = new TrigoEvaluationAgent(await loadInferencer(evalModelPath));

	const searchSettings: MCTSConfig = {
		numSimulations: config.mctsSimulations,
		cPuct: config.mctsCPuct,
		temperature: config.temperature,
		dirichletAlpha: config.mctsDirichletAlpha,
		dirichletEpsilon: config.mctsDirichletEpsilon
	};
	return new MCTSAgent(policyAgent, valueAgent, searchSettings);
}
/**
 * Load both competing agents, one after the other, logging progress for each.
 *
 * @returns The agent for model 1 and the agent for model 2.
 */
async function initializeAgents(config: BattleConfig): Promise<{ agent1: Agent; agent2: Agent }> {
	console.log("\nInitializing agents...");

	const loadAndAnnounce = async (
		treePath: string,
		evalPath: string | undefined,
		label: string
	): Promise<Agent> => {
		const agent = await createAgent(treePath, evalPath, config, label);
		console.log(` ✓ ${label} ready`);
		return agent;
	};

	const agent1 = await loadAndAnnounce(config.model1TreePath, config.model1EvalPath, "Model 1");
	const agent2 = await loadAndAnnounce(config.model2TreePath, config.model2EvalPath, "Model 2");
	return { agent1, agent2 };
}
// ============================================================================
// Game Playing
// ============================================================================

/**
 * Play a single game between the two agents and score it.
 *
 * The game ends on two consecutive passes, on reaching config.maxMoves, or
 * when game.isNaturallyTerminal() reports true after a stone placement (this
 * check only starts once half the board's positions' worth of moves has been
 * played). A move-selection error, or a move the game rejects, is replaced
 * by a pass.
 *
 * Scoring: territory from game.getTerritory(), with DEFAULT_KOMI added to White.
 *
 * @param model1AsBlack - Whether agent1 plays Black in this game.
 * @returns The game result; `tgn` is filled only when config.outputDir is set.
 */
async function playBattleGame(
	agent1: Agent,
	agent2: Agent,
	gameId: number,
	config: BattleConfig,
	model1AsBlack: boolean
): Promise<BattleResult> {
	const startTime = Date.now();
	const game = new TrigoGame(config.boardShape, {});

	// Determine which agent plays which color
	const blackAgent = model1AsBlack ? agent1 : agent2;
	const whiteAgent = model1AsBlack ? agent2 : agent1;

	let moveCount = 0;
	let consecutivePasses = 0;
	const totalPositions = config.boardShape.x * config.boardShape.y * config.boardShape.z;
	// Natural-termination checks are deferred until half the board's worth of
	// moves has been played (presumably because they cannot succeed earlier).
	const coverageThreshold = Math.floor(totalPositions * 0.5);

	while (moveCount < config.maxMoves) {
		// Evaluated before this move is counted. moveCount only grows, so this is
		// equivalent to a "started" flag that latches on once the threshold is reached.
		const territoryCheckActive = moveCount >= coverageThreshold;
		const currentAgent = game.getCurrentPlayer() === StoneType.BLACK ? blackAgent : whiteAgent;

		let selectedMove: Move;
		try {
			if (currentAgent instanceof MCTSAgent) {
				const result = await currentAgent.selectMove(game, moveCount);
				selectedMove = result.move;
			} else {
				// Tree agent mode - use selectMove with temperature
				selectedMove = await currentAgent.selectMove(game, config.temperature);
			}
		} catch (err) {
			if (config.verbose) {
				console.error(`Error getting move: ${err}`);
			}
			// Play a pass instead. Count it BEFORE the end-of-game check so that
			// moveCount agrees with the regular pass path below (previously the
			// game-ending pass was left uncounted here).
			game.pass();
			consecutivePasses++;
			moveCount++;
			if (consecutivePasses >= 2) break;
			continue;
		}

		// Apply move
		let success: boolean;
		if (selectedMove.isPass) {
			success = game.pass();
			consecutivePasses++;
		} else if (selectedMove.x !== undefined && selectedMove.y !== undefined && selectedMove.z !== undefined) {
			success = game.drop({
				x: selectedMove.x,
				y: selectedMove.y,
				z: selectedMove.z
			});
			consecutivePasses = 0;
		} else {
			// A non-pass move without full coordinates is treated as a pass.
			success = game.pass();
			consecutivePasses++;
		}

		if (!success) {
			if (config.verbose) {
				console.warn(`Move failed, passing instead`);
			}
			game.pass();
			consecutivePasses++;
		}

		moveCount++;

		// Check game end conditions
		if (consecutivePasses >= 2) break;
		// Natural termination (all territory claimed, no capturing moves)
		if (territoryCheckActive && !selectedMove.isPass && game.isNaturallyTerminal()) break;
	}

	// Calculate final scores with komi (added to White's score)
	const territory = game.getTerritory();
	const blackScore = territory.black;
	const whiteScore = territory.white + DEFAULT_KOMI;
	const scoreDiff = blackScore - whiteScore; // Positive = Black wins by X

	let winner: "model1" | "model2" | "draw";
	if (blackScore > whiteScore) {
		winner = model1AsBlack ? "model1" : "model2";
	} else if (whiteScore > blackScore) {
		winner = model1AsBlack ? "model2" : "model1";
	} else {
		winner = "draw";
	}

	const duration = Date.now() - startTime;

	// Generate TGN record if output is enabled
	let tgn: string | undefined;
	if (config.outputDir) {
		// Player label: "<parent dir>/<model file name without _tree.onnx>"
		const formatModelName = (treePath: string): string => {
			const dir = path.basename(path.dirname(treePath));
			const filename = path.basename(treePath).replace(/_tree\.onnx$/, "");
			return `${dir}/${filename}`;
		};
		const model1Name = formatModelName(config.model1TreePath);
		const model2Name = formatModelName(config.model2TreePath);
		tgn = game.toTGN(
			{
				black: model1AsBlack ? model1Name : model2Name,
				white: model1AsBlack ? model2Name : model1Name,
				event: "Model Battle",
				// "YYYY.MM.DD" from the ISO date part
				date: new Date().toISOString().slice(0, 10).replace(/-/g, ".")
			},
			{ markResult: true }
		);
	}

	return {
		gameId,
		model1AsBlack,
		winner,
		model1Score: model1AsBlack ? blackScore : whiteScore,
		model2Score: model1AsBlack ? whiteScore : blackScore,
		scoreDiff,
		moveCount,
		duration,
		tgn
	};
}
// ============================================================================
// Battle Execution
// ============================================================================

/**
 * Play config.numGames games with alternating colours and aggregate the results.
 *
 * Odd game IDs give model 1 Black and even IDs give it White. When `gamesDir`
 * is given, each game's TGN record is written there as soon as the game ends.
 *
 * @returns Win counts, win rates (percent of totalGames), averages and per-game results.
 */
async function runBattle(
	agent1: Agent,
	agent2: Agent,
	config: BattleConfig,
	gamesDir?: string
): Promise<BattleStatistics> {
	const results: BattleResult[] = [];
	const startTime = Date.now();

	console.log(`\nStarting battle: ${config.numGames} games`);
	console.log("=".repeat(80));

	// Short winner label shared by TGN filenames and the progress line.
	const winnerLabel = (winner: BattleResult["winner"]): string =>
		winner === "model1" ? "M1" : winner === "model2" ? "M2" : "Draw";

	for (let i = 0; i < config.numGames; i++) {
		const gameId = i + 1;
		const model1AsBlack = i % 2 === 0; // Odd game IDs: M1=Black, even game IDs: M1=White
		const colorStr = model1AsBlack ? "M1=Black" : "M1=White";

		process.stdout.write(`Game ${gameId}/${config.numGames} (${colorStr}): `);

		const result = await playBattleGame(agent1, agent2, gameId, config, model1AsBlack);
		results.push(result);
		const label = winnerLabel(result.winner);

		// Save TGN immediately so finished games survive a later failure
		if (gamesDir && result.tgn) {
			const colorSuffix = model1AsBlack ? "M1B" : "M1W";
			const tgnFilename = `game_${String(gameId).padStart(3, "0")}_${colorSuffix}_${label}.tgn`;
			fs.writeFileSync(path.join(gamesDir, tgnFilename), result.tgn);
		}

		// Draws have scoreDiff === 0; report them as such rather than "Draw wins (W+0.0)".
		const outcome = result.winner === "draw" ? "Draw" : `${label} wins`;
		let margin: string;
		if (result.scoreDiff > 0) {
			margin = `B+${result.scoreDiff.toFixed(1)}`;
		} else if (result.scoreDiff < 0) {
			margin = `W+${(-result.scoreDiff).toFixed(1)}`;
		} else {
			margin = "0.0";
		}
		console.log(`${result.moveCount} moves, ${outcome} (${margin}), ${(result.duration / 1000).toFixed(1)}s`);
	}

	const totalTime = Date.now() - startTime;

	const countWhere = (predicate: (r: BattleResult) => boolean): number => results.filter(predicate).length;
	// Guard divisions so an empty run yields 0 instead of NaN.
	const meanOf = (pick: (r: BattleResult) => number): number =>
		results.length === 0 ? 0 : results.reduce((sum, r) => sum + pick(r), 0) / results.length;
	const percentOfGames = (wins: number): number =>
		config.numGames === 0 ? 0 : (wins / config.numGames) * 100;

	const model1Wins = countWhere((r) => r.winner === "model1");
	const model2Wins = countWhere((r) => r.winner === "model2");

	return {
		totalGames: config.numGames,
		model1Wins,
		model2Wins,
		draws: countWhere((r) => r.winner === "draw"),
		model1WinRate: percentOfGames(model1Wins),
		model2WinRate: percentOfGames(model2Wins),
		model1AsBlackWins: countWhere((r) => r.model1AsBlack && r.winner === "model1"),
		model1AsWhiteWins: countWhere((r) => !r.model1AsBlack && r.winner === "model1"),
		model2AsBlackWins: countWhere((r) => !r.model1AsBlack && r.winner === "model2"),
		model2AsWhiteWins: countWhere((r) => r.model1AsBlack && r.winner === "model2"),
		averageGameLength: meanOf((r) => r.moveCount),
		averageScoreDiff: meanOf((r) => r.scoreDiff),
		totalTime,
		results
	};
}
// ============================================================================
// Output and Statistics
// ============================================================================

/**
 * Print a human-readable summary of a finished battle to stdout.
 *
 * The average score difference is Black minus White with komi included
 * (the same sign convention as BattleResult.scoreDiff); it was previously
 * mislabelled "(White - Black)".
 */
function printStatistics(stats: BattleStatistics, config: BattleConfig): void {
	const heavyRule = "=".repeat(80);
	const lightRule = "-".repeat(80);

	console.log("\n" + heavyRule);
	console.log("Battle Results");
	console.log(heavyRule);
	console.log(`\nModel 1: ${path.basename(config.model1TreePath)}`);
	console.log(`Model 2: ${path.basename(config.model2TreePath)}`);
	console.log(`Board: ${config.boardShape.x}×${config.boardShape.y}×${config.boardShape.z}`);
	console.log(`Games: ${stats.totalGames} (${stats.totalGames / 2} as each color)`);
	console.log(`Mode: ${config.useMCTS ? `MCTS (${config.mctsSimulations} simulations)` : "Tree Attention"}`);

	console.log("\n" + lightRule);
	console.log("Win Statistics:");
	console.log(lightRule);
	console.log(`\n Model 1 Wins: ${stats.model1Wins} (${stats.model1WinRate.toFixed(1)}%)`);
	console.log(` - As Black: ${stats.model1AsBlackWins}`);
	console.log(` - As White: ${stats.model1AsWhiteWins}`);
	console.log(`\n Model 2 Wins: ${stats.model2Wins} (${stats.model2WinRate.toFixed(1)}%)`);
	console.log(` - As Black: ${stats.model2AsBlackWins}`);
	console.log(` - As White: ${stats.model2AsWhiteWins}`);
	console.log(`\n Draws: ${stats.draws}`);

	console.log("\n" + lightRule);
	console.log("Game Statistics:");
	console.log(lightRule);
	const sign = stats.averageScoreDiff > 0 ? "+" : "";
	// Avoid NaN when no games were played.
	const secondsPerGame = stats.totalGames > 0 ? stats.totalTime / stats.totalGames / 1000 : 0;
	console.log(` Average game length: ${stats.averageGameLength.toFixed(1)} moves`);
	console.log(` Average score diff: ${sign}${stats.averageScoreDiff.toFixed(1)} (Black - White, komi included)`);
	console.log(` Total time: ${(stats.totalTime / 60000).toFixed(1)} minutes`);
	console.log(` Average time per game: ${secondsPerGame.toFixed(1)} seconds`);

	console.log("\n" + heavyRule);
	if (stats.model1Wins > stats.model2Wins) {
		console.log(`Winner: Model 1 (${stats.model1WinRate.toFixed(1)}% vs ${stats.model2WinRate.toFixed(1)}%)`);
	} else if (stats.model2Wins > stats.model1Wins) {
		console.log(`Winner: Model 2 (${stats.model2WinRate.toFixed(1)}% vs ${stats.model1WinRate.toFixed(1)}%)`);
	} else {
		console.log(`Result: Tie (${stats.model1WinRate.toFixed(1)}% each)`);
	}
	console.log(heavyRule);
}
/**
 * Write the battle summary JSON into config.outputDir (no-op when it is unset).
 *
 * The file is named after the per-game TGN directory when one is given, so the
 * two are easy to pair; otherwise a fresh timestamped name is used. TGN text is
 * stripped from the per-game results to keep the JSON compact. The saved config
 * includes every parameter that affects play (move limit, MCTS search
 * parameters, komi), so a run can be reproduced from its summary.
 */
function saveStatistics(stats: BattleStatistics, config: BattleConfig, gamesDir?: string): void {
	if (!config.outputDir) return;

	// recursive: true also makes this a no-op when the directory already exists.
	fs.mkdirSync(config.outputDir, { recursive: true });

	const dirName = gamesDir ? path.basename(gamesDir) : `battle_${new Date().toISOString().replace(/[:.]/g, "-")}`;
	const summaryFilepath = path.join(config.outputDir, `${dirName}.json`);

	const resultsWithoutTgn = stats.results.map(({ tgn: _tgn, ...rest }) => rest);

	const output = {
		config: {
			model1Tree: config.model1TreePath,
			model2Tree: config.model2TreePath,
			model1Eval: config.model1EvalPath,
			model2Eval: config.model2EvalPath,
			boardShape: config.boardShape,
			numGames: config.numGames,
			temperature: config.temperature,
			maxMoves: config.maxMoves,
			useMCTS: config.useMCTS,
			mctsSimulations: config.mctsSimulations,
			mctsCPuct: config.mctsCPuct,
			mctsDirichletAlpha: config.mctsDirichletAlpha,
			mctsDirichletEpsilon: config.mctsDirichletEpsilon,
			komi: DEFAULT_KOMI
		},
		statistics: { ...stats, results: resultsWithoutTgn },
		timestamp: new Date().toISOString()
	};

	fs.writeFileSync(summaryFilepath, JSON.stringify(output, null, 2));
	console.log(`\nResults saved to: ${summaryFilepath}`);
	if (gamesDir) {
		console.log(`Game records saved to: ${gamesDir}/`);
	}
}
// ============================================================================
// Main
// ============================================================================

/**
 * Script entry point. Loads environment settings, parses arguments, loads
 * both agents and plays the battle, then prints the results and saves them
 * when an output directory is configured.
 * Exits with status 1 when argument parsing returns null.
 */
async function main(): Promise<void> {
	await loadEnvConfig();

	const config = parseArgs();
	if (config === null) {
		process.exit(1);
	}

	const divider = "=".repeat(80);
	const { x, y, z } = config.boardShape;
	console.log(divider);
	console.log("Model vs Model Battle");
	console.log(divider);
	console.log(`Model 1: ${config.model1TreePath}`);
	console.log(`Model 2: ${config.model2TreePath}`);
	console.log(`Board: ${x}×${y}×${z}`);
	console.log(`Games: ${config.numGames}`);
	if (config.useMCTS) {
		console.log("Mode: MCTS");
		console.log(`MCTS Simulations: ${config.mctsSimulations}`);
	} else {
		console.log("Mode: Tree Attention");
	}

	const { agent1, agent2 } = await initializeAgents(config);

	// Per-game TGN records go into a timestamped sub-directory of the output dir.
	const gamesDir = config.outputDir
		? path.join(config.outputDir, `battle_${new Date().toISOString().replace(/[:.]/g, "-")}`)
		: undefined;
	if (gamesDir) {
		fs.mkdirSync(gamesDir, { recursive: true });
	}

	const stats = await runBattle(agent1, agent2, config, gamesDir);
	printStatistics(stats, config);
	if (config.outputDir) {
		saveStatistics(stats, config, gamesDir);
	}
}

main().catch((err: unknown) => {
	console.error("Error:", err);
	process.exit(1);
});