| | import chess.engine |
| | from constants import ROOT_DIR, CONTENT_HEIGHT, LEFT_PANE_WIDTH, EXPORT_FORMAT, EXPORT_SCALE |
| | from time import sleep |
| | |
| |
|
| | from copy import deepcopy |
| |
|
| | from board2planes import board2planes |
| |
|
| | import yaml |
| |
|
| | import os |
| | from os.path import isdir, join |
| | import sys |
| |
|
| | import numpy as np |
| |
|
| |
|
# When True, update_activations_data() fills the activations with random arrays
# instead of running a model, and __init__ does not call load_model().
SIMULATE_TF = False

# When True, TensorFlow is not imported and load_model() builds a DummyModel.
DEV_MODE = False
SIMULATED_LAYERS = 6  # number of random "layers" produced when simulating
SIMULATED_HEADS = 64  # number of heads in each simulated layer
# If both are truthy, update_grid_shape() uses them instead of computing a grid.
FIXED_ROW = None
FIXED_COL = None
if DEV_MODE:
    class DummyModel:
        """Stand-in for a real network that returns random attention maps."""

        def __init__(self, layers, heads):
            """Remember how many layers and heads each call should fabricate."""
            self.layers = layers
            self.heads = heads

        def __call__(self, *args, **kwargs):
            """Ignore the inputs and return a 4-item list whose last item is the fake data.

            The last item is a list of ``layers`` arrays of shape
            ``(1, heads, 64, 64)``; callers in this file read ``outputs[-1]``.
            """
            data = [np.random.rand(1, self.heads, 64, 64) for i in range(self.layers)]
            return [None, None, None, data]

else:
    # TensorFlow is only needed when a real model is going to be evaluated.
    import tensorflow as tf
    from tensorflow.compat.v1 import ConfigProto
    from tensorflow.compat.v1 import InteractiveSession
| |
|
| |
|
| | |
| | |
| | |
| | |
class GlobalData:
    """Mutable application state shared across the visualizer.

    Holds the current chess position, the discovered/loaded models, the raw
    and per-layer activations, the layer/head selection, and the layout
    numbers used to size the heatmap grid.
    """

    def __init__(self):
        """Set defaults, discover models on disk and prime the activations.

        Construction performs I/O: it lists the ``models`` directory under
        ``ROOT_DIR`` (via ``find_models2``) and, unless ``DEV_MODE`` is set,
        pins ``CUDA_VISIBLE_DEVICES`` to ``"0"``.
        """
        if not DEV_MODE:
            os.environ["CUDA_VISIBLE_DEVICES"] = "0"

        self.tmp = 0
        self.export_format = EXPORT_FORMAT
        self.export_scale = EXPORT_SCALE
        self.fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1'
        self.board = chess.Board(fen=self.fen)
        self.focused_square_ind = 0
        self.active_move_table_cell = None

        # Selection / visualization state.
        self.activations = None
        self.visualization_mode = 'ROW'
        self.visualization_mode_is_64x64 = False
        self.subplot_mode = 'big'
        self.subplot_cols = 0
        self.subplot_rows = 0
        self.number_of_heads = 0
        self.selected_head = None
        self.show_all_heads = True

        self.show_colorscale = False
        self.colorscale_mode = 'mode1'

        self.figure_container_height = '100%'

        self.running_counter = 0
        self.grid_has_changed = False

        # Layout measurements; filled in by set_screen_size / set_heatmap_size.
        self.screen_w = 0
        self.screen_h = 0
        self.figure_w = 0
        self.figure_h = 0
        self.heatmap_w = 0
        self.heatmap_h = 0
        self.heatmap_fig_w = 0
        self.heatmap_fig_h = 0
        self.heatmap_gap = 0
        self.colorscale_x_offset = 0

        self.heatmap_horizontal_gap = 0.275

        self.figure_cache = {}

        self.update_grid_shape()

        self.pgn_data = []
        self.move_table_boards = {}

        # In simulation mode there is always data, so start on layer 0.
        self.selected_layer = 0 if SIMULATE_TF else None

        self.nr_of_layers_in_body = -1
        self.has_attention_policy = False

        # Model discovery and (optional) loading.
        self.model_paths = []
        self.model_names = []
        self.model_yamls = {}
        self.model_cache = {}
        self.find_models2()
        self.model_path = None
        self.model = None
        self.tfp = None
        if not SIMULATE_TF:
            self.load_model()
        self.activations_data = None

        if self.model is not None or SIMULATE_TF:
            self.update_activations_data()

        if self.selected_layer is not None:
            self.set_layer(self.selected_layer)

        self.move_table_active_cell = None

        self.force_update_graph = False

    def set_subplot_mode(self, fit_to_page):
        """Select the 'fit' layout when ``fit_to_page == [True]``, else 'big'."""
        self.subplot_mode = 'fit' if fit_to_page == [True] else 'big'
        self.update_grid_shape()

    def set_screen_size(self, w, h):
        """Record the screen size and derive the figure area from it.

        The figure area is ``LEFT_PANE_WIDTH`` percent of ``w`` by
        ``CONTENT_HEIGHT`` percent of ``h``.
        """
        self.screen_w = w
        self.screen_h = h

        self.figure_w = w * LEFT_PANE_WIDTH / 100
        self.figure_h = h * CONTENT_HEIGHT / 100
        print('GRAPH AREA', self.figure_w, self.figure_h)

    def set_heatmap_size(self, size):
        """Store heatmap measurements from a 7-item ``size`` sequence.

        Nothing happens when ``size == '1'``. Otherwise items 0-3 are the
        heatmap and heatmap-figure width/height, item 4 the gap (rounded to
        two decimals), item 5 an x offset that is normalised by the figure
        width, and item 6 equals ``1`` when a graph update must be forced.
        """
        if size == '1':
            return

        self.heatmap_w = float(size[0])
        self.heatmap_h = float(size[1])
        self.heatmap_fig_w = float(size[2])
        self.heatmap_fig_h = float(size[3])
        self.heatmap_gap = round(float(size[4]), 2)

        self.colorscale_x_offset = float(size[5]) / self.heatmap_fig_w

        self.force_update_graph = bool(size[6] == 1)

    def set_colorscale_mode(self, mode, colorscale_mode, colorscale_mode_64x64, show):
        """Pick the colorscale mode matching ``mode`` and store the show flag.

        ``colorscale_mode_64x64`` is used when ``mode == '64x64'``; the
        colorscale is shown only when ``show == [True]``.
        """
        if mode == '64x64':
            self.colorscale_mode = colorscale_mode_64x64
        else:
            self.colorscale_mode = colorscale_mode
        self.show_colorscale = show == [True]

    def cache_figure(self, fig):
        """Cache a deep copy of ``fig`` under the current cache key.

        Nothing is stored if the key is already cached or ``fig == {}``.
        The copy has its ``coloraxis1`` layout entry reset to ``None``.
        """
        if self.check_if_figure_is_cached() or fig == {}:
            return
        cached_fig = deepcopy(fig)
        cached_fig.update_layout({'coloraxis1': None}, overwrite=True)
        self.figure_cache[self.get_figure_cache_key()] = cached_fig

    def get_cached_figure(self):
        """Return a deep copy of the cached figure for the current key, or None."""
        if not self.check_if_figure_is_cached():
            return None
        return deepcopy(self.figure_cache[self.get_figure_cache_key()])

    def check_if_figure_is_cached(self):
        """Return True if a figure exists for the current cache key."""
        return self.get_figure_cache_key() in self.figure_cache

    def get_figure_cache_key(self):
        """Build the tuple that identifies a cached figure.

        The selected head only takes part when a single head is shown;
        otherwise ``-1`` is used in its place.
        """
        head_key = -1 if self.show_all_heads else self.selected_head
        return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64,
                head_key, self.show_colorscale, self.colorscale_mode,
                self.board.turn)

    def get_side_to_move(self):
        """Return 'White' when ``board.turn`` is truthy, else 'Black'."""
        return ['Black', 'White'][self.board.turn]

    def load_model(self):
        """Load (or fetch from cache) the model at ``self.model_path``.

        Sets ``self.model`` and ``self.tfp``; both become ``None`` when no
        model path is selected. In ``DEV_MODE`` a ``DummyModel`` is used.
        """
        if self.model_path in self.model_cache:
            self.model, self.tfp = self.model_cache[self.model_path]

        elif self.model_path is not None:
            if not DEV_MODE:
                net = self.model_path
                yaml_path = self.model_yamls[self.model_path]
                with open(yaml_path) as f:
                    cfg = yaml.safe_load(f)

                # Dropout is forced off so that evaluation is deterministic.
                if 'dropout_rate' in cfg['model']:
                    print('Setting dropout_rate to 0.0')
                    cfg['model']['dropout_rate'] = 0.0

                # NOTE(review): `tfprocess` is not imported in the visible part
                # of this file - confirm it is brought into scope elsewhere.
                tfp = tfprocess.TFProcess(cfg)
                tfp.init_net()
                tfp.replace_weights(net, ignore_errors=True)
                self.model = tfp.model
                self.tfp = tfp
            else:
                self.model = DummyModel(SIMULATED_LAYERS, SIMULATED_HEADS)
                self.tfp = None

        else:
            self.model = None
            self.tfp = None

    def find_models(self):
        """List sub-directories of ``ROOT_DIR/models`` as model names and paths."""
        models_root_folder = os.path.join(ROOT_DIR, 'models')
        # Scan the directory once so names and paths always line up.
        model_folders = [f for f in os.listdir(models_root_folder)
                         if isdir(join(models_root_folder, f))]
        self.model_names = model_folders
        self.model_paths = [os.path.relpath(join(models_root_folder, f)) for f in model_folders]

    def find_models2(self):
        """Discover ``.pb.gz`` networks and pair each with its folder's YAML.

        A sub-directory of ``ROOT_DIR/models`` is used only if it contains
        exactly one ``.yaml`` file and at least one ``.pb.gz`` file. Results
        go to ``model_names`` (file names), ``model_paths`` (relative paths)
        and ``model_yamls`` (model path -> YAML path).
        """
        models_root_folder = os.path.join(ROOT_DIR, 'models')
        model_dirs = [os.path.relpath(join(models_root_folder, f))
                      for f in os.listdir(models_root_folder)
                      if isdir(join(models_root_folder, f))]

        models = []
        paths = []
        yamls = []
        for path in model_dirs:
            entries = os.listdir(path)
            yaml_files = [entry for entry in entries if entry.endswith(".yaml")]
            if len(yaml_files) != 1:
                continue
            model_files = [entry for entry in entries if entry.endswith(".pb.gz")]
            if not model_files:
                continue

            models += model_files
            paths += [os.path.relpath(join(path, f)) for f in model_files]
            yaml_file = os.path.relpath(join(path, yaml_files[0]))
            yamls += [yaml_file] * len(model_files)

        self.model_yamls = dict(zip(paths, yamls))
        self.model_names = models
        self.model_paths = paths

    def update_activations_data(self):
        """Recompute ``self.activations_data`` for the current board and model.

        With a model and a non-Smolgen layer, the model is evaluated and the
        last element of its output is stored. For 'Smolgen' the weights of
        ``tfp.smol_weight_gen_dense`` are stored instead. Under
        ``SIMULATE_TF`` random arrays are generated. Afterwards the model
        cache, the body-layer count and ``has_attention_policy`` are updated.
        """
        if self.model is not None and self.selected_layer is None:
            self.selected_layer = 0

        if not SIMULATE_TF:
            if self.selected_layer is not None and self.model is not None and self.selected_layer != 'Smolgen':
                if not DEV_MODE:
                    inputs = board2planes(self.board)
                    inputs = tf.reshape(tf.convert_to_tensor(inputs, dtype=tf.float32), [-1, 112, 8, 8])
                else:
                    inputs = None

                outputs = self.model(inputs)
                self.activations_data = outputs[-1]
                for i, x in enumerate(self.activations_data):
                    print('LAYERS', i, x.shape)

            elif self.selected_layer == 'Smolgen' and self.tfp is not None and self.tfp.use_smolgen:
                weights = self.tfp.smol_weight_gen_dense.get_weights()[0]
                self.activations_data = weights.reshape((weights.shape[0], 64, 64))
                print('TYPEEEEE', type(self.activations_data))

        else:
            layers = SIMULATED_LAYERS
            heads = SIMULATED_HEADS
            self.activations_data = [np.random.rand(1, heads, 64, 64) for _ in range(layers)]

        if self.model is not None:
            if self.model_path not in self.model_cache:
                self.model_cache[self.model_path] = [self.model, self.tfp]

            self.update_layers_in_body_count()

        # Heuristic: a second-to-last output of shape (1, 8, 24) is taken to
        # mean the network has an attention policy head.
        self.has_attention_policy = bool(
            self.activations_data is not None
            and self.activations_data[-2].shape == (1, 8, 24)
        )

    def update_grid_shape(self):
        """Compute subplot rows/cols and container height for the head grid.

        ``FIXED_ROW``/``FIXED_COL`` override everything when both are truthy.
        Otherwise 'fit' packs up to 8 heads per row (4 rows per screen) and
        any other mode uses the 'big' layout of 4 heads per row (2 rows per
        screen). ``grid_has_changed`` is raised when the shape differs from
        the previous one.
        """
        if FIXED_ROW and FIXED_COL:
            self.subplot_cols = FIXED_COL
            self.subplot_rows = FIXED_ROW
            return None

        heads = self.number_of_heads
        if self.subplot_mode == 'fit':
            max_rows_in_screen = 4
            if heads <= 4:
                rows = 1
            elif heads <= 8:
                rows = 2
            else:
                rows = heads // 8 + int(heads % 8 != 0)
        else:
            # 'big' layout; also the fallback so rows is always defined.
            max_rows_in_screen = 2
            rows = heads // 4 + int(heads % 4 != 0)

        if rows > max_rows_in_screen:
            container_height = f'{int((rows / max_rows_in_screen) * 100)}%'
        else:
            container_height = '100%'

        # Ceiling division: enough columns to place every head.
        cols = -(-heads // rows) if rows != 0 else 0

        if self.subplot_rows != rows or self.subplot_cols != cols:
            self.grid_has_changed = True

        self.subplot_cols = cols
        self.subplot_rows = rows

        self.figure_container_height = container_height if self.show_all_heads else '100%'

    def update_selected_activation_data(self):
        """Extract the array for the selected layer into ``self.activations``.

        Does nothing when no activations have been computed. The stored
        array has its second axis reversed.
        """
        if self.activations_data is None:
            return

        if self.selected_layer == 'Policy':
            print('RAW POLICY SHAPE', self.activations_data[-1].shape)
            activations = self.activations_data[-1].numpy()
        elif self.selected_layer == 'Smolgen':
            weights = self.tfp.smol_weight_gen_dense.get_weights()[0]
            # Derive the leading dimension from the weights, matching
            # update_activations_data, instead of assuming 256.
            activations = weights.reshape((weights.shape[0], 64, 64))
        elif not DEV_MODE:
            activations = tf.squeeze(self.activations_data[self.selected_layer], axis=0).numpy()
        else:
            activations = np.squeeze(self.activations_data[self.selected_layer], axis=0)

        self.activations = activations[:, ::-1, :]

    def set_visualization_mode(self, mode):
        """Store the visualization mode and whether it is the '64x64' mode."""
        self.visualization_mode = mode
        self.visualization_mode_is_64x64 = mode == '64x64'

    def _head_count_for_layer(self, layer):
        """Return the number of heads available for ``layer``."""
        if layer == 'Policy':
            return 1
        if layer == 'Smolgen':
            return self.activations.shape[0]
        return self.activations_data[layer].shape[1]

    def set_layer(self, layer):
        """Select ``layer``, refresh its activations and reset to head 0."""
        self.selected_layer = layer
        self.update_selected_activation_data()
        self.number_of_heads = self._head_count_for_layer(layer)
        self.set_head(0)
        self.update_grid_shape()

    def set_head(self, head):
        """Store the selected head index."""
        self.selected_head = head

    def set_model(self, model):
        """Switch to the model at path ``model`` and refresh dependent state.

        Nothing happens when ``model`` is already the active path. The
        selected head is clamped to the new head count.
        """
        if model == self.model_path:
            return

        self.model_path = model
        self.load_model()
        self.update_activations_data()
        self.update_selected_activation_data()
        # Shared with set_layer so 'Policy' and 'Smolgen' selections do not
        # index the activations list with a string.
        self.number_of_heads = self._head_count_for_layer(self.selected_layer)
        if self.selected_head is None:
            self.selected_head = 0
        else:
            self.selected_head = min(self.selected_head, self.number_of_heads - 1)
        self.update_grid_shape()
        if SIMULATE_TF:
            sleep(2)

    def update_layers_in_body_count(self):
        """Count the leading 4-D outputs that share the first output's head count.

        The count stops at the first entry that is not 4-D or whose second
        dimension differs. A numeric ``selected_layer`` is clamped to the
        last counted layer.
        """
        heads = self.activations_data[0].shape[1]
        body_layers = 0
        for layer in self.activations_data:
            # Rank is checked first so low-rank outputs never hit shape[1].
            if len(layer.shape) != 4 or layer.shape[1] != heads:
                break
            body_layers += 1
        self.nr_of_layers_in_body = body_layers
        if self.selected_layer not in ('Policy', 'Smolgen'):
            self.selected_layer = min(self.selected_layer, self.nr_of_layers_in_body - 1)

    def get_head_data(self, head):
        """Return the array to plot for ``head`` in the current mode.

        Returns ``None`` when ``head`` is out of range. In '64x64' mode the
        full matrix is returned; in 'ROW' mode one row and otherwise one
        column is reshaped to 8x8. When ``board.turn`` is falsy (and the
        layer is not 'Smolgen') the index is remapped and the 8x8 result is
        flipped differently.
        """
        if self.activations.shape[0] <= head:
            return None

        if self.visualization_mode == '64x64':
            return self.activations[head, :, :]

        unflipped = self.board.turn or self.selected_layer == 'Smolgen'

        if self.visualization_mode == 'ROW':
            if unflipped:
                row = 63 - self.focused_square_ind
                return self.activations[head, row, :].reshape((8, 8))
            # Reverse the position within its group of 8, keep the group.
            group, offset = divmod(self.focused_square_ind, 8)
            row = (7 - offset) + group * 8
            return self.activations[head, row, :].reshape((8, 8))[::-1, :]

        if unflipped:
            col = self.focused_square_ind
            return self.activations[head, :, col].reshape((8, 8))[::-1, ::-1]
        group, offset = divmod(63 - self.focused_square_ind, 8)
        col = (7 - offset) + group * 8
        return self.activations[head, :, col].reshape((8, 8))[:, ::-1]

    def set_fen(self, fen):
        """Load ``fen`` into the board and recompute the activations."""
        self.board.set_fen(fen)
        self.fen = fen
        self.update_activations_data()
        self.update_selected_activation_data()

    def set_board(self, board):
        """Replace the board with a deep copy of ``board`` and recompute."""
        self.board = deepcopy(board)
        self.update_activations_data()
        self.update_selected_activation_data()
|
| |
|
# Module-level shared instance. Constructing it runs GlobalData.__init__, which
# scans the models directory, so importing this module performs that I/O.
global_data = GlobalData()
print('global data created')
| |
|