import marimo

__generated_with = "0.8.22"
app = marimo.App(width="medium")


@app.cell
def __():
    import marimo as mo
    return (mo,)


@app.cell
def __():
    import pandas as pd

    # Puzzle table; later cells read its "Rating", "Themes", "FEN" and "Moves" columns.
    df = pd.read_csv("datasets/test_set.csv")
    return df, pd


@app.cell
def __():
    import pickle
    from utils import ChessBoard
    import onnxruntime as ort
    from leela_board import _idx_to_move_bn, _idx_to_move_wn
    import numpy as np
    from onnx2torch import convert
    import onnx
    import torch
    import os

    def get_models(root="models"):
        """Return the paths of all ``.onnx`` files directly inside *root*.

        ``os.listdir`` returns entries in arbitrary (filesystem-dependent)
        order, so the names are sorted to make positional picks such as
        ``get_models()[-1]`` select the same model on every machine.
        ``endswith`` (rather than a substring test) keeps side files such
        as ``model.onnx.data`` out of the list.
        """
        return [
            os.path.join(root, name)
            for name in sorted(os.listdir(root))
            if name.endswith(".onnx")
        ]

    def get_activations_from_model(model_path, pattern, fen):
        """Run one position through an ONNX model and capture module outputs.

        Args:
            model_path: path to an ``.onnx`` file; it is converted to a torch
                module with ``onnx2torch.convert``.
            pattern: substring; every module whose qualified name contains
                it gets a forward hook.
            fen: position to evaluate; encoded with ``ChessBoard(fen).t``.

        Returns:
            dict mapping module name -> that module's output as a numpy
            array, in the order the hooks fired during the forward pass.
        """
        model = convert(onnx.load(model_path))
        activations = {}

        def make_hook(name):
            def hook(module, inputs, output):
                activations[name] = output.detach().numpy()

            return hook

        handles = [
            module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()
            if pattern in name
        ]
        try:
            inputs = ChessBoard(fen).t
            # Only the hooked activations are needed, not the model outputs.
            with torch.no_grad():
                model(inputs.unsqueeze(dim=0))
        finally:
            # Detach hooks even if encoding or the forward pass fails.
            for handle in handles:
                handle.remove()
        return activations

    return (
        ChessBoard,
        convert,
        get_activations_from_model,
        get_models,
        np,
        onnx,
        ort,
        os,
        pickle,
        torch,
    )


@app.cell
def __(df, mo):
    # Bucket bounds rounded down to whole hundreds.
    min_elo = int(df["Rating"].min()) // 100 * 100
    max_elo = int(df["Rating"].max()) // 100 * 100
    elo_list = [f"{elo}" for elo in range(min_elo, max_elo + 100, 100)]
    # Prefer the 1000 bucket, but fall back to the lowest bucket when the
    # dataset does not cover 1000 so the default is always a valid option.
    dropdown_elo = mo.ui.dropdown(
        value="1000" if "1000" in elo_list else elo_list[0],
        options=elo_list,
        label=f"Select rating in range of {min_elo} - {max_elo}",
    )
    dropdown_elo
    return dropdown_elo, elo_list, max_elo, min_elo


@app.cell
def __(df, dropdown_elo, mo):
    # Half-open bucket [elo, elo + 100): a rating exactly on a boundary
    # (e.g. 1100) belongs to one bucket only.
    _bucket_low = int(dropdown_elo.value)
    df_rated = df[
        (df["Rating"] >= _bucket_low) & (df["Rating"] < _bucket_low + 100)
    ]
    # Buckets inside [min_elo, max_elo] can be empty; halt this cell and its
    # dependents with a message instead of crashing on an empty list below.
    mo.stop(df_rated.empty, mo.md("No puzzles in the selected rating range."))

    # "Themes" holds space-separated tags; collect every distinct tag.
    unique_themes = {
        theme_name
        for themes_str in df_rated["Themes"]
        for theme_name in themes_str.split(" ")
    }
    unique_themes_list = sorted(unique_themes)
    dropdown_themes = mo.ui.dropdown(
        value=unique_themes_list[0],
        options=unique_themes_list,
        label="Select puzzle theme",
    )
    dropdown_themes
    return df_rated, dropdown_themes, unique_themes, unique_themes_list


@app.cell
def __(df_rated, dropdown_themes):
    # Positional (iloc) indices of the puzzles tagged with the selected theme.
    themes_mask = [
        idx
        for idx, themes_str in enumerate(df_rated["Themes"])
        if dropdown_themes.value in themes_str.split(" ")
    ]
    fens = list(df_rated.iloc[themes_mask]["FEN"])
    df_rated.iloc[themes_mask][["FEN", "Moves", "Themes", "Rating"]]
    return fens, themes_mask


@app.cell
def __(fens, mo):
    dropdown_fen = mo.ui.dropdown(value=fens[0], options=fens, label="Select FEN")
    dropdown_fen
    return (dropdown_fen,)


@app.cell
def __(df_rated, dropdown_fen, mo):
    moves = (
        df_rated[df_rated["FEN"] == dropdown_fen.value]["Moves"].iloc[0].split(" ")
    )
    # Odd-indexed moves are treated as the player's; even-indexed ones
    # (starting with moves[0]) are the board's / opponent's replies.
    player_moves = moves[1::2]
    # board_moves[k] is every move played before the player's k-th move.
    board_moves = [moves[: 2 * k + 1] for k in range(len(player_moves))]
    # Dropdown label = the player's move; value = moves to replay before it.
    moves_dict = dict(zip(player_moves, board_moves))
    dropdown_moves = mo.ui.dropdown(
        options=moves_dict,
        value=player_moves[0],
        label="Select which player move to look at",
    )
    # print(moves)
    dropdown_moves
    return board_moves, dropdown_moves, moves, moves_dict, player_moves


@app.cell
def __(dropdown_moves, mo):
    dropdown_layer = mo.ui.dropdown(
        value="0",
        options=[f"{i}" for i in range(15)],
        label="Select layer (smaller - closer to input)",
    )
    # Default focus: the origin square (first two characters) of the
    # selected player move.
    focus_square = mo.ui.text_area(
        value=dropdown_moves.selected_key[:2],
        placeholder="Input square to look at (e.g. a1, b8, ...",
    )
    mo.vstack([dropdown_layer, focus_square])
    return dropdown_layer, focus_square


@app.cell
def __(ChessBoard, dropdown_fen, dropdown_moves):
    def _():
        # Replay the moves preceding the selected player move and return
        # the resulting position as a FEN string.
        board = ChessBoard(dropdown_fen.value)
        for move in dropdown_moves.value:
            # print(move)
            board.move(move)
        return board.board.pc_board.fen()

    FEN = _()
    return (FEN,)


@app.cell
def __(torch):
    def rollout(x, skip_last_layers=0, skip_connection=True, parse="min"):
        """Combine per-layer attention maps into a single square map.

        Args:
            x: dict of captured attention outputs (e.g. from
                ``get_activations_from_model``), iterated in insertion
                order. Each value is indexed as ``v[0, :, ::-1, :]``, so it
                is assumed to be (batch, heads, query, key) — TODO confirm.
            skip_last_layers: despite the name, the number of *leading*
                layers that are combined; 0 combines none and yields an
                all-ones matrix (empty product).
            skip_connection: if true, add the identity to each layer's map.
            parse: how heads are reduced per layer: "min", "max" or "mean".

        Returns:
            2-D torch tensor of shape (query, key).

        Raises:
            ValueError: if *parse* is not one of the supported reducers.

        NOTE(review): layers are combined with an element-wise product
        (``torch.prod`` over the layer axis), not the matrix product used
        by attention rollout (Abnar & Zuidema, 2020). Confirm intended.
        """
        reducers = {
            "min": lambda t: torch.min(t, dim=0).values,
            "max": lambda t: torch.max(t, dim=0).values,
            "mean": lambda t: torch.mean(t, dim=0),
        }
        if parse not in reducers:
            raise ValueError(
                f"parse must be one of {sorted(reducers)}, got {parse!r}"
            )
        reduce_heads = reducers[parse]

        attns = []
        for v in x.values():
            # Reverse the query axis; copy because torch.tensor cannot wrap
            # a numpy array with negative strides.
            v = torch.tensor(v[0, :, ::-1, :].copy())
            item = reduce_heads(v)
            attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
        return torch.prod(torch.stack(attns)[:skip_last_layers], dim=0)

    return (rollout,)


@app.cell
def __(mo):
    METHODS = ["Attention visualization", "Attention rollout (MIN)", "Attention rollout (MEAN)", "Attention rollout (MAX)"]
    dropdown_method = mo.ui.dropdown(options=METHODS, value=METHODS[0], label="Select XAI method")
    dropdown_method
    return METHODS, dropdown_method


@app.cell
def __(mo):
    switch = mo.ui.switch(value=False, label="To use skip connection in rollout")
    switch
    return (switch,)


@app.cell
def __(mo):
    max_value_switch = mo.ui.switch(value=False, label="use rollout layers with max value")
    max_value_switch
    return (max_value_switch,)


@app.cell
def __(mo):
    highlight_squares_switch = mo.ui.switch(value=False, label="Use rollout to highlight squares")
    highlight_squares_switch
    return (highlight_squares_switch,)


@app.cell
def __(
    dropdown_layer,
    dropdown_method,
    focus_square,
    highlight_squares_switch,
    max_value_switch,
    rollout,
    switch,
    torch,
):
    import chess
    from global_data import global_data

    def find_max(a):
        """Return the slice ``a[k]`` (k >= 1) whose largest entry is greatest.

        Index 0 is skipped: for stacks built from ``rollout`` it is the
        zero-layer result (an all-ones matrix).
        """
        per_slice_max = torch.max(a.reshape(a.shape[0], -1), dim=1).values
        best = torch.argmax(per_slice_max[1:])
        return a[best + 1]

    def parse_activations(act, layer_number=None):
        """Turn captured attention activations into the array to plot.

        "Attention visualization" returns one layer's raw per-head maps.
        The rollout methods return 32 identical copies of the rollout map
        (one per subplot), as a numpy array.
        """
        method = dropdown_method.value
        if method == "Attention visualization":
            # Build the requested layer's key from the first key containing
            # "0" (assumed to be layer 0). NOTE(review): replace() swaps
            # every "0" in the name, not only the layer index.
            template = next(k for k in act.keys() if "0" in k)
            layer_key = template.replace("0", f"{layer_number}")
            return act[layer_key][0, :, ::-1, :]

        parse_modes = {
            "Attention rollout (MIN)": "min",
            "Attention rollout (MAX)": "max",
            "Attention rollout (MEAN)": "mean",
        }
        if method not in parse_modes:
            raise ValueError(f"Unknown XAI method: {method!r}")
        parse = parse_modes[method]

        if max_value_switch.value:
            # Roll out over 0..14 leading layers and keep the strongest map.
            a = torch.stack(
                [
                    rollout(act, skip_last_layers=n, skip_connection=switch.value, parse=parse)
                    for n in range(0, 15)
                ],
                dim=0,
            )
            a = find_max(a)
        else:
            a = rollout(
                act,
                skip_last_layers=int(dropdown_layer.value),
                skip_connection=switch.value,
                parse=parse,
            )

        if highlight_squares_switch.value:
            # Column-wise max, broadcast to every row, so each focused
            # square shows the same per-square highlight.
            a = a.max(dim=0).values
            a = torch.stack([a for _ in range(64)], dim=0)
        # 32 copies to fill the 8 x 4 subplot grid used by the heatmap.
        return torch.stack([a for _ in range(32)]).numpy()

    # chess.parse_square maps "a1".."h8" to 0..63 (rank * 8 + file), the
    # same index as the manual formula, but it raises ValueError on an
    # invalid name instead of yielding an out-of-range index. strip() and
    # lower() tolerate stray whitespace/newlines from the text area and
    # upper-case input.
    focus_square_ind = chess.parse_square(focus_square.value.strip().lower())

    def set_plotting_parameters(act, fen):
        """Fill the shared ``global_data`` object with the plot settings.

        Presumably read by ``activation_heatmap.heatmap_figure`` — confirm.
        """
        global_data.model = 'test'
        global_data.activations = act
        print(global_data.activations.shape)
        # 8 x 4 = 32 subplots, one per head / per copy from parse_activations.
        global_data.subplot_rows = 8
        global_data.subplot_cols = 4
        global_data.board = chess.Board(fen)
        global_data.show_all_heads = True
        # global_data.selected_head = 1
        global_data.visualization_mode = 'ROW'
        global_data.focused_square_ind = focus_square_ind
        # global_data.heatmap_horizontal_gap = 0.001
        global_data.visualization_mode_is_64x64 = False
        global_data.colorscale_mode = "mode1"
        global_data.show_colorscale = False

    return (
        chess,
        find_max,
        focus_square_ind,
        global_data,
        parse_activations,
        set_plotting_parameters,
    )


@app.cell
def __(
    FEN,
    dropdown_layer,
    get_activations_from_model,
    get_models,
    parse_activations,
    set_plotting_parameters,
):
    # FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
    # board = ChessBoard("r1b2rk1/pp2pp1p/6p1/3Qb2q/1P4n1/2P1BN2/P2N1PPP/R4RK1 w - - 0 14")
    # board.move("f3e5")
    # FEN = board.board.pc_board.fen()
    # Substring matched against module names to choose what to capture.
    PATTERN = "mha/QK/softmax"
    # PATTERN = "smolgen_weights"
    _model_paths = get_models()
    if not _model_paths:
        raise FileNotFoundError("No .onnx models found in the 'models' directory")
    # get_models() sorts by file name, so this is the last model by name.
    MODEL = _model_paths[-1]
    ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
    a = parse_activations(ACTIVATIONS, layer_number=int(dropdown_layer.value))
    set_plotting_parameters(a, FEN)
    from activation_heatmap import heatmap_figure

    fig = heatmap_figure()
    fig.update_layout(height=1500, width=1200)
    fig
    return ACTIVATIONS, MODEL, PATTERN, a, fig, heatmap_figure


@app.cell
def __():
    # Add fens after opponents moves
    # Default squares of interest
    return


if __name__ == "__main__":
    app.run()