// trigo / trigo-web / app / src / utils / mctsDataParser.ts
// k-l-lambda's picture
// Update trigo-web with VS People multiplayer mode
// 15f353f
/**
* MCTS Data Parser
*
* Parses TGN game files and visit count JSON files to reconstruct
* move-by-move MCTS statistics.
*/
import { TrigoGame } from "../../../inc/trigo/game";
import type { MCTSMoveData, MCTSMoveStatistic } from "../types/mcts";
/**
 * Parse MCTS data from TGN and visit counts JSON.
 *
 * The game described by the TGN is replayed from an empty board of the same
 * shape. Before each move is applied, the visit-count row for that move is
 * mapped by index onto the list returned by `validMovePositions()`; the entry
 * right after the last valid position is read as the pass count. Counts are
 * then normalised into the search policy π (left at 0 when a row sums to 0).
 *
 * @param tgnContent - TGN game notation string
 * @param visitCountsJson - JSON string containing the visit counts array:
 *   one row per move, each row an array of finite, non-negative numbers
 * @returns Array of MCTS move data for each move in the game
 * @throws Error if the TGN cannot be parsed, the board is not 2D (z !== 1),
 *   the visit counts JSON is malformed or contains invalid entries, or no
 *   move could be reconstructed
 */
export async function parseMCTSData(
	tgnContent: string,
	visitCountsJson: string
): Promise<MCTSMoveData[]> {
	// Numeric value the step history uses for a pass (StepType.PASS).
	const PASS_STEP_TYPE = 1;

	// 1. Parse TGN to get game
	let game: TrigoGame;
	try {
		game = TrigoGame.fromTGN(tgnContent);
	} catch (error) {
		throw new Error(`Failed to parse TGN file: ${error instanceof Error ? error.message : String(error)}`);
	}

	// 2. Validate board shape (must be 2D)
	const shape = game.getShape();
	if (shape.z !== 1) {
		throw new Error(
			`Only 2D boards (z=1) are supported. This game has shape ${shape.x}×${shape.y}×${shape.z}`
		);
	}

	// 3. Parse visit counts JSON and validate its structure before trusting it
	let visitCounts: number[][];
	try {
		const parsed: unknown = JSON.parse(visitCountsJson);
		if (!Array.isArray(parsed)) {
			throw new Error("Visit counts must be an array");
		}
		if (parsed.length === 0) {
			throw new Error("Visit counts array is empty");
		}

		// Validate that each element is an array of finite, non-negative numbers.
		// Anything else (strings, null, negatives, Infinity from e.g. 1e999)
		// would corrupt the totals and the normalised policy below.
		for (let i = 0; i < parsed.length; i++) {
			const row: unknown = parsed[i];
			if (!Array.isArray(row)) {
				throw new Error(`Visit counts at index ${i} is not an array`);
			}
			for (let j = 0; j < row.length; j++) {
				const count: unknown = row[j];
				if (typeof count !== "number" || !Number.isFinite(count) || count < 0) {
					throw new Error(
						`Visit count at index ${i}, position ${j} is not a finite non-negative number`
					);
				}
			}
		}

		visitCounts = parsed;
	} catch (error) {
		throw new Error(`Failed to parse visit counts JSON: ${error instanceof Error ? error.message : String(error)}`);
	}

	// 4. Replay game move-by-move and map visit counts to positions
	const result: MCTSMoveData[] = [];
	const replayGame = new TrigoGame(shape, {});
	const stepHistory = game.getHistory();

	// Ensure we have visit counts for each move
	if (visitCounts.length !== stepHistory.length) {
		console.warn(
			`Visit counts length (${visitCounts.length}) doesn't match step history length (${stepHistory.length}). Using minimum.`
		);
	}

	const minLength = Math.min(visitCounts.length, stepHistory.length);
	for (let i = 0; i < minLength; i++) {
		const counts = visitCounts[i];

		// Get current player
		const currentPlayer = replayGame.getCurrentPlayer() === 1 ? "black" : "white";

		// Get all valid positions at this state
		const validPos = replayGame.validMovePositions();

		// Map visit counts to positions by index; missing entries count as 0 visits
		const statistics: MCTSMoveStatistic[] = validPos.map((pos, idx) => ({
			position: pos,
			actionKey: `${pos.x},${pos.y},${pos.z}`,
			N: counts[idx] ?? 0,
			pi: 0 // Will be calculated after
		}));

		// Add pass move (read from the slot right after the last valid position)
		statistics.push({
			position: null,
			actionKey: "pass",
			N: counts[validPos.length] ?? 0,
			pi: 0
		});

		// Calculate total visits
		const totalN = statistics.reduce((sum, stat) => sum + stat.N, 0);

		// Normalize to get search policy π
		if (totalN > 0) {
			for (const stat of statistics) {
				stat.pi = stat.N / totalN;
			}
		}

		// Store move data; the cloned state is the position BEFORE this move is played
		result.push({
			moveNumber: i,
			player: currentPlayer,
			gameState: replayGame.clone(),
			statistics
		});

		// Apply move to advance game state.
		// NOTE(review): a step that is neither a pass nor carries a position is
		// skipped without advancing the replay — confirm that is intended.
		const step = stepHistory[i];
		if (step.type === PASS_STEP_TYPE) {
			replayGame.pass();
		} else if (step.position) {
			replayGame.drop(step.position);
		}
	}

	if (result.length === 0) {
		throw new Error("No MCTS data could be parsed");
	}

	console.log(`[MCTS Parser] Successfully parsed ${result.length} moves`);
	return result;
}
/**
 * Type guard for visit-count data.
 *
 * Returns true only when the value is an array whose every entry is itself an
 * array consisting solely of finite, non-negative numbers. Empty outer or
 * inner arrays are accepted.
 */
export function validateVisitCounts(visitCounts: any): visitCounts is number[][] {
	const isValidCount = (value: unknown): boolean =>
		typeof value === "number" && Number.isFinite(value) && value >= 0;

	if (!Array.isArray(visitCounts)) {
		return false;
	}

	// Index loops read missing (sparse) slots as undefined, so they are rejected.
	for (let rowIndex = 0; rowIndex < visitCounts.length; rowIndex++) {
		const row = visitCounts[rowIndex];
		if (!Array.isArray(row)) {
			return false;
		}
		for (let colIndex = 0; colIndex < row.length; colIndex++) {
			if (!isValidCount(row[colIndex])) {
				return false;
			}
		}
	}

	return true;
}
/**
 * Build the column label for a zero-based column index.
 *
 * Indices 0..25 map to "A".."Z"; larger indices continue spreadsheet-style
 * ("AA", "AB", …, "ZZ", "AAA", …) so the label is always made of letters.
 * Non-integer or negative indices keep the legacy single-character mapping.
 */
function columnLabel(index: number): string {
	if (!Number.isInteger(index) || index < 0) {
		return String.fromCharCode(65 + index);
	}

	let label = "";
	// Bijective base-26: work on the 1-based value so "Z" rolls over to "AA".
	let remaining = index + 1;
	while (remaining > 0) {
		const digit = (remaining - 1) % 26;
		label = String.fromCharCode(65 + digit) + label;
		remaining = Math.floor((remaining - 1) / 26);
	}
	return label;
}

/**
 * Format position for display.
 *
 * @param position - Board position, or null for a pass. Only `x` and `y` are
 *   used; `z` is ignored.
 * @returns "Pass" for a null position, otherwise chess-like notation: column
 *   letter(s) from `x` followed by the 1-based row number from `y`
 *   (e.g. `{x: 0, y: 0}` → "A1", `{x: 26, y: 0}` → "AA1").
 */
export function formatPosition(position: { x: number; y: number; z?: number } | null): string {
	if (!position) {
		return "Pass";
	}

	// Convert to chess-like notation: column letter(s) + row number
	const col = columnLabel(position.x); // A, B, C, ..., Z, AA, AB, ...
	const row = position.y + 1; // 1, 2, 3, ...
	return `${col}${row}`;
}