import marimo

# marimo notebook: every @app.cell function below is one reactive cell whose
# top-level names are shared with the rest of the notebook.
__generated_with = "0.8.22"
app = marimo.App(width="medium")
| |
|
| |
|
@app.cell
def __():
    # Expose the marimo API to the other cells under the name `mo`.
    import marimo as mo

    return (mo,)
| |
|
| |
|
@app.cell
def __():
    # Load the puzzle table (path resolved against the current working
    # directory) and show a preview of its first five rows.
    import pandas as pd

    df = pd.read_csv(filepath_or_buffer="our_visualization/datasets/test_set.csv")
    df.head(n=5)
    return df, pd
| |
|
| |
|
@app.cell
def __():
    import pickle
    from utils import ChessBoard
    import onnxruntime as ort
    from leela_board import _idx_to_move_bn, _idx_to_move_wn
    import numpy as np
    from onnx2torch import convert
    import onnx
    import torch
    import os

    def get_models(root="/Users/sereda/Documents/chessXAI/our_visualization/models"):
        """Return the paths of all ``.onnx`` files directly inside *root*.

        The result is sorted because ``os.listdir`` returns entries in
        arbitrary order and callers pick a model by position (e.g. ``[-1]``);
        sorting makes that choice reproducible. Only names that actually end
        in ``.onnx`` are returned, so side files such as ``model.onnx.data``
        are skipped.
        """
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.endswith(".onnx")
        )

    def get_activations_from_model(model_path, pattern, fen):
        """Run one forward pass of an ONNX model and capture module outputs.

        The ONNX file at *model_path* is converted to a torch module, a
        forward hook is attached to every submodule whose name contains
        *pattern*, and the model is evaluated once on ``ChessBoard(fen).t``.

        Args:
            model_path: path to an ``.onnx`` model file.
            pattern: substring selecting which named modules to capture.
            fen: position to evaluate.

        Returns:
            dict mapping module name -> numpy array of that module's output.
        """

        def register_hooks_for_capture(model, pattern):
            # Returns the (initially empty) capture dict and the hook handles
            # needed to detach the hooks again.
            activations = {}

            def get_activation(name):
                def hook(_module, _inputs, output):
                    activations[name] = output.detach().numpy()

                return hook

            handles = []
            for n, m in model.named_modules():
                if pattern in n:
                    handles.append(m.register_forward_hook(get_activation(n)))
            return activations, handles

        model = convert(onnx.load(model_path))
        # Inference only: eval() switches any train/eval-sensitive layers to
        # inference behaviour and no_grad() avoids building an autograd graph.
        model.eval()
        act, handles = register_hooks_for_capture(model, pattern)
        try:
            inputs = ChessBoard(fen).t
            with torch.no_grad():
                # Only the hooked activations are needed, not the outputs.
                model(inputs.unsqueeze(dim=0))
        finally:
            # Detach the hooks even if the forward pass raises.
            for handle in handles:
                handle.remove()
        return act

    return (
        ChessBoard,
        convert,
        get_activations_from_model,
        get_models,
        np,
        onnx,
        ort,
        os,
        pickle,
        torch,
    )
| |
|
| |
|
@app.cell
def __(df, mo):
    # Offer the puzzle ratings as 100-point buckets, from the bucket holding
    # the lowest rating up to the bucket holding the highest one.
    # int() keeps range() working even if the column was parsed as float.
    min_elo = int(df["Rating"].min()) // 100 * 100
    max_elo = int(df["Rating"].max()) // 100 * 100
    elo_list = [f"{elo}" for elo in range(min_elo, max_elo + 100, 100)]
    # marimo rejects a default value that is not one of the options, so only
    # use "1000" when the data actually contains that bucket.
    _default_elo = "1000" if "1000" in elo_list else elo_list[0]
    dropdown_elo = mo.ui.dropdown(
        value=_default_elo,
        options=elo_list,
        label=f"Select rating in range of {min_elo} - {max_elo}",
    )
    dropdown_elo
    return dropdown_elo, elo_list, max_elo, min_elo
| |
|
| |
|
@app.cell
def __(df, dropdown_elo, mo):
    # Puzzles whose rating lies in the selected bucket [lower, lower + 100).
    # The upper bound is exclusive so a boundary rating (e.g. 1100) belongs
    # to exactly one bucket.
    _lower = int(dropdown_elo.value)
    df_rated = df[(df["Rating"] >= _lower) & (df["Rating"] < _lower + 100)]

    # Each "Themes" entry is a space-separated list of theme tags.
    unique_themes = {
        _theme for _themes in df_rated["Themes"] for _theme in _themes.split(" ")
    }
    unique_themes_list = sorted(unique_themes)

    # An empty bucket yields no themes; fall back to no selection instead of
    # failing on unique_themes_list[0].
    dropdown_themes = mo.ui.dropdown(
        value=unique_themes_list[0] if unique_themes_list else None,
        options=unique_themes_list,
        label="Select puzzle theme",
    )
    dropdown_themes
    return df_rated, dropdown_themes, unique_themes, unique_themes_list
| |
|
| |
|
@app.cell
def __(df_rated, dropdown_themes):
    # Positional (iloc) indices of the rated puzzles tagged with the chosen
    # theme; the theme must equal one of the space-separated tags exactly.
    _wanted = dropdown_themes.value
    themes_mask = [
        _pos
        for _pos, _tags in enumerate(df_rated["Themes"])
        if _wanted in _tags.split(" ")
    ]

    # FENs of the matching puzzles, plus a table preview of them.
    _matching = df_rated.iloc[themes_mask]
    fens = list(_matching["FEN"])
    _matching[["FEN", "Moves", "Themes", "Rating"]]
    return fens, themes_mask
| |
|
| |
|
@app.cell
def __(fens, mo):
    # Halt this cell (and everything depending on it) with a readable
    # message instead of an IndexError when no puzzle matches the filters.
    mo.stop(not fens, mo.md("**No puzzles match the selected rating and theme.**"))
    dropdown_fen = mo.ui.dropdown(value=fens[0], options=fens, label="Select FEN")
    dropdown_fen
    return (dropdown_fen,)
| |
|
| |
|
@app.cell
def __(df_rated, dropdown_fen, mo):
    # Split the selected puzzle's space-separated move list. Odd-indexed
    # moves (1, 3, ...) are the player's moves; for the k-th of them,
    # board_moves[k] holds every move played before it (moves[:2k + 1]).
    _row = df_rated[df_rated["FEN"] == dropdown_fen.value]
    moves = _row["Moves"].iloc[0].split(" ")
    player_moves = moves[1::2]
    board_moves = [moves[: 2 * _k + 1] for _k in range(len(player_moves))]

    # Dropdown keyed by the player's move; its value is the preceding moves.
    moves_dict = dict(zip(player_moves, board_moves))
    dropdown_moves = mo.ui.dropdown(
        options=moves_dict,
        value=player_moves[0],
        label="Select which player move to look at",
    )
    dropdown_moves
    return board_moves, dropdown_moves, moves, moves_dict, player_moves
| |
|
| |
|
@app.cell
def __(dropdown_moves, mo):
    # Layer picker and focus-square input. NOTE(review): the 0-14 range
    # assumes a 15-layer model — TODO confirm against the loaded model.
    dropdown_layer = mo.ui.dropdown(
        value="0",
        options=[f"{i}" for i in range(15)],
        label="Select layer (smaller - closer to input)",
    )
    # Defaults to the first two characters of the selected player move,
    # i.e. its origin square assuming UCI notation such as "e2e4".
    focus_square = mo.ui.text_area(
        value=dropdown_moves.selected_key[:2],
        placeholder="Input square to look at (e.g. a1, b8, ...)",
    )
    mo.vstack([dropdown_layer, focus_square])
    return dropdown_layer, focus_square
| |
|
| |
|
@app.cell
def __(ChessBoard, dropdown_fen, dropdown_moves):
    # FEN of the position reached by playing, from the puzzle's starting FEN,
    # every move that precedes the selected player move. Previously the moves
    # were only printed, so FEN never changed with the selection.
    # (ChessBoard is kept in the signature for compatibility but is no
    # longer needed here.)
    def _():
        import chess

        board = chess.Board(dropdown_fen.value)
        for move in dropdown_moves.value:
            # Raises ValueError if a move is not legal in the position.
            board.push_uci(move)
        return board.fen()

    FEN = _()
    return (FEN,)
| |
|
| |
|
@app.cell
def __(focus_square):
    import chess
    from global_data import global_data

    # Focus square index with a1 = 0, b1 = 1, ..., h8 = 63 (same value as
    # 8 * (rank - 1) + file). python-chess does the parsing; the text is
    # stripped and lower-cased first because it comes from a free-text area
    # (trailing newline, "E2", ...). Invalid squares raise ValueError.
    focus_square_ind = chess.parse_square(focus_square.value.strip().lower())

    def set_plotting_parameters(act, layer_number, fen):
        """Populate the shared ``global_data`` object with plot settings.

        Args:
            act: dict of module name -> numpy activation array, as returned
                by ``get_activations_from_model``.
            layer_number: index of the layer whose activations are stored.
            fen: position stored as ``global_data.board``.
        """
        # Take the first captured key containing "0" and substitute the
        # requested layer number for every "0" in it. NOTE(review): assumes
        # the only "0" in that module name is the layer index — confirm
        # against the model's module names.
        layer_key = [k for k in act.keys() if "0" in k][0].replace(
            "0", f"{layer_number}"
        )
        global_data.model = "test"
        # Batch element 0, with axis 2 reversed. NOTE(review): presumably
        # (batch, heads, 64, 64) attention maps — TODO confirm.
        global_data.activations = act[layer_key][0, :, ::-1, :]
        global_data.subplot_rows = 8
        global_data.subplot_cols = 4
        global_data.board = chess.Board(fen)
        global_data.show_all_heads = True
        global_data.visualization_mode = "ROW"
        global_data.focused_square_ind = focus_square_ind
        # Display options, re-applied on every call so they are always in
        # effect when the heatmap is built.
        global_data.visualization_mode_is_64x64 = False
        global_data.colorscale_mode = "mode1"
        global_data.show_colorscale = False

    return chess, focus_square_ind, global_data, set_plotting_parameters
| |
|
| |
|
@app.cell
def __(
    FEN,
    dropdown_layer,
    get_activations_from_model,
    get_models,
    set_plotting_parameters,
):
    # Capture the outputs of modules whose name contains "mha/QK/softmax"
    # (attention softmax, judging by the name) from the last model returned
    # by get_models(), then render the chosen layer as a heatmap figure.
    PATTERN = "mha/QK/softmax"
    _available = get_models()
    MODEL = _available[-1]
    ACTIVATIONS = get_activations_from_model(
        model_path=MODEL, pattern=PATTERN, fen=FEN
    )
    set_plotting_parameters(
        act=ACTIVATIONS, layer_number=int(dropdown_layer.value), fen=FEN
    )
    from activation_heatmap import heatmap_figure

    fig = heatmap_figure()
    fig.update_layout(width=1200, height=1500)
    fig
    return ACTIVATIONS, MODEL, PATTERN, fig, heatmap_figure
| |
|
| |
|
@app.cell
def __():
    # Empty scratch cell.
    return
| |
|
| |
|
if __name__ == "__main__":
    # Executing the file directly runs the notebook as a script.
    app.run()
| |
|