| | import marimo |
| |
|
# marimo version string recorded in this notebook file.
__generated_with = "0.8.22"
# Notebook application object; every cell below registers itself on it
# through the `@app.cell` decorator.
app = marimo.App(width="medium")
| |
|
| |
|
@app.cell
def __():
    # Expose marimo under the alias `mo`; the UI cells below declare `mo`
    # as a dependency and build their widgets from it.
    import marimo as mo

    return (mo,)
| |
|
| |
|
@app.cell
def __():
    # Load the table that drives every selector in this notebook. Later
    # cells read its "Rating", "Themes", "FEN" and "Moves" columns.
    import pandas as pd

    df = pd.read_csv("datasets/test_set.csv")
    return (df, pd)
| |
|
| |
|
@app.cell
def __():
    import pickle
    from utils import ChessBoard
    import onnxruntime as ort
    from leela_board import _idx_to_move_bn, _idx_to_move_wn
    import numpy as np
    from onnx2torch import convert
    import onnx
    import torch
    import os

    def get_models(root="models"):
        """Return the paths of the ONNX model files located directly in *root*.

        Only entries whose name ends in ``.onnx`` are kept. The result is
        sorted so that positional selection such as ``get_models()[-1]`` does
        not depend on the arbitrary order in which ``os.listdir`` reports
        directory entries.

        Args:
            root: Directory to scan (not recursive).

        Returns:
            Sorted list of ``os.path.join(root, name)`` strings.
        """
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.endswith(".onnx")
        )

    def get_activations_from_model(model_path, pattern, fen):
        """Run one forward pass and capture the outputs of matching modules.

        Args:
            model_path: Path of the ONNX file; it is loaded with ``onnx.load``
                and converted to a torch module with ``onnx2torch.convert``.
            pattern: Substring; every module whose qualified name contains it
                gets a forward hook.
            fen: Position string handed to ``ChessBoard``.

        Returns:
            Dict mapping each hooked module name to that module's output,
            detached and converted to a NumPy array.
        """
        model = convert(onnx.load(model_path))
        activations = {}

        def make_hook(name):
            # Bind `name` per module so each hook writes to its own key.
            def hook(_module, _inputs, output):
                activations[name] = output.detach().numpy()
            return hook

        handles = [
            module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()
            if pattern in name
        ]

        try:
            board = ChessBoard(fen)
            # Only the hooked activations are needed, so no autograd graph
            # has to be recorded for this pass.
            with torch.no_grad():
                model(board.t.unsqueeze(dim=0))
        finally:
            # Detach the hooks even when the forward pass fails.
            for handle in handles:
                handle.remove()
        return activations

    return (
        ChessBoard,
        convert,
        get_activations_from_model,
        get_models,
        np,
        onnx,
        ort,
        os,
        pickle,
        torch,
    )
| |
|
| |
|
@app.cell
def __(df, mo):
    # Round the observed rating extremes down to the nearest hundred.
    # int() also copes with a float-typed "Rating" column, whose floor
    # division result range() would otherwise reject.
    min_elo = int(df["Rating"].min()) // 100 * 100
    max_elo = int(df["Rating"].max()) // 100 * 100
    elo_list = [f"{elo}" for elo in range(min_elo, max_elo + 100, 100)]
    dropdown_elo = mo.ui.dropdown(
        # "1000" stays the preferred default, but the dropdown rejects a value
        # that is not among its options, so fall back to the lowest bucket
        # when the data does not span 1000.
        value="1000" if "1000" in elo_list else elo_list[0],
        options=elo_list,
        label=f"Select rating in range of {min_elo} - {max_elo}",
    )
    dropdown_elo
    return dropdown_elo, elo_list, max_elo, min_elo
| |
|
| |
|
@app.cell
def __(df, dropdown_elo, mo):
    # Keep the rows whose rating lies in the 100-point window that starts at
    # the selected value (both bounds inclusive).
    df_rated = df[(df["Rating"] >= int(dropdown_elo.value)) & (df["Rating"] <= int(dropdown_elo.value) + 100)]

    # Collect every distinct tag from the space-separated "Themes" column.
    # The column is iterated directly instead of materialising each row
    # with `iloc[i]`.
    unique_themes = set()
    for i, themes in enumerate(row.split(" ") for row in df_rated["Themes"]):
        for theme in themes:
            unique_themes.add(theme)
    unique_themes_list = sorted(unique_themes)

    # A rating window with no rows has no themes to offer. Halt this cell
    # (and the cells that depend on it) with a message instead of failing on
    # `unique_themes_list[0]`.
    mo.stop(
        not unique_themes_list,
        mo.md("No puzzles found in the selected rating range."),
    )

    dropdown_themes = mo.ui.dropdown(
        value=unique_themes_list[0],
        options=unique_themes_list,
        label="Select puzzle theme",
    )
    dropdown_themes
    return (
        df_rated,
        dropdown_themes,
        i,
        theme,
        themes,
        unique_themes,
        unique_themes_list,
    )
| |
|
| |
|
@app.cell
def __(df_rated, dropdown_themes):
    # Positional indices (into df_rated) of the rows whose space-separated
    # "Themes" entry contains the selected theme.
    themes_mask = [
        position
        for position, tags in enumerate(df_rated["Themes"])
        if dropdown_themes.value in tags.split(" ")
    ]
    _matching_rows = df_rated.iloc[themes_mask]
    fens = list(_matching_rows["FEN"])
    # Last expression of the cell: the filtered table shown as its output.
    df_rated.iloc[themes_mask][["FEN", "Moves", "Themes", "Rating"]]
    return fens, themes_mask
| |
|
| |
|
@app.cell
def __(fens, mo):
    # Selector over the FEN strings left after filtering; it starts on the
    # first one.
    dropdown_fen = mo.ui.dropdown(
        options=fens,
        value=fens[0],
        label="Select FEN",
    )
    dropdown_fen
    return (dropdown_fen,)
| |
|
| |
|
@app.cell
def __(df_rated, dropdown_fen, mo):
    # Move list of the first row whose FEN equals the selected one.
    moves = df_rated[df_rated["FEN"] == dropdown_fen.value]["Moves"].iloc[0].split(" ")
    # Every second entry, starting at index 1, is offered as a "player move".
    player_moves = moves[1::2]
    # For the k-th player move, the moves that precede it: moves[:2k + 1].
    board_moves = [moves[: 2 * ply + 1] for ply in range(len(player_moves))]
    # Dropdown label -> list of moves to replay before that player move.
    moves_dict = dict(zip(player_moves, board_moves))
    dropdown_moves = mo.ui.dropdown(
        options=moves_dict,
        value=player_moves[0],
        label="Select which player move to look at",
    )

    dropdown_moves
    return board_moves, dropdown_moves, moves, moves_dict, player_moves
| |
|
| |
|
@app.cell
def __(dropdown_moves, mo):
    # Layer selector offering the strings "0" .. "14".
    dropdown_layer = mo.ui.dropdown(
        options=list(map(str, range(15))),
        value="0",
        label="Select layer (smaller - closer to input)",
    )
    # Free-text square, pre-filled with the first two characters of the
    # currently selected move label.
    focus_square = mo.ui.text_area(
        placeholder="Input square to look at (e.g. a1, b8, ...",
        value=dropdown_moves.selected_key[:2],
    )
    mo.vstack([dropdown_layer, focus_square])
    return dropdown_layer, focus_square
| |
|
| |
|
@app.cell
def __(ChessBoard, dropdown_fen, dropdown_moves):
    def _():
        # Start from the selected position and replay the moves attached to
        # the selected dropdown entry, in order.
        position = ChessBoard(dropdown_fen.value)
        for played in dropdown_moves.value:
            position.move(played)
        # FEN of the resulting position.
        return position.board.pc_board.fen()

    FEN = _()
    return (FEN,)
| |
|
| |
|
@app.cell
def __(torch):
    def rollout(x, skip_last_layers=0, skip_connection=True, parse="min"):
        """Combine per-layer attention maps into a single 2-D map.

        Every array in *x* is indexed as ``v[0, :, ::-1, :]`` (first entry of
        the leading axis, third axis reversed) and reduced over its new
        leading axis with the reduction named by *parse*. An identity matrix
        is optionally added to each reduced map. The first
        ``skip_last_layers`` reduced maps, in the iteration order of *x*, are
        then multiplied element-wise.

        Args:
            x: Mapping whose values are 4-D NumPy arrays. Presumably
                (batch, heads, query, key) attention weights - TODO confirm.
            skip_last_layers: Number of leading maps that enter the product.
                NOTE(review): despite the name, the slice
                ``[:skip_last_layers]`` keeps the FIRST N maps. With 0 the
                selection is empty and the product is an all-ones map.
            skip_connection: When truthy, add the identity to each map.
            parse: Reduction applied per layer: "min", "max" or "mean".

        Returns:
            2-D ``torch.Tensor`` holding the combined map.

        Raises:
            ValueError: If *parse* is not one of the supported names.
        """
        # NOTE(review): the maps are combined with an element-wise product
        # (torch.prod), not a matrix product - confirm this is intended.
        reducers = {
            "min": lambda t: torch.min(t, dim=0).values,
            "max": lambda t: torch.max(t, dim=0).values,
            "mean": lambda t: torch.mean(t, dim=0),
        }
        if parse not in reducers:
            # Fail up front instead of hitting an unbound (or stale) `item`.
            raise ValueError(
                f"parse must be one of {sorted(reducers)}, got {parse!r}"
            )
        reduce_layer = reducers[parse]

        attns = []
        for v in x.values():
            # .copy() is required: torch.tensor cannot ingest the
            # negative-stride view produced by the ::-1 slice.
            v = torch.tensor(v[0, :, ::-1, :].copy())
            item = reduce_layer(v)
            attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
        return torch.prod(torch.stack(attns)[:skip_last_layers], dim=0)

    return (rollout,)
| |
|
| |
|
@app.cell
def __(mo):
    # Names of the selectable explanation methods; the first is the default.
    METHODS = [
        "Attention visualization",
        "Attention rollout (MIN)",
        "Attention rollout (MEAN)",
        "Attention rollout (MAX)",
    ]
    dropdown_method = mo.ui.dropdown(
        label="Select XAI method",
        options=METHODS,
        value=METHODS[0],
    )
    dropdown_method
    return METHODS, dropdown_method
| |
|
| |
|
@app.cell
def __(mo):
    # Off by default; its value is passed to rollout() as `skip_connection`.
    switch = mo.ui.switch(
        label="To use skip connection in rollout",
        value=False,
    )
    switch
    return (switch,)
| |
|
| |
|
@app.cell
def __(mo):
    # Off by default; when on, the rollout depth is chosen through find_max()
    # instead of being taken from the layer dropdown.
    max_value_switch = mo.ui.switch(
        label="use rollout layers with max value",
        value=False,
    )
    max_value_switch
    return (max_value_switch,)
| |
|
| |
|
@app.cell
def __(mo):
    # Off by default; when on, the rollout map is collapsed with a max over
    # its first axis before it is plotted.
    highlight_squares_switch = mo.ui.switch(
        label="Use rollout to highlight squares",
        value=False,
    )
    highlight_squares_switch
    return (highlight_squares_switch,)
| |
|
| |
|
@app.cell
def __(
    dropdown_layer,
    dropdown_method,
    focus_square,
    highlight_squares_switch,
    max_value_switch,
    rollout,
    switch,
    torch,
):
    import chess
    from global_data import global_data

    def find_max(a):
        """Return the slice of *a* along axis 0 with the largest peak value.

        Index 0 is left out of the comparison, so the result is always one of
        ``a[1:]``.
        """
        peaks = torch.max(a.reshape(a.shape[0], -1), dim=1).values
        # argmax is taken over peaks[1:], hence the +1 to index back into `a`.
        return a[torch.argmax(peaks[1:]) + 1]

    def parse_activations(act, layer_number=None):
        """Turn captured activations into the array handed to the plot.

        The method selected in ``dropdown_method`` decides the result:

        * "Attention visualization": the activation of layer *layer_number*,
          indexed as ``[0, :, ::-1, :]``.
        * "Attention rollout (MIN|MAX|MEAN)": a rollout map, repeated 32
          times along a new leading axis and returned as a NumPy array.

        Args:
            act: Dict of module name -> NumPy activation.
            layer_number: Layer to show; only used by the visualization
                method. The rollout methods read ``dropdown_layer`` instead.

        Returns:
            NumPy array to plot.
        """
        method = dropdown_method.value
        if method == "Attention visualization":
            # NOTE(review): takes the first key containing "0" and replaces
            # EVERY "0" in it with the layer number - assumes the key holds
            # no other zeros; confirm against the module names.
            layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
            return act[layer_key][0, :, ::-1, :]

        # The three rollout variants differ only in the reduction they use.
        parse = {
            "Attention rollout (MIN)": "min",
            "Attention rollout (MAX)": "max",
            "Attention rollout (MEAN)": "mean",
        }[method]

        if max_value_switch.value:
            # Try every depth 0..14 and keep the one with the largest peak.
            candidates = torch.stack(
                [
                    rollout(act, skip_last_layers=_i, skip_connection=switch.value, parse=parse)
                    for _i in range(0, 15)
                ],
                dim=0,
            )
            a = find_max(candidates)
        else:
            a = rollout(
                act,
                skip_last_layers=int(dropdown_layer.value),
                skip_connection=switch.value,
                parse=parse,
            )

        if highlight_squares_switch.value:
            # Collapse the first axis with max, then repeat the resulting
            # vector 64 times so the map is 2-D again.
            a = a.max(dim=0).values
            a = torch.stack([a for _ in range(64)], dim=0)
        # Repeat the map 32 times along a new leading axis. Presumably this
        # matches the per-head layout of the visualization branch - TODO
        # confirm.
        return torch.stack([a for _ in range(32)]).numpy()

    # Board index of the focused square (a1 = 0, b1 = 1, ..., h8 = 63), taken
    # from the first two characters of the text box. parse_square raises
    # ValueError on anything that is not a square name, where the former
    # manual arithmetic silently produced a wrong or out-of-range index.
    focus_square_ind = chess.parse_square(focus_square.value.strip().lower()[:2])

    def set_plotting_parameters(act, fen):
        """Store the plot inputs and display options on ``global_data``.

        Args:
            act: Activation array to plot; its shape is printed.
            fen: Position string used to build the ``chess.Board``.
        """
        global_data.model = 'test'
        global_data.activations = act
        print(global_data.activations.shape)
        global_data.subplot_rows = 8
        global_data.subplot_cols = 4
        global_data.board = chess.Board(fen)
        global_data.show_all_heads = True

        global_data.visualization_mode = 'ROW'
        global_data.focused_square_ind = focus_square_ind

        global_data.visualization_mode_is_64x64 = False
        global_data.colorscale_mode = "mode1"
        global_data.show_colorscale = False

    return (
        chess,
        find_max,
        focus_square_ind,
        global_data,
        parse_activations,
        set_plotting_parameters,
    )
| |
|
| |
|
@app.cell
def __(
    FEN,
    dropdown_layer,
    get_activations_from_model,
    get_models,
    parse_activations,
    set_plotting_parameters,
):
    # Substring used to pick the modules whose outputs are captured.
    PATTERN = "mha/QK/softmax"

    # Use the last model path returned by get_models().
    MODEL = get_models()[-1]
    # One forward pass on the current position; returns name -> activation.
    ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
    a = parse_activations(ACTIVATIONS, layer_number=int(dropdown_layer.value))
    # Publish the array and board on global_data before the figure is built.
    set_plotting_parameters(a, FEN)
    # NOTE(review): this import is kept after set_plotting_parameters();
    # presumably heatmap_figure reads global_data - confirm before moving it.
    from activation_heatmap import heatmap_figure
    fig = heatmap_figure()
    fig.update_layout(height=1500, width=1200)
    # Last expression of the cell: the figure shown as its output.
    fig
    return ACTIVATIONS, MODEL, PATTERN, a, fig, heatmap_figure
| |
|
| |
|
@app.cell
def __():
    # Empty cell: defines no names and produces no output.
    return
| |
|
| |
|
# Run the notebook's cells when this file is executed directly as a script.
if __name__ == "__main__":
    app.run()
| |
|