/**
 * MCTS Data Parser
 *
 * Parses TGN game files and visit count JSON files to reconstruct
 * move-by-move MCTS statistics.
 */

import { TrigoGame } from "../../../inc/trigo/game";
import type { MCTSMoveData, MCTSMoveStatistic } from "../types/mcts";

/**
 * Step `type` value that marks a pass in the step history.
 *
 * The original inline note identified this value as `StepType.PASS`; the enum
 * itself is not imported by this module, so the literal is kept here.
 */
const PASS_STEP_TYPE = 1;

/**
 * Extract a human-readable message from a caught value.
 */
function describeError(error: unknown): string {
	return error instanceof Error ? error.message : String(error);
}

/**
 * Check that a single visit count entry is a finite, non-negative number.
 */
function isVisitCount(value: unknown): value is number {
	return typeof value === "number" && Number.isFinite(value) && value >= 0;
}

/**
 * Parse and validate the visit counts JSON document.
 *
 * @param visitCountsJson - JSON string expected to hold an array of count arrays
 * @returns The validated visit counts, one row per move
 * @throws Error (prefixed with "Failed to parse visit counts JSON") when the
 *   JSON is malformed, is not a non-empty array of arrays, or contains an
 *   entry that is not a finite, non-negative number
 */
function parseVisitCounts(visitCountsJson: string): number[][] {
	try {
		// Keep the parsed value as `unknown` until its structure has been checked.
		const parsed: unknown = JSON.parse(visitCountsJson);

		if (!Array.isArray(parsed)) {
			throw new Error("Visit counts must be an array");
		}

		if (parsed.length === 0) {
			throw new Error("Visit counts array is empty");
		}

		for (let i = 0; i < parsed.length; i++) {
			const row: unknown = parsed[i];
			if (!Array.isArray(row)) {
				throw new Error(`Visit counts at index ${i} is not an array`);
			}

			// A non-numeric entry (e.g. the string "12") would otherwise turn the
			// later visit total into a string concatenation and corrupt every π.
			for (let j = 0; j < row.length; j++) {
				if (!isVisitCount(row[j])) {
					throw new Error(
						`Visit counts at index ${i} has an invalid count at position ${j}`
					);
				}
			}
		}

		return parsed as number[][];
	} catch (error) {
		throw new Error(`Failed to parse visit counts JSON: ${describeError(error)}`);
	}
}

/**
 * Parse MCTS data from TGN and visit counts JSON
 *
 * The game is replayed from an empty board of the same shape. Before each
 * step, the visit counts for that step are paired with the positions returned
 * by `validMovePositions()` in order, and the entry right after them is used
 * as the pass count. Missing entries count as 0. If the number of visit count
 * rows differs from the number of steps, a warning is logged and only the
 * common prefix is processed.
 *
 * @param tgnContent - TGN game notation string
 * @param visitCountsJson - JSON string containing visit counts array
 * @returns Array of MCTS move data for each move in the game
 * @throws Error if parsing fails or data is invalid
 */
export async function parseMCTSData(
	tgnContent: string,
	visitCountsJson: string
): Promise<MCTSMoveData[]> {
	// 1. Parse TGN to get game
	let game: TrigoGame;
	try {
		game = TrigoGame.fromTGN(tgnContent);
	} catch (error) {
		throw new Error(`Failed to parse TGN file: ${describeError(error)}`);
	}

	// 2. Validate board shape (must be 2D)
	const shape = game.getShape();
	if (shape.z !== 1) {
		throw new Error(
			`Only 2D boards (z=1) are supported. This game has shape ${shape.x}×${shape.y}×${shape.z}`
		);
	}

	// 3. Parse visit counts JSON
	const visitCounts = parseVisitCounts(visitCountsJson);

	// 4. Replay game move-by-move and map visit counts to positions
	const result: MCTSMoveData[] = [];
	const replayGame = new TrigoGame(shape, {});
	const stepHistory = game.getHistory();

	// Ensure we have visit counts for each move
	if (visitCounts.length !== stepHistory.length) {
		console.warn(
			`Visit counts length (${visitCounts.length}) doesn't match step history length (${stepHistory.length}). Using minimum.`
		);
	}

	const minLength = Math.min(visitCounts.length, stepHistory.length);

	for (let i = 0; i < minLength; i++) {
		const counts = visitCounts[i] ?? [];

		// Get current player
		const currentPlayer = replayGame.getCurrentPlayer() === 1 ? "black" : "white";

		// Get all valid positions at this state
		const validPos = replayGame.validMovePositions();

		// Map visit counts to positions; rows shorter than the move list fall back to 0
		const statistics: MCTSMoveStatistic[] = validPos.map((pos, idx) => ({
			position: pos,
			actionKey: `${pos.x},${pos.y},${pos.z}`,
			N: counts[idx] ?? 0,
			pi: 0 // Will be calculated after
		}));

		// Add pass move (typically the last element in visit counts)
		statistics.push({
			position: null,
			actionKey: "pass",
			N: counts[validPos.length] ?? 0,
			pi: 0
		});

		// Calculate total visits
		const totalN = statistics.reduce((sum, stat) => sum + stat.N, 0);

		// Normalize to get search policy π; left at 0 when nothing was visited
		if (totalN > 0) {
			for (const stat of statistics) {
				stat.pi = stat.N / totalN;
			}
		}

		// Store move data (snapshot taken before the move is applied)
		result.push({
			moveNumber: i,
			player: currentPlayer,
			gameState: replayGame.clone(),
			statistics
		});

		// Apply move to advance game state
		const step = stepHistory[i];
		if (step.type === PASS_STEP_TYPE) {
			replayGame.pass();
		} else if (step.position) {
			replayGame.drop(step.position);
		}
	}

	if (result.length === 0) {
		throw new Error("No MCTS data could be parsed");
	}

	console.log(`[MCTS Parser] Successfully parsed ${result.length} moves`);
	return result;
}

/**
 * Validate that visit counts data has the expected structure
 *
 * @param visitCounts - Value to check, typically the result of `JSON.parse`
 * @returns `true` when the value is an array whose elements are all arrays of
 *   finite, non-negative numbers (an empty outer or inner array passes)
 */
export function validateVisitCounts(visitCounts: unknown): visitCounts is number[][] {
	if (!Array.isArray(visitCounts)) {
		return false;
	}

	for (const counts of visitCounts) {
		if (!Array.isArray(counts)) {
			return false;
		}

		for (const count of counts) {
			if (!isVisitCount(count)) {
				return false;
			}
		}
	}

	return true;
}

/**
 * Format position for display
 *
 * @param position - Board position, or `null` for a pass; `z` is not used
 * @returns `"Pass"` for a missing position, otherwise a column letter followed
 *   by a 1-based row number (e.g. `{ x: 0, y: 0 }` gives `"A1"`). Columns map
 *   to the letters A–Z only for `x` in 0–25.
 */
export function formatPosition(position: { x: number; y: number; z?: number } | null): string {
	if (!position) {
		return "Pass";
	}

	// Convert to chess-like notation: column letter + row number
	const col = String.fromCharCode(65 + position.x); // A, B, C, ...
	const row = position.y + 1; // 1, 2, 3, ...

	return `${col}${row}`;
}