import marimo

__generated_with = "0.8.22"
app = marimo.App(width="medium")


@app.cell
def __():
    import marimo as mo
    return (mo,)


@app.cell
def __():
    import pandas as pd

    # Puzzle table; later cells read its FEN, Moves, Themes and Rating columns.
    df = pd.read_csv("our_visualization/datasets/test_set.csv")
    df.head()
    return df, pd


@app.cell
def __():
    import pickle
    from utils import ChessBoard
    import onnxruntime as ort
    from leela_board import _idx_to_move_bn, _idx_to_move_wn
    import numpy as np
    from onnx2torch import convert
    import onnx
    import torch
    import os

    def get_models(root="/Users/sereda/Documents/chessXAI/our_visualization/models"):
        """Return the paths of all ``.onnx`` files directly inside *root*.

        The result is sorted so that positional picks such as ``[-1]`` select
        the same model on every run (``os.listdir`` order is arbitrary).
        """
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.endswith(".onnx")
        )

    def get_activations_from_model(model_path, pattern, fen):
        """Run the position *fen* through an ONNX model and capture activations.

        The ONNX file at *model_path* is converted to a torch module. Every
        submodule whose qualified name contains *pattern* gets a forward hook.
        The returned dict maps each such module name to its output, detached
        and converted to a NumPy array.
        """
        model = convert(onnx.load(model_path))
        activations = {}

        def make_hook(name):
            def hook(module, inputs, output):
                activations[name] = output.detach().numpy()
            return hook

        handles = [
            module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()
            if pattern in name
        ]
        try:
            board = ChessBoard(fen)
            # Only the hooked outputs matter; skip building an autograd graph.
            with torch.no_grad():
                model(board.t.unsqueeze(dim=0))
        finally:
            # Remove hooks even if the forward pass raises.
            for handle in handles:
                handle.remove()
        return activations
    return (
        ChessBoard,
        convert,
        get_activations_from_model,
        get_models,
        np,
        onnx,
        ort,
        os,
        pickle,
        torch,
    )


@app.cell
def __(df, mo):
    # Ratings are grouped into 100-point bands; "1000" selects [1000, 1100).
    min_elo = df["Rating"].min() // 100 * 100
    max_elo = df["Rating"].max() // 100 * 100
    elo_list = [f"{elo}" for elo in range(min_elo, max_elo + 100, 100)]
    # The dropdown's default must be one of its options, so only use "1000"
    # when the dataset actually contains that band.
    _default_elo = "1000" if "1000" in elo_list else elo_list[0]
    dropdown_elo = mo.ui.dropdown(
        value=_default_elo,
        options=elo_list,
        label=f"Select rating in range of {min_elo} - {max_elo}",
    )
    dropdown_elo
    return dropdown_elo, elo_list, max_elo, min_elo


@app.cell
def __(df, dropdown_elo, mo):
    _low = int(dropdown_elo.value)
    # Half-open band so a rating on a boundary belongs to exactly one band.
    df_rated = df[(df["Rating"] >= _low) & (df["Rating"] < _low + 100)]
    mo.stop(
        df_rated.empty,
        mo.md(f"No puzzles with rating in [{_low}, {_low + 100})."),
    )
    # "Themes" holds space-separated tags; collect every distinct tag.
    unique_themes = {
        theme
        for themes_str in df_rated["Themes"]
        for theme in themes_str.split(" ")
    }
    unique_themes_list = sorted(unique_themes)
    dropdown_themes = mo.ui.dropdown(
        value=unique_themes_list[0],
        options=unique_themes_list,
        label=f"Select puzzle theme",
    )
    dropdown_themes
    return df_rated, dropdown_themes, unique_themes, unique_themes_list


@app.cell
def __(df_rated, dropdown_themes):
    # Positional (iloc) indices of in-band puzzles tagged with the chosen theme.
    themes_mask = [
        pos
        for pos, themes_str in enumerate(df_rated["Themes"])
        if dropdown_themes.value in themes_str.split(" ")
    ]
    fens = list(df_rated.iloc[themes_mask]["FEN"])
    df_rated.iloc[themes_mask][["FEN", "Moves", "Themes", "Rating"]]
    return fens, themes_mask


@app.cell
def __(fens, mo):
    dropdown_fen = mo.ui.dropdown(value=fens[0], options=fens, label="Select FEN")
    dropdown_fen
    return (dropdown_fen,)


@app.cell
def __(df_rated, dropdown_fen, mo):
    moves = df_rated[df_rated["FEN"] == dropdown_fen.value]["Moves"].iloc[0].split(" ")
    # Odd-indexed moves are treated as the player's moves; the even-indexed
    # ones (starting with moves[0]) are presumably the opponent's replies.
    player_moves = moves[1::2]
    # For player move k (== moves[2k + 1]), all moves played before it.
    board_moves = [moves[: 2 * k + 1] for k in range(len(player_moves))]
    moves_dict = dict(zip(player_moves, board_moves))
    dropdown_moves = mo.ui.dropdown(
        options=moves_dict,
        value=player_moves[0],
        label="Select which player move to look at",
    )
    dropdown_moves
    return board_moves, dropdown_moves, moves, moves_dict, player_moves


@app.cell
def __(dropdown_moves, mo):
    dropdown_layer = mo.ui.dropdown(
        value="0",
        options=[f"{i}" for i in range(15)],
        label="Select layer (smaller - closer to input)",
    )
    # Default focus: the from-square of the selected player move (first two
    # characters of its move string).
    focus_square = mo.ui.text_area(
        value=dropdown_moves.selected_key[:2],
        placeholder="Input square to look at (e.g. a1, b8, ...",
    )
    mo.vstack([dropdown_layer, focus_square])
    return dropdown_layer, focus_square


@app.cell
def __(ChessBoard, dropdown_fen, dropdown_moves):
    def _():
        board = ChessBoard(dropdown_fen.value)
        for move in dropdown_moves.value:
            print(move)
            # TODO(review): moves are only printed, not applied, so FEN below is
            # always the puzzle's starting position regardless of the selected
            # player move.
            # board.move(move)
        return board.board.pc_board.fen()

    FEN = _()
    return (FEN,)


@app.cell
def __(focus_square):
    import chess
    from global_data import global_data

    _square_name = focus_square.value.strip().lower()
    try:
        # a1 -> 0, b1 -> 1, ..., h8 -> 63, i.e. 8 * (rank - 1) + file.
        focus_square_ind = chess.parse_square(_square_name)
    except ValueError as _err:
        raise ValueError(
            f"Invalid focus square {focus_square.value!r}; expected e.g. 'e4'."
        ) from _err

    def set_plotting_parameters(act, layer_number, fen):
        """Fill the shared ``global_data`` object from captured activations.

        *act* is the dict returned by ``get_activations_from_model``,
        *layer_number* selects which layer's entry to plot, and *fen* is the
        position drawn alongside the heatmaps.
        """
        # NOTE(review): the key for layer N is derived by taking the first key
        # that contains "0" and replacing *every* "0" in it with N. This
        # assumes the layer index is the only "0" in that key; confirm
        # against the converted model's module names.
        layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
        print(act.keys())
        global_data.model = 'test'
        # Drops the leading (batch) axis and reverses axis 2 of the stored
        # array; which axis that is depends on the model — TODO confirm.
        global_data.activations = act[layer_key][0, :, ::-1, :]
        print(global_data.activations.shape)
        global_data.subplot_rows = 8
        global_data.subplot_cols = 4
        global_data.board = chess.Board(fen)
        global_data.show_all_heads = True
        # global_data.selected_head = 1
        global_data.visualization_mode = 'ROW'
        global_data.focused_square_ind = focus_square_ind
        # global_data.heatmap_horizontal_gap = 0.001
        global_data.visualization_mode_is_64x64 = False
        global_data.colorscale_mode = "mode1"
        global_data.show_colorscale = False
    return chess, focus_square_ind, global_data, set_plotting_parameters


@app.cell
def __(
    FEN,
    dropdown_layer,
    get_activations_from_model,
    get_models,
    set_plotting_parameters,
):
    # FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    # board = ChessBoard("r1b2rk1/pp2pp1p/6p1/3Qb2q/1P4n1/2P1BN2/P2N1PPP/R4RK1 w - - 0 14")
    # board.move("f3e5")
    # FEN = board.board.pc_board.fen()
    PATTERN = "mha/QK/softmax"
    # PATTERN = "smolgen_weights"
    _model_paths = get_models()
    if not _model_paths:
        raise FileNotFoundError("No .onnx models found by get_models().")
    # get_models() is sorted, so this is deterministically the last path.
    MODEL = _model_paths[-1]
    ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
    set_plotting_parameters(ACTIVATIONS, int(dropdown_layer.value), FEN)

    from activation_heatmap import heatmap_figure

    fig = heatmap_figure()
    fig.update_layout(height=1500, width=1200)
    fig
    return ACTIVATIONS, MODEL, PATTERN, fig, heatmap_figure


@app.cell
def __():
    # TODO: add FENs after the opponent's moves.
    # TODO: default squares of interest.
    return


if __name__ == "__main__":
    app.run()