trigo / trigo-web /app /src /views /TrigoTreeTestView.vue
k-l-lambda's picture
updated
502af73
<template>
<div class="tree-test-container">
<h1>Tree Attention Test - Trigo AI Agent</h1>
<!-- Intro: explains what this test page exercises -->
<div class="info-section">
	<h2>Evaluation Mode with Tree Attention</h2>
	<p>
		This test uses the evaluation mode ONNX model to evaluate all valid moves in
		parallel using tree attention. Each move is an independent branch that can only see
		the prefix and itself.
	</p>
</div>
<!-- Agent Status: model path, lifecycle state and vocabulary size, plus the init trigger -->
<div class="status-section">
	<h3>Agent Status</h3>
	<div class="status-grid">
		<div class="status-item">
			<span class="label">Model:</span>
			<span class="value">{{ modelPath }}</span>
		</div>
		<div class="status-item">
			<span class="label">Status:</span>
			<span class="value" :class="statusClass">{{ agentStatus }}</span>
		</div>
		<div class="status-item">
			<span class="label">Vocab Size:</span>
			<span class="value">{{ vocabSize }}</span>
		</div>
	</div>
	<!-- Explicit type="button": a bare <button> defaults to type="submit" -->
	<button
		type="button"
		@click="initializeAgent"
		:disabled="isInitializing || isReady"
		class="btn-primary"
	>
		{{
			isInitializing
				? "Initializing..."
				: isReady
					? "✓ Agent Ready"
					: "Initialize Agent"
		}}
	</button>
</div>
<!-- Game Board: textual summary of the current position and a reset control -->
<div class="board-section">
	<h3>Game Board (5×5×5)</h3>
	<div class="board-info">
		<div class="info-item">
			<span class="label">Current Player:</span>
			<span class="value">{{ currentPlayer }}</span>
		</div>
		<div class="info-item">
			<span class="label">Move Count:</span>
			<span class="value">{{ moveCount }}</span>
		</div>
		<div class="info-item">
			<span class="label">Valid Moves:</span>
			<span class="value">{{ validMovesCount }}</span>
		</div>
	</div>
	<div class="board-display">
		<pre>{{ boardDisplay }}</pre>
	</div>
	<!-- Explicit type="button": a bare <button> defaults to type="submit" -->
	<button type="button" @click="resetGame" class="btn-secondary">Reset Game</button>
</div>
<!-- Move Generation: triggers scoring, shows timing, the ranked move table and errors -->
<div class="generation-section">
	<h3>Tree Attention Move Generation</h3>
	<div class="button-group">
		<!-- All buttons carry type="button"; the HTML default would be "submit" -->
		<button
			type="button"
			@click="generateMoves"
			:disabled="!isReady || isGenerating"
			class="btn-primary btn-large"
		>
			{{ isGenerating ? "Generating..." : "Score All Moves (Tree Attention)" }}
		</button>
		<button
			type="button"
			@click="generateBestMove"
			:disabled="!isReady || isGenerating"
			class="btn-primary btn-large"
		>
			{{ isGenerating ? "Generating..." : "Generate & Apply Best Move" }}
		</button>
	</div>
	<div v-if="generationTime !== null" class="timing-info">
		<strong>Generation Time:</strong> {{ generationTime }}ms
		<span class="timing-detail">
			(~{{ (generationTime / validMovesCount).toFixed(1) }}ms per move)
		</span>
	</div>
	<div v-if="bestMoveApplied" class="success-info">
		<strong>Best Move Applied:</strong> {{ bestMoveApplied }}
	</div>
	<!-- Scored Moves Display -->
	<div v-if="scoredMoves.length > 0" class="moves-section">
		<h4>Scored Moves (Top 20)</h4>
		<div class="moves-table-container">
			<table class="moves-table">
				<thead>
					<tr>
						<!-- scope="col" ties each header to its column for assistive technology -->
						<th scope="col">Rank</th>
						<th scope="col">Notation</th>
						<th scope="col">Position</th>
						<th scope="col">Log Prob</th>
						<th scope="col">Probability</th>
						<th scope="col">Action</th>
					</tr>
				</thead>
				<tbody>
					<!-- The first row (highest score) is highlighted via .top-move -->
					<tr
						v-for="(scoredMove, index) in topMoves"
						:key="index"
						:class="{ 'top-move': index === 0 }"
					>
						<td>{{ index + 1 }}</td>
						<td class="notation">{{ scoredMove.notation }}</td>
						<td class="position">{{ formatPosition(scoredMove.move) }}</td>
						<td class="score">{{ scoredMove.score.toFixed(3) }}</td>
						<td class="probability">
							{{ ((scoredMove.probability ?? 0) * 100).toFixed(4) }}%
						</td>
						<td>
							<button type="button" @click="applyMove(scoredMove.move)" class="btn-small">
								Apply
							</button>
						</td>
					</tr>
				</tbody>
			</table>
		</div>
	</div>
	<!-- Error Display: role="alert" lets screen readers announce the failure when it appears -->
	<div v-if="errorMessage" class="error-section" role="alert">
		<h4>Error</h4>
		<pre>{{ errorMessage }}</pre>
	</div>
</div>
<!-- Tree Visualization: rendered only after a scoring run has stored a tree structure -->
<div v-if="treeVisualization" class="visualization-section">
	<h3>Tree Structure Visualization</h3>
	<!-- Move Details -->
	<div class="move-details">
		<h4>Move Details</h4>
		<div class="details-grid">
			<div
				v-for="(move, index) in treeVisualization.moveData"
				:key="index"
				class="move-detail-item"
			>
				<span class="move-notation">{{ move.notation }}</span>
				<span class="move-positions"
					>parent={{ move.parentPos }}, leaf={{ move.leafPos }}</span
				>
			</div>
		</div>
	</div>
	<!-- Token Sequence -->
	<div class="token-sequence">
		<h4>
			Evaluated Token Sequence ({{ treeVisualization.evaluatedIds.length }} tokens)
		</h4>
		<div class="tokens-display">
			<div
				v-for="(tokenId, index) in treeVisualization.evaluatedIds"
				:key="index"
				class="token-item"
			>
				<div class="token-pos">{{ index }}</div>
				<div class="token-char">{{ String.fromCharCode(tokenId) }}</div>
				<div class="token-id">{{ tokenId }}</div>
			</div>
		</div>
	</div>
	<!-- Attention Mask Matrix -->
	<div class="mask-matrix">
		<!--
			The dimension comes from evaluatedIds.length, the same value that drives the
			rows/columns below and the row stride inside getMaskValue, so the heading
			always matches the table that is actually rendered.
		-->
		<h4>
			Attention Mask Matrix ({{ treeVisualization.evaluatedIds.length }}×{{
				treeVisualization.evaluatedIds.length
			}})
		</h4>
		<div class="matrix-container">
			<table class="matrix-table">
				<thead>
					<tr>
						<th class="corner-cell"></th>
						<th
							v-for="(tokenId, col) in treeVisualization.evaluatedIds"
							:key="col"
							class="header-cell"
							scope="col"
						>
							<div class="header-content">
								<div class="header-pos">{{ col }}</div>
								<div class="header-char">
									{{ String.fromCharCode(tokenId) }}
								</div>
							</div>
						</th>
					</tr>
				</thead>
				<tbody>
					<tr v-for="(tokenId, row) in treeVisualization.evaluatedIds" :key="row">
						<th class="row-header" scope="row">
							<div class="row-content">
								<div class="row-pos">{{ row }}</div>
								<div class="row-char">
									{{ String.fromCharCode(tokenId) }}
								</div>
							</div>
						</th>
						<!-- v-for over a number is 1-based, hence the "col - 1" everywhere -->
						<td
							v-for="col in treeVisualization.evaluatedIds.length"
							:key="col - 1"
							:class="getMaskCellClass(row, col - 1)"
							:title="`[${row},${col - 1}] = ${getMaskValue(row, col - 1)}`"
						>
							{{ getMaskValue(row, col - 1) }}
						</td>
					</tr>
				</tbody>
			</table>
		</div>
		<div class="matrix-legend">
			<span class="legend-item"
				><span class="legend-box active"></span> 1 (can attend)</span
			>
			<span class="legend-item"
				><span class="legend-box inactive"></span> 0 (cannot attend)</span
			>
		</div>
	</div>
</div>
<!-- Debug Info: raw TGN text of the current game -->
<div class="debug-section">
	<h3>Debug Information</h3>
	<div class="debug-content">
		<div class="debug-item">
			<span class="label">Current TGN:</span>
			<pre class="debug-value">{{ currentTGN }}</pre>
		</div>
	</div>
</div>
</div>
</template>
<script setup lang="ts">
import { ref, computed, onMounted } from "vue";
import * as ort from "onnxruntime-web";
import { TrigoGame } from "../../../inc/trigo/game";
import { ModelInferencer } from "../../../inc/modelInferencer";
import { TrigoTreeAgent } from "../../../inc/trigoTreeAgent";
import type { ScoredMove } from "../../../inc/trigoTreeAgent";
import type { Move } from "../../../inc/trigo/types";
// Configuration
// URL of the evaluation-mode ONNX model; loaded by initializeAgent() and shown in the status panel.
const modelPath = "/onnx/GPT2CausalLM_ep0015_evaluation.onnx";
// Vocabulary size handed to ModelInferencer and shown in the status panel.
// NOTE(review): presumably 256 byte values plus 3 special tokens — confirm against the tokenizer.
const vocabSize = 259;
// State
// The game under test; replaced with a fresh instance by resetGame().
const game = ref<TrigoGame>(new TrigoGame({ x: 5, y: 5, z: 5 }, {}));
// Inference wrapper and agent; both remain null until initializeAgent() assigns them.
const inferencer = ref<ModelInferencer | null>(null);
const agent = ref<TrigoTreeAgent | null>(null);
// UI lifecycle flags that drive button state and the status label.
const isInitializing = ref(false);
const isReady = ref(false);
const isGenerating = ref(false);
// Result of the last scoring run, sorted by descending score in generateMoves().
const scoredMoves = ref<ScoredMove[]>([]);
// Duration of the last generation in whole milliseconds; null hides the timing banner.
const generationTime = ref<number | null>(null);
// Last error text; an empty string hides the error panel.
const errorMessage = ref<string>("");
// Label of the move applied by generateBestMove(); an empty string hides the success banner.
const bestMoveApplied = ref<string>("");
// Tree visualization data
// Snapshot of agent.getTreeStructure() captured in generateMoves(); null until the first run.
// `mask` is read as a flat row-major matrix with row stride evaluatedIds.length (see getMaskValue).
const treeVisualization = ref<{
evaluatedIds: number[];
mask: number[];
moveData: Array<{ notation: string; parentPos: number; leafPos: number }>;
} | null>(null);
// Computed properties
// Status label for the agent panel: ready wins over initializing, otherwise not initialized.
const agentStatus = computed(() =>
	isReady.value ? "Ready" : isInitializing.value ? "Initializing..." : "Not initialized"
);
// CSS class that colours the status label; mirrors the precedence used by agentStatus.
const statusClass = computed(() =>
	isReady.value ? "status-ready" : isInitializing.value ? "status-loading" : "status-error"
);
// Display name of the side to move: player id 1 is shown as Black, anything else as White.
const currentPlayer = computed(() => (game.value.getCurrentPlayer() === 1 ? "Black" : "White"));
// Number of entries in the game's step history.
const moveCount = computed(() => {
	return game.value.stepHistory.length;
});
// Candidate count for the side to move: every valid board position plus the pass move.
const validMovesCount = computed(() => 1 + game.value.validMovePositions().length);
// TGN text of the current game, shown in the debug panel.
const currentTGN = computed(() => {
	return game.value.toTGN();
});
// Softmax over the scored moves: attaches a `probability` to each entry so that the
// probabilities of all scored moves sum to 1. Order of the input list is preserved.
const movesWithProbability = computed(() => {
	const moves = scoredMoves.value;
	if (moves.length === 0) {
		return [];
	}

	// Shift by the largest score before exponentiating to keep exp() numerically stable.
	const peak = Math.max(...moves.map((entry) => entry.score));
	const weights = moves.map((entry) => Math.exp(entry.score - peak));

	// Accumulate left to right, starting from 0, as the normalising constant.
	let total = 0;
	for (const weight of weights) {
		total = total + weight;
	}

	return moves.map((entry, idx) => ({
		...entry,
		probability: weights[idx] / total
	}));
});
// The table renders at most the first 20 entries of the probability-annotated list.
const topMoves = computed(() => {
	const ranked = movesWithProbability.value;
	return ranked.slice(0, 20);
});
// Plain-text summary of the position (dimensions, move count, side to move, open positions).
// Rendered inside a <pre>, so the line breaks are significant.
const boardDisplay = computed(() => {
	const { x, y, z } = game.value.getShape();
	const lines = [
		`Board: ${x}×${y}×${z}`,
		"",
		`Moves played: ${moveCount.value}`,
		`Current player: ${currentPlayer.value}`,
		// validMovesCount includes the pass move, so subtract it for the board positions.
		`Valid positions: ${validMovesCount.value - 1} (+ pass)`,
		""
	];
	return lines.join("\n");
});
// Methods
/**
 * Load the ONNX evaluation model and build the inferencer and tree agent on top of it.
 *
 * On success `inferencer`, `agent` and `isReady` are set. On failure the error is logged,
 * stored in `errorMessage`, and `isReady` is forced to false. `isInitializing` is always
 * cleared at the end.
 */
async function initializeAgent() {
	errorMessage.value = "";
	isInitializing.value = true;

	try {
		console.log("Loading evaluation model:", modelPath);

		// Create inference session
		const onnxSession = await ort.InferenceSession.create(modelPath);
		console.log("✓ Model loaded successfully");

		// Create inferencer and hand it the already-created session
		const modelInferencer = new ModelInferencer(ort.Tensor as any, {
			vocabSize,
			seqLen: 2048,
			modelPath
		});
		modelInferencer.setSession(onnxSession as any);
		inferencer.value = modelInferencer;

		// Create agent
		agent.value = new TrigoTreeAgent(modelInferencer);
		isReady.value = true;
		console.log("✓ Agent initialized");
	} catch (err) {
		console.error("Failed to initialize agent:", err);
		errorMessage.value = String(err);
		isReady.value = false;
	} finally {
		isInitializing.value = false;
	}
}
/**
 * Score every valid move (plus pass) for the side to move and publish the results.
 *
 * Stores the agent's tree structure in `treeVisualization`, the scored moves (sorted by
 * descending score) in `scoredMoves`, and the elapsed time in `generationTime`.
 * Errors are logged and surfaced through `errorMessage`.
 */
async function generateMoves() {
	const treeAgent = agent.value;
	if (!treeAgent || !isReady.value) {
		errorMessage.value = "Agent not ready";
		return;
	}

	isGenerating.value = true;
	errorMessage.value = "";
	scoredMoves.value = [];
	generationTime.value = null;

	try {
		console.log("Generating moves with tree attention...");
		const t0 = performance.now();

		// Build the candidate list: one move per valid position, then the pass move.
		// Named `side` so it does not shadow the `currentPlayer` computed.
		const side = game.value.getCurrentPlayer() === 1 ? "black" : "white";
		const candidates: Move[] = [];
		for (const { x, y, z } of game.value.validMovePositions()) {
			candidates.push({ x, y, z, player: side });
		}
		candidates.push({ player: side, isPass: true }); // Add pass
		console.log(`Scoring ${candidates.length} moves...`);

		// Capture the tree structure for the visualization panel.
		const { evaluatedIds, mask, moveData } = treeAgent.getTreeStructure(game.value, candidates);
		treeVisualization.value = {
			evaluatedIds,
			mask,
			moveData: moveData.map(({ notation, parentPos, leafPos }) => ({
				notation,
				parentPos,
				leafPos
			}))
		};

		// Score all moves using tree attention
		const scored = await treeAgent.scoreMoves(game.value, candidates);
		generationTime.value = Math.round(performance.now() - t0);

		// Highest score first; the table highlights row 0 as the top move.
		scoredMoves.value = scored.sort((a, b) => b.score - a.score);
		console.log(`✓ Scored ${scored.length} moves in ${generationTime.value}ms`);

		const preview = movesWithProbability.value.slice(0, 5).map((entry) => ({
			notation: entry.notation,
			score: entry.score.toFixed(3),
			prob: ((entry.probability ?? 0) * 100).toFixed(4) + "%"
		}));
		console.log("Top 5 moves:", preview);
	} catch (err) {
		console.error("Failed to generate moves:", err);
		errorMessage.value = String(err);
	} finally {
		isGenerating.value = false;
	}
}
/**
 * Ask the agent for its best move in the current position and play it on the board.
 *
 * Records the elapsed selection time in `generationTime`. On success the applied move's
 * label is stored in `bestMoveApplied`; if the agent returns no move, or the game rejects
 * it, an explanatory message is stored in `errorMessage`. Thrown errors are logged and
 * surfaced the same way. `isGenerating` is always cleared at the end.
 */
async function generateBestMove() {
	if (!agent.value || !isReady.value) {
		errorMessage.value = "Agent not ready";
		return;
	}
	isGenerating.value = true;
	errorMessage.value = "";
	bestMoveApplied.value = "";
	scoredMoves.value = [];
	generationTime.value = null;
	try {
		console.log("Generating best move with tree attention...");
		const startTime = performance.now();
		// Use selectBestMove to get the best move
		const bestMove = await agent.value.selectBestMove(game.value);
		const endTime = performance.now();
		generationTime.value = Math.round(endTime - startTime);
		if (!bestMove) {
			errorMessage.value = "No valid moves available";
			return;
		}
		// Apply the best move. A non-pass move missing a coordinate falls through with
		// success === false and is reported as a failure below.
		let success = false;
		let moveNotation = "";
		if (bestMove.isPass) {
			success = game.value.pass();
			moveNotation = "pass";
		} else if (
			bestMove.x !== undefined &&
			bestMove.y !== undefined &&
			bestMove.z !== undefined
		) {
			success = game.value.drop({ x: bestMove.x, y: bestMove.y, z: bestMove.z });
			// The move is labelled by its raw coordinates.
			moveNotation = `(${bestMove.x}, ${bestMove.y}, ${bestMove.z})`;
		}
		if (success) {
			bestMoveApplied.value = moveNotation;
			console.log(`✓ Best move applied: ${moveNotation} in ${generationTime.value}ms`);
		} else {
			errorMessage.value = "Failed to apply best move";
		}
	} catch (error) {
		console.error("Failed to generate best move:", error);
		errorMessage.value = String(error);
	} finally {
		isGenerating.value = false;
	}
}
/**
 * Play a move chosen from the scored-moves table.
 *
 * @param move Move to apply: either a pass or a move carrying x/y/z coordinates.
 *
 * On success the scored list and timing are cleared because they described the previous
 * position. If the game rejects the move (or it is neither a pass nor fully specified),
 * `errorMessage` is set. Thrown errors are logged and surfaced the same way.
 */
function applyMove(move: Move) {
	// Start from a clean slate like the other actions do, so an error left over from an
	// earlier attempt does not stay on screen after this move succeeds.
	errorMessage.value = "";
	try {
		let success = false;
		if (move.isPass) {
			success = game.value.pass();
		} else if (move.x !== undefined && move.y !== undefined && move.z !== undefined) {
			success = game.value.drop({ x: move.x, y: move.y, z: move.z });
		}
		if (success) {
			console.log("✓ Move applied");
			scoredMoves.value = [];
			generationTime.value = null;
		} else {
			errorMessage.value = "Failed to apply move";
		}
	} catch (error) {
		console.error("Error applying move:", error);
		errorMessage.value = String(error);
	}
}
/**
 * Start over with a fresh 5×5×5 game and restore all derived UI state to its initial
 * values (no scores, no timing, no error, no success banner, no tree visualization).
 */
function resetGame() {
	game.value = new TrigoGame({ x: 5, y: 5, z: 5 }, {});
	scoredMoves.value = [];
	generationTime.value = null;
	errorMessage.value = "";
	bestMoveApplied.value = "";
	// The stored tree was built for the discarded game; drop it so the visualization
	// panel does not keep showing a position that no longer exists.
	treeVisualization.value = null;
	console.log("✓ Game reset");
}
/**
 * Render a move for the "Position" column.
 *
 * @param move Move to describe.
 * @returns "Pass" for a pass, "(x, y, z)" when all three coordinates are present,
 *          otherwise "Unknown".
 */
function formatPosition(move: Move): string {
	if (move.isPass) return "Pass";

	const { x, y, z } = move;
	const hasAllCoords = x !== undefined && y !== undefined && z !== undefined;
	return hasAllCoords ? `(${x}, ${y}, ${z})` : "Unknown";
}
// Mask matrix helper functions
/**
 * Look up one cell of the attention mask.
 *
 * @param row Zero-based row index.
 * @param col Zero-based column index.
 * @returns The mask entry at (row, col), or 0 when no tree structure is stored.
 */
function getMaskValue(row: number, col: number): number {
	const viz = treeVisualization.value;
	if (!viz) return 0;

	// The mask is read as a flat row-major matrix whose row stride is the token count.
	const { evaluatedIds, mask } = viz;
	return mask[row * evaluatedIds.length + col];
}
/**
 * CSS classes for one mask cell: the "active" style for a value of exactly 1, the
 * "inactive" style for anything else.
 */
function getMaskCellClass(row: number, col: number): string {
	if (getMaskValue(row, col) !== 1) {
		return "mask-cell mask-inactive";
	}
	return "mask-cell mask-active";
}
// Lifecycle
// Only logs on mount; the agent is initialized on demand via the button, not here.
onMounted(function logMounted() {
	console.log("TrigoTreeTestView mounted");
});
</script>
<style scoped>
/* ---------- Page shell ---------- */
.tree-test-container {
	max-width: 1200px;
	min-height: 100vh;
	margin: 0 auto;
	padding: 2rem;
	overflow-y: auto;
	font-family: -apple-system, BlinkMacSystemFont, "Segoe UI", Roboto, sans-serif;
}

/* ---------- Headings ---------- */
h1 {
	margin-bottom: 1rem;
	color: #2c3e50;
}

h2, h3, h4 {
	margin-top: 2rem;
	margin-bottom: 1rem;
	color: #34495e;
}

/* ---------- Intro panel ---------- */
.info-section {
	margin-bottom: 2rem;
	padding: 1.5rem;
	border-radius: 8px;
	background: #f8f9fa;
}

.info-section p {
	margin: 0;
	line-height: 1.6;
	color: #555;
}
/* ---------- Card-style panels ---------- */
.status-section, .board-section, .generation-section, .debug-section {
	margin-bottom: 2rem;
	padding: 1.5rem;
	border: 1px solid #e1e4e8;
	border-radius: 8px;
	background: white;
}

/* ---------- Status grid and label/value rows ---------- */
.status-grid {
	display: grid;
	grid-template-columns: repeat(auto-fit, minmax(250px, 1fr));
	gap: 1rem;
	margin-bottom: 1rem;
}

.status-item, .info-item {
	display: flex;
	justify-content: space-between;
	padding: 0.5rem;
	border-radius: 4px;
	background: #f8f9fa;
}

.label {
	font-weight: 600;
	color: #666;
}

.value {
	color: #2c3e50;
}

/* Status colours: green = ready, amber = loading, red = not initialized */
.status-ready {
	font-weight: 600;
	color: #28a745;
}

.status-loading {
	font-weight: 600;
	color: #ffc107;
}

.status-error {
	font-weight: 600;
	color: #dc3545;
}

/* ---------- Board panel ---------- */
.board-info {
	display: grid;
	grid-template-columns: repeat(auto-fit, minmax(200px, 1fr));
	gap: 1rem;
	margin-bottom: 1rem;
}

.board-display {
	margin-bottom: 1rem;
	padding: 1rem;
	border: 1px solid #e1e4e8;
	border-radius: 4px;
	background: #f8f9fa;
}

.board-display pre {
	margin: 0;
	font-family: "Courier New", monospace;
	font-size: 0.9rem;
	color: #2c3e50;
}
/* ---------- Buttons ---------- */
.button-group {
	display: flex;
	flex-wrap: wrap;
	gap: 1rem;
	margin-bottom: 1rem;
}

/* Shared base; .btn-large and .btn-small below override padding and font size */
.btn-primary, .btn-secondary, .btn-small {
	padding: 0.75rem 1.5rem;
	border: none;
	border-radius: 4px;
	font-size: 1rem;
	font-weight: 600;
	cursor: pointer;
	transition: all 0.2s;
}

.btn-primary {
	background: #007bff;
	color: white;
}

.btn-primary:hover:not(:disabled) {
	background: #0056b3;
}

.btn-primary:disabled {
	background: #6c757d;
	opacity: 0.6;
	cursor: not-allowed;
}

.btn-secondary {
	background: #6c757d;
	color: white;
}

.btn-secondary:hover {
	background: #545b62;
}

/* Size modifiers */
.btn-large {
	padding: 1rem 2rem;
	font-size: 1.1rem;
}

.btn-small {
	padding: 0.25rem 0.75rem;
	font-size: 0.875rem;
}
/* ---------- Timing and success banners ---------- */
.timing-info {
	margin-top: 1rem;
	padding: 1rem;
	border-left: 4px solid #007bff;
	border-radius: 4px;
	background: #e7f5ff;
}

.timing-detail {
	margin-left: 0.5rem;
	font-size: 0.9rem;
	color: #666;
}

.success-info {
	margin-top: 1rem;
	padding: 1rem;
	border-left: 4px solid #28a745;
	border-radius: 4px;
	background: #d4edda;
	font-weight: 600;
	color: #155724;
}
/* ---------- Scored moves table ---------- */
.moves-section {
	margin-top: 2rem;
}

.moves-table-container {
	margin-top: 1rem;
	overflow-x: auto;
}

.moves-table {
	width: 100%;
	border-collapse: collapse;
	font-size: 0.9rem;
}

.moves-table thead {
	background: #f8f9fa;
}

.moves-table th {
	padding: 0.75rem;
	border-bottom: 2px solid #dee2e6;
	text-align: left;
	font-weight: 600;
	color: #495057;
}

.moves-table td {
	padding: 0.75rem;
	border-bottom: 1px solid #dee2e6;
}

.moves-table tbody tr:hover {
	background: #f8f9fa;
}

/* Highest-scoring row */
.moves-table .top-move {
	background: #d4edda;
	font-weight: 600;
}

/* Column-specific cell styles */
.notation {
	font-family: "Courier New", monospace;
	font-weight: 600;
	color: #007bff;
}

.position {
	font-family: "Courier New", monospace;
	color: #666;
}

.score {
	font-family: "Courier New", monospace;
	text-align: right;
}

.probability {
	font-family: "Courier New", monospace;
	text-align: right;
	color: #28a745;
}
/* ---------- Error panel ---------- */
.error-section {
	margin-top: 1rem;
	padding: 1rem;
	border: 1px solid #f5c6cb;
	border-radius: 4px;
	background: #f8d7da;
	color: #721c24;
}

.error-section pre {
	margin: 0.5rem 0 0 0;
	font-size: 0.9rem;
	white-space: pre-wrap;
	word-wrap: break-word;
}

/* ---------- Debug panel ---------- */
/* Overrides the white background set by the shared panel rule */
.debug-section {
	background: #f8f9fa;
}

.debug-content {
	display: flex;
	flex-direction: column;
	gap: 1rem;
}

.debug-item {
	display: flex;
	flex-direction: column;
	gap: 0.5rem;
}

.debug-value {
	margin: 0;
	padding: 1rem;
	border: 1px solid #e1e4e8;
	border-radius: 4px;
	overflow-x: auto;
	background: white;
	font-family: "Courier New", monospace;
	font-size: 0.85rem;
	color: #2c3e50;
}
/* ---------- Tree visualization ---------- */
.visualization-section {
	margin-bottom: 2rem;
	padding: 1.5rem;
	border: 1px solid #e1e4e8;
	border-radius: 8px;
	background: white;
}

/* Per-move parent/leaf details */
.move-details {
	margin-bottom: 2rem;
}

.details-grid {
	display: grid;
	grid-template-columns: repeat(auto-fill, minmax(200px, 1fr));
	gap: 0.5rem;
	margin-top: 1rem;
}

.move-detail-item {
	display: flex;
	flex-direction: column;
	padding: 0.5rem;
	border-radius: 4px;
	background: #f8f9fa;
	font-size: 0.85rem;
}

.move-notation {
	font-family: "Courier New", monospace;
	font-size: 1rem;
	font-weight: 600;
	color: #007bff;
}

.move-positions {
	margin-top: 0.25rem;
	font-family: "Courier New", monospace;
	font-size: 0.75rem;
	color: #666;
}

/* Evaluated token sequence */
.token-sequence {
	margin-bottom: 2rem;
}

.tokens-display {
	display: flex;
	flex-wrap: wrap;
	gap: 0.5rem;
	margin-top: 1rem;
}

.token-item {
	display: flex;
	flex-direction: column;
	align-items: center;
	min-width: 50px;
	padding: 0.5rem;
	border: 1px solid #007bff;
	border-radius: 4px;
	background: #e7f5ff;
}

.token-pos {
	font-size: 0.7rem;
	font-weight: 600;
	color: #666;
}

.token-char {
	margin: 0.25rem 0;
	font-family: "Courier New", monospace;
	font-size: 1.2rem;
	font-weight: 600;
	color: #007bff;
}

.token-id {
	font-family: "Courier New", monospace;
	font-size: 0.7rem;
	color: #666;
}
/* ---------- Attention mask matrix ---------- */
.mask-matrix {
	margin-top: 2rem;
}

/* Scroll container; header row and first column stay pinned via position: sticky */
.matrix-container {
	max-height: 600px;
	margin-top: 1rem;
	overflow: auto;
	border: 1px solid #e1e4e8;
	border-radius: 4px;
}

.matrix-table {
	min-width: 100%;
	border-collapse: collapse;
	font-size: 0.75rem;
}

/* Top-left cell sits above both the sticky header row and the sticky first column */
.corner-cell {
	position: sticky;
	top: 0;
	left: 0;
	z-index: 3;
	border: 1px solid #dee2e6;
	background: #f8f9fa;
}

.header-cell {
	position: sticky;
	top: 0;
	z-index: 2;
	min-width: 40px;
	padding: 0.5rem;
	border: 1px solid #dee2e6;
	background: #f8f9fa;
	text-align: center;
}

.row-header {
	position: sticky;
	left: 0;
	z-index: 1;
	min-width: 50px;
	padding: 0.5rem;
	border: 1px solid #dee2e6;
	background: #f8f9fa;
	text-align: center;
}

/* Column and row headers share the same inner layout: position index above the character */
.header-content, .row-content {
	display: flex;
	flex-direction: column;
	align-items: center;
}

.header-pos, .row-pos {
	font-size: 0.65rem;
	font-weight: 600;
	color: #666;
}

.header-char, .row-char {
	margin-top: 0.1rem;
	font-family: "Courier New", monospace;
	font-size: 0.9rem;
	font-weight: 600;
	color: #007bff;
}

/* Mask cells */
.mask-cell {
	padding: 0.5rem;
	border: 1px solid #dee2e6;
	text-align: center;
	font-family: "Courier New", monospace;
	font-size: 0.8rem;
	font-weight: 600;
}

.mask-active {
	background: #d4edda;
	color: #155724;
}

.mask-inactive {
	background: #f8f9fa;
	color: #ccc;
}

/* Legend */
.matrix-legend {
	display: flex;
	gap: 1.5rem;
	margin-top: 1rem;
	padding: 0.75rem;
	border-radius: 4px;
	background: #f8f9fa;
}

.legend-item {
	display: flex;
	align-items: center;
	gap: 0.5rem;
	font-size: 0.85rem;
}

.legend-box {
	width: 20px;
	height: 20px;
	border: 1px solid #dee2e6;
	border-radius: 3px;
}

.legend-box.active {
	background: #d4edda;
}

.legend-box.inactive {
	background: #f8f9fa;
}
</style>