| | import chess |
| | import dash |
| | import plotly.graph_objs as go |
| | from plotly.subplots import make_subplots |
| | from global_data import global_data |
| |
|
| | from svg_pieces import get_svg_board |
| | from dash import dcc, html, Input, Output, State |
| |
|
| |
|
| | import time |
| |
|
| | import numpy as np |
| |
|
| | from plotly.io import to_json |
| |
|
| | import pickle |
| |
|
# Vertical spacing between subplot rows; divided by the number of rows before
# being passed to make_subplots (and reused when positioning colorbars).
V_GAP: float = 0.15
# Top/bottom layout margin in pixels; also subtracted from the figure height
# when sizing per-subplot colorbars.
LAYOUT_MARGIN_V: int = 40
| |
|
def heatmap_data(head):
    """Return the per-head data stored in ``global_data`` for ``head``.

    The result is passed through unchanged from ``global_data.get_head_data``;
    callers treat ``None`` as "nothing to plot for this head".
    """
    return global_data.get_head_data(head)
| |
|
| |
|
def heatmap_figure():
    """Build the complete heatmap figure for the current ``global_data`` state.

    Returns an empty dict when no model is loaded. Otherwise the figure is
    created (or taken from the cache by ``make_figure``), filled with heatmap
    traces, given its layout and axes, and, when not in 64x64 mode, overlaid
    with chess-board images. The duration of each stage is printed to stdout.

    When the 'Smolgen' layer is selected, the figure is also dumped as pretty
    JSON to ``fig_as_json_no_pieces.json`` (before the board overlay) and to
    ``fig_as_json.json`` (final figure) in the current working directory.
    """
    if global_data.model is None:
        return {}

    start = time.time()
    fig = make_figure()
    print('make fig:', time.time() - start)

    start = time.time()
    fig = add_heatmap_traces(fig)
    print('add traces:', time.time() - start)

    start = time.time()
    fig = add_layout(fig)
    print('add layout total:', time.time() - start)

    is_smolgen = global_data.selected_layer == 'Smolgen'
    if is_smolgen:
        _dump_figure_json(fig, 'fig_as_json_no_pieces.json')

    # Start the clock only after the optional JSON dump so that the dump is
    # not misreported as time spent adding the board pieces.
    start = time.time()
    if not global_data.visualization_mode_is_64x64:
        fig = add_pieces(fig)
    print('add pieces:', time.time() - start)

    if is_smolgen:
        _dump_figure_json(fig, 'fig_as_json.json')

    return fig


def _dump_figure_json(fig, path):
    """Write ``fig`` as pretty-printed plotly JSON to ``path`` using UTF-8."""
    # Explicit encoding: the platform default (e.g. cp1252) may not be able to
    # encode every character in the serialized figure.
    with open(path, 'w', encoding='utf-8') as f:
        f.write(to_json(fig, pretty=True))
| |
|
| |
|
def heatmap():
    """Return the heatmap graph wrapped in a full-size, scrollable ``html.Div``.

    The time taken to build the graph is printed to stdout.
    """
    t0 = time.time()
    container_style = {'height': '100%', 'width': '100%', "overflow": "auto"}
    container = html.Div(
        id='graph-container',
        children=[heatmap_graph()],
        style=container_style,
    )
    print('GRAPH CREATION:', time.time() - t0)
    return container
| |
|
| |
|
def heatmap_graph():
    """Create the ``dcc.Graph`` component that displays the heatmap figure.

    The mode bar is shown without the plotly logo and without the navigation
    buttons listed below. Image export uses ``global_data.export_format`` and
    ``global_data.export_scale``. The figure is handed to
    ``global_data.cache_figure`` before the component is returned.
    """
    figure = heatmap_figure()

    export_options = {
        'format': global_data.export_format,
        'scale': global_data.export_scale,
    }
    graph_config = {
        'displaylogo': False,
        'displayModeBar': True,
        'modeBarButtonsToRemove': ['zoom', 'pan', 'select', 'zoomIn', 'zoomOut', 'autoScale', 'resetScale'],
        'toImageButtonOptions': export_options,
    }
    graph_style = {'height': global_data.figure_container_height, 'width': '100%'}

    graph = dcc.Graph(
        figure=figure,
        id='graph',
        style=graph_style,
        responsive='auto',
        config=graph_config,
    )

    global_data.cache_figure(figure)
    return graph
| |
|
| |
|
def make_figure():
    """Return the cached figure, or build a new subplot grid if none is cached.

    With ``global_data.show_all_heads`` the grid has ``subplot_rows`` x
    ``subplot_cols`` cells titled "Head 1" ... "Head N" (N = number_of_heads).
    Otherwise a single subplot titled after the selected head is created.
    """
    cached = global_data.get_cached_figure()
    if cached is not None:
        return cached

    if not global_data.show_all_heads:
        print('CREATING 1x1')
        return make_subplots(rows=1, cols=1,
                             subplot_titles=[f"head {global_data.selected_head +1}"])

    titles = [f"Head {n}" for n in range(1, global_data.number_of_heads + 1)]
    rows = global_data.subplot_rows
    cols = global_data.subplot_cols
    print('MAKING SUBPLOTS', 'rows:', rows, 'cols:', cols)
    print('NUMBER OF HEADS:', global_data.number_of_heads)
    # The configured gaps are divided by the grid size before being passed on.
    return make_subplots(
        rows=rows,
        cols=cols,
        subplot_titles=titles,
        horizontal_spacing=global_data.heatmap_horizontal_gap / cols,
        vertical_spacing=V_GAP / rows,
    )
| |
|
| |
|
def add_layout(fig):
    """Apply margins, mode-bar orientation, the shared color axis and axis styling.

    In colorscale 'mode3', ``coloraxis1`` is set to a Viridis scale spanning
    the min/max of all of ``global_data.activations``. In any other mode it is
    set to ``None``. For a figure that is already cached, only ``coloraxis1``
    is refreshed and the rest of the existing layout is kept. Timings are
    printed to stdout.
    """
    t0 = time.time()

    shared_axis = None
    if global_data.colorscale_mode == 'mode3':
        everything = global_data.activations[:, :, :]
        shared_axis = {
            'colorscale': 'Viridis',
            'colorbar': {'ypad': 0},
            'cmin': np.amin(everything),
            'cmax': np.amax(everything),
            'showscale': global_data.show_colorscale,
        }

    if global_data.check_if_figure_is_cached():
        print('Using existing layout')
        fig.update_layout({'coloraxis1': shared_axis}, overwrite=True)
        return fig

    fig.update_layout(go.Layout(
        margin={'t': LAYOUT_MARGIN_V, 'b': LAYOUT_MARGIN_V, 'r': 40, 'l': 40},
        coloraxis1=shared_axis,
        modebar={'orientation': 'v'},
    ))
    print('update layout:', time.time() - t0)

    t0 = time.time()
    fig = update_axis(fig)
    print('update axis:', time.time() - t0)
    return fig
| |
|
| |
|
def update_axis(fig):
    """Style the x and y axes of every subplot in ``fig`` and return it.

    64x64 mode:
        - Every 4th column is labelled with square names (a1, e1, a2, ...).
        - The row labels are the same names in reverse order.
        - The rank order is reversed when it is not white's turn, unless the
          'Smolgen' layer is selected.
        - Ticks are drawn outside, and the axes get "Keys"/"Queries" titles.

    Other modes:
        - The axes are labelled a-h and 1-8 over an 8-wide range.
        - There are no titles or tick marks.

    In both modes x and y are scale-anchored to each other, constrained to
    their domain, and zooming is disabled (``fixedrange``).
    """
    if global_data.visualization_mode_is_64x64:
        ranks = '1122334455667788'
        if not (global_data.board.turn or global_data.selected_layer == 'Smolgen'):
            ranks = ranks[::-1]
        ticktext_x = [file_ + rank for file_, rank in zip('ae' * 8, ranks)]
        ticktext_y = ticktext_x[::-1]
        tickvals_x = list(range(0, 64, 4))
        tickvals_y = list(range(3, 67, 4))
        val_range = [-0.5, 63.5]
        ticks = 'outside'
        title_x = {'text': "Keys ('to' square)", 'standoff': 1}
        title_y = {'text': "Queries ('from' square)", 'standoff': 1}
    else:
        ticktext_x = list('abcdefgh')
        ticktext_y = list('12345678')
        tickvals_x = list(range(8))
        tickvals_y = tickvals_x
        val_range = [-0.5, 7.5]
        ticks = ''
        title_x = title_y = None

    single_plot = (not global_data.show_all_heads
                   or (global_data.subplot_cols == 1 and global_data.subplot_rows == 1))
    constraintowards_x = 'center' if single_plot else 'right'

    # Settings identical for both axes.
    common = dict(
        range=val_range,
        zeroline=False,
        showgrid=False,
        constrain='domain',
        ticks=ticks,
        showticklabels=True,
        fixedrange=True,
    )
    fig.update_xaxes(title=title_x, scaleanchor='y', constraintoward=constraintowards_x,
                     ticktext=ticktext_x, tickvals=tickvals_x, **common)
    fig.update_yaxes(title=title_y, scaleanchor='x', constraintoward='top',
                     ticktext=ticktext_y, tickvals=tickvals_y, **common)
    return fig
| |
|
| |
|
def calc_colorbar(row, col):
    """Return colorbar settings that place a bar next to subplot (``row``, ``col``).

    ``row`` and ``col`` are the 1-based grid positions also passed to
    ``fig.add_trace``. The ``x``/``y`` values are fractional positions derived
    from the grid size, ``global_data.heatmap_horizontal_gap`` and ``V_GAP``.
    The bar length is a fraction (``lenmode='fraction'``). When
    ``global_data.heatmap_h`` is non-zero, it is taken from that value relative
    to ``heatmap_fig_h`` minus the vertical margins.
    """
    rows = global_data.subplot_rows
    cols = global_data.subplot_cols
    h_gap = global_data.heatmap_horizontal_gap

    # Flip the row index: grid rows are counted from the top, while the
    # colorbar y position is accumulated from the bottom.
    row = rows - row + 1

    dy = (1/rows)
    dx = (1/cols)

    # One cell pitch minus half of the per-column spacing passed to
    # make_subplots (gap / cols).
    offset = 1/cols - 2*(h_gap/(cols))/4

    if global_data.heatmap_h == 0:
        # No heatmap height available (presumably not measured yet): use one
        # row's share of the figure after removing the row gaps.
        bar_len = (1 - V_GAP/rows) / rows
        offset2 = bar_len / 2
    else:
        bar_len = global_data.heatmap_h/(global_data.heatmap_fig_h - 2*LAYOUT_MARGIN_V)
        offset2 = 1 - (rows-1)*(dy + (V_GAP / rows)/rows) - bar_len/2

    cx = (col-1)*(dx + (h_gap / cols)/cols) + offset
    cy = (row-1)*(dy + (V_GAP / rows)/rows) + offset2

    return dict(len=bar_len, y=cy, x=cx, ypad=0, xpad=0, ticklabelposition='inside',
                ticks='inside', ticklen=3, lenmode='fraction',
                tickfont=dict(color='#7e807f'))
| |
|
def add_heatmap_trace(fig, row, col, head=None):
    """Add one head's heatmap to subplot (``row``, ``col``) of ``fig`` and return it.

    Parameters
    ----------
    fig : plotly figure with a subplot grid.
    row, col : 1-based subplot position.
    head : index of the head to plot. Defaults to the row-major index of
        (row, col) in the ``subplot_cols``-wide grid.

    ``fig`` is returned unchanged when ``heatmap_data(head)`` is ``None``
    (presumably for grid cells past the last head).

    Colour scaling follows ``global_data.colorscale_mode``:
        - 'mode1': this head's plotted data.
        - 'mode2': this head's slice of ``global_data.activations``.
        - 'mode3': the shared layout colour axis ``coloraxis1``.
        - Any other mode leaves zmin/zmax unset.

    A per-subplot colorbar is added only for multi-cell grids with
    ``show_colorscale`` enabled.
    """
    if head is None:
        head = (row - 1) * global_data.subplot_cols + (col - 1)
    data = heatmap_data(head)
    if data is None:
        return fig

    if global_data.visualization_mode_is_64x64:
        xgap = ygap = 0
        hovertemplate = 'Query: <b>%{customdata[0]}</b> <br>Key: <b>%{customdata[1]}</b> <br>value: <b>%{z:.5f}</b><extra></extra>'
        white_view = global_data.board.turn or global_data.selected_layer == 'Smolgen'
        customdata = _square_customdata(white_view)
    else:
        xgap = ygap = 2
        hovertemplate = '<b>%{x}%{y}</b>: <b>%{z}</b><extra></extra>'
        customdata = None

    mode = global_data.colorscale_mode
    coloraxis = None
    colorscale = 'Viridis'
    colorbar = None
    if mode == 'mode3':
        # Delegate colouring to the shared layout colour axis.
        coloraxis = 'coloraxis1'
        colorscale = None
    elif (global_data.show_colorscale and global_data.show_all_heads
          and not (global_data.subplot_cols == 1 and global_data.subplot_rows == 1)):
        # Only multi-cell grids get an explicitly positioned per-subplot bar.
        colorbar = calc_colorbar(row, col)

    zmin = zmax = None
    if mode == 'mode2':
        head_values = global_data.activations[head, :, :]
        zmin = np.amin(head_values)
        zmax = np.amax(head_values)
    elif mode == 'mode1':
        zmin = np.amin(data)
        zmax = np.amax(data)

    trace = go.Heatmap(
        z=data,
        colorscale=colorscale,
        showscale=global_data.show_colorscale,
        colorbar=colorbar,
        xgap=xgap,
        ygap=ygap,
        hovertemplate=hovertemplate,
        customdata=customdata,
        zmin=zmin,
        zmax=zmax,
        coloraxis=coloraxis,
    )
    fig.add_trace(trace, row=row, col=col)
    return fig


def _square_customdata(white_view):
    """Return a (64, 64, 2) array of [query, key] square names for hover text.

    A truthy ``white_view`` selects the ordering used when white is to move
    (or for the 'Smolgen' layer). Otherwise the black-to-move ordering is used.
    """
    files = 'abcdefgh'
    ranks = '12345678'
    key_ranks = ranks if white_view else ranks[::-1]
    query_ranks = ranks[::-1] if white_view else ranks
    # Every row shares the same column (key) labels a1..h8 in the chosen rank order.
    keys = [[f + r for r in key_ranks for f in files] for _ in range(64)]
    # Each row repeats one query label across all 64 columns.
    queries = [[f + r for _ in range(64)] for r in query_ranks for f in files[::-1]]
    return np.moveaxis([queries, keys], 0, -1)
| |
|
| |
|
def add_heatmap_traces(fig):
    """Replace all traces of ``fig`` with freshly built heatmaps and return it.

    With ``show_all_heads`` a trace is attempted for every cell of the
    ``subplot_rows`` x ``subplot_cols`` grid, in row-major order. Otherwise
    only ``selected_head`` is drawn in cell (1, 1).
    """
    print('adding traces, rows:', global_data.subplot_rows, 'cols:', global_data.subplot_cols)
    # make_figure may hand back a cached figure, so clear its old traces first.
    fig.data = []

    if not global_data.show_all_heads:
        return add_heatmap_trace(fig, 1, 1, global_data.selected_head)

    cells = ((r, c)
             for r in range(1, global_data.subplot_rows + 1)
             for c in range(1, global_data.subplot_cols + 1))
    for r, c in cells:
        fig = add_heatmap_trace(fig, r, c)
    return fig
| |
|
| |
|
def add_pieces(fig):
    """Overlay the chess board SVG on every subplot of ``fig`` and return it.

    The board drawn is ``global_data.board``, except for the 'Smolgen' layer,
    where an empty ``chess.Board(fen=None)`` is used. The SVG comes from
    ``get_svg_board`` and is called with ``global_data.focused_square_ind``.

    Each image is centred at (3.5, 3.5) and stretched to 8x8 data units on its
    subplot's axes. Any images previously in the layout are replaced.
    """
    if global_data.selected_layer != 'Smolgen':
        board = global_data.board
    else:
        board = chess.Board(fen=None)
    board_svg = get_svg_board(board, global_data.focused_square_ind, True)

    # make_subplots names its axes x/y, x2/y2, ..., xN/yN. Reference only the
    # axes that exist in the current grid, so the SVG is not serialized for
    # hundreds of non-existent subplots.
    if global_data.show_all_heads:
        n_subplots = global_data.subplot_rows * global_data.subplot_cols
    else:
        n_subplots = 1
    axis_suffixes = [''] + [str(i) for i in range(2, n_subplots + 1)]

    fig.layout.images = [
        dict(
            source=board_svg,
            xref="x" + suffix,
            yref="y" + suffix,
            x=3.5,
            y=3.5,
            sizex=8,
            sizey=8,
            xanchor='center',
            yanchor='middle',
            sizing="stretch",
        )
        for suffix in axis_suffixes
    ]
    return fig
| |
|
| |
|