"""Build the Plotly/Dash heatmap view from the state held in ``global_data``.

Every function here reads its configuration (grid size, selected head/layer,
colorscale mode, board, ...) from the shared ``global_data`` object rather
than from parameters.
"""
import pickle
import time

import chess
import dash
import numpy as np
import plotly.graph_objs as go
from dash import dcc, html, Input, Output, State
from plotly.io import to_json
from plotly.subplots import make_subplots

from global_data import global_data
from svg_pieces import get_svg_board

# Numerator of the vertical subplot spacing; divided by the row count below.
V_GAP = 0.15
# Top and bottom layout margin, in pixels.
LAYOUT_MARGIN_V = 40
# add_pieces() attaches a board image to axis pairs x/y, x2/y2, ... up to this
# index, independent of how many subplots the figure really has.
_MAX_IMAGE_AXES = 256


def _is_single_plot():
    """Return True when only one heatmap is drawn (single head or 1x1 grid)."""
    return not global_data.show_all_heads or (
        global_data.subplot_cols == 1 and global_data.subplot_rows == 1
    )


def _dump_figure_json(fig, path):
    """Write ``fig`` as pretty-printed Plotly JSON to ``path`` (debug aid)."""
    with open(path, 'w') as f:
        f.write(to_json(fig, pretty=True))


def heatmap_data(head):
    """Return ``global_data.get_head_data(head)`` for the given head index."""
    return global_data.get_head_data(head)


def heatmap_figure():
    """Assemble the complete heatmap figure.

    Returns an empty dict when no model is loaded. Otherwise builds (or reuses
    a cached) figure, adds traces and layout and, unless the 64x64 mode is
    active, the board images. Each stage's duration is printed. When the
    selected layer is 'Smolgen', the figure is also dumped to JSON files in
    the working directory before and after the pieces stage.
    """
    if global_data.model is None:
        return {}

    start = time.time()
    fig = make_figure()
    print('make fig:', time.time() - start)

    start = time.time()
    fig = add_heatmap_traces(fig)
    print('add traces:', time.time() - start)

    start = time.time()
    fig = add_layout(fig)
    print('add layout total:', time.time() - start)

    start = time.time()
    is_smolgen = global_data.selected_layer == 'Smolgen'
    if is_smolgen:
        _dump_figure_json(fig, 'fig_as_json_no_pieces.json')
    if not global_data.visualization_mode_is_64x64:
        fig = add_pieces(fig)
    print('add pieces:', time.time() - start)
    if is_smolgen:
        _dump_figure_json(fig, 'fig_as_json.json')
    return fig


def heatmap():
    """Return the 'graph-container' Div wrapping the heatmap graph.

    The graph is recreated (not just its figure) because the layout breaks
    when the grid size changes (possibly a Dash bug); the container's children
    are what a callback replaces to trigger that recalculation.
    """
    start = time.time()
    container_style = {'height': '100%', 'width': '100%', "overflow": "auto"}
    graph = html.Div(id='graph-container',
                     children=[heatmap_graph()],
                     style=container_style)
    print('GRAPH CREATION:', time.time() - start)
    return graph


def heatmap_graph():
    """Create the ``dcc.Graph`` for the current figure and cache that figure."""
    fig = heatmap_figure()
    config = {
        'displaylogo': False,
        'displayModeBar': True,
        'modeBarButtonsToRemove': ['zoom', 'pan', 'select', 'zoomIn', 'zoomOut',
                                   'autoScale', 'resetScale'],
        'toImageButtonOptions': {
            'format': global_data.export_format,
            'scale': global_data.export_scale,
        },
    }
    style = {'height': global_data.figure_container_height, 'width': '100%'}
    graph = dcc.Graph(figure=fig,
                      id='graph',
                      style=style,
                      responsive='auto',
                      config=config)
    global_data.cache_figure(fig)
    return graph


def make_figure():
    """Return the cached figure, or a fresh subplot grid when none is cached."""
    fig = global_data.get_cached_figure()
    if fig is not None:
        return fig

    if global_data.show_all_heads:
        titles = [f"Head {i + 1}" for i in range(global_data.number_of_heads)]
        print('MAKING SUBPLOTS', 'rows:', global_data.subplot_rows,
              'cols:', global_data.subplot_cols)
        print('NUMBER OF HEADS:', global_data.number_of_heads)
        return make_subplots(
            rows=global_data.subplot_rows,
            cols=global_data.subplot_cols,
            subplot_titles=titles,
            horizontal_spacing=global_data.heatmap_horizontal_gap / global_data.subplot_cols,
            vertical_spacing=V_GAP / global_data.subplot_rows,
        )

    print('CREATING 1x1')
    titles = [f"head {global_data.selected_head + 1}"]
    return make_subplots(rows=1, cols=1, subplot_titles=titles)


def add_layout(fig):
    """Apply margins, modebar orientation, shared color axis and axis styling.

    In colorscale mode 'mode3' a shared ``coloraxis1`` spanning the min/max of
    all activations is configured. When the figure came from the cache only
    ``coloraxis1`` is refreshed and the remaining layout is left untouched.
    """
    start = time.time()

    coloraxis = None
    if global_data.colorscale_mode == 'mode3':
        cmin = np.amin(global_data.activations[:, :, :])
        cmax = np.amax(global_data.activations[:, :, :])
        coloraxis = {
            'colorscale': 'Viridis',
            'colorbar': {'ypad': 0},
            'cmin': cmin,
            'cmax': cmax,
            'showscale': global_data.show_colorscale,
        }

    if global_data.check_if_figure_is_cached():
        print('Using existing layout')
        fig.update_layout({'coloraxis1': coloraxis}, overwrite=True)
        return fig

    layout = go.Layout(
        margin={'t': LAYOUT_MARGIN_V, 'b': LAYOUT_MARGIN_V, 'r': 40, 'l': 40},
        coloraxis1=coloraxis,
        modebar={'orientation': 'v'},
    )
    fig.update_layout(layout)
    print('update layout:', time.time() - start)

    start = time.time()
    fig = update_axis(fig)
    print('update axis:', time.time() - start)
    return fig


def update_axis(fig):
    """Configure ticks, titles, ranges and scaling constraints on all axes.

    64x64 mode labels every fourth cell with a square name and titles the
    axes; 8x8 mode labels files a-h on x and ranks 1-8 on y without titles.
    """
    if global_data.visualization_mode_is_64x64:
        tickvals_x = list(range(0, 64, 4))
        tickvals_y = list(range(3, 67, 4))
        ranks = '1122334455667788'
        if not (global_data.board.turn or global_data.selected_layer == 'Smolgen'):
            # Rank order is reversed when board.turn is falsy (Smolgen excepted).
            ranks = ranks[::-1]
        ticktext_x = [x + y for x, y in zip('ae' * 8, ranks)]
        ticktext_y = ticktext_x[::-1]
        val_range = [-0.5, 63.5]
        ticks = 'outside'
        title_x = {'text': "Keys ('to' square)", 'standoff': 1}
        title_y = {'text': "Queries ('from' square)", 'standoff': 1}
    else:
        title_x = None
        title_y = None
        tickvals_x = list(range(8))
        tickvals_y = tickvals_x
        ticktext_x = list('abcdefgh')
        ticktext_y = list('12345678')
        val_range = [-0.5, 7.5]
        ticks = ''
    showticklabels = True

    constraintowards_x = 'center' if _is_single_plot() else 'right'

    fig.update_xaxes(
        title=title_x,
        range=val_range,
        zeroline=False,
        showgrid=False,
        scaleanchor='y',
        constrain='domain',
        constraintoward=constraintowards_x,
        ticks=ticks,
        ticktext=ticktext_x,
        tickvals=tickvals_x,
        showticklabels=showticklabels,
        fixedrange=True,
    )
    fig.update_yaxes(
        title=title_y,
        range=val_range,
        zeroline=False,
        showgrid=False,
        scaleanchor='x',
        constrain='domain',
        constraintoward='top',
        ticks=ticks,
        ticktext=ticktext_y,
        tickvals=tickvals_y,
        showticklabels=showticklabels,
        fixedrange=True,
    )
    return fig


def calc_colorbar(row, col):
    """Return the colorbar dict for the subplot at 1-based (``row``, ``col``).

    Length and position are fractions computed from the grid size, the
    horizontal gap and, when ``global_data.heatmap_h`` is non-zero, the
    measured heatmap and figure heights.
    """
    n_rows = global_data.subplot_rows
    n_cols = global_data.subplot_cols
    h_gap = global_data.heatmap_horizontal_gap

    # Subplot rows are counted from the top, colorbar y from the bottom.
    row = n_rows - row + 1
    dy = (1/n_rows)
    dx = (1/n_cols)
    offset = 1/n_cols - 2*(h_gap/(n_cols))/4

    lenmode = 'fraction'
    if global_data.heatmap_h == 0:
        # No measured heatmap height: derive the length from the grid alone.
        bar_len = (1 - V_GAP/n_rows) / n_rows
        offset2 = bar_len / 2
    else:
        bar_len = global_data.heatmap_h/(global_data.heatmap_fig_h - 2*LAYOUT_MARGIN_V)
        offset2 = 1 - (n_rows-1)*(dy + (V_GAP / n_rows)/n_rows) - bar_len/2

    cx = (col-1)*(dx + (h_gap / n_cols)/n_cols) + offset
    cy = (row-1)*(dy + (V_GAP / n_rows)/n_rows) + offset2

    return dict(len=bar_len, y=cy, x=cx, ypad=0, xpad=0,
                ticklabelposition='inside', ticks='inside', ticklen=3,
                lenmode=lenmode, tickfont=dict(color='#7e807f'))


def add_heatmap_trace(fig, row, col, head=None):
    """Add one heatmap trace to the subplot at 1-based (``row``, ``col``).

    ``head`` defaults to the row-major index of the subplot. The figure is
    returned unchanged when ``heatmap_data`` yields ``None`` for that head.
    """
    if head is None:
        head = (row - 1) * global_data.subplot_cols + (col - 1)

    data = heatmap_data(head)
    if data is None:
        return fig

    if global_data.visualization_mode_is_64x64:
        xgap, ygap = 0, 0
        # NOTE(review): the line-break tags inside this template were lost in
        # the source text; '<br>' is assumed here - confirm.
        hovertemplate = 'Query: %{customdata[0]}<br>Key: %{customdata[1]}<br>value: %{z:.5f}'
        if global_data.board.turn or global_data.selected_layer == 'Smolgen':
            customdata_x = [[letter + ind for ind in '12345678' for letter in 'abcdefgh']
                            for _ in range(64)]
            customdata_y = [[letter + ind for _ in range(64)]
                            for ind in '12345678'[::-1] for letter in 'abcdefgh'[::-1]]
        else:
            customdata_x = [[letter + ind for ind in '12345678'[::-1] for letter in 'abcdefgh']
                            for _ in range(64)]
            customdata_y = [[letter + ind for _ in range(64)]
                            for ind in '12345678' for letter in 'abcdefgh'[::-1]]
        # Stack into (64, 64, 2): index 0 feeds 'Query', index 1 feeds 'Key'.
        customdata = np.moveaxis([customdata_y, customdata_x], 0, -1)
    else:
        xgap, ygap = 2, 2
        hovertemplate = '%{x}%{y}: %{z}'
        customdata = None

    coloraxis = None
    colorscale = 'Viridis'
    colorbar = None
    if global_data.colorscale_mode == 'mode3':
        # Shared color axis (configured in add_layout) replaces the per-trace scale.
        coloraxis = 'coloraxis1'
        colorscale = None
    elif global_data.show_colorscale and not _is_single_plot():
        colorbar = calc_colorbar(row, col)

    zmin, zmax = None, None
    if global_data.colorscale_mode == 'mode2':
        zmin = np.amin(global_data.activations[head, :, :])
        zmax = np.amax(global_data.activations[head, :, :])
    elif global_data.colorscale_mode == 'mode1':
        zmin = np.amin(data)
        zmax = np.amax(data)

    trace = go.Heatmap(
        z=data,
        colorscale=colorscale,
        showscale=global_data.show_colorscale,
        colorbar=colorbar,
        xgap=xgap,
        ygap=ygap,
        hovertemplate=hovertemplate,
        customdata=customdata,
        zmin=zmin,
        zmax=zmax,
        coloraxis=coloraxis,
    )
    fig.add_trace(trace, row=row, col=col)
    return fig


def add_heatmap_traces(fig):
    """Replace all traces on ``fig`` with freshly built heatmap traces."""
    print('adding traces, rows:', global_data.subplot_rows, 'cols:', global_data.subplot_cols)
    # Adding traces is quick, so cached traces are not reused: wipe and re-add.
    fig.data = []
    if global_data.show_all_heads:
        for row in range(global_data.subplot_rows):
            for col in range(global_data.subplot_cols):
                fig = add_heatmap_trace(fig, row + 1, col + 1)
    else:
        fig = add_heatmap_trace(fig, 1, 1, global_data.selected_head)
    return fig


def add_pieces(fig):
    """Overlay the board SVG on the heatmap axes via ``fig.layout.images``.

    For the 'Smolgen' layer an empty board is used so that only the focused
    square is drawn. The image list is assigned directly for axis pairs
    x/y, x2/y2, ... x256/y256. An earlier ``add_layout_image`` variant that
    sat after the ``return`` was unreachable and has been removed.
    """
    if global_data.selected_layer != 'Smolgen':
        board = global_data.board
    else:
        board = chess.Board(fen=None)  # empty board: only the focused square is drawn
    board_svg = get_svg_board(board, global_data.focused_square_ind, True)

    # The first axis pair has no numeric suffix; the rest are numbered from 2.
    axis_suffixes = [''] + [str(i) for i in range(2, _MAX_IMAGE_AXES + 1)]
    fig.layout.images = [
        dict(
            source=board_svg,
            xref="x" + suffix,
            yref="y" + suffix,
            x=3.5,
            y=3.5,
            sizex=8,
            sizey=8,
            xanchor='center',
            yanchor='middle',
            sizing="stretch",
        )
        for suffix in axis_suffixes
    ]
    return fig