# chessXAI / app.py
# seredapj — "Update app.py" (commit 5bb75ee, verified)
import marimo
# marimo version recorded for this notebook file.
__generated_with = "0.8.22"
# The notebook application; every function decorated with @app.cell is a cell.
app = marimo.App(width="medium")
@app.cell
def __():
    # marimo's UI / markdown helpers, exported as `mo` to the other cells.
    import marimo as mo
    return (mo,)
@app.cell
def __():
    import pandas as pd

    # Puzzle table used by every downstream cell (columns read later:
    # "FEN", "Moves", "Themes", "Rating"). The path is resolved relative to
    # the current working directory.
    _CSV_PATH = "datasets/test_set.csv"
    df = pd.read_csv(_CSV_PATH)
    return df, pd
@app.cell
def __():
    import pickle
    from utils import ChessBoard
    import onnxruntime as ort
    from leela_board import _idx_to_move_bn, _idx_to_move_wn
    import numpy as np
    from onnx2torch import convert
    import onnx
    import torch
    import os

    def get_models(root="models"):
        """Return the paths of the ``*.onnx`` files found directly in ``root``.

        The result is sorted: ``os.listdir`` order is arbitrary, and callers
        pick a model by position (e.g. ``get_models()[-1]``), so an unsorted
        list would make that choice vary between machines/runs. Only names
        that *end* in ``.onnx`` are returned, so side files such as
        ``model.onnx.data`` are not mistaken for models.
        """
        return sorted(
            os.path.join(root, name)
            for name in os.listdir(root)
            if name.endswith(".onnx")
        )

    def get_activations_from_model(model_path, pattern, fen):
        """Run one position through a model and capture selected module outputs.

        Args:
            model_path: path to an ONNX model; it is converted to a PyTorch
                module with ``onnx2torch.convert`` on every call.
            pattern: substring matched against ``model.named_modules()``
                names; every matching module gets a forward hook.
            fen: position to evaluate; encoded via ``ChessBoard(fen).t`` and
                given a leading batch dimension of 1.

        Returns:
            dict mapping module name -> that module's forward output as a
            numpy array, in the order the hooks fired during the forward pass.
        """
        activations = {}

        def make_hook(name):
            # Bind `name` per module (avoids the late-binding closure trap).
            def hook(module, module_inputs, output):
                activations[name] = output.detach().numpy()
            return hook

        model = convert(onnx.load(model_path))
        model.eval()
        handles = [
            module.register_forward_hook(make_hook(name))
            for name, module in model.named_modules()
            if pattern in name
        ]
        try:
            inputs = ChessBoard(fen).t
            with torch.no_grad():
                # The model's own outputs are not needed, only the hooked
                # activations, so the return value is discarded regardless
                # of how many outputs the network has.
                model(inputs.unsqueeze(dim=0))
        finally:
            # Always detach the hooks, even if the forward pass fails.
            for handle in handles:
                handle.remove()
        return activations

    return (
        ChessBoard,
        convert,
        get_activations_from_model,
        get_models,
        np,
        onnx,
        ort,
        os,
        pickle,
        torch,
    )
@app.cell
def __(df, mo):
    # Rating range of the dataset, rounded down to whole hundreds; each
    # dropdown entry is the lower edge of a 100-point bucket.
    min_elo = int(df["Rating"].min()) // 100 * 100
    max_elo = int(df["Rating"].max()) // 100 * 100
    elo_list = [str(elo) for elo in range(min_elo, max_elo + 100, 100)]
    # Prefer 1000 as the initial bucket, but fall back to the lowest bucket
    # when 1000 is outside the dataset's range -- a dropdown value must be
    # one of its options.
    dropdown_elo = mo.ui.dropdown(
        value="1000" if "1000" in elo_list else elo_list[0],
        options=elo_list,
        label=f"Select rating in range of {min_elo} - {max_elo}",
    )
    dropdown_elo
    return dropdown_elo, elo_list, max_elo, min_elo
@app.cell
def __(df, dropdown_elo, mo):
    # Puzzles in the selected 100-point bucket [low, low + 100). The upper
    # bound is exclusive so every rating belongs to exactly one bucket.
    _low = int(dropdown_elo.value)
    df_rated = df[(df["Rating"] >= _low) & (df["Rating"] < _low + 100)]
    # Buckets span the overall min..max rating, so some may hold no puzzles;
    # halt this cell (and its dependents) with a message instead of crashing
    # on an empty theme list below.
    mo.stop(
        df_rated.empty,
        mo.md(f"No puzzles rated from {_low} up to (but excluding) {_low + 100}."),
    )

    # Gather every theme tag ("Themes" holds space-separated tags).
    # i / themes / theme stay cell-level names because this cell exports them.
    unique_themes = set()
    for i, _theme_str in enumerate(df_rated["Themes"]):
        themes = _theme_str.split(" ")
        for theme in themes:
            unique_themes.add(theme)
    unique_themes_list = sorted(unique_themes)
    dropdown_themes = mo.ui.dropdown(
        value=unique_themes_list[0],
        options=unique_themes_list,
        label="Select puzzle theme",
    )
    dropdown_themes
    return (
        df_rated,
        dropdown_themes,
        i,
        theme,
        themes,
        unique_themes,
        unique_themes_list,
    )
@app.cell
def __(df_rated, dropdown_themes):
    # Positional indices (within df_rated) of the puzzles tagged with the
    # selected theme; "Themes" is a space-separated tag string.
    themes_mask = [
        row
        for row, tags in enumerate(df_rated["Themes"])
        if dropdown_themes.value in tags.split(" ")
    ]
    _selected = df_rated.iloc[themes_mask]
    fens = list(_selected["FEN"])
    # Last expression: rendered by marimo as this cell's output table.
    _selected[["FEN", "Moves", "Themes", "Rating"]]
    return fens, themes_mask
@app.cell
def __(fens, mo):
    # Choose one puzzle start position among those matching the filters;
    # the first match is preselected.
    dropdown_fen = mo.ui.dropdown(
        label="Select FEN",
        options=fens,
        value=fens[0],
    )
    dropdown_fen
    return (dropdown_fen,)
@app.cell
def __(df_rated, dropdown_fen, mo):
    # Move list of the selected puzzle (space-separated in "Moves").
    _puzzle = df_rated[df_rated["FEN"] == dropdown_fen.value]
    moves = _puzzle["Moves"].iloc[0].split(" ")
    # The player's moves are the odd-indexed entries (1, 3, 5, ...).
    player_moves = moves[1::2]
    # board_moves[k] = every move played before the k-th player move.
    board_moves = [moves[: 2 * k + 1] for k in range(len(player_moves))]
    moves_dict = dict(zip(player_moves, board_moves))
    dropdown_moves = mo.ui.dropdown(
        options=moves_dict,
        value=player_moves[0],
        label="Select which player move to look at",
    )
    dropdown_moves
    return board_moves, dropdown_moves, moves, moves_dict, player_moves
@app.cell
def __(dropdown_moves, mo):
    # Layer indices "0".."14". NOTE(review): the count of 15 is hard-coded and
    # presumably matches the number of attention layers in the model.
    dropdown_layer = mo.ui.dropdown(
        value="0",
        options=[str(layer) for layer in range(15)],
        label="Select layer (smaller - closer to input)",
    )
    # Defaults to the first two characters of the selected player move,
    # i.e. its origin square when the move is written like "e2e4".
    focus_square = mo.ui.text_area(
        value=dropdown_moves.selected_key[:2],
        placeholder="Input square to look at (e.g. a1, b8, ...)",
    )
    mo.vstack([dropdown_layer, focus_square])
    return dropdown_layer, focus_square
@app.cell
def __(ChessBoard, dropdown_fen, dropdown_moves):
    # Replay the moves that precede the selected player move on the puzzle's
    # start position and export the FEN of the resulting position.
    _board = ChessBoard(dropdown_fen.value)
    for _uci in dropdown_moves.value:
        _board.move(_uci)
    FEN = _board.board.pc_board.fen()
    return (FEN,)
@app.cell
def __(torch):
    def rollout(x, skip_last_layers=0, skip_connection=True, parse="min"):
        """Attention rollout through the leading ``skip_last_layers`` layers.

        For every selected layer, batch element 0 is taken and its heads are
        fused with ``parse``. The fused matrices are then chained by *matrix*
        multiplication, ``R = A_k @ ... @ A_2 @ A_1``. The earliest layer is
        the first entry in ``x``'s iteration order.

        With ``skip_connection`` the identity is added to every layer's matrix
        to model the residual path. Rows are re-normalised to sum to 1 at each
        step: head-wise min/max and the identity term break row-stochasticity,
        which would otherwise compound geometrically with depth.

        The query axis (rows) of the result is reversed before returning. This
        is the same row reversal the single-layer attention view applies.

        Args:
            x: mapping layer name -> array indexed ``[batch, head, q, k]``.
                NOTE(review): assumed square (q == k) attention maps ordered
                from input towards output -- confirm against the hooks.
            skip_last_layers: number of *leading* layers to include. Despite
                its name, ``x``'s first ``skip_last_layers`` entries are used;
                0 yields the (row-reversed) identity.
            skip_connection: add the identity (residual path) to each layer.
            parse: head fusion, one of "min", "max", "mean".

        Returns:
            torch.Tensor of shape (q, k).

        Raises:
            ValueError: for an unknown ``parse`` or an empty ``x``.
        """
        reducers = {
            "min": lambda heads: heads.min(dim=0).values,
            "max": lambda heads: heads.max(dim=0).values,
            "mean": lambda heads: heads.mean(dim=0),
        }
        if parse not in reducers:
            raise ValueError(
                f"parse must be one of 'min', 'max', 'mean'; got {parse!r}"
            )
        layers = list(x.values())
        if not layers:
            raise ValueError("rollout needs at least one captured attention layer")

        roll = None
        for layer_attn in layers[:skip_last_layers]:
            # Batch element 0, fused over heads -> (q, k). torch.tensor copies,
            # so the caller's arrays are never modified.
            attn = reducers[parse](torch.tensor(layer_attn[0]))
            if skip_connection:
                attn = attn + torch.eye(attn.shape[0], dtype=attn.dtype)
            attn = attn / attn.sum(dim=-1, keepdim=True).clamp_min(1e-12)
            # Later layers multiply on the left: R_l = A_l @ R_{l-1}.
            roll = attn if roll is None else attn @ roll
        if roll is None:
            # No layers selected: the empty matrix product is the identity.
            roll = torch.eye(
                layers[0].shape[-1], dtype=torch.as_tensor(layers[0]).dtype
            )
        return roll.flip(0)

    return (rollout,)
@app.cell
def __(mo):
    # Explainability methods offered in the UI; the first is the default.
    METHODS = [
        "Attention visualization",
        "Attention rollout (MIN)",
        "Attention rollout (MEAN)",
        "Attention rollout (MAX)",
    ]
    dropdown_method = mo.ui.dropdown(
        label="Select XAI method",
        options=METHODS,
        value=METHODS[0],
    )
    dropdown_method
    return METHODS, dropdown_method
@app.cell
def __(mo):
    # Off by default; its value is forwarded as rollout's skip_connection.
    switch = mo.ui.switch(
        label="To use skip connection in rollout",
        value=False,
    )
    switch
    return (switch,)
@app.cell
def __(mo):
    # Off by default; when on, the rollout layer count with the largest
    # value is shown instead of the layer picked in the dropdown.
    max_value_switch = mo.ui.switch(
        label="use rollout layers with max value",
        value=False,
    )
    max_value_switch
    return (max_value_switch,)
@app.cell
def __(mo):
    # Off by default; when on, the rollout matrix is collapsed to its
    # column-wise maximum before plotting.
    highlight_squares_switch = mo.ui.switch(
        label="Use rollout to highlight squares",
        value=False,
    )
    highlight_squares_switch
    return (highlight_squares_switch,)
@app.cell
def __(
    dropdown_layer,
    dropdown_method,
    focus_square,
    highlight_squares_switch,
    max_value_switch,
    rollout,
    switch,
    torch,
):
    import chess
    from global_data import global_data

    def find_max(a):
        """Return ``a[k]`` for the k >= 1 whose slice holds the largest value.

        ``a`` is a stack of rollouts indexed by how many layers were included;
        slice 0 (no layers) is excluded from the search.
        """
        per_slice_max = a.reshape(a.shape[0], -1).max(dim=1).values
        return a[torch.argmax(per_slice_max[1:]) + 1]

    def parse_activations(act, layer_number=None):
        """Convert captured activations into the array the heatmap displays.

        Behaviour follows the "XAI method" dropdown:

        * "Attention visualization": activation of layer ``layer_number``,
          batch element 0, with its third axis reversed.
        * "Attention rollout (MIN|MEAN|MAX)": ``rollout`` over the first
          ``layer_number`` layers (or, with the max-value switch on, over the
          layer count that yields the largest entry), optionally collapsed by
          the highlight switch, then repeated into 32 head panels.

        Args:
            act: mapping module name -> captured activation array.
            layer_number: layer index; defaults to the layer dropdown.

        Returns:
            numpy array; for rollout methods its shape is (32, rows, cols).

        Raises:
            KeyError: if the layer key cannot be derived from ``act``.
            ValueError: if the selected method is not recognised.
        """
        if layer_number is None:
            layer_number = int(dropdown_layer.value)
        method = dropdown_method.value

        if method == "Attention visualization":
            # Derive the layer's key from the first key containing "0"
            # (presumably layer 0's) by swapping in the layer number. Only the
            # first "0" is replaced so any other zeros in the name survive.
            zero_keys = [k for k in act if "0" in k]
            if not zero_keys:
                raise KeyError("no activation key contains '0'; cannot locate layer keys")
            layer_key = zero_keys[0].replace("0", str(layer_number), 1)
            return act[layer_key][0, :, ::-1, :]

        parse_by_method = {
            "Attention rollout (MIN)": "min",
            "Attention rollout (MEAN)": "mean",
            "Attention rollout (MAX)": "max",
        }
        if method not in parse_by_method:
            raise ValueError(f"Unknown XAI method: {method!r}")
        parse = parse_by_method[method]

        if max_value_switch.value:
            # Try layer counts 0..14 (as many as the layer dropdown offers)
            # and keep the rollout with the largest single entry.
            candidates = torch.stack(
                [
                    rollout(act, skip_last_layers=n, skip_connection=switch.value, parse=parse)
                    for n in range(15)
                ],
                dim=0,
            )
            a = find_max(candidates)
        else:
            a = rollout(act, skip_last_layers=layer_number, skip_connection=switch.value, parse=parse)

        if highlight_squares_switch.value:
            # Column-wise maximum (over rows), repeated on every row.
            n_rows = a.shape[0]
            a = a.max(dim=0).values.unsqueeze(0).repeat(n_rows, 1)
        # One copy per head panel (the plot uses an 8 x 4 grid = 32 panels).
        return torch.stack([a] * 32).numpy()

    # Square name -> index with a1 = 0, b1 = 1, ..., h8 = 63 (rank-major).
    # Case and surrounding whitespace are ignored; only the first two
    # characters are read.
    try:
        focus_square_ind = chess.parse_square(focus_square.value.strip().lower()[:2])
    except ValueError as _err:
        raise ValueError(
            f"Focus square must be a square name such as 'e4', got {focus_square.value!r}"
        ) from _err

    def set_plotting_parameters(act, fen):
        """Store ``act``, the board for ``fen`` and display options on ``global_data``.

        NOTE(review): presumably read by ``activation_heatmap.heatmap_figure``,
        which the plotting cell calls right after this function.
        """
        global_data.model = 'test'
        global_data.activations = act
        global_data.subplot_rows = 8
        global_data.subplot_cols = 4
        global_data.board = chess.Board(fen)
        global_data.show_all_heads = True
        global_data.visualization_mode = 'ROW'
        global_data.focused_square_ind = focus_square_ind
        global_data.visualization_mode_is_64x64 = False
        global_data.colorscale_mode = "mode1"
        global_data.show_colorscale = False

    return (
        chess,
        find_max,
        focus_square_ind,
        global_data,
        parse_activations,
        set_plotting_parameters,
    )
@app.cell
def __(
    FEN,
    dropdown_layer,
    get_activations_from_model,
    get_models,
    parse_activations,
    set_plotting_parameters,
):
    # Modules whose name contains this substring are hooked.
    # Alternative explored during development: "smolgen_weights".
    PATTERN = "mha/QK/softmax"

    # Sort explicitly so the chosen model never depends on directory order.
    _models = sorted(get_models())
    if not _models:
        raise FileNotFoundError("No .onnx model found in the 'models' directory")
    MODEL = _models[-1]
    ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
    a = parse_activations(ACTIVATIONS, layer_number=int(dropdown_layer.value))
    set_plotting_parameters(a, FEN)
    # Imported after the plotting parameters are set, as in the original order.
    from activation_heatmap import heatmap_figure
    fig = heatmap_figure()
    fig.update_layout(height=1500, width=1200)
    fig
    return ACTIVATIONS, MODEL, PATTERN, a, fig, heatmap_figure
@app.cell
def __():
    # Placeholder cell for planned features (not implemented yet):
    # TODO: add FENs for the positions reached after the opponent's moves.
    # TODO: preselect default squares of interest.
    return
if __name__ == "__main__":
    # Execute the notebook's cells when this file is run directly.
    app.run()