// Source: trigo / trigo-web / tools / modelBattle.ts
// Last commit 15f353f by k-l-lambda: "Update trigo-web with VS People multiplayer mode"
/**
* Model vs Model Battle Script
*
* This script enables different ONNX model versions to battle against each other.
* Features:
* - Fair color rotation (each model plays equal games as Black and White)
* - Fixed board size
* - Win rate statistics
* - Optional MCTS search (disabled by default)
*
* Usage: npx tsx tools/modelBattle.ts [options]
*/
import * as ort from "onnxruntime-node";
import * as path from "path";
import * as fs from "fs";
import { TrigoGame, StoneType } from "../inc/trigo/game";
import { ModelInferencer } from "../inc/modelInferencer";
import { TrigoTreeAgent } from "../inc/trigoTreeAgent";
import { TrigoEvaluationAgent } from "../inc/trigoEvaluationAgent";
import { MCTSAgent, MCTSConfig } from "../inc/mctsAgent";
import type { Move, BoardShape } from "../inc/trigo/types";
import { loadEnvConfig, getOnnxSessionOptions } from "../inc/config";
// ============================================================================
// Constants
// ============================================================================

// --- Model geometry: passed to every ModelInferencer via BattleConfig ---
/** Token vocabulary size given to each ModelInferencer. */
const DEFAULT_VOCAB_SIZE = 128;
/** Sequence length given to each ModelInferencer. */
const DEFAULT_SEQ_LEN = 256;

// --- Battle setup defaults (overridable from the command line) ---
/** Number of games when --games is omitted. */
const DEFAULT_NUM_GAMES = 10;
/** Sampling temperature when --temperature is omitted. */
const DEFAULT_TEMPERATURE = 1.0;
/** Move cap when --max-moves is omitted; parseArgs replaces it with 2× the board cell count. */
const DEFAULT_MAX_MOVES = 300;

// --- MCTS defaults ---
/** Simulations per move when --mcts-simulations is omitted. */
const DEFAULT_MCTS_SIMULATIONS = 600;
/** PUCT exploration constant when --mcts-cpuct is omitted. */
const DEFAULT_MCTS_CPUCT = 1.0;
/** Dirichlet noise concentration when --mcts-dirichlet-alpha is omitted. */
const DEFAULT_MCTS_DIRICHLET_ALPHA = 0.03;
/** Dirichlet noise weight; zero means no root noise, for deterministic evaluation. */
const DEFAULT_MCTS_DIRICHLET_EPSILON = 0.0;

// --- Scoring ---
/** Komi: points added to White's territory count when scoring each game. */
const DEFAULT_KOMI = 0.5;
// ============================================================================
// Interfaces
// ============================================================================

/** Fully resolved settings for one battle run, as produced by parseArgs(). */
interface BattleConfig {
	// --- Model locations ---
	/** Directory containing model 1's files. */
	model1Dir: string;
	/** Directory containing model 2's files. */
	model2Dir: string;
	/** Path to model 1's `_tree.onnx` file. */
	model1TreePath: string;
	/** Path to model 2's `_tree.onnx` file. */
	model2TreePath: string;
	/** Path to model 1's `_evaluation.onnx` file, when one was found. */
	model1EvalPath?: string;
	/** Path to model 2's `_evaluation.onnx` file, when one was found. */
	model2EvalPath?: string;

	// --- Game setup ---
	boardShape: BoardShape;
	/** Total games to play (even, so each model plays both colors equally). */
	numGames: number;
	/** Sampling temperature for move selection. */
	temperature: number;
	/** Upper bound on turns per game. */
	maxMoves: number;

	// --- MCTS ---
	useMCTS: boolean;
	mctsSimulations: number;
	mctsCPuct: number;
	mctsDirichletAlpha: number;
	mctsDirichletEpsilon: number;

	// --- Output and inference ---
	/** When set, TGN records and a JSON summary are written under this directory. */
	outputDir?: string;
	verbose: boolean;
	vocabSize: number;
	seqLen: number;
}

/** Outcome of a single game. */
interface BattleResult {
	/** 1-based game index. */
	gameId: number;
	/** True if model 1 played Black in this game. */
	model1AsBlack: boolean;
	winner: "model1" | "model2" | "draw";
	/** Model 1's final score (includes komi when model 1 played White). */
	model1Score: number;
	/** Model 2's final score (includes komi when model 2 played White). */
	model2Score: number;
	/** Black score minus White score (komi included); positive means Black is ahead. */
	scoreDiff: number;
	/** Number of turns played, passes included. */
	moveCount: number;
	/** Wall-clock time of the game in milliseconds. */
	duration: number;
	/** TGN game record, present only when output is enabled. */
	tgn?: string;
}

/** Aggregate results of a whole battle. */
interface BattleStatistics {
	totalGames: number;
	model1Wins: number;
	model2Wins: number;
	draws: number;
	/** Percentage (0-100) of all games won by model 1. */
	model1WinRate: number;
	/** Percentage (0-100) of all games won by model 2. */
	model2WinRate: number;
	model1AsBlackWins: number;
	model1AsWhiteWins: number;
	model2AsBlackWins: number;
	model2AsWhiteWins: number;
	/** Mean moveCount over all games. */
	averageGameLength: number;
	/** Mean scoreDiff (Black − White) over all games. */
	averageScoreDiff: number;
	/** Wall-clock time of the whole battle in milliseconds. */
	totalTime: number;
	results: BattleResult[];
}

/** Anything that can pick a move: a plain tree-model agent or an MCTS wrapper. */
type Agent = TrigoTreeAgent | MCTSAgent;
// ============================================================================
// Argument Parsing
// ============================================================================

/**
 * Print command-line help to stdout.
 *
 * Default values are interpolated from the module constants so the help text
 * cannot drift from the values parseArgs() actually uses. (It previously
 * advertised an epsilon of 0.25 and a fixed move cap of 300, neither of which
 * was true.)
 */
function printUsage(): void {
	console.log(`
Usage: npx tsx tools/modelBattle.ts [options]

Required:
  --model1 <path>               Model 1 base path (without suffix, e.g., ./dir/ModelName)
  --model2 <path>               Model 2 base path (without suffix, e.g., ./dir/ModelName)
  --board <shape>               Board shape "X*Y*Z" (e.g., "5*5*5")

Options:
  --games <n>                   Number of games (default: ${DEFAULT_NUM_GAMES}; odd values are rounded up to even)
  --temperature <t>             Sampling temperature (default: ${DEFAULT_TEMPERATURE})
  --max-moves <n>               Maximum moves per game (default: 2 × number of board cells)
  --output <dir>                Output directory for results
  --verbose                     Enable verbose logging

MCTS Options:
  --use-mcts                    Enable MCTS for move selection (requires evaluation model)
  --mcts-simulations <n>        MCTS simulations per move (default: ${DEFAULT_MCTS_SIMULATIONS})
  --mcts-cpuct <f>              PUCT exploration constant (default: ${DEFAULT_MCTS_CPUCT})
  --mcts-dirichlet-alpha <f>    Dirichlet noise alpha (default: ${DEFAULT_MCTS_DIRICHLET_ALPHA})
  --mcts-dirichlet-epsilon <f>  Dirichlet noise epsilon (default: ${DEFAULT_MCTS_DIRICHLET_EPSILON})

  --help, -h                    Show this help message

Scoring:
  White receives a komi of ${DEFAULT_KOMI} points.

Model Path Format:
  Provide the base path WITHOUT the suffix. The script auto-appends:
    - _tree.onnx       -> Tree model (required)
    - _evaluation.onnx -> Evaluation model (auto-detected, required for MCTS)

Examples:
  # Battle two models on 5x5x5 board for 20 games
  npx tsx tools/modelBattle.ts \\
    --model1 ./public/onnx/20251215-trigo-value-llama-l6-h64-251211/LlamaCausalLM_ep0045 \\
    --model2 ./public/onnx/20251204-trigo-value-gpt2-l6-h64-251125-lr500/GPT2LMHeadModel_ep0100 \\
    --board "5*5*5" --games 20

  # Battle with MCTS enabled
  npx tsx tools/modelBattle.ts \\
    --model1 ./public/onnx/model_v1/Model_ep0100 \\
    --model2 ./public/onnx/model_v2/Model_ep0100 \\
    --board "5*5*5" --games 20 --use-mcts --mcts-simulations 400
`);
}
/**
 * Parse a board shape string of the form "X*Y*Z" (e.g. "5*5*5").
 *
 * Surrounding whitespace is ignored.
 *
 * @param shapeStr - Raw value of the --board flag.
 * @returns The parsed shape, or null when the string is malformed or any
 *          dimension is zero (a board needs at least one cell along every axis).
 */
function parseBoardShape(shapeStr: string): BoardShape | null {
	const match = shapeStr.trim().match(/^(\d+)\*(\d+)\*(\d+)$/);
	if (!match) return null;

	const x = parseInt(match[1], 10);
	const y = parseInt(match[2], 10);
	const z = parseInt(match[3], 10);

	// The regex already rules out negatives; this rejects zero-sized axes like "0*5*5".
	if (x < 1 || y < 1 || z < 1) return null;

	return { x, y, z };
}
/**
 * Derive a model's ONNX file paths from its base path (given without suffix).
 *
 * `<base>_tree.onnx` is mandatory. `<base>_evaluation.onnx` is optional and is
 * reported as `evalPath: undefined` when that file does not exist.
 *
 * @param basePath - Model path without the `_tree.onnx` / `_evaluation.onnx` suffix.
 * @returns Both paths, or null (after logging an error) if the tree model is missing.
 */
function resolveModelPaths(basePath: string): { treePath: string; evalPath?: string } | null {
	const treePath = basePath + "_tree.onnx";
	if (!fs.existsSync(treePath)) {
		console.error(`Tree model not found: ${treePath}`);
		return null;
	}

	const evalCandidate = basePath + "_evaluation.onnx";
	const evalPath = fs.existsSync(evalCandidate) ? evalCandidate : undefined;
	return { treePath, evalPath };
}
/**
 * Return the value that follows the flag at `args[i]`.
 *
 * The value counts as missing when there is no next token, the token is empty,
 * or it starts with "-" (i.e. it looks like another flag). In that case an
 * error and the usage text are printed and null is returned.
 *
 * @param args - Command-line arguments (without node/script).
 * @param i - Index of the flag in `args`.
 * @param flag - Flag name, used in the error message.
 */
function requireNextArg(args: string[], i: number, flag: string): string | null {
	const candidate = args[i + 1];
	if (candidate && !candidate.startsWith("-")) {
		return candidate;
	}
	console.error(`Error: ${flag} requires a value`);
	printUsage();
	return null;
}
/**
 * Parse and validate an integer command-line value.
 *
 * The whole (trimmed) string must be an optionally signed decimal integer.
 * Plain `parseInt` would silently accept "10abc" (→ 10) or "1.5" (→ 1).
 *
 * @param value - Raw argument text.
 * @param flag - Flag name, used in error messages.
 * @param minValue - Optional inclusive lower bound.
 * @returns The integer, or null (after logging) when invalid or below `minValue`.
 */
function parseIntArg(value: string, flag: string, minValue?: number): number | null {
	const trimmed = value.trim();
	const n = /^[+-]?\d+$/.test(trimmed) ? Number(trimmed) : NaN;
	// isSafeInteger also rejects NaN and values too large to represent exactly
	if (!Number.isSafeInteger(n)) {
		console.error(`Error: Invalid integer for ${flag}: ${value}`);
		return null;
	}
	if (minValue !== undefined && n < minValue) {
		console.error(`Error: ${flag} must be >= ${minValue}`);
		return null;
	}
	return n;
}
/**
 * Parse and validate a finite numeric command-line value.
 *
 * The whole (trimmed) string must be numeric. Plain `parseFloat` would
 * silently accept trailing garbage such as "1.5abc" (→ 1.5).
 *
 * @param value - Raw argument text.
 * @param flag - Flag name, used in error messages.
 * @param minValue - Optional inclusive lower bound.
 * @returns The number, or null (after logging) when invalid, non-finite, or below `minValue`.
 */
function parseFloatArg(value: string, flag: string, minValue?: number): number | null {
	const trimmed = value.trim();
	// Number("") is 0, so the empty string must be rejected explicitly.
	const n = trimmed === "" ? NaN : Number(trimmed);
	if (!Number.isFinite(n)) {
		console.error(`Error: Invalid number for ${flag}: ${value}`);
		return null;
	}
	if (minValue !== undefined && n < minValue) {
		console.error(`Error: ${flag} must be >= ${minValue}`);
		return null;
	}
	return n;
}
/**
 * Parse and validate process.argv into a BattleConfig.
 *
 * Returns null (after printing a message and/or the usage text) when --help/-h
 * is given or when any argument is missing or invalid. On success it also:
 *  - rounds an odd --games count up to the next even number (fair color rotation);
 *  - checks that evaluation models exist when --use-mcts is set;
 *  - sets the move cap to 2× the board cell count when --max-moves was not given.
 */
function parseArgs(): BattleConfig | null {
	const args = process.argv.slice(2);

	if (args.includes("--help") || args.includes("-h")) {
		printUsage();
		return null;
	}

	let model1Base = "";
	let model2Base = "";
	let boardShape: BoardShape = { x: 5, y: 5, z: 5 };
	let boardSpecified = false;
	let numGames = DEFAULT_NUM_GAMES;
	let temperature = DEFAULT_TEMPERATURE;
	let maxMoves = DEFAULT_MAX_MOVES;
	// Tracked explicitly so that "--max-moves 300" is not mistaken for "not given"
	// (comparing against DEFAULT_MAX_MOVES would silently override it).
	let maxMovesSpecified = false;
	let useMCTS = false;
	let mctsSimulations = DEFAULT_MCTS_SIMULATIONS;
	let mctsCPuct = DEFAULT_MCTS_CPUCT;
	let mctsDirichletAlpha = DEFAULT_MCTS_DIRICHLET_ALPHA;
	let mctsDirichletEpsilon = DEFAULT_MCTS_DIRICHLET_EPSILON;
	let outputDir: string | undefined;
	let verbose = false;

	for (let i = 0; i < args.length; i++) {
		const arg = args[i];
		switch (arg) {
			case "--model1": {
				const value = requireNextArg(args, i, "--model1");
				if (!value) return null;
				model1Base = value;
				i++;
				break;
			}
			case "--model2": {
				const value = requireNextArg(args, i, "--model2");
				if (!value) return null;
				model2Base = value;
				i++;
				break;
			}
			case "--board": {
				const value = requireNextArg(args, i, "--board");
				if (!value) return null;
				const shape = parseBoardShape(value);
				if (!shape) {
					console.error(`Invalid board shape: ${value}. Use format "X*Y*Z" (e.g., "5*5*5")`);
					return null;
				}
				boardShape = shape;
				boardSpecified = true;
				i++;
				break;
			}
			case "--games": {
				const value = requireNextArg(args, i, "--games");
				if (!value) return null;
				const parsed = parseIntArg(value, "--games", 1);
				if (parsed === null) return null;
				numGames = parsed;
				i++;
				break;
			}
			case "--temperature": {
				const value = requireNextArg(args, i, "--temperature");
				if (!value) return null;
				const parsed = parseFloatArg(value, "--temperature", 0);
				if (parsed === null) return null;
				temperature = parsed;
				i++;
				break;
			}
			case "--max-moves": {
				const value = requireNextArg(args, i, "--max-moves");
				if (!value) return null;
				const parsed = parseIntArg(value, "--max-moves", 1);
				if (parsed === null) return null;
				maxMoves = parsed;
				maxMovesSpecified = true;
				i++;
				break;
			}
			case "--output": {
				const value = requireNextArg(args, i, "--output");
				if (!value) return null;
				outputDir = value;
				i++;
				break;
			}
			case "--verbose":
				verbose = true;
				break;
			case "--use-mcts":
				useMCTS = true;
				break;
			case "--mcts-simulations": {
				const value = requireNextArg(args, i, "--mcts-simulations");
				if (!value) return null;
				const parsed = parseIntArg(value, "--mcts-simulations", 1);
				if (parsed === null) return null;
				mctsSimulations = parsed;
				i++;
				break;
			}
			case "--mcts-cpuct": {
				const value = requireNextArg(args, i, "--mcts-cpuct");
				if (!value) return null;
				const parsed = parseFloatArg(value, "--mcts-cpuct", 0);
				if (parsed === null) return null;
				mctsCPuct = parsed;
				i++;
				break;
			}
			case "--mcts-dirichlet-alpha": {
				const value = requireNextArg(args, i, "--mcts-dirichlet-alpha");
				if (!value) return null;
				const parsed = parseFloatArg(value, "--mcts-dirichlet-alpha", 0);
				if (parsed === null) return null;
				mctsDirichletAlpha = parsed;
				i++;
				break;
			}
			case "--mcts-dirichlet-epsilon": {
				const value = requireNextArg(args, i, "--mcts-dirichlet-epsilon");
				if (!value) return null;
				const parsed = parseFloatArg(value, "--mcts-dirichlet-epsilon", 0);
				if (parsed === null) return null;
				// Epsilon is a mixing weight; values above 1 would give the prior a negative weight.
				if (parsed > 1) {
					console.error("Error: --mcts-dirichlet-epsilon must be <= 1");
					return null;
				}
				mctsDirichletEpsilon = parsed;
				i++;
				break;
			}
			default:
				if (arg.startsWith("-")) {
					console.error(`Unknown option: ${arg}`);
					printUsage();
					return null;
				}
		}
	}

	// Validation
	if (!model1Base) {
		console.error("Error: --model1 is required");
		printUsage();
		return null;
	}
	if (!model2Base) {
		console.error("Error: --model2 is required");
		printUsage();
		return null;
	}
	if (!boardSpecified) {
		console.error("Error: --board is required");
		printUsage();
		return null;
	}

	// Resolve model paths
	const model1Paths = resolveModelPaths(model1Base);
	if (!model1Paths) {
		return null;
	}
	const model2Paths = resolveModelPaths(model2Base);
	if (!model2Paths) {
		return null;
	}

	// Ensure an even number of games so each model plays each color equally often
	if (numGames % 2 !== 0) {
		numGames++;
		console.log(`Note: Adjusted games to ${numGames} (must be even for fair color rotation)`);
	}

	// MCTS needs an evaluation model for both sides
	if (useMCTS) {
		if (!model1Paths.evalPath) {
			console.error(`Error: Evaluation model not found for model1 (expected: ${model1Base}_evaluation.onnx)`);
			return null;
		}
		if (!model2Paths.evalPath) {
			console.error(`Error: Evaluation model not found for model2 (expected: ${model2Base}_evaluation.onnx)`);
			return null;
		}
	}

	// Derive the move cap from the board size (2× cell count) unless it was given explicitly
	if (!maxMovesSpecified) {
		maxMoves = boardShape.x * boardShape.y * boardShape.z * 2;
	}

	return {
		model1Dir: path.dirname(model1Base),
		model2Dir: path.dirname(model2Base),
		model1TreePath: model1Paths.treePath,
		model2TreePath: model2Paths.treePath,
		model1EvalPath: model1Paths.evalPath,
		model2EvalPath: model2Paths.evalPath,
		boardShape,
		numGames,
		temperature,
		maxMoves,
		useMCTS,
		mctsSimulations,
		mctsCPuct,
		mctsDirichletAlpha,
		mctsDirichletEpsilon,
		outputDir,
		verbose,
		vocabSize: DEFAULT_VOCAB_SIZE,
		seqLen: DEFAULT_SEQ_LEN
	};
}
// ============================================================================
// Agent Initialization
// ============================================================================

/**
 * Build the move-selection agent for one model.
 *
 * Without MCTS this is a TrigoTreeAgent over the tree model. With MCTS the
 * evaluation model is loaded too, and both are wrapped in an MCTSAgent
 * configured from `config`.
 *
 * @param treeModelPath - Path to the `_tree.onnx` model.
 * @param evalModelPath - Path to the `_evaluation.onnx` model (required when MCTS is enabled).
 * @param config - Battle configuration (inference geometry, MCTS settings, temperature).
 * @param name - Human-readable label used in log and error messages.
 * @throws Error if MCTS is enabled but `evalModelPath` is missing.
 */
async function createAgent(
	treeModelPath: string,
	evalModelPath: string | undefined,
	config: BattleConfig,
	name: string
): Promise<Agent> {
	console.log(` Loading ${name}...`);

	// Validate MCTS prerequisites before loading anything, so a misconfigured run
	// fails immediately instead of after a potentially slow ONNX session load.
	let mctsEvalPath: string | undefined;
	if (config.useMCTS) {
		if (!evalModelPath) {
			throw new Error(`MCTS enabled but evaluation model path is missing for ${name}`);
		}
		mctsEvalPath = evalModelPath;
	}

	const sessionOptions = getOnnxSessionOptions();

	// Create an ONNX session for `modelPath` and wrap it in a ModelInferencer.
	// NOTE(review): the `as any` casts bridge onnxruntime-node's types to the
	// inferencer's expected types; confirm whether a typed adapter exists.
	const loadInferencer = async (modelPath: string): Promise<ModelInferencer> => {
		const session = await ort.InferenceSession.create(modelPath, sessionOptions);
		const inferencer = new ModelInferencer(ort.Tensor as any, {
			vocabSize: config.vocabSize,
			seqLen: config.seqLen,
			modelPath
		});
		inferencer.setSession(session as any);
		return inferencer;
	};

	const treeAgent = new TrigoTreeAgent(await loadInferencer(treeModelPath));
	if (mctsEvalPath === undefined) {
		return treeAgent;
	}

	const evalAgent = new TrigoEvaluationAgent(await loadInferencer(mctsEvalPath));
	const mctsConfig: MCTSConfig = {
		numSimulations: config.mctsSimulations,
		cPuct: config.mctsCPuct,
		temperature: config.temperature,
		dirichletAlpha: config.mctsDirichletAlpha,
		dirichletEpsilon: config.mctsDirichletEpsilon
	};
	return new MCTSAgent(treeAgent, evalAgent, mctsConfig);
}
/**
 * Load both competitors sequentially (model 1 first), logging a confirmation
 * after each one is ready.
 *
 * @param config - Battle configuration holding both models' paths.
 * @returns The two agents, keyed by model number.
 */
async function initializeAgents(config: BattleConfig): Promise<{ agent1: Agent; agent2: Agent }> {
	console.log("\nInitializing agents...");

	// Load one model and announce it; sequential awaits keep the log order stable.
	const loadAndAnnounce = async (treePath: string, evalPath: string | undefined, label: string): Promise<Agent> => {
		const agent = await createAgent(treePath, evalPath, config, label);
		console.log(` ✓ ${label} ready`);
		return agent;
	};

	const agent1 = await loadAndAnnounce(config.model1TreePath, config.model1EvalPath, "Model 1");
	const agent2 = await loadAndAnnounce(config.model2TreePath, config.model2EvalPath, "Model 2");
	return { agent1, agent2 };
}
// ============================================================================
// Game Playing
// ============================================================================

/** Format a tree-model path as "parentDir/ModelName" (without `_tree.onnx`) for TGN headers. */
function formatModelLabel(treePath: string): string {
	const dir = path.basename(path.dirname(treePath));
	const filename = path.basename(treePath).replace(/_tree\.onnx$/, "");
	return `${dir}/${filename}`;
}

/**
 * Ask an agent for its next move.
 *
 * MCTS agents receive the current move number; tree agents receive the
 * configured sampling temperature. Any error is logged (when verbose) and
 * reported as null, so the caller can treat the turn as a pass.
 */
async function selectAgentMove(
	agent: Agent,
	game: TrigoGame,
	moveCount: number,
	config: BattleConfig
): Promise<Move | null> {
	try {
		if (agent instanceof MCTSAgent) {
			const result = await agent.selectMove(game, moveCount);
			return result.move;
		}
		return await agent.selectMove(game, config.temperature);
	} catch (err) {
		if (config.verbose) {
			console.error(`Error getting move: ${err}`);
		}
		return null;
	}
}

/**
 * Play `move` on `game`. A pass, a move without full coordinates, or a drop
 * the game rejects all result in exactly one pass.
 *
 * @returns true if a stone was placed, false if the turn ended as a pass.
 */
function applySelectedMove(game: TrigoGame, move: Move, verbose: boolean): boolean {
	if (!move.isPass && move.x !== undefined && move.y !== undefined && move.z !== undefined) {
		if (game.drop({ x: move.x, y: move.y, z: move.z })) {
			return true;
		}
		if (verbose) {
			console.warn(`Move failed, passing instead`);
		}
	}
	game.pass();
	return false;
}

/**
 * Play one complete game between the two agents.
 *
 * The game ends when any of these happens:
 *  - two consecutive turns end as passes;
 *  - `config.maxMoves` turns have been played;
 *  - a stone is placed while at least half the board's cell count in turns had
 *    already been played, and game.isNaturallyTerminal() reports true.
 * Scores are territory counts, with DEFAULT_KOMI added to White.
 *
 * @param gameId - 1-based game number, copied into the result.
 * @param model1AsBlack - True if agent1 plays Black in this game.
 */
async function playBattleGame(
	agent1: Agent,
	agent2: Agent,
	gameId: number,
	config: BattleConfig,
	model1AsBlack: boolean
): Promise<BattleResult> {
	const startTime = Date.now();
	const game = new TrigoGame(config.boardShape, {});

	const blackAgent = model1AsBlack ? agent1 : agent2;
	const whiteAgent = model1AsBlack ? agent2 : agent1;

	const totalPositions = config.boardShape.x * config.boardShape.y * config.boardShape.z;
	const coverageThreshold = Math.floor(totalPositions * 0.5);

	let moveCount = 0;
	let consecutivePasses = 0;

	while (moveCount < config.maxMoves) {
		// Natural-termination checks only start once enough turns have been
		// played; the threshold is measured before this turn's move.
		const territoryCheckActive = moveCount >= coverageThreshold;
		const currentAgent = game.getCurrentPlayer() === StoneType.BLACK ? blackAgent : whiteAgent;

		const selectedMove = await selectAgentMove(currentAgent, game, moveCount, config);
		let stonePlaced: boolean;
		if (selectedMove === null) {
			// Agent failed to produce a move: forfeit the turn with a single pass.
			game.pass();
			stonePlaced = false;
		} else {
			stonePlaced = applySelectedMove(game, selectedMove, config.verbose);
		}

		// Every turn counts, including error passes, so moveCount matches the record.
		moveCount++;
		// Any turn that ended as a pass (including rejected drops) counts toward the two-pass end.
		consecutivePasses = stonePlaced ? 0 : consecutivePasses + 1;

		if (consecutivePasses >= 2) break;
		if (territoryCheckActive && stonePlaced && game.isNaturallyTerminal()) break;
	}

	// Final scores, with komi added to White
	const territory = game.getTerritory();
	const blackScore = territory.black;
	const whiteScore = territory.white + DEFAULT_KOMI;
	const scoreDiff = blackScore - whiteScore; // Positive = Black ahead

	let winner: "model1" | "model2" | "draw";
	if (blackScore > whiteScore) {
		winner = model1AsBlack ? "model1" : "model2";
	} else if (whiteScore > blackScore) {
		winner = model1AsBlack ? "model2" : "model1";
	} else {
		winner = "draw";
	}

	const duration = Date.now() - startTime;

	// A TGN record is only needed when results are being written to disk
	let tgn: string | undefined;
	if (config.outputDir) {
		const model1Label = formatModelLabel(config.model1TreePath);
		const model2Label = formatModelLabel(config.model2TreePath);
		tgn = game.toTGN(
			{
				black: model1AsBlack ? model1Label : model2Label,
				white: model1AsBlack ? model2Label : model1Label,
				event: "Model Battle",
				// UTC date as YYYY.MM.DD
				date: new Date().toISOString().slice(0, 10).replace(/-/g, ".")
			},
			{ markResult: true }
		);
	}

	return {
		gameId,
		model1AsBlack,
		winner,
		model1Score: model1AsBlack ? blackScore : whiteScore,
		model2Score: model1AsBlack ? whiteScore : blackScore,
		scoreDiff,
		moveCount,
		duration,
		tgn
	};
}
// ============================================================================
// Battle Execution
// ============================================================================

/** Short winner label ("M1", "M2" or "Draw") for console output and TGN filenames. */
function winnerLabel(winner: BattleResult["winner"]): string {
	return winner === "model1" ? "M1" : winner === "model2" ? "M2" : "Draw";
}

/** Format a Black-minus-White score difference as "B+x", "W+x", or "±0.0" for an exact tie. */
function formatScoreDiff(scoreDiff: number): string {
	if (scoreDiff > 0) return `B+${scoreDiff.toFixed(1)}`;
	if (scoreDiff < 0) return `W+${(-scoreDiff).toFixed(1)}`;
	return "±0.0";
}

/** Arithmetic mean, or 0 for an empty list (avoids NaN in the summary). */
function averageOf(values: number[]): number {
	return values.length === 0 ? 0 : values.reduce((sum, v) => sum + v, 0) / values.length;
}

/**
 * Play `config.numGames` games, alternating colors, and aggregate the results.
 *
 * Odd-numbered games have model 1 as Black; even-numbered games have model 1 as
 * White. When `gamesDir` is given, each game's TGN record is written as soon as
 * that game finishes.
 *
 * @param gamesDir - Directory for per-game TGN files (optional).
 */
async function runBattle(
	agent1: Agent,
	agent2: Agent,
	config: BattleConfig,
	gamesDir?: string
): Promise<BattleStatistics> {
	const results: BattleResult[] = [];
	const startTime = Date.now();

	console.log(`\nStarting battle: ${config.numGames} games`);
	console.log("=".repeat(80));

	for (let i = 0; i < config.numGames; i++) {
		const gameId = i + 1;
		const model1AsBlack = i % 2 === 0;
		const colorStr = model1AsBlack ? "M1=Black" : "M1=White";
		process.stdout.write(`Game ${gameId}/${config.numGames} (${colorStr}): `);

		const result = await playBattleGame(agent1, agent2, gameId, config, model1AsBlack);
		results.push(result);
		const winnerStr = winnerLabel(result.winner);

		// Save the TGN right away so an interrupted battle still leaves its finished games on disk
		if (gamesDir && result.tgn) {
			const colorSuffix = model1AsBlack ? "M1B" : "M1W";
			const tgnFilename = `game_${String(gameId).padStart(3, "0")}_${colorSuffix}_${winnerStr}.tgn`;
			fs.writeFileSync(path.join(gamesDir, tgnFilename), result.tgn);
		}

		const outcome = result.winner === "draw" ? "Draw" : `${winnerStr} wins`;
		console.log(
			`${result.moveCount} moves, ${outcome} (${formatScoreDiff(result.scoreDiff)}), ${(result.duration / 1000).toFixed(1)}s`
		);
	}

	const totalTime = Date.now() - startTime;
	const totalGames = config.numGames;
	const count = (predicate: (r: BattleResult) => boolean): number => results.filter(predicate).length;
	const winRate = (wins: number): number => (totalGames > 0 ? (wins / totalGames) * 100 : 0);

	const model1Wins = count((r) => r.winner === "model1");
	const model2Wins = count((r) => r.winner === "model2");

	// Key order matches the JSON summary written by saveStatistics
	return {
		totalGames,
		model1Wins,
		model2Wins,
		draws: count((r) => r.winner === "draw"),
		model1WinRate: winRate(model1Wins),
		model2WinRate: winRate(model2Wins),
		model1AsBlackWins: count((r) => r.model1AsBlack && r.winner === "model1"),
		model1AsWhiteWins: count((r) => !r.model1AsBlack && r.winner === "model1"),
		model2AsBlackWins: count((r) => !r.model1AsBlack && r.winner === "model2"),
		model2AsWhiteWins: count((r) => r.model1AsBlack && r.winner === "model2"),
		averageGameLength: averageOf(results.map((r) => r.moveCount)),
		averageScoreDiff: averageOf(results.map((r) => r.scoreDiff)),
		totalTime,
		results
	};
}
// ============================================================================
// Output and Statistics
// ============================================================================

/**
 * Print a human-readable summary of the battle to stdout.
 *
 * The average score difference is Black minus White with komi included,
 * matching how BattleResult.scoreDiff is computed. (It was previously
 * mislabeled as "White - Black".)
 */
function printStatistics(stats: BattleStatistics, config: BattleConfig): void {
	const rule = (ch: string): string => ch.repeat(80);
	const { x, y, z } = config.boardShape;

	console.log("\n" + rule("="));
	console.log("Battle Results");
	console.log(rule("="));
	console.log(`\nModel 1: ${path.basename(config.model1TreePath)}`);
	console.log(`Model 2: ${path.basename(config.model2TreePath)}`);
	console.log(`Board: ${x}×${y}×${z}`);
	console.log(`Games: ${stats.totalGames} (${stats.totalGames / 2} as each color)`);
	console.log(`Mode: ${config.useMCTS ? `MCTS (${config.mctsSimulations} simulations)` : "Tree Attention"}`);

	console.log("\n" + rule("-"));
	console.log("Win Statistics:");
	console.log(rule("-"));
	console.log(`\n Model 1 Wins: ${stats.model1Wins} (${stats.model1WinRate.toFixed(1)}%)`);
	console.log(` - As Black: ${stats.model1AsBlackWins}`);
	console.log(` - As White: ${stats.model1AsWhiteWins}`);
	console.log(`\n Model 2 Wins: ${stats.model2Wins} (${stats.model2WinRate.toFixed(1)}%)`);
	console.log(` - As Black: ${stats.model2AsBlackWins}`);
	console.log(` - As White: ${stats.model2AsWhiteWins}`);
	console.log(`\n Draws: ${stats.draws}`);

	console.log("\n" + rule("-"));
	console.log("Game Statistics:");
	console.log(rule("-"));
	console.log(` Average game length: ${stats.averageGameLength.toFixed(1)} moves`);
	const diffSign = stats.averageScoreDiff > 0 ? "+" : "";
	console.log(` Average score diff: ${diffSign}${stats.averageScoreDiff.toFixed(1)} (Black - White, komi included)`);
	console.log(` Total time: ${(stats.totalTime / 60000).toFixed(1)} minutes`);
	// Guard against division by zero for an empty battle
	const secondsPerGame = stats.totalGames > 0 ? stats.totalTime / stats.totalGames / 1000 : 0;
	console.log(` Average time per game: ${secondsPerGame.toFixed(1)} seconds`);

	console.log("\n" + rule("="));
	if (stats.model1Wins > stats.model2Wins) {
		console.log(`Winner: Model 1 (${stats.model1WinRate.toFixed(1)}% vs ${stats.model2WinRate.toFixed(1)}%)`);
	} else if (stats.model2Wins > stats.model1Wins) {
		console.log(`Winner: Model 2 (${stats.model2WinRate.toFixed(1)}% vs ${stats.model1WinRate.toFixed(1)}%)`);
	} else {
		console.log(`Result: Tie (${stats.model1WinRate.toFixed(1)}% each)`);
	}
	console.log(rule("="));
}
/**
 * Write the battle summary as JSON to `config.outputDir`. Does nothing when no
 * output directory is configured.
 *
 * The file is named after `gamesDir` (so the summary sits next to its TGN
 * folder), or after a fresh timestamp when no games directory exists. The
 * summary records every setting that affects play (including maxMoves and all
 * MCTS parameters), so a run can be reproduced from it. TGN text is stripped
 * from the per-game results because each record is already saved as its own file.
 */
function saveStatistics(stats: BattleStatistics, config: BattleConfig, gamesDir?: string): void {
	if (!config.outputDir) return;

	// With recursive: true this is a no-op when the directory already exists
	fs.mkdirSync(config.outputDir, { recursive: true });

	const dirName = gamesDir ? path.basename(gamesDir) : `battle_${new Date().toISOString().replace(/[:.]/g, "-")}`;
	const summaryFilepath = path.join(config.outputDir, `${dirName}.json`);

	const resultsWithoutTgn = stats.results.map(({ tgn: _tgn, ...rest }) => rest);

	const output = {
		config: {
			model1Tree: config.model1TreePath,
			model2Tree: config.model2TreePath,
			model1Eval: config.model1EvalPath,
			model2Eval: config.model2EvalPath,
			boardShape: config.boardShape,
			numGames: config.numGames,
			temperature: config.temperature,
			maxMoves: config.maxMoves,
			useMCTS: config.useMCTS,
			mctsSimulations: config.mctsSimulations,
			mctsCPuct: config.mctsCPuct,
			mctsDirichletAlpha: config.mctsDirichletAlpha,
			mctsDirichletEpsilon: config.mctsDirichletEpsilon,
			komi: DEFAULT_KOMI
		},
		statistics: { ...stats, results: resultsWithoutTgn },
		timestamp: new Date().toISOString()
	};

	fs.writeFileSync(summaryFilepath, JSON.stringify(output, null, 2));
	console.log(`\nResults saved to: ${summaryFilepath}`);
	if (gamesDir) {
		console.log(`Game records saved to: ${gamesDir}/`);
	}
}
// ============================================================================
// Main
// ============================================================================

/**
 * Entry point. It does the following in order:
 *  - loads the environment config;
 *  - parses the arguments and loads both models;
 *  - plays the battle and prints the results;
 *  - saves the results when --output is given.
 *
 * Exit status: 0 after --help/-h, 1 on invalid arguments or any uncaught error.
 */
async function main(): Promise<void> {
	await loadEnvConfig();

	const config = parseArgs();
	if (!config) {
		// parseArgs() returns null both after printing help and on invalid input;
		// only the latter should be reported to the shell as a failure.
		const helpRequested = process.argv.slice(2).some((arg) => arg === "--help" || arg === "-h");
		process.exit(helpRequested ? 0 : 1);
	}

	console.log("=".repeat(80));
	console.log("Model vs Model Battle");
	console.log("=".repeat(80));
	console.log(`Model 1: ${config.model1TreePath}`);
	console.log(`Model 2: ${config.model2TreePath}`);
	console.log(`Board: ${config.boardShape.x}×${config.boardShape.y}×${config.boardShape.z}`);
	console.log(`Games: ${config.numGames}`);
	console.log(`Mode: ${config.useMCTS ? "MCTS" : "Tree Attention"}`);
	if (config.useMCTS) {
		console.log(`MCTS Simulations: ${config.mctsSimulations}`);
	}

	const { agent1, agent2 } = await initializeAgents(config);

	// Per-run folder for TGN records, created only when output is enabled
	let gamesDir: string | undefined;
	if (config.outputDir) {
		const timestamp = new Date().toISOString().replace(/[:.]/g, "-");
		gamesDir = path.join(config.outputDir, `battle_${timestamp}`);
		fs.mkdirSync(gamesDir, { recursive: true });
	}

	const stats = await runBattle(agent1, agent2, config, gamesDir);
	printStatistics(stats, config);

	if (config.outputDir) {
		saveStatistics(stats, config, gamesDir);
	}
}

main().catch((err) => {
	console.error("Error:", err);
	process.exit(1);
});