Spaces:
Sleeping
Sleeping
| /** | |
| * MCTS Data Parser | |
| * | |
| * Parses TGN game files and visit count JSON files to reconstruct | |
| * move-by-move MCTS statistics. | |
| */ | |
| import { TrigoGame } from "../../../inc/trigo/game"; | |
| import type { MCTSMoveData, MCTSMoveStatistic } from "../types/mcts"; | |
/**
 * Parse MCTS data from TGN and visit counts JSON.
 *
 * The TGN game is replayed from an empty board of the same shape. Before each
 * recorded step, the positions returned by `validMovePositions()` are paired
 * index-by-index with that step's visit-count array. The entry immediately
 * after the last valid position is treated as the pass move. Visit counts are
 * then normalised into the search policy π (left at 0 when a move has no visits).
 *
 * @param tgnContent - TGN game notation string
 * @param visitCountsJson - JSON string containing an array of visit-count
 *   arrays, one inner array per recorded step
 * @returns Array of MCTS move data for each move in the game, truncated to the
 *   shorter of the step history and the visit-count array (a warning is logged
 *   when their lengths differ)
 * @throws Error if the TGN cannot be parsed, the board is not 2D (z !== 1),
 *   the visit counts JSON is malformed (not a non-empty array of arrays of
 *   finite, non-negative numbers), or no move data could be produced
 */
export async function parseMCTSData(
	tgnContent: string,
	visitCountsJson: string
): Promise<MCTSMoveData[]> {
	// 1. Parse TGN to get game
	let game: TrigoGame;
	try {
		game = TrigoGame.fromTGN(tgnContent);
	} catch (error) {
		throw new Error(`Failed to parse TGN file: ${error instanceof Error ? error.message : String(error)}`);
	}

	// 2. Validate board shape (must be 2D)
	const shape = game.getShape();
	if (shape.z !== 1) {
		throw new Error(
			`Only 2D boards (z=1) are supported. This game has shape ${shape.x}×${shape.y}×${shape.z}`
		);
	}

	// 3. Parse and validate visit counts JSON
	const visitCounts = parseVisitCountsJson(visitCountsJson);

	// 4. Replay game move-by-move and map visit counts to positions
	const result: MCTSMoveData[] = [];
	const replayGame = new TrigoGame(shape, {});
	const stepHistory = game.getHistory();

	// Ensure we have visit counts for each move
	if (visitCounts.length !== stepHistory.length) {
		console.warn(
			`Visit counts length (${visitCounts.length}) doesn't match step history length (${stepHistory.length}). Using minimum.`
		);
	}

	const minLength = Math.min(visitCounts.length, stepHistory.length);
	for (let i = 0; i < minLength; i++) {
		// i < visitCounts.length, so this is always defined; the fallback only
		// satisfies noUncheckedIndexedAccess.
		const counts = visitCounts[i] ?? [];

		// Player to move in the position *before* step i is applied
		// (assumes getCurrentPlayer() returns 1 for black).
		const currentPlayer = replayGame.getCurrentPlayer() === 1 ? "black" : "white";

		// Get all valid positions at this state
		const validPos = replayGame.validMovePositions();

		// Map visit counts to positions by index; missing entries count as 0 visits.
		const statistics: MCTSMoveStatistic[] = validPos.map((pos, idx) => ({
			position: pos,
			actionKey: `${pos.x},${pos.y},${pos.z}`,
			N: counts[idx] ?? 0,
			pi: 0 // Will be calculated after
		}));

		// Pass move is expected right after the last valid position.
		statistics.push({
			position: null,
			actionKey: "pass",
			N: counts[validPos.length] ?? 0,
			pi: 0
		});

		// Normalize to get search policy π (counts are validated as non-negative numbers)
		const totalN = statistics.reduce((sum, stat) => sum + stat.N, 0);
		if (totalN > 0) {
			for (const stat of statistics) {
				stat.pi = stat.N / totalN;
			}
		}

		// Store a snapshot of the position before the move is played
		result.push({
			moveNumber: i,
			player: currentPlayer,
			gameState: replayGame.clone(),
			statistics
		});

		// Apply move to advance game state.
		// NOTE(review): steps that are neither a pass nor carry a position are
		// skipped here, which would desynchronise the replay — confirm which
		// step types can appear in the history.
		const step = stepHistory[i];
		if (step.type === 1) { // StepType.PASS
			replayGame.pass();
		} else if (step.position) {
			replayGame.drop(step.position);
		}
	}

	if (result.length === 0) {
		throw new Error("No MCTS data could be parsed");
	}

	console.log(`[MCTS Parser] Successfully parsed ${result.length} moves`);
	return result;
}

/**
 * Parse the visit-counts JSON into a non-empty array of arrays of finite,
 * non-negative numbers.
 *
 * Every failure is reported as an Error whose message starts with
 * "Failed to parse visit counts JSON:".
 */
function parseVisitCountsJson(visitCountsJson: string): number[][] {
	try {
		const parsed: unknown = JSON.parse(visitCountsJson);
		if (!Array.isArray(parsed)) {
			throw new Error("Visit counts must be an array");
		}
		if (parsed.length === 0) {
			throw new Error("Visit counts array is empty");
		}
		// Validate that each element is an array of numbers. Without the inner
		// check, a string entry such as "5" would turn the visit sum into string
		// concatenation and corrupt π.
		for (let i = 0; i < parsed.length; i++) {
			const counts: unknown = parsed[i];
			if (!Array.isArray(counts)) {
				throw new Error(`Visit counts at index ${i} is not an array`);
			}
			for (let j = 0; j < counts.length; j++) {
				const count: unknown = counts[j];
				if (typeof count !== "number" || !Number.isFinite(count) || count < 0) {
					throw new Error(
						`Visit count at index ${i}, entry ${j} is not a finite non-negative number`
					);
				}
			}
		}
		// Safe: every element was checked above.
		return parsed as number[][];
	} catch (error) {
		throw new Error(`Failed to parse visit counts JSON: ${error instanceof Error ? error.message : String(error)}`);
	}
}
/**
 * Validate that visit counts data has the expected structure: an array whose
 * every element is itself an array of finite, non-negative numbers.
 *
 * Empty outer and inner arrays are accepted; callers that require at least one
 * move must check the length themselves. Holes in sparse arrays are treated as
 * `undefined` and therefore rejected.
 *
 * @param visitCounts - Untrusted value, e.g. the result of `JSON.parse`
 * @returns `true` (narrowing `visitCounts` to `number[][]`) when the structure is valid
 */
export function validateVisitCounts(visitCounts: unknown): visitCounts is number[][] {
	if (!Array.isArray(visitCounts)) {
		return false;
	}

	// for...of (rather than Array.prototype.every) also visits holes in sparse arrays.
	for (const counts of visitCounts as unknown[]) {
		if (!Array.isArray(counts)) {
			return false;
		}
		for (const count of counts as unknown[]) {
			const isValidCount =
				typeof count === "number" && Number.isFinite(count) && count >= 0;
			if (!isValidCount) {
				return false;
			}
		}
	}

	return true;
}
/**
 * Convert a zero-based column index into spreadsheet-style letters:
 * 0 → "A", 25 → "Z", 26 → "AA", 27 → "AB", 52 → "BA", ...
 *
 * For indices 0–25 this is exactly `String.fromCharCode(65 + index)`; larger
 * indices no longer fall through to non-letter characters such as "[".
 * Assumes index >= 0.
 */
function toColumnLabel(index: number): string {
	let label = "";
	let remaining = index;
	do {
		label = String.fromCharCode(65 + (remaining % 26)) + label;
		// Bijective base-26: subtract one so that 26 maps to "AA" rather than "BA".
		remaining = Math.floor(remaining / 26) - 1;
	} while (remaining >= 0);
	return label;
}

/**
 * Format position for display.
 *
 * Uses chess-like notation: column letters derived from `x` (A, B, …, Z, AA, …)
 * followed by the one-based row number `y + 1`. `z` is not included in the output.
 *
 * @param position - Board position, or `null` for a pass
 * @returns "Pass" for `null`, otherwise e.g. "A1", "C5", "AA1"
 */
export function formatPosition(position: { x: number; y: number; z?: number } | null): string {
	if (!position) {
		return "Pass";
	}
	const col = toColumnLabel(position.x);
	const row = position.y + 1; // 1, 2, 3, ...
	return `${col}${row}`;
}