seredapj commited on
Commit
ec03afc
·
1 Parent(s): 5922570

Delete our_visualization

Browse files
our_visualization/.DS_Store DELETED
Binary file (8.2 kB)
 
our_visualization/__pycache__/activation_heatmap.cpython-38.pyc DELETED
Binary file (9.72 kB)
 
our_visualization/__pycache__/board2planes.cpython-38.pyc DELETED
Binary file (3.69 kB)
 
our_visualization/__pycache__/constants.cpython-38.pyc DELETED
Binary file (614 Bytes)
 
our_visualization/__pycache__/global_data.cpython-38.pyc DELETED
Binary file (13 kB)
 
our_visualization/__pycache__/leela_board.cpython-38.pyc DELETED
Binary file (33 kB)
 
our_visualization/__pycache__/python_chess_customized_svg.cpython-38.pyc DELETED
Binary file (21.8 kB)
 
our_visualization/__pycache__/svg_pieces.cpython-38.pyc DELETED
Binary file (782 Bytes)
 
our_visualization/__pycache__/utils.cpython-38.pyc DELETED
Binary file (1.67 kB)
 
our_visualization/__pycache__/visualization_demo.cpython-38.pyc DELETED
Binary file (7.44 kB)
 
our_visualization/activation_heatmap.py DELETED
@@ -1,436 +0,0 @@
1
- import chess
2
- import dash
3
- import plotly.graph_objs as go
4
- from plotly.subplots import make_subplots
5
- from global_data import global_data
6
-
7
- from svg_pieces import get_svg_board
8
- from dash import dcc, html, Input, Output, State
9
-
10
-
11
- import time
12
-
13
- import numpy as np
14
-
15
- from plotly.io import to_json
16
-
17
- import pickle
18
-
19
- V_GAP = 0.15
20
- LAYOUT_MARGIN_V = 40
21
-
22
- def heatmap_data(head):
23
- data = global_data.get_head_data(head)
24
- return data
25
-
26
-
27
- def heatmap_figure():
28
- if global_data.model is None:
29
- return {}
30
- start = time.time()
31
- fig = make_figure()
32
- print('make fig:', time.time() - start)
33
-
34
- start = time.time()
35
- fig = add_heatmap_traces(fig)
36
- print('add traces:', time.time() - start)
37
- with open("./test_activations_starting.pkl", 'wb') as f:
38
- print("saving activations")
39
- pickle.dump(global_data.activations, f)
40
- start = time.time()
41
- fig = add_layout(fig)
42
- print('add layout total:', time.time() - start)
43
-
44
- start = time.time()
45
-
46
- if global_data.selected_layer == 'Smolgen':
47
- with open('fig_as_json_no_pieces.json', 'w') as f:
48
- f.write(to_json(fig, pretty=True))
49
-
50
- if not global_data.visualization_mode_is_64x64:# and global_data.selected_layer != 'Smolgen':
51
- fig = add_pieces(fig)
52
- print('add pieces:', time.time() - start)
53
-
54
- if global_data.selected_layer == 'Smolgen':
55
- with open('fig_as_json.json', 'w') as f:
56
- f.write(to_json(fig, pretty=True))
57
-
58
- return fig
59
-
60
-
61
- def heatmap():
62
- start = time.time()
63
- # We need to recalculate graph when grid size changes, other wise layout is a mess (Dash bug?). Use hidden Div's children to trigger callback for graph recalc.
64
- # Otherwise, we can just recalculate figure part and frontend rendering will be much faster
65
- #
66
- graph = html.Div(id='graph-container', children=[heatmap_graph()],
67
- style={'height': '100%', 'width': '100%', "overflow": "auto"#, 'textAlign': 'center'#, "display": "flex", "justifyContent":"center"
68
- })
69
- print('GRAPH CREATION:', time.time() - start)
70
- return graph
71
-
72
-
73
- def heatmap_graph():
74
- fig = heatmap_figure()
75
-
76
- config = {
77
- 'displaylogo': False,
78
- 'displayModeBar': True,
79
- 'modeBarButtonsToRemove': ['zoom', 'pan', 'select', 'zoomIn', 'zoomOut', 'autoScale', 'resetScale'],
80
- 'toImageButtonOptions': {
81
- 'format': global_data.export_format,
82
- 'scale': global_data.export_scale
83
- }}
84
-
85
- style = {'height': global_data.figure_container_height, 'width': '100%'}#, 'margin': '0 auto'}
86
-
87
- graph = dcc.Graph(figure=fig, id='graph', style=style,
88
- responsive='auto',#True, # True,
89
- config=config
90
- )
91
-
92
- # graph = html.Div(id='graph-container', children=[graph], style={'height': '100%', 'width': '100%', "overflow": "auto"
93
- # })
94
- # graph_component.children = [graph]
95
-
96
- global_data.cache_figure(fig)
97
-
98
- return graph
99
-
100
-
101
- def make_figure():
102
- #print('assumed key', global_data.subplot_rows, global_data.subplot_cols, global_data.visualization_mode_is_64x64, global_data.selected_head if not global_data.show_all_heads else -1)
103
- #print('key', global_data.get_figure_cache_key())
104
- #print('all keys', global_data.figure_cache.keys())
105
- fig = global_data.get_cached_figure()
106
- if fig is None:
107
- if global_data.show_all_heads:
108
- titles = [f"Head {i + 1}" for i in range(global_data.number_of_heads)]
109
- print('MAKING SUBPLOTS', 'rows:', global_data.subplot_rows, 'cols:', global_data.subplot_cols)
110
- print('NUMBER OF HEADS:', global_data.number_of_heads)
111
- fig = make_subplots(rows=global_data.subplot_rows, cols=global_data.subplot_cols, subplot_titles=titles,
112
- horizontal_spacing=global_data.heatmap_horizontal_gap / global_data.subplot_cols,
113
- vertical_spacing=V_GAP / global_data.subplot_rows,
114
- )
115
- else:
116
- print('CREATING 1x1')
117
- titles = [f"head {global_data.selected_head +1}"]
118
- fig = make_subplots(rows=1, cols=1, subplot_titles=titles)#go.Figure()#make_subplots(rows=1, cols=1, subplot_titles=titles)
119
-
120
- return fig
121
-
122
-
123
- def add_layout(fig):
124
- start = time.time()
125
-
126
- #coloraxis1 = None
127
- #if global_data.visualization_mode_is_64x64:
128
- # if global_data.colorscale_mode == '3':
129
- # coloraxis1 = {'colorscale': 'Viridis'}
130
-
131
- coloraxis = None
132
- if global_data.colorscale_mode == 'mode3':
133
- cmin = np.amin(global_data.activations[:, :, :])
134
- cmax = np.amax(global_data.activations[:, :, :])
135
- coloraxis = {'colorscale': 'Viridis', 'colorbar': {'ypad': 0} , 'cmin': cmin, 'cmax': cmax, 'showscale': global_data.show_colorscale}
136
-
137
- if global_data.check_if_figure_is_cached():
138
- print('Using existing layout')
139
- fig.update_layout({'coloraxis1': coloraxis}, overwrite=True)
140
- return fig
141
-
142
- layout = go.Layout(
143
- # title='Plot title goes here',
144
- margin={'t': LAYOUT_MARGIN_V, 'b': LAYOUT_MARGIN_V, 'r': 40, 'l': 40},
145
- coloraxis1=coloraxis,
146
- modebar={'orientation': 'v'}
147
- #coloraxis={'colorscale': 'Viridis'}
148
- #pa
149
- #plot_bgcolor='rgb(0,0,0)',
150
- #paper_bgcolor="black"
151
- )
152
-
153
- fig.update_layout(layout)
154
- # fig['layout'].update(layout)
155
-
156
- print('update layout:', time.time() - start)
157
-
158
- start = time.time()
159
- fig = update_axis(fig)
160
- print('update axis:', time.time() - start)
161
- # print(fig)
162
- return fig
163
-
164
-
165
- def update_axis(fig):
166
- if global_data.visualization_mode_is_64x64:
167
- tickvals_x = list(range(0, 64, 4))
168
- tickvals_y = list(range(3, 67, 4))#list(range(0, 64, 4))#list(range(3, 67, 4))
169
- if global_data.board.turn or global_data.selected_layer == 'Smolgen':
170
- ticktext_x = [x + y for x, y in zip('ae' * 8, '1122334455667788')]
171
- #tickvals = list(range(0, 64))
172
- #ticktext_x = [x + y for x, y in zip('abcdefg' * 8, '1'*8 + '2'*8 + '3'*8 + '4'*8 + '5'*8 + '6'*8 + '7'*8 + '8'*8)]
173
- ticktext_y = ticktext_x[::-1]
174
- else:
175
- ticktext_x = [x + y for x, y in zip('ae' * 8, '1122334455667788'[::-1])]
176
- ticktext_y = ticktext_x[::-1]
177
- showticklabels = True
178
- #ticklabelstep = 4
179
- val_range = [-0.5, 63.5]
180
- ticks = 'outside'
181
- title_x = {'text': "Keys ('to' square)", 'standoff': 1}
182
- title_y = {'text': "Queries ('from' square)", 'standoff': 1}
183
- else:
184
- title_x = None
185
- title_y = None
186
- tickvals_x = list(range(8)) # [0, 1, 2, 3, 4, 5, 6, 7]
187
- tickvals_y = tickvals_x
188
- ticktext_x = [letter for letter in 'abcdefgh']
189
- ticktext_y = [letter for letter in '12345678']
190
- showticklabels = True
191
- #ticklabelstep = 1
192
- val_range = [-0.5, 7.5]
193
- ticks = ''
194
-
195
- if not global_data.show_all_heads or (global_data.subplot_cols == 1 and global_data.subplot_rows == 1):
196
- constraintowards_x = 'center'
197
- else:
198
- constraintowards_x = 'right'
199
-
200
-
201
- fig.update_xaxes(title=title_x,
202
- range=val_range,
203
- # ticklen=50,
204
- zeroline=False,
205
- showgrid=False,
206
- scaleanchor='y',
207
- constrain='domain',
208
- constraintoward=constraintowards_x,
209
- ticks=ticks, # ticks,
210
- ticktext=ticktext_x,
211
- tickvals=tickvals_x,
212
- showticklabels=showticklabels,
213
- # mirror='ticks',
214
- fixedrange=True,
215
- #ticklabelstep=ticklabelstep,
216
- )
217
-
218
- fig.update_yaxes(title=title_y,
219
- range=val_range,
220
- zeroline=False,
221
- showgrid=False,
222
- scaleanchor='x',
223
- constrain='domain',
224
- constraintoward='top',
225
- ticks=ticks, # ticks,
226
- ticktext=ticktext_y,
227
- tickvals=tickvals_y,
228
- showticklabels=showticklabels,
229
- # mirror='allticks',
230
- # side='top',
231
- fixedrange=True,
232
- #ticklabelstep=ticklabelstep
233
- )
234
- return fig
235
-
236
-
237
- def calc_colorbar(row, col):
238
- row = global_data.subplot_rows - row + 1 #invert
239
- #row = global_data.subplot_rows - row - 1 #invert
240
-
241
- dy = (1/global_data.subplot_rows)
242
- dx = (1/global_data.subplot_cols)
243
-
244
- offset = 1/global_data.subplot_cols - 2*(global_data.heatmap_horizontal_gap/(global_data.subplot_cols))/4#global_data.colorscale_x_offset#(494.1125)/2239.2#1/global_data.subplot_cols - 3*(global_data.heatmap_horizontal_gap/(global_data.subplot_cols))/4
245
-
246
- if global_data.heatmap_h == 0:
247
- len = (1 - V_GAP/global_data.subplot_rows) / global_data.subplot_rows #- #V_GAP/global_data.subplot_rows
248
- lenmode = 'fraction'
249
- offset2 = len / 2
250
- else:
251
- #total_h = global_data.heatmap_fig_h * global_data.heatmap_h + (global_data.subplot_rows - 1)
252
- len = global_data.heatmap_h/(global_data.heatmap_fig_h - 2*LAYOUT_MARGIN_V)
253
- lenmode = 'fraction'
254
- #offset2 = len / 2
255
- #lenmode = 'pixels'
256
- #offset2 = 1 - len/(global_data.subplot_rows*len + V_GAP) #1/global_data.subplot_rows - (V_GAP/global_data.subplot_rows)
257
-
258
- #tot_h = global_data.heatmap_fig_h - 2*LAYOUT_MARGIN_V
259
- #max_h = ((1 - V_GAP)) / global_data.subplot_rows
260
- #cur_h = len
261
- offset2 = 1 - (global_data.subplot_rows-1)*(dy + (V_GAP / global_data.subplot_rows)/global_data.subplot_rows) - len/2 #0#len/2 #+ (max_h - cur_h)
262
-
263
-
264
- #offset = global_data.colorscale_x_offset
265
- #shift = (global_data.heatmap_w + 20 + 20 + global_data.heatmap_gap)/global_data.heatmap_fig_w
266
- #cx = (col - 1) * shift + offset
267
-
268
-
269
- cx = (col-1)*(dx + (global_data.heatmap_horizontal_gap / global_data.subplot_cols)/global_data.subplot_cols) + offset
270
- cy = (row-1)*(dy + (V_GAP / global_data.subplot_rows)/global_data.subplot_rows) + offset2
271
- #cy = (global_data.subplot_rows - 1 - row) * (dy + (V_GAP / global_data.subplot_rows) / global_data.subplot_rows) + offset2
272
-
273
- #######################
274
-
275
-
276
- colorbar=dict(len=len, y=cy, x=cx, ypad=0, xpad=0, ticklabelposition='inside', ticks='inside', ticklen=3, lenmode=lenmode,
277
- tickfont=dict(color='#7e807f'))
278
-
279
- return colorbar
280
-
281
- def add_heatmap_trace(fig, row, col, head=None):
282
- # print('ADDING heatmap', row, col)
283
- if head is None:
284
- head = (row - 1) * global_data.subplot_cols + (col - 1)
285
- data = heatmap_data(head)
286
-
287
- if data is None:
288
- return fig
289
-
290
- if global_data.visualization_mode_is_64x64:
291
- xgap, ygap = 0, 0
292
- #hovertemplate = 'Query: <b>%{y}</b> <br> Key: <b>%{x}</b> <br> value: <b>%{z}</b><extra></extra>'
293
- hovertemplate = 'Query: <b>%{customdata[0]}</b> <br>Key: <b>%{customdata[1]}</b> <br>value: <b>%{z:.5f}</b><extra></extra>'
294
- if global_data.board.turn or global_data.selected_layer == 'Smolgen':
295
- customdata_x = [[letter + ind for ind in '12345678' for letter in 'abcdefgh'] for _ in range(64)]
296
- customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678'[::-1] for letter in 'abcdefgh'[::-1]]
297
- else:
298
- customdata_x = [[letter + ind for ind in '12345678'[::-1] for letter in 'abcdefgh'] for _ in range(64)]
299
- customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678' for letter in 'abcdefgh'[::-1]]
300
-
301
- #customdata_y = [[letter + ind for _ in range(64)] for ind in '12345678' for letter in 'abcdefgh']
302
- customdata = np.moveaxis([customdata_y, customdata_x], 0, -1)#[customdata_x, customdata_y]
303
-
304
- else:
305
- xgap, ygap = 2, 2
306
- hovertemplate = '<b>%{x}%{y}</b>: <b>%{z}</b><extra></extra>'
307
- customdata = None
308
-
309
- coloraxis = None
310
- #Colorscale
311
-
312
- #if global_data.visualization_mode_is_64x64:
313
- # if global_data.colorscale_mode == '3':
314
- # coloraxis = 'coloraxis1'
315
-
316
- coloraxis = None
317
- colorscale = 'Viridis'
318
- colorbar = None
319
- #if global_data.show_colorscale and global_data.colorscale_mode == 'mode3':
320
- if global_data.colorscale_mode == 'mode3':
321
- coloraxis = 'coloraxis1'
322
- colorscale = None
323
-
324
- elif global_data.show_colorscale and not global_data.colorscale_mode == 'mode3' and not (not global_data.show_all_heads or (global_data.subplot_cols == 1 and global_data.subplot_rows == 1)):
325
- colorbar = calc_colorbar(row, col)
326
-
327
- zmin, zmax = None, None
328
-
329
- if global_data.colorscale_mode == 'mode2':
330
- pass
331
- zmin = np.amin(global_data.activations[head, :, :])
332
- zmax = np.amax(global_data.activations[head, :, :])
333
- #print('ZMINMAX M2', head, zmin, zmax)
334
-
335
- elif global_data.colorscale_mode == 'mode1':
336
- zmin = np.amin(data)
337
- zmax = np.amax(data)
338
-
339
- #print('Trace data shape', data.shape)
340
- trace = go.Heatmap(
341
- z=data,
342
- colorscale=colorscale,
343
- showscale=global_data.show_colorscale,#True,
344
- colorbar=colorbar,
345
- #colorbar=dict(len=len, y=cy, x=cx, ypad=0, ticklabelposition='inside', ticks='inside', ticklen=3,
346
- # tickfont=dict(color='#7e807f')),
347
- xgap=xgap,
348
- ygap=ygap,
349
- hovertemplate=hovertemplate,
350
- customdata=customdata,
351
- zmin=zmin,
352
- zmax=zmax,
353
- coloraxis=coloraxis
354
- #zmin=zmin,
355
- #zmax=zmax
356
- )
357
- fig.add_trace(trace, row=row, col=col)
358
- return fig
359
-
360
-
361
- def add_heatmap_traces(fig):
362
- print('adding traces, rows:', global_data.subplot_rows, 'cols:', global_data.subplot_cols)
363
- #adding traces is quick so we don't bother using cached values. Wipe old traces and add new.
364
- fig.data = []
365
- if global_data.show_all_heads:
366
- for row in range(global_data.subplot_rows):
367
- for col in range(global_data.subplot_cols):
368
- fig = add_heatmap_trace(fig, row + 1, col + 1)
369
- else:
370
- fig = add_heatmap_trace(fig, 1, 1, global_data.selected_head)
371
- return fig
372
-
373
-
374
- def add_pieces(fig):
375
- if global_data.selected_layer != 'Smolgen':
376
- board = global_data.board
377
-
378
- else:
379
- board = chess.Board(fen=None) #Empty board, we want to draw only the focused square
380
- board_svg = get_svg_board(board, global_data.focused_square_ind, True)
381
-
382
- images = [dict(
383
- source=board_svg,
384
- xref="x"+str(i),
385
- yref="y"+str(i),
386
- x=3.5,
387
- y=3.5,
388
- sizex=8,
389
- sizey=8,
390
- xanchor='center',
391
- yanchor='middle',
392
- sizing="stretch",
393
- )
394
- for i in range(2, 2+255)
395
- ]
396
- images = [dict(
397
- source=board_svg,
398
- xref="x",
399
- yref="y",
400
- x=3.5,
401
- y=3.5,
402
- sizex=8,
403
- sizey=8,
404
- xanchor='center',
405
- yanchor='middle',
406
- sizing="stretch",
407
- )] + images
408
-
409
- fig.layout.images = images
410
- return fig
411
- board_svg = get_svg_board(board, global_data.focused_square_ind, True)
412
- if global_data.check_if_figure_is_cached():
413
- print('USING CACHED')
414
- for img in fig.layout.images:
415
- img['source'] = board_svg
416
- else:
417
- fig.add_layout_image(
418
- dict(
419
- source=board_svg,
420
- xref="x",
421
- yref="y",
422
- x=3.5,
423
- y=3.5,
424
- sizex=8,
425
- sizey=8,
426
- xanchor='center',
427
- yanchor='middle',
428
- sizing="stretch",
429
- ),
430
- row='all',
431
- col='all',
432
- exclude_empty_subplots=True,
433
- )
434
-
435
- return fig
436
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/assets/DejaVuSans-Bold.ttf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:e6476c1b80502924294eed40894c5b18e06c181444ca953e5334262df9c27724
3
- size 705684
 
 
 
 
our_visualization/assets/DejaVuSans.ttf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:7da195a74c55bef988d0d48f9508bd5d849425c1770dba5d7bfc6ce9ed848954
3
- size 757076
 
 
 
 
our_visualization/assets/DejavuSansMono-5m7L.ttf DELETED
@@ -1,3 +0,0 @@
1
- version https://git-lfs.github.com/spec/v1
2
- oid sha256:dfbac4c793ca4b896c34af5c7ccbbaf37924b46cd521a5ecff9130b9c331f575
3
- size 333636
 
 
 
 
our_visualization/assets/custom.css DELETED
@@ -1,59 +0,0 @@
1
- *[class="model-loading"][data-dash-is-loading="true"]::before{
2
- content: "Loading model...";
3
- display: inline-block;
4
- color: red;
5
- visibility: visible;
6
- font-size: 16px;
7
- /*transition-delay: 1s;
8
- transition-property: font-size;
9
- transition-duration: 200ms;*/
10
- }
11
-
12
- .completely-hidden{
13
- display: none;
14
- }
15
-
16
- .hidden-but-reserve-space{
17
- visibility: hidden;
18
- }
19
-
20
- #link{
21
- text-decoration: underline;
22
- cursor: pointer;
23
- }
24
-
25
- .header-container {
26
- border-bottom: thin lightgrey solid;
27
- box-sizing: border-box;
28
- white-space: nowrap;
29
- overflow-y: auto;
30
- }
31
-
32
- .header-control-container {
33
- margin-left: 20px;
34
- margin-right: 20px;
35
- }
36
-
37
- body {
38
- margin: 0;
39
- padding: 0;
40
- font-family: 'BundledDejavuSans';
41
- font-size: 14px;
42
- -moz-user-select: none;
43
- }
44
-
45
- @font-face {
46
- font-family: 'BundledDejavuSansMono';
47
- src: url('DejavuSansMono-5m7L.ttf');
48
- }
49
-
50
- @font-face {
51
- font-family: 'BundledDejavuSans';
52
- src: url('DejaVuSans.ttf');
53
- }
54
-
55
- @font-face {
56
- font-family: 'BundledDejavuSans';
57
- src: url('DejaVuSans-Bold.ttf');
58
- font-weight: bold;
59
- }
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/assets/favicon.ico DELETED
Binary file (15.4 kB)
 
our_visualization/board2planes.py DELETED
@@ -1,97 +0,0 @@
1
- #Author: https://github.com/Arcturai
2
-
3
- import chess
4
- import numpy as np
5
- import re
6
-
7
-
8
- WPAWN = chess.Piece(chess.PAWN, chess.WHITE)
9
- WKNIGHT = chess.Piece(chess.KNIGHT, chess.WHITE)
10
- WBISHOP = chess.Piece(chess.BISHOP, chess.WHITE)
11
- WROOK = chess.Piece(chess.ROOK, chess.WHITE)
12
- WQUEEN = chess.Piece(chess.QUEEN, chess.WHITE)
13
- WKING = chess.Piece(chess.KING, chess.WHITE)
14
- BPAWN = chess.Piece(chess.PAWN, chess.BLACK)
15
- BKNIGHT = chess.Piece(chess.KNIGHT, chess.BLACK)
16
- BBISHOP = chess.Piece(chess.BISHOP, chess.BLACK)
17
- BROOK = chess.Piece(chess.ROOK, chess.BLACK)
18
- BQUEEN = chess.Piece(chess.QUEEN, chess.BLACK)
19
- BKING = chess.Piece(chess.KING, chess.BLACK)
20
-
21
-
22
- def assign_piece2(planes, piece_step, row, col):
23
- planes[piece_step][row][col] = 1
24
-
25
-
26
- DISPATCH2 = {}
27
-
28
- DISPATCH2[str(WPAWN)] = lambda retval, row, col: assign_piece2(retval, 0, row, col)
29
- DISPATCH2[str(WKNIGHT)] = lambda retval, row, col: assign_piece2(retval, 1, row, col)
30
- DISPATCH2[str(WBISHOP)] = lambda retval, row, col: assign_piece2(retval, 2, row, col)
31
- DISPATCH2[str(WROOK)] = lambda retval, row, col: assign_piece2(retval, 3, row, col)
32
- DISPATCH2[str(WQUEEN)] = lambda retval, row, col: assign_piece2(retval, 4, row, col)
33
- DISPATCH2[str(WKING)] = lambda retval, row, col: assign_piece2(retval, 5, row, col)
34
- DISPATCH2[str(BPAWN)] = lambda retval, row, col: assign_piece2(retval, 6, row, col)
35
- DISPATCH2[str(BKNIGHT)] = lambda retval, row, col: assign_piece2(retval, 7, row, col)
36
- DISPATCH2[str(BBISHOP)] = lambda retval, row, col: assign_piece2(retval, 8, row, col)
37
- DISPATCH2[str(BROOK)] = lambda retval, row, col: assign_piece2(retval, 9, row, col)
38
- DISPATCH2[str(BQUEEN)] = lambda retval, row, col: assign_piece2(retval, 10, row, col)
39
- DISPATCH2[str(BKING)] = lambda retval, row, col: assign_piece2(retval, 11, row, col)
40
-
41
-
42
- def append_plane(planes, ones):
43
- if ones:
44
- return np.append(planes, np.ones((1, 8, 8), dtype=np.float), axis=0)
45
- else:
46
- return np.append(planes, np.zeros((1, 8, 8), dtype=np.float), axis=0)
47
-
48
-
49
- def fill_planes(board):
50
- planes = np.zeros((12, 8, 8), dtype=np.float)
51
- for row in range(8):
52
- for col in range(8):
53
- piece = str(board.piece_at(chess.SQUARES[row * 8 + col]))
54
- if piece != "None":
55
- DISPATCH2[piece](planes, row, col)
56
- planes = append_plane(planes, board.is_repetition(2))
57
- return planes
58
-
59
-
60
- def board2planes(board_):
61
- if not board_.turn:
62
- board = board_.mirror()
63
- else:
64
- board = board_
65
-
66
- retval = fill_planes(board)
67
-
68
- s_board = board_.copy()
69
- for i in range(7):
70
- if s_board.move_stack.__len__() > 0:
71
- s_board.pop()
72
- b = s_board.mirror() if not board_.turn else s_board.copy()
73
- retval = np.append(retval, fill_planes(b), axis=0)
74
- else:
75
- retval = np.append(retval, np.zeros((13, 8, 8), dtype=np.float), axis=0)
76
-
77
- retval = append_plane(retval, bool(board.castling_rights & chess.BB_H1))
78
- retval = append_plane(retval, bool(board.castling_rights & chess.BB_A1))
79
- retval = append_plane(retval, bool(board.castling_rights & chess.BB_H8))
80
- retval = append_plane(retval, bool(board.castling_rights & chess.BB_A8))
81
- retval = append_plane(retval, not board_.turn)
82
- retval = np.append(retval, np.full((1, 8, 8), fill_value=board_.halfmove_clock/99., dtype=np.float), axis=0)
83
- retval = append_plane(retval, False)
84
- retval = append_plane(retval, True)
85
-
86
- return np.expand_dims(retval, axis=0)
87
-
88
-
89
- def bulk_board2planes(boards):
90
- planes = []
91
- for b in boards:
92
- temp = board2planes(b)
93
- planes.append(temp)
94
- pl = tuple(planes)
95
- retval = np.concatenate(pl, axis=0)
96
- return retval
97
-
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/constants.py DELETED
@@ -1,24 +0,0 @@
1
- import sys
2
- import os
3
-
4
-
5
- def root_directory():
6
- if getattr(sys, 'frozen', False):
7
- # The application is frozen
8
- root = os.path.dirname(sys.executable)
9
- else:
10
- # The application is not frozen
11
- root = os.path.dirname(__file__) # os.path.dirname(os.path.abspath(__file__))#os.path.dirname(__file__)
12
- return (root)
13
-
14
-
15
- ROOT_DIR = root_directory()
16
-
17
- LEFT_PANE_WIDTH = 90
18
- RIGHT_PANE_WIDTH = 100 - LEFT_PANE_WIDTH
19
- GRAPH_PANE_HEIGHT = 100
20
- HEADER_HEIGHT = 11
21
- CONTENT_HEIGHT = 100 - HEADER_HEIGHT
22
-
23
- EXPORT_FORMAT = 'png' #one of png, svg, jpeg, webp
24
- EXPORT_SCALE = 1.0 #When 1.0, the figure is exported as same size as currently in the browser. Use e.g. 0.5 to scale to half.
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/datasets/test_set.csv DELETED
The diff for this file is too large to render. See raw diff
 
our_visualization/global_data.py DELETED
@@ -1,539 +0,0 @@
1
- import chess.engine
2
- from constants import ROOT_DIR, CONTENT_HEIGHT, LEFT_PANE_WIDTH, EXPORT_FORMAT, EXPORT_SCALE
3
- from time import sleep
4
- # from test_array import activations_array
5
-
6
- from copy import deepcopy
7
-
8
- from board2planes import board2planes
9
-
10
- import yaml
11
-
12
- import os
13
- from os.path import isdir, join
14
- import sys
15
-
16
- import numpy as np
17
-
18
-
19
- SIMULATE_TF = False #TODO: Remove this option, deprecated
20
- # turn off tensorflow importing and generate random data to speed up development
21
- DEV_MODE = False
22
- SIMULATED_LAYERS = 6
23
- SIMULATED_HEADS = 64
24
- FIXED_ROW = None # 1 #None to disable
25
- FIXED_COL = None # 5 #None to disable
26
- if DEV_MODE:
27
- class DummyModel:
28
- def __init__(self, layers, heads):
29
- self.layers = layers
30
- self.heads = heads
31
-
32
- def __call__(self, *args, **kwargs):
33
- data = [np.random.rand(1, self.heads, 64, 64) for i in range(self.layers)]
34
- return [None, None, None, data]
35
-
36
- else:
37
- import tensorflow as tf
38
- from tensorflow.compat.v1 import ConfigProto
39
- from tensorflow.compat.v1 import InteractiveSession
40
-
41
-
42
- # class to hold data, state and configurations
43
- # Dash is stateless and in general it is very bad idea to store data in global variables on server side
44
- # However, this application is ment to be run by single user on local machine, so it is safe to store data and state
45
- # information on global object
46
- class GlobalData:
47
- def __init__(self):
48
- import os
49
- if not DEV_MODE:
50
- # import tensorflow as tf
51
- os.environ["CUDA_VISIBLE_DEVICES"] = "0"
52
-
53
- # from tensorflow.compat.v1 import ConfigProto
54
- # from tensorflow.compat.v1 import InteractiveSession
55
- # import chess
56
- # import matplotlib.patheffects as path_effects
57
-
58
- #config = ConfigProto()
59
- #config.gpu_options.allow_growth = True
60
- #session = InteractiveSession(config=config)
61
- #tf.keras.backend.clear_session()
62
-
63
- self.tmp = 0
64
- self.export_format = EXPORT_FORMAT
65
- self.export_scale = EXPORT_SCALE
66
- self.fen = 'rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1' # '2kr3r/ppp2b2/2n4p/4p3/Q2Pq1pP/2P1N3/PP3PP1/R1B1KB1R w KQ - 3 18'#'6n1/1p1k4/3p4/pNp5/P1P4p/7P/1P4KP/r7 w - - 2 121'#
67
- self.board = chess.Board(fen=self.fen)
68
- self.focused_square_ind = 0
69
- self.active_move_table_cell = None # tuple (row_ind, col_id), e.g. (12, 'White')
70
-
71
- self.activations = None # activations_array
72
- self.visualization_mode = 'ROW'
73
- self.visualization_mode_is_64x64 = False
74
- self.subplot_mode = 'big' #'fit' # big'#'fit'#, 'big'
75
- self.subplot_cols = 0
76
- self.subplot_rows = 0
77
- self.number_of_heads = 0
78
- self.selected_head = None
79
- self.show_all_heads = True
80
-
81
- self.show_colorscale = False
82
- self.colorscale_mode = 'mode1'
83
-
84
- self.figure_container_height = '100%' # '100%'
85
-
86
- self.running_counter = 0 # used to pass new values to hidden indicator elements which will trigger follow-up callback
87
- self.grid_has_changed = False
88
-
89
- # self.has_subplot_grid_changed = True
90
- # self.figure_layout_images = None #store layout and only recalculate when subplot grid has changed
91
- # self.figure_layout_annotations = None
92
- # self.need_update_axis = True
93
-
94
- self.screen_w = 0
95
- self.screen_h = 0
96
- self.figure_w = 0
97
- self.figure_h = 0
98
- self.heatmap_w = 0
99
- self.heatmap_h = 0
100
- self.heatmap_fig_w = 0
101
- self.heatmap_fig_h = 0
102
- self.heatmap_gap = 0
103
- self.colorscale_x_offset = 0
104
-
105
- self.heatmap_horizontal_gap = 0.275
106
-
107
- self.figure_cache = {}
108
-
109
- self.update_grid_shape()
110
-
111
- self.pgn_data = [] # list of boards in pgn
112
- self.move_table_boards = {} # dict of boards in pgn, key is (move_table.row_ind, move_table.column_id)
113
-
114
- if not SIMULATE_TF:
115
- self.selected_layer = None
116
- else:
117
- self.selected_layer = 0
118
-
119
- self.nr_of_layers_in_body = -1
120
- self.has_attention_policy = False
121
-
122
- self.model_paths = []
123
- self.model_names = []
124
- self.model_yamls = {} #key = model path, value = yaml of that model
125
- self.model_cache = {}
126
- self.find_models2()
127
- self.model_path = None#self.model_paths[0] # '/home/jusufe/PycharmProjects/lc0-attention-visualizer/T12_saved_model_1M'
128
- self.model = None
129
- self.tfp = None #TensorflowProcess
130
- if not SIMULATE_TF:
131
- self.load_model()
132
- self.activations_data = None
133
-
134
- if self.model is not None or SIMULATE_TF:
135
- self.update_activations_data()
136
-
137
- if self.selected_layer is not None:
138
- self.set_layer(self.selected_layer)
139
-
140
- self.move_table_active_cell = None
141
-
142
- self.force_update_graph = False
143
-
144
- def set_subplot_mode(self, fit_to_page):
145
- if fit_to_page == [True]:
146
- self.subplot_mode = 'fit'
147
- else:
148
- self.subplot_mode = 'big'
149
- self.update_grid_shape()
150
-
151
- def set_screen_size(self, w, h):
152
- self.screen_w = w
153
- self.screen_h = h
154
-
155
- self.figure_w = w*LEFT_PANE_WIDTH/100
156
- self.figure_h = h*CONTENT_HEIGHT/100
157
- print('GRAPH AREA', self.figure_w, self.figure_h)
158
-
159
- def set_heatmap_size(self, size):
160
- if size != '1':
161
- #print('-----------------------HEATMAP SIZE', size)
162
- # w, h = size
163
- # print('TYETETETETEU', global_data.screen_w)
164
- # global_data.set_screen_size(w, h)
165
- #print('>>>>>: HEATMAP WIDTH', size[0])
166
- #print('>>>>>: HEATMAP HEIGHT', size[1])
167
- #print('>>>>>: FIG WIDTH', size[2])
168
- #print('>>>>>: FIG HEIGHT', size[3])
169
- #print('>>>>>: HEATMAP GAP', size[4])
170
-
171
- self.heatmap_w = float(size[0])
172
- self.heatmap_h = float(size[1])
173
- self.heatmap_fig_w = float(size[2])
174
- self.heatmap_fig_h = float(size[3])
175
- self.heatmap_gap = round(float(size[4]), 2)
176
-
177
- self.colorscale_x_offset = float(size[5])/self.heatmap_fig_w
178
-
179
- if size[6] == 1:
180
- self.force_update_graph = True
181
- else:
182
- self.force_update_graph = False
183
-
184
-
185
- #if self.heatmap_gap < 30:
186
- # self.heatmap_horizontal_gap += 0.025
187
-
188
- # self.heatmap_horizontal_gap = min(0.25, self.heatmap_horizontal_gap)
189
- #if self.heatmap_gap < 200:
190
- # self.heatmap_horizontal_gap += -0.025
191
- # self.heatmap_horizontal_gap = max(0.1, self.heatmap_horizontal_gap)
192
-
193
- def set_colorscale_mode(self, mode, colorscale_mode, colorscale_mode_64x64, show):
194
- if mode == '64x64':
195
- self.colorscale_mode = colorscale_mode_64x64
196
- else:
197
- self.colorscale_mode = colorscale_mode
198
- #print('SHOW value', show)
199
- self.show_colorscale = show == [True]
200
-
201
- def cache_figure(self, fig):
202
- if not self.check_if_figure_is_cached() and fig != {}:
203
- key = self.get_figure_cache_key()
204
- cached_fig = deepcopy(fig)
205
- cached_fig.update_layout({'coloraxis1': None}, overwrite=True)
206
- #print('CACHING FIGURE:')
207
- self.figure_cache[key] = cached_fig
208
-
209
- def get_cached_figure(self):
210
- if self.check_if_figure_is_cached():
211
- key = self.get_figure_cache_key()
212
- fig = deepcopy(self.figure_cache[key])
213
- else:
214
- fig = None
215
- return fig
216
-
217
- def check_if_figure_is_cached(self):
218
- key = self.get_figure_cache_key()
219
- return key in self.figure_cache
220
-
221
- def get_figure_cache_key(self):
222
- return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64,
223
- self.selected_head if not self.show_all_heads else -1, self.show_colorscale, self.colorscale_mode,
224
- self.board.turn)
225
- #return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64, self.selected_head if not self.show_all_heads else -1, self.heatmap_horizontal_gap, self.heatmap_fig_h, self.heatmap_fig_w)
226
- #return (self.subplot_rows, self.subplot_cols, self.visualization_mode_is_64x64, self.show_all_heads)
227
-
228
- def get_side_to_move(self):
229
- return ['Black', 'White'][self.board.turn]
230
-
231
- def load_model(self):
232
- if self.model_path in self.model_cache:
233
- self.model, self.tfp = self.model_cache[self.model_path]
234
-
235
- elif self.model_path is not None:
236
- #net = '/home/jusufe/Projects/lc0/BT1024-3142c-swa-186000.pb.gz'
237
- #yaml_path = '/home/jusufe/Downloads/cfg.yaml'
238
- if not DEV_MODE:
239
- net = self.model_path
240
- yaml_path = self.model_yamls[self.model_path]
241
- with open(yaml_path) as f:
242
- cfg = f.read()
243
- cfg = yaml.safe_load(cfg)
244
-
245
- if 'dropout_rate' in cfg['model']:
246
- print('Setting dropout_rate to 0.0')
247
- cfg['model']['dropout_rate'] = 0.0
248
-
249
- tfp = tfprocess.TFProcess(cfg)
250
- tfp.init_net()
251
- tfp.replace_weights(net, ignore_errors=True)
252
- self.model = tfp.model
253
- self.tfp = tfp
254
- else:
255
- self.model = DummyModel(SIMULATED_LAYERS, SIMULATED_HEADS)
256
- self.tfp = None
257
-
258
- else:
259
- self.model = None
260
- self.tfp = None
261
-
262
- def find_models(self):
263
- root = ROOT_DIR
264
- models_root_folder = os.path.join(root, 'models')
265
- model_folders = [f for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
266
- model_paths = [os.path.relpath(join(models_root_folder, f)) for f in os.listdir(models_root_folder) if
267
- isdir(join(models_root_folder, f))]
268
- self.model_names = model_folders
269
- self.model_paths = model_paths
270
-
271
- #print('MODELS:')
272
- #print(self.model_names)
273
- #print(self.model_paths)
274
-
275
- def find_models2(self):
276
- import os
277
- from os.path import isdir, join
278
- root = ROOT_DIR
279
- models_root_folder = os.path.join(root, 'models')
280
- model_folders = [f for f in os.listdir(models_root_folder) if isdir(join(models_root_folder, f))]
281
- model_paths = [os.path.relpath(join(models_root_folder, f)) for f in os.listdir(models_root_folder) if
282
- isdir(join(models_root_folder, f))]
283
-
284
- models = []
285
- paths = []
286
- yamls = []
287
- for path in model_paths:
288
- yaml_files = [file for file in os.listdir(path) if file.endswith(".yaml")]
289
- if len(yaml_files) != 1:
290
- continue
291
- model_files = [file for file in os.listdir(path) if file.endswith(".pb.gz")]
292
- if len(model_files) == 0:
293
- continue
294
-
295
- models += model_files
296
- paths += [os.path.relpath(join(path, f)) for f in model_files]
297
- yaml_file = os.path.relpath(join(path, yaml_files[0]))
298
- yamls += [yaml_file]*len(model_files)
299
-
300
- self.model_yamls = {path: yaml_file for path, yaml_file in zip(paths, yamls)}
301
- self.model_names = models
302
- self.model_paths = paths#model_paths
303
-
304
-
305
- def update_activations_data(self):
306
-
307
- if self.model is not None and self.selected_layer is None:
308
- self.selected_layer = 0
309
-
310
- if not SIMULATE_TF:
311
- if self.selected_layer is not None and self.model is not None and self.selected_layer != 'Smolgen':
312
- if not DEV_MODE:
313
- inputs = board2planes(self.board)
314
- inputs = tf.reshape(tf.convert_to_tensor(inputs, dtype=tf.float32), [-1, 112, 8, 8])
315
- else:
316
- inputs = None
317
-
318
- outputs = self.model(inputs)
319
- self.activations_data = outputs[-1]
320
- for i,x in enumerate(self.activations_data):
321
- print( 'LAYERS', i, x.shape)
322
-
323
- #smolgen = self.tfp.smol_weight_gen_dense.get_weights()[0].reshape((256, 64, 64))
324
- #print('Smolgen')
325
- #print(type(smolgen))
326
- #print(smolgen.shape)
327
- #print(type(smolgen[0]))
328
- #print(smolgen[0].shape)
329
- #_, _, _, self.activations_data = self.model(inputs)
330
- elif self.selected_layer == 'Smolgen' and self.tfp is not None and self.tfp.use_smolgen:
331
- weights = self.tfp.smol_weight_gen_dense.get_weights()[0]
332
- self.activations_data = weights.reshape((weights.shape[0], 64, 64))
333
- print('TYPEEEEE', type(self.activations_data))
334
-
335
- else:
336
- layers = SIMULATED_LAYERS
337
- heads = SIMULATED_HEADS
338
- self.activations_data = [np.random.rand(1, heads, 64, 64) for i in range(layers)]
339
-
340
- if self.model is not None:
341
-
342
- if self.model_path not in self.model_cache:
343
- self.model_cache[self.model_path] = [self.model, self.tfp]
344
-
345
- self.update_layers_in_body_count()
346
-
347
- #TODO: figure out better way to determine if we have policy attention weights
348
- #TODO: What happens if policy vis is selected and user switches to model without policy layer? Take care of this case.
349
- if self.activations_data is not None and self.activations_data[-2].shape == (1, 8, 24):
350
- self.has_attention_policy = True
351
- else:
352
- self.has_attention_policy = False
353
- # self.update_selected_activation_data()
354
- # self.activations = self.activations_data[self.selected_layer]
355
-
356
- def update_grid_shape(self):
357
- # TODO: add client side callback triggered by Interval component to save window or precise container dimensions to Div
358
- # TODO: Trigger server side figure update callback when dimensions are recorded and store in global_data
359
- # TODO: If needed, recalculate subplot rows and cols and container scaler based on the changed dimension
360
-
361
- def calc_cols(heads, rows):
362
- if heads % rows == 0:
363
- cols = int(heads / rows)
364
- else:
365
- cols = int(1 + heads / rows)
366
- return cols
367
-
368
- if FIXED_ROW and FIXED_COL:
369
- self.subplot_cols = FIXED_COL
370
- self.subplot_rows = FIXED_ROW
371
- return None
372
-
373
- heads = self.number_of_heads
374
- if self.subplot_mode == 'fit':
375
- max_rows_in_screen = 4
376
- if heads <= 4:
377
- rows = 1
378
- elif heads <= 8:
379
- rows = 2
380
- else:
381
- rows = heads // 8 + int(heads % 8 != 0)
382
-
383
- elif self.subplot_mode == 'big':
384
- #print(heads)
385
-
386
- max_rows_in_screen = 2
387
- rows = heads // 4 + int(heads % 4 != 0)
388
- #print(rows)
389
-
390
- if rows > max_rows_in_screen:
391
- container_height = f'{int((rows / max_rows_in_screen) * 100)}%'
392
- else:
393
- container_height = '100%'
394
-
395
- if rows != 0:
396
- cols = calc_cols(heads, rows)
397
- else:
398
- cols = 0
399
-
400
- if self.subplot_rows != rows or self.subplot_cols != cols:
401
- self.grid_has_changed = True
402
-
403
- self.subplot_cols = cols
404
- self.subplot_rows = rows
405
-
406
- if self.show_all_heads:
407
- self.figure_container_height = container_height
408
- else:
409
- self.figure_container_height = '100%'
410
-
411
- def update_selected_activation_data(self):
412
- # import numpy as np
413
- # self.activations = activations_array + np.random.rand(8, 64, 64)
414
- if self.activations_data is not None:
415
- if self.selected_layer not in ('Policy', 'Smolgen'):
416
- if not DEV_MODE:
417
- activations = tf.squeeze(self.activations_data[self.selected_layer], axis=0).numpy()
418
- #self.activations = activations[:, ::-1, :] #Flip along y-axis
419
- else:
420
- activations = np.squeeze(self.activations_data[self.selected_layer], axis=0)
421
- elif self.selected_layer == 'Policy':
422
- print('RAW POLICY SHAPE', self.activations_data[-1].shape)
423
- activations = self.activations_data[-1].numpy()
424
- #print('POLICY SHAPE', activations.shape)
425
-
426
- #print('RAW POLICY SHAPE', self.activations_data[-1].shape)
427
- #activations = np.squeeze(self.activations_data[-1].numpy(), axis=0) #shape 64,64
428
- #promo = np.squeeze(self.activations_data[-2].numpy(), axis=0) #shape 8,24
429
- #print('promo shape:', promo.shape)
430
- #if self.board.turn:
431
- # pad_shape = (48, 8)
432
- #else:
433
- # pad_shape = (8, 48)
434
- #promo_padded = np.pad(promo, (pad_shape, (0, 0)), mode='constant', constant_values=None) #shape 64,24
435
- #self.activations = np.expand_dims(np.concatenate((activations, promo_padded), axis=1), axis=0)#shape 1,64,88
436
- #print('POLICY SHAPE', self.activations.shape)
437
- elif self.selected_layer == 'Smolgen':
438
- activations = self.tfp.smol_weight_gen_dense.get_weights()[0].reshape((256, 64, 64))
439
-
440
- self.activations = activations[:, ::-1, :] # Flip along y-axis
441
-
442
- def set_visualization_mode(self, mode):
443
- self.visualization_mode = mode
444
- self.visualization_mode_is_64x64 = mode == '64x64'
445
-
446
- def set_layer(self, layer):
447
- self.selected_layer = layer
448
- self.update_selected_activation_data()
449
- if layer not in ('Policy', 'Smolgen'):
450
- self.number_of_heads = self.activations_data[self.selected_layer].shape[1]
451
- elif layer == 'Policy':
452
- self.number_of_heads = 1
453
- elif layer == 'Smolgen':
454
- self.number_of_heads = self.activations.shape[0]
455
- self.set_head(0)
456
- self.update_grid_shape()
457
-
458
- def set_head(self, head):
459
- self.selected_head = head
460
-
461
- def set_model(self, model):
462
- if model != self.model_path:
463
- self.model_path = model
464
- self.load_model()
465
- self.update_activations_data()
466
- self.update_selected_activation_data()
467
- self.number_of_heads = self.activations_data[self.selected_layer].shape[1]
468
- if self.selected_head is None:
469
- self.selected_head = 0
470
- else:
471
- self.selected_head = min(self.selected_head, self.number_of_heads - 1)
472
- self.update_grid_shape()
473
- if SIMULATE_TF:
474
- sleep(2)
475
-
476
- def update_layers_in_body_count(self):
477
- # TODO: figure out robust way to separate attention layers in body from the rest. UPDATE: Use yaml
478
- heads = self.activations_data[0].shape[1]
479
- for ind, layer in enumerate(self.activations_data):
480
- if layer.shape[1] != heads or len(layer.shape) != 4:
481
- ind = ind - 1
482
- break
483
- self.nr_of_layers_in_body = ind + 1
484
- if self.selected_layer not in ('Policy', 'Smolgen'):
485
- self.selected_layer = min(self.selected_layer, self.nr_of_layers_in_body - 1)
486
-
487
- def get_head_data(self, head):
488
-
489
- if self.activations.shape[0] <= head:
490
- return None
491
-
492
- if self.visualization_mode == '64x64':
493
- # print('64x64 selection')
494
- data = self.activations[head, :, :]
495
-
496
- elif self.visualization_mode == 'ROW':
497
- # print('ROW selection')
498
- if self.board.turn or self.selected_layer == 'Smolgen': #White turn to move
499
- row = 63 - self.focused_square_ind
500
- data = self.activations[head, row, :].reshape((8, 8))
501
- else:
502
- #row = self.focused_square_ind
503
- multiples = self.focused_square_ind // 8
504
- remainder = self.focused_square_ind % 8
505
-
506
- a = 7 - remainder
507
- b = multiples * 8
508
- row = a + b
509
- data = self.activations[head, row, :].reshape((8, 8))[::-1, :]
510
- else:
511
- # print('COL selection')
512
- if self.board.turn or self.selected_layer == 'Smolgen': #White turn to move
513
- col = self.focused_square_ind
514
- data = self.activations[head, :, col].reshape((8, 8))[::-1, ::-1]
515
- else:
516
- focused = 63 - self.focused_square_ind
517
- multiples = focused // 8
518
- remainder = focused % 8
519
- a = 7 - remainder
520
- b = multiples * 8
521
- col = a + b
522
- #print('COL!!!!!!!!!!!!!!!!!', col, a, b, focused, self.focused_square_ind)
523
- data = self.activations[head, :, col].reshape((8, 8))[:, ::-1]
524
- return data
525
-
526
- def set_fen(self, fen):
527
- self.board.set_fen(fen)
528
- self.fen = fen
529
- self.update_activations_data()
530
- self.update_selected_activation_data()
531
-
532
- def set_board(self, board):
533
- self.board = deepcopy(board)
534
- self.update_activations_data()
535
- self.update_selected_activation_data()
536
-
537
-
538
- global_data = GlobalData()
539
- print('global data created')
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/layouts/visualization_demo.slides.json DELETED
@@ -1,4 +0,0 @@
1
- {
2
- "type": "slides",
3
- "data": {}
4
- }
 
 
 
 
 
our_visualization/leela_board.py DELETED
@@ -1,617 +0,0 @@
1
- # -*- coding: utf-8 -*-
2
-
3
- # uci_to_idx is a list of four dicts {uci -> NN policy index}
4
- # 0 = white, no-castling
5
- # 1 = white, castling
6
- # 2 = black, no-castling
7
- # 3 = black, castling
8
- # Black moves are flipped, and castling moves are mapped to the e8a8, e8h8, e1a1, e1h1 indexes
9
- # from their respective UCI names
10
-
11
- uci_to_idx = []
12
-
13
- # The index-to-uci list originates from here:
14
- # https://github.com/glinscott/leela-chess/blob/master/lc0/src/chess/bitboard.cc
15
-
16
- # White, no-castling
17
- _idx_to_move_wn = [
18
- 'a1b1', 'a1c1', 'a1d1', 'a1e1', 'a1f1', 'a1g1', 'a1h1',
19
- 'a1a2', 'a1b2', 'a1c2', 'a1a3', 'a1b3', 'a1c3', 'a1a4',
20
- 'a1d4', 'a1a5', 'a1e5', 'a1a6', 'a1f6', 'a1a7', 'a1g7',
21
- 'a1a8', 'a1h8', 'b1a1', 'b1c1', 'b1d1', 'b1e1', 'b1f1',
22
- 'b1g1', 'b1h1', 'b1a2', 'b1b2', 'b1c2', 'b1d2', 'b1a3',
23
- 'b1b3', 'b1c3', 'b1d3', 'b1b4', 'b1e4', 'b1b5', 'b1f5',
24
- 'b1b6', 'b1g6', 'b1b7', 'b1h7', 'b1b8', 'c1a1', 'c1b1',
25
- 'c1d1', 'c1e1', 'c1f1', 'c1g1', 'c1h1', 'c1a2', 'c1b2',
26
- 'c1c2', 'c1d2', 'c1e2', 'c1a3', 'c1b3', 'c1c3', 'c1d3',
27
- 'c1e3', 'c1c4', 'c1f4', 'c1c5', 'c1g5', 'c1c6', 'c1h6',
28
- 'c1c7', 'c1c8', 'd1a1', 'd1b1', 'd1c1', 'd1e1', 'd1f1',
29
- 'd1g1', 'd1h1', 'd1b2', 'd1c2', 'd1d2', 'd1e2', 'd1f2',
30
- 'd1b3', 'd1c3', 'd1d3', 'd1e3', 'd1f3', 'd1a4', 'd1d4',
31
- 'd1g4', 'd1d5', 'd1h5', 'd1d6', 'd1d7', 'd1d8', 'e1a1',
32
- 'e1b1', 'e1c1', 'e1d1', 'e1f1', 'e1g1', 'e1h1', 'e1c2',
33
- 'e1d2', 'e1e2', 'e1f2', 'e1g2', 'e1c3', 'e1d3', 'e1e3',
34
- 'e1f3', 'e1g3', 'e1b4', 'e1e4', 'e1h4', 'e1a5', 'e1e5',
35
- 'e1e6', 'e1e7', 'e1e8', 'f1a1', 'f1b1', 'f1c1', 'f1d1',
36
- 'f1e1', 'f1g1', 'f1h1', 'f1d2', 'f1e2', 'f1f2', 'f1g2',
37
- 'f1h2', 'f1d3', 'f1e3', 'f1f3', 'f1g3', 'f1h3', 'f1c4',
38
- 'f1f4', 'f1b5', 'f1f5', 'f1a6', 'f1f6', 'f1f7', 'f1f8',
39
- 'g1a1', 'g1b1', 'g1c1', 'g1d1', 'g1e1', 'g1f1', 'g1h1',
40
- 'g1e2', 'g1f2', 'g1g2', 'g1h2', 'g1e3', 'g1f3', 'g1g3',
41
- 'g1h3', 'g1d4', 'g1g4', 'g1c5', 'g1g5', 'g1b6', 'g1g6',
42
- 'g1a7', 'g1g7', 'g1g8', 'h1a1', 'h1b1', 'h1c1', 'h1d1',
43
- 'h1e1', 'h1f1', 'h1g1', 'h1f2', 'h1g2', 'h1h2', 'h1f3',
44
- 'h1g3', 'h1h3', 'h1e4', 'h1h4', 'h1d5', 'h1h5', 'h1c6',
45
- 'h1h6', 'h1b7', 'h1h7', 'h1a8', 'h1h8', 'a2a1', 'a2b1',
46
- 'a2c1', 'a2b2', 'a2c2', 'a2d2', 'a2e2', 'a2f2', 'a2g2',
47
- 'a2h2', 'a2a3', 'a2b3', 'a2c3', 'a2a4', 'a2b4', 'a2c4',
48
- 'a2a5', 'a2d5', 'a2a6', 'a2e6', 'a2a7', 'a2f7', 'a2a8',
49
- 'a2g8', 'b2a1', 'b2b1', 'b2c1', 'b2d1', 'b2a2', 'b2c2',
50
- 'b2d2', 'b2e2', 'b2f2', 'b2g2', 'b2h2', 'b2a3', 'b2b3',
51
- 'b2c3', 'b2d3', 'b2a4', 'b2b4', 'b2c4', 'b2d4', 'b2b5',
52
- 'b2e5', 'b2b6', 'b2f6', 'b2b7', 'b2g7', 'b2b8', 'b2h8',
53
- 'c2a1', 'c2b1', 'c2c1', 'c2d1', 'c2e1', 'c2a2', 'c2b2',
54
- 'c2d2', 'c2e2', 'c2f2', 'c2g2', 'c2h2', 'c2a3', 'c2b3',
55
- 'c2c3', 'c2d3', 'c2e3', 'c2a4', 'c2b4', 'c2c4', 'c2d4',
56
- 'c2e4', 'c2c5', 'c2f5', 'c2c6', 'c2g6', 'c2c7', 'c2h7',
57
- 'c2c8', 'd2b1', 'd2c1', 'd2d1', 'd2e1', 'd2f1', 'd2a2',
58
- 'd2b2', 'd2c2', 'd2e2', 'd2f2', 'd2g2', 'd2h2', 'd2b3',
59
- 'd2c3', 'd2d3', 'd2e3', 'd2f3', 'd2b4', 'd2c4', 'd2d4',
60
- 'd2e4', 'd2f4', 'd2a5', 'd2d5', 'd2g5', 'd2d6', 'd2h6',
61
- 'd2d7', 'd2d8', 'e2c1', 'e2d1', 'e2e1', 'e2f1', 'e2g1',
62
- 'e2a2', 'e2b2', 'e2c2', 'e2d2', 'e2f2', 'e2g2', 'e2h2',
63
- 'e2c3', 'e2d3', 'e2e3', 'e2f3', 'e2g3', 'e2c4', 'e2d4',
64
- 'e2e4', 'e2f4', 'e2g4', 'e2b5', 'e2e5', 'e2h5', 'e2a6',
65
- 'e2e6', 'e2e7', 'e2e8', 'f2d1', 'f2e1', 'f2f1', 'f2g1',
66
- 'f2h1', 'f2a2', 'f2b2', 'f2c2', 'f2d2', 'f2e2', 'f2g2',
67
- 'f2h2', 'f2d3', 'f2e3', 'f2f3', 'f2g3', 'f2h3', 'f2d4',
68
- 'f2e4', 'f2f4', 'f2g4', 'f2h4', 'f2c5', 'f2f5', 'f2b6',
69
- 'f2f6', 'f2a7', 'f2f7', 'f2f8', 'g2e1', 'g2f1', 'g2g1',
70
- 'g2h1', 'g2a2', 'g2b2', 'g2c2', 'g2d2', 'g2e2', 'g2f2',
71
- 'g2h2', 'g2e3', 'g2f3', 'g2g3', 'g2h3', 'g2e4', 'g2f4',
72
- 'g2g4', 'g2h4', 'g2d5', 'g2g5', 'g2c6', 'g2g6', 'g2b7',
73
- 'g2g7', 'g2a8', 'g2g8', 'h2f1', 'h2g1', 'h2h1', 'h2a2',
74
- 'h2b2', 'h2c2', 'h2d2', 'h2e2', 'h2f2', 'h2g2', 'h2f3',
75
- 'h2g3', 'h2h3', 'h2f4', 'h2g4', 'h2h4', 'h2e5', 'h2h5',
76
- 'h2d6', 'h2h6', 'h2c7', 'h2h7', 'h2b8', 'h2h8', 'a3a1',
77
- 'a3b1', 'a3c1', 'a3a2', 'a3b2', 'a3c2', 'a3b3', 'a3c3',
78
- 'a3d3', 'a3e3', 'a3f3', 'a3g3', 'a3h3', 'a3a4', 'a3b4',
79
- 'a3c4', 'a3a5', 'a3b5', 'a3c5', 'a3a6', 'a3d6', 'a3a7',
80
- 'a3e7', 'a3a8', 'a3f8', 'b3a1', 'b3b1', 'b3c1', 'b3d1',
81
- 'b3a2', 'b3b2', 'b3c2', 'b3d2', 'b3a3', 'b3c3', 'b3d3',
82
- 'b3e3', 'b3f3', 'b3g3', 'b3h3', 'b3a4', 'b3b4', 'b3c4',
83
- 'b3d4', 'b3a5', 'b3b5', 'b3c5', 'b3d5', 'b3b6', 'b3e6',
84
- 'b3b7', 'b3f7', 'b3b8', 'b3g8', 'c3a1', 'c3b1', 'c3c1',
85
- 'c3d1', 'c3e1', 'c3a2', 'c3b2', 'c3c2', 'c3d2', 'c3e2',
86
- 'c3a3', 'c3b3', 'c3d3', 'c3e3', 'c3f3', 'c3g3', 'c3h3',
87
- 'c3a4', 'c3b4', 'c3c4', 'c3d4', 'c3e4', 'c3a5', 'c3b5',
88
- 'c3c5', 'c3d5', 'c3e5', 'c3c6', 'c3f6', 'c3c7', 'c3g7',
89
- 'c3c8', 'c3h8', 'd3b1', 'd3c1', 'd3d1', 'd3e1', 'd3f1',
90
- 'd3b2', 'd3c2', 'd3d2', 'd3e2', 'd3f2', 'd3a3', 'd3b3',
91
- 'd3c3', 'd3e3', 'd3f3', 'd3g3', 'd3h3', 'd3b4', 'd3c4',
92
- 'd3d4', 'd3e4', 'd3f4', 'd3b5', 'd3c5', 'd3d5', 'd3e5',
93
- 'd3f5', 'd3a6', 'd3d6', 'd3g6', 'd3d7', 'd3h7', 'd3d8',
94
- 'e3c1', 'e3d1', 'e3e1', 'e3f1', 'e3g1', 'e3c2', 'e3d2',
95
- 'e3e2', 'e3f2', 'e3g2', 'e3a3', 'e3b3', 'e3c3', 'e3d3',
96
- 'e3f3', 'e3g3', 'e3h3', 'e3c4', 'e3d4', 'e3e4', 'e3f4',
97
- 'e3g4', 'e3c5', 'e3d5', 'e3e5', 'e3f5', 'e3g5', 'e3b6',
98
- 'e3e6', 'e3h6', 'e3a7', 'e3e7', 'e3e8', 'f3d1', 'f3e1',
99
- 'f3f1', 'f3g1', 'f3h1', 'f3d2', 'f3e2', 'f3f2', 'f3g2',
100
- 'f3h2', 'f3a3', 'f3b3', 'f3c3', 'f3d3', 'f3e3', 'f3g3',
101
- 'f3h3', 'f3d4', 'f3e4', 'f3f4', 'f3g4', 'f3h4', 'f3d5',
102
- 'f3e5', 'f3f5', 'f3g5', 'f3h5', 'f3c6', 'f3f6', 'f3b7',
103
- 'f3f7', 'f3a8', 'f3f8', 'g3e1', 'g3f1', 'g3g1', 'g3h1',
104
- 'g3e2', 'g3f2', 'g3g2', 'g3h2', 'g3a3', 'g3b3', 'g3c3',
105
- 'g3d3', 'g3e3', 'g3f3', 'g3h3', 'g3e4', 'g3f4', 'g3g4',
106
- 'g3h4', 'g3e5', 'g3f5', 'g3g5', 'g3h5', 'g3d6', 'g3g6',
107
- 'g3c7', 'g3g7', 'g3b8', 'g3g8', 'h3f1', 'h3g1', 'h3h1',
108
- 'h3f2', 'h3g2', 'h3h2', 'h3a3', 'h3b3', 'h3c3', 'h3d3',
109
- 'h3e3', 'h3f3', 'h3g3', 'h3f4', 'h3g4', 'h3h4', 'h3f5',
110
- 'h3g5', 'h3h5', 'h3e6', 'h3h6', 'h3d7', 'h3h7', 'h3c8',
111
- 'h3h8', 'a4a1', 'a4d1', 'a4a2', 'a4b2', 'a4c2', 'a4a3',
112
- 'a4b3', 'a4c3', 'a4b4', 'a4c4', 'a4d4', 'a4e4', 'a4f4',
113
- 'a4g4', 'a4h4', 'a4a5', 'a4b5', 'a4c5', 'a4a6', 'a4b6',
114
- 'a4c6', 'a4a7', 'a4d7', 'a4a8', 'a4e8', 'b4b1', 'b4e1',
115
- 'b4a2', 'b4b2', 'b4c2', 'b4d2', 'b4a3', 'b4b3', 'b4c3',
116
- 'b4d3', 'b4a4', 'b4c4', 'b4d4', 'b4e4', 'b4f4', 'b4g4',
117
- 'b4h4', 'b4a5', 'b4b5', 'b4c5', 'b4d5', 'b4a6', 'b4b6',
118
- 'b4c6', 'b4d6', 'b4b7', 'b4e7', 'b4b8', 'b4f8', 'c4c1',
119
- 'c4f1', 'c4a2', 'c4b2', 'c4c2', 'c4d2', 'c4e2', 'c4a3',
120
- 'c4b3', 'c4c3', 'c4d3', 'c4e3', 'c4a4', 'c4b4', 'c4d4',
121
- 'c4e4', 'c4f4', 'c4g4', 'c4h4', 'c4a5', 'c4b5', 'c4c5',
122
- 'c4d5', 'c4e5', 'c4a6', 'c4b6', 'c4c6', 'c4d6', 'c4e6',
123
- 'c4c7', 'c4f7', 'c4c8', 'c4g8', 'd4a1', 'd4d1', 'd4g1',
124
- 'd4b2', 'd4c2', 'd4d2', 'd4e2', 'd4f2', 'd4b3', 'd4c3',
125
- 'd4d3', 'd4e3', 'd4f3', 'd4a4', 'd4b4', 'd4c4', 'd4e4',
126
- 'd4f4', 'd4g4', 'd4h4', 'd4b5', 'd4c5', 'd4d5', 'd4e5',
127
- 'd4f5', 'd4b6', 'd4c6', 'd4d6', 'd4e6', 'd4f6', 'd4a7',
128
- 'd4d7', 'd4g7', 'd4d8', 'd4h8', 'e4b1', 'e4e1', 'e4h1',
129
- 'e4c2', 'e4d2', 'e4e2', 'e4f2', 'e4g2', 'e4c3', 'e4d3',
130
- 'e4e3', 'e4f3', 'e4g3', 'e4a4', 'e4b4', 'e4c4', 'e4d4',
131
- 'e4f4', 'e4g4', 'e4h4', 'e4c5', 'e4d5', 'e4e5', 'e4f5',
132
- 'e4g5', 'e4c6', 'e4d6', 'e4e6', 'e4f6', 'e4g6', 'e4b7',
133
- 'e4e7', 'e4h7', 'e4a8', 'e4e8', 'f4c1', 'f4f1', 'f4d2',
134
- 'f4e2', 'f4f2', 'f4g2', 'f4h2', 'f4d3', 'f4e3', 'f4f3',
135
- 'f4g3', 'f4h3', 'f4a4', 'f4b4', 'f4c4', 'f4d4', 'f4e4',
136
- 'f4g4', 'f4h4', 'f4d5', 'f4e5', 'f4f5', 'f4g5', 'f4h5',
137
- 'f4d6', 'f4e6', 'f4f6', 'f4g6', 'f4h6', 'f4c7', 'f4f7',
138
- 'f4b8', 'f4f8', 'g4d1', 'g4g1', 'g4e2', 'g4f2', 'g4g2',
139
- 'g4h2', 'g4e3', 'g4f3', 'g4g3', 'g4h3', 'g4a4', 'g4b4',
140
- 'g4c4', 'g4d4', 'g4e4', 'g4f4', 'g4h4', 'g4e5', 'g4f5',
141
- 'g4g5', 'g4h5', 'g4e6', 'g4f6', 'g4g6', 'g4h6', 'g4d7',
142
- 'g4g7', 'g4c8', 'g4g8', 'h4e1', 'h4h1', 'h4f2', 'h4g2',
143
- 'h4h2', 'h4f3', 'h4g3', 'h4h3', 'h4a4', 'h4b4', 'h4c4',
144
- 'h4d4', 'h4e4', 'h4f4', 'h4g4', 'h4f5', 'h4g5', 'h4h5',
145
- 'h4f6', 'h4g6', 'h4h6', 'h4e7', 'h4h7', 'h4d8', 'h4h8',
146
- 'a5a1', 'a5e1', 'a5a2', 'a5d2', 'a5a3', 'a5b3', 'a5c3',
147
- 'a5a4', 'a5b4', 'a5c4', 'a5b5', 'a5c5', 'a5d5', 'a5e5',
148
- 'a5f5', 'a5g5', 'a5h5', 'a5a6', 'a5b6', 'a5c6', 'a5a7',
149
- 'a5b7', 'a5c7', 'a5a8', 'a5d8', 'b5b1', 'b5f1', 'b5b2',
150
- 'b5e2', 'b5a3', 'b5b3', 'b5c3', 'b5d3', 'b5a4', 'b5b4',
151
- 'b5c4', 'b5d4', 'b5a5', 'b5c5', 'b5d5', 'b5e5', 'b5f5',
152
- 'b5g5', 'b5h5', 'b5a6', 'b5b6', 'b5c6', 'b5d6', 'b5a7',
153
- 'b5b7', 'b5c7', 'b5d7', 'b5b8', 'b5e8', 'c5c1', 'c5g1',
154
- 'c5c2', 'c5f2', 'c5a3', 'c5b3', 'c5c3', 'c5d3', 'c5e3',
155
- 'c5a4', 'c5b4', 'c5c4', 'c5d4', 'c5e4', 'c5a5', 'c5b5',
156
- 'c5d5', 'c5e5', 'c5f5', 'c5g5', 'c5h5', 'c5a6', 'c5b6',
157
- 'c5c6', 'c5d6', 'c5e6', 'c5a7', 'c5b7', 'c5c7', 'c5d7',
158
- 'c5e7', 'c5c8', 'c5f8', 'd5d1', 'd5h1', 'd5a2', 'd5d2',
159
- 'd5g2', 'd5b3', 'd5c3', 'd5d3', 'd5e3', 'd5f3', 'd5b4',
160
- 'd5c4', 'd5d4', 'd5e4', 'd5f4', 'd5a5', 'd5b5', 'd5c5',
161
- 'd5e5', 'd5f5', 'd5g5', 'd5h5', 'd5b6', 'd5c6', 'd5d6',
162
- 'd5e6', 'd5f6', 'd5b7', 'd5c7', 'd5d7', 'd5e7', 'd5f7',
163
- 'd5a8', 'd5d8', 'd5g8', 'e5a1', 'e5e1', 'e5b2', 'e5e2',
164
- 'e5h2', 'e5c3', 'e5d3', 'e5e3', 'e5f3', 'e5g3', 'e5c4',
165
- 'e5d4', 'e5e4', 'e5f4', 'e5g4', 'e5a5', 'e5b5', 'e5c5',
166
- 'e5d5', 'e5f5', 'e5g5', 'e5h5', 'e5c6', 'e5d6', 'e5e6',
167
- 'e5f6', 'e5g6', 'e5c7', 'e5d7', 'e5e7', 'e5f7', 'e5g7',
168
- 'e5b8', 'e5e8', 'e5h8', 'f5b1', 'f5f1', 'f5c2', 'f5f2',
169
- 'f5d3', 'f5e3', 'f5f3', 'f5g3', 'f5h3', 'f5d4', 'f5e4',
170
- 'f5f4', 'f5g4', 'f5h4', 'f5a5', 'f5b5', 'f5c5', 'f5d5',
171
- 'f5e5', 'f5g5', 'f5h5', 'f5d6', 'f5e6', 'f5f6', 'f5g6',
172
- 'f5h6', 'f5d7', 'f5e7', 'f5f7', 'f5g7', 'f5h7', 'f5c8',
173
- 'f5f8', 'g5c1', 'g5g1', 'g5d2', 'g5g2', 'g5e3', 'g5f3',
174
- 'g5g3', 'g5h3', 'g5e4', 'g5f4', 'g5g4', 'g5h4', 'g5a5',
175
- 'g5b5', 'g5c5', 'g5d5', 'g5e5', 'g5f5', 'g5h5', 'g5e6',
176
- 'g5f6', 'g5g6', 'g5h6', 'g5e7', 'g5f7', 'g5g7', 'g5h7',
177
- 'g5d8', 'g5g8', 'h5d1', 'h5h1', 'h5e2', 'h5h2', 'h5f3',
178
- 'h5g3', 'h5h3', 'h5f4', 'h5g4', 'h5h4', 'h5a5', 'h5b5',
179
- 'h5c5', 'h5d5', 'h5e5', 'h5f5', 'h5g5', 'h5f6', 'h5g6',
180
- 'h5h6', 'h5f7', 'h5g7', 'h5h7', 'h5e8', 'h5h8', 'a6a1',
181
- 'a6f1', 'a6a2', 'a6e2', 'a6a3', 'a6d3', 'a6a4', 'a6b4',
182
- 'a6c4', 'a6a5', 'a6b5', 'a6c5', 'a6b6', 'a6c6', 'a6d6',
183
- 'a6e6', 'a6f6', 'a6g6', 'a6h6', 'a6a7', 'a6b7', 'a6c7',
184
- 'a6a8', 'a6b8', 'a6c8', 'b6b1', 'b6g1', 'b6b2', 'b6f2',
185
- 'b6b3', 'b6e3', 'b6a4', 'b6b4', 'b6c4', 'b6d4', 'b6a5',
186
- 'b6b5', 'b6c5', 'b6d5', 'b6a6', 'b6c6', 'b6d6', 'b6e6',
187
- 'b6f6', 'b6g6', 'b6h6', 'b6a7', 'b6b7', 'b6c7', 'b6d7',
188
- 'b6a8', 'b6b8', 'b6c8', 'b6d8', 'c6c1', 'c6h1', 'c6c2',
189
- 'c6g2', 'c6c3', 'c6f3', 'c6a4', 'c6b4', 'c6c4', 'c6d4',
190
- 'c6e4', 'c6a5', 'c6b5', 'c6c5', 'c6d5', 'c6e5', 'c6a6',
191
- 'c6b6', 'c6d6', 'c6e6', 'c6f6', 'c6g6', 'c6h6', 'c6a7',
192
- 'c6b7', 'c6c7', 'c6d7', 'c6e7', 'c6a8', 'c6b8', 'c6c8',
193
- 'c6d8', 'c6e8', 'd6d1', 'd6d2', 'd6h2', 'd6a3', 'd6d3',
194
- 'd6g3', 'd6b4', 'd6c4', 'd6d4', 'd6e4', 'd6f4', 'd6b5',
195
- 'd6c5', 'd6d5', 'd6e5', 'd6f5', 'd6a6', 'd6b6', 'd6c6',
196
- 'd6e6', 'd6f6', 'd6g6', 'd6h6', 'd6b7', 'd6c7', 'd6d7',
197
- 'd6e7', 'd6f7', 'd6b8', 'd6c8', 'd6d8', 'd6e8', 'd6f8',
198
- 'e6e1', 'e6a2', 'e6e2', 'e6b3', 'e6e3', 'e6h3', 'e6c4',
199
- 'e6d4', 'e6e4', 'e6f4', 'e6g4', 'e6c5', 'e6d5', 'e6e5',
200
- 'e6f5', 'e6g5', 'e6a6', 'e6b6', 'e6c6', 'e6d6', 'e6f6',
201
- 'e6g6', 'e6h6', 'e6c7', 'e6d7', 'e6e7', 'e6f7', 'e6g7',
202
- 'e6c8', 'e6d8', 'e6e8', 'e6f8', 'e6g8', 'f6a1', 'f6f1',
203
- 'f6b2', 'f6f2', 'f6c3', 'f6f3', 'f6d4', 'f6e4', 'f6f4',
204
- 'f6g4', 'f6h4', 'f6d5', 'f6e5', 'f6f5', 'f6g5', 'f6h5',
205
- 'f6a6', 'f6b6', 'f6c6', 'f6d6', 'f6e6', 'f6g6', 'f6h6',
206
- 'f6d7', 'f6e7', 'f6f7', 'f6g7', 'f6h7', 'f6d8', 'f6e8',
207
- 'f6f8', 'f6g8', 'f6h8', 'g6b1', 'g6g1', 'g6c2', 'g6g2',
208
- 'g6d3', 'g6g3', 'g6e4', 'g6f4', 'g6g4', 'g6h4', 'g6e5',
209
- 'g6f5', 'g6g5', 'g6h5', 'g6a6', 'g6b6', 'g6c6', 'g6d6',
210
- 'g6e6', 'g6f6', 'g6h6', 'g6e7', 'g6f7', 'g6g7', 'g6h7',
211
- 'g6e8', 'g6f8', 'g6g8', 'g6h8', 'h6c1', 'h6h1', 'h6d2',
212
- 'h6h2', 'h6e3', 'h6h3', 'h6f4', 'h6g4', 'h6h4', 'h6f5',
213
- 'h6g5', 'h6h5', 'h6a6', 'h6b6', 'h6c6', 'h6d6', 'h6e6',
214
- 'h6f6', 'h6g6', 'h6f7', 'h6g7', 'h6h7', 'h6f8', 'h6g8',
215
- 'h6h8', 'a7a1', 'a7g1', 'a7a2', 'a7f2', 'a7a3', 'a7e3',
216
- 'a7a4', 'a7d4', 'a7a5', 'a7b5', 'a7c5', 'a7a6', 'a7b6',
217
- 'a7c6', 'a7b7', 'a7c7', 'a7d7', 'a7e7', 'a7f7', 'a7g7',
218
- 'a7h7', 'a7a8', 'a7b8', 'a7c8', 'b7b1', 'b7h1', 'b7b2',
219
- 'b7g2', 'b7b3', 'b7f3', 'b7b4', 'b7e4', 'b7a5', 'b7b5',
220
- 'b7c5', 'b7d5', 'b7a6', 'b7b6', 'b7c6', 'b7d6', 'b7a7',
221
- 'b7c7', 'b7d7', 'b7e7', 'b7f7', 'b7g7', 'b7h7', 'b7a8',
222
- 'b7b8', 'b7c8', 'b7d8', 'c7c1', 'c7c2', 'c7h2', 'c7c3',
223
- 'c7g3', 'c7c4', 'c7f4', 'c7a5', 'c7b5', 'c7c5', 'c7d5',
224
- 'c7e5', 'c7a6', 'c7b6', 'c7c6', 'c7d6', 'c7e6', 'c7a7',
225
- 'c7b7', 'c7d7', 'c7e7', 'c7f7', 'c7g7', 'c7h7', 'c7a8',
226
- 'c7b8', 'c7c8', 'c7d8', 'c7e8', 'd7d1', 'd7d2', 'd7d3',
227
- 'd7h3', 'd7a4', 'd7d4', 'd7g4', 'd7b5', 'd7c5', 'd7d5',
228
- 'd7e5', 'd7f5', 'd7b6', 'd7c6', 'd7d6', 'd7e6', 'd7f6',
229
- 'd7a7', 'd7b7', 'd7c7', 'd7e7', 'd7f7', 'd7g7', 'd7h7',
230
- 'd7b8', 'd7c8', 'd7d8', 'd7e8', 'd7f8', 'e7e1', 'e7e2',
231
- 'e7a3', 'e7e3', 'e7b4', 'e7e4', 'e7h4', 'e7c5', 'e7d5',
232
- 'e7e5', 'e7f5', 'e7g5', 'e7c6', 'e7d6', 'e7e6', 'e7f6',
233
- 'e7g6', 'e7a7', 'e7b7', 'e7c7', 'e7d7', 'e7f7', 'e7g7',
234
- 'e7h7', 'e7c8', 'e7d8', 'e7e8', 'e7f8', 'e7g8', 'f7f1',
235
- 'f7a2', 'f7f2', 'f7b3', 'f7f3', 'f7c4', 'f7f4', 'f7d5',
236
- 'f7e5', 'f7f5', 'f7g5', 'f7h5', 'f7d6', 'f7e6', 'f7f6',
237
- 'f7g6', 'f7h6', 'f7a7', 'f7b7', 'f7c7', 'f7d7', 'f7e7',
238
- 'f7g7', 'f7h7', 'f7d8', 'f7e8', 'f7f8', 'f7g8', 'f7h8',
239
- 'g7a1', 'g7g1', 'g7b2', 'g7g2', 'g7c3', 'g7g3', 'g7d4',
240
- 'g7g4', 'g7e5', 'g7f5', 'g7g5', 'g7h5', 'g7e6', 'g7f6',
241
- 'g7g6', 'g7h6', 'g7a7', 'g7b7', 'g7c7', 'g7d7', 'g7e7',
242
- 'g7f7', 'g7h7', 'g7e8', 'g7f8', 'g7g8', 'g7h8', 'h7b1',
243
- 'h7h1', 'h7c2', 'h7h2', 'h7d3', 'h7h3', 'h7e4', 'h7h4',
244
- 'h7f5', 'h7g5', 'h7h5', 'h7f6', 'h7g6', 'h7h6', 'h7a7',
245
- 'h7b7', 'h7c7', 'h7d7', 'h7e7', 'h7f7', 'h7g7', 'h7f8',
246
- 'h7g8', 'h7h8', 'a8a1', 'a8h1', 'a8a2', 'a8g2', 'a8a3',
247
- 'a8f3', 'a8a4', 'a8e4', 'a8a5', 'a8d5', 'a8a6', 'a8b6',
248
- 'a8c6', 'a8a7', 'a8b7', 'a8c7', 'a8b8', 'a8c8', 'a8d8',
249
- 'a8e8', 'a8f8', 'a8g8', 'a8h8', 'b8b1', 'b8b2', 'b8h2',
250
- 'b8b3', 'b8g3', 'b8b4', 'b8f4', 'b8b5', 'b8e5', 'b8a6',
251
- 'b8b6', 'b8c6', 'b8d6', 'b8a7', 'b8b7', 'b8c7', 'b8d7',
252
- 'b8a8', 'b8c8', 'b8d8', 'b8e8', 'b8f8', 'b8g8', 'b8h8',
253
- 'c8c1', 'c8c2', 'c8c3', 'c8h3', 'c8c4', 'c8g4', 'c8c5',
254
- 'c8f5', 'c8a6', 'c8b6', 'c8c6', 'c8d6', 'c8e6', 'c8a7',
255
- 'c8b7', 'c8c7', 'c8d7', 'c8e7', 'c8a8', 'c8b8', 'c8d8',
256
- 'c8e8', 'c8f8', 'c8g8', 'c8h8', 'd8d1', 'd8d2', 'd8d3',
257
- 'd8d4', 'd8h4', 'd8a5', 'd8d5', 'd8g5', 'd8b6', 'd8c6',
258
- 'd8d6', 'd8e6', 'd8f6', 'd8b7', 'd8c7', 'd8d7', 'd8e7',
259
- 'd8f7', 'd8a8', 'd8b8', 'd8c8', 'd8e8', 'd8f8', 'd8g8',
260
- 'd8h8', 'e8e1', 'e8e2', 'e8e3', 'e8a4', 'e8e4', 'e8b5',
261
- 'e8e5', 'e8h5', 'e8c6', 'e8d6', 'e8e6', 'e8f6', 'e8g6',
262
- 'e8c7', 'e8d7', 'e8e7', 'e8f7', 'e8g7', 'e8a8', 'e8b8',
263
- 'e8c8', 'e8d8', 'e8f8', 'e8g8', 'e8h8', 'f8f1', 'f8f2',
264
- 'f8a3', 'f8f3', 'f8b4', 'f8f4', 'f8c5', 'f8f5', 'f8d6',
265
- 'f8e6', 'f8f6', 'f8g6', 'f8h6', 'f8d7', 'f8e7', 'f8f7',
266
- 'f8g7', 'f8h7', 'f8a8', 'f8b8', 'f8c8', 'f8d8', 'f8e8',
267
- 'f8g8', 'f8h8', 'g8g1', 'g8a2', 'g8g2', 'g8b3', 'g8g3',
268
- 'g8c4', 'g8g4', 'g8d5', 'g8g5', 'g8e6', 'g8f6', 'g8g6',
269
- 'g8h6', 'g8e7', 'g8f7', 'g8g7', 'g8h7', 'g8a8', 'g8b8',
270
- 'g8c8', 'g8d8', 'g8e8', 'g8f8', 'g8h8', 'h8a1', 'h8h1',
271
- 'h8b2', 'h8h2', 'h8c3', 'h8h3', 'h8d4', 'h8h4', 'h8e5',
272
- 'h8h5', 'h8f6', 'h8g6', 'h8h6', 'h8f7', 'h8g7', 'h8h7',
273
- 'h8a8', 'h8b8', 'h8c8', 'h8d8', 'h8e8', 'h8f8', 'h8g8',
274
- 'a7a8q', 'a7a8r', 'a7a8b', 'a7b8q', 'a7b8r', 'a7b8b', 'b7a8q',
275
- 'b7a8r', 'b7a8b', 'b7b8q', 'b7b8r', 'b7b8b', 'b7c8q', 'b7c8r',
276
- 'b7c8b', 'c7b8q', 'c7b8r', 'c7b8b', 'c7c8q', 'c7c8r', 'c7c8b',
277
- 'c7d8q', 'c7d8r', 'c7d8b', 'd7c8q', 'd7c8r', 'd7c8b', 'd7d8q',
278
- 'd7d8r', 'd7d8b', 'd7e8q', 'd7e8r', 'd7e8b', 'e7d8q', 'e7d8r',
279
- 'e7d8b', 'e7e8q', 'e7e8r', 'e7e8b', 'e7f8q', 'e7f8r', 'e7f8b',
280
- 'f7e8q', 'f7e8r', 'f7e8b', 'f7f8q', 'f7f8r', 'f7f8b', 'f7g8q',
281
- 'f7g8r', 'f7g8b', 'g7f8q', 'g7f8r', 'g7f8b', 'g7g8q', 'g7g8r',
282
- 'g7g8b', 'g7h8q', 'g7h8r', 'g7h8b', 'h7g8q', 'h7g8r', 'h7g8b',
283
- 'h7h8q', 'h7h8r', 'h7h8b'
284
- ]
285
-
286
- # White, no castling
287
- _uci_to_idx_wn = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_wn))
288
-
289
- # White, castling
290
- _uci_to_idx_wc = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_wn))
291
- _uci_to_idx_wc['e1g1'], _uci_to_idx_wc['e1h1'] = _uci_to_idx_wc['e1h1'], _uci_to_idx_wc['e1g1']
292
- _uci_to_idx_wc['e1c1'], _uci_to_idx_wc['e1a1'] = _uci_to_idx_wc['e1a1'], _uci_to_idx_wc['e1c1']
293
-
294
-
295
- # Black, no castling
296
- _idx_to_move_bn = []
297
- for move in _idx_to_move_wn:
298
- c0,r0,c1,r1,p = move[0],int(move[1]),move[2],int(move[3]),move[4:]
299
- r0 = 9 - r0
300
- r1 = 9 - r1
301
- _idx_to_move_bn.append('{}{}{}{}{}'.format(c0,r0,c1,r1,p))
302
- _uci_to_idx_bn = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_bn))
303
-
304
- # Black, castling
305
- _uci_to_idx_bc = dict((uci, idx) for idx, uci in enumerate(_idx_to_move_bn))
306
- _uci_to_idx_bc['e8g8'], _uci_to_idx_bc['e8h8'] = _uci_to_idx_bc['e8h8'], _uci_to_idx_bc['e8g8']
307
- _uci_to_idx_bc['e8c8'], _uci_to_idx_bc['e8a8'] = _uci_to_idx_bc['e8a8'], _uci_to_idx_bc['e8c8']
308
-
309
- uci_to_idx = [_uci_to_idx_wn, _uci_to_idx_wc, _uci_to_idx_bn, _uci_to_idx_bc]
310
-
311
-
312
- import collections
313
- import struct
314
- import zlib
315
-
316
- import chess
317
- import numpy as np
318
- from chess import Move
319
-
320
- flat_planes = []
321
- for i in range(256):
322
- flat_planes.append(np.ones((8,8), dtype=np.uint8)*i)
323
-
324
- LeelaBoardData = collections.namedtuple('LeelaBoardData',
325
- 'plane_bytes repetition '
326
- 'transposition_key us_ooo us_oo them_ooo them_oo '
327
- 'side_to_move rule50_count')
328
-
329
- def pc_board_property(propertyname):
330
- '''Create a property based on self.pc_board'''
331
- def prop(self):
332
- return getattr(self.pc_board, propertyname)
333
- return property(prop)
334
-
335
- class LeelaBoard:
336
- turn = pc_board_property('turn')
337
- move_stack = pc_board_property('move_stack')
338
- _plane_bytes_struct = struct.Struct('>Q')
339
-
340
- def __init__(self, leela_board = None, *args, **kwargs):
341
- '''If leela_board is passed as an argument, return a copy'''
342
- self.pc_board = chess.Board(*args, **kwargs)
343
- self.lcz_stack = []
344
- self._lcz_transposition_counter = collections.Counter()
345
- self._lcz_push()
346
- self.is_game_over = self.pc_method('is_game_over')
347
- self.can_claim_draw = self.pc_method('can_claim_draw')
348
- self.generate_legal_moves = self.pc_method('generate_legal_moves')
349
-
350
- def copy(self, history=7):
351
- """Note! Currently the copy constructor uses pc_board.copy(stack=False), which makes pops impossible"""
352
- cls = type(self)
353
- copied = cls.__new__(cls)
354
- copied.pc_board = self.pc_board.copy(stack=False)
355
- copied.pc_board.stack[:] = self.pc_board.stack[-history:]
356
- copied.pc_board.move_stack[:] = self.pc_board.move_stack[-history:]
357
- copied.lcz_stack = self.lcz_stack[-history:]
358
- copied._lcz_transposition_counter = self._lcz_transposition_counter.copy()
359
- copied.is_game_over = copied.pc_method('is_game_over')
360
- copied.can_claim_draw = copied.pc_method('can_claim_draw')
361
- copied.generate_legal_moves = copied.pc_method('generate_legal_moves')
362
- return copied
363
-
364
- def pc_method(self, methodname):
365
- '''Return attribute of self.pc_board, useful for copying method bindings'''
366
- return getattr(self.pc_board, methodname)
367
-
368
- def is_threefold(self):
369
- transposition_key = self.pc_board._transposition_key()
370
- return self._lcz_transposition_counter[transposition_key] >= 3
371
-
372
- def is_fifty_moves(self):
373
- return self.pc_board.halfmove_clock >= 100
374
-
375
- def is_draw(self):
376
- return self.is_threefold() or self.is_fifty_moves()
377
-
378
- def push(self, move):
379
- self.pc_board.push(move)
380
- self._lcz_push()
381
-
382
- def push_uci(self, uci):
383
- # don't check for legality - it takes much longer to run...
384
- # self.pc_board.push_uci(uci)
385
- self.pc_board.push(Move.from_uci(uci))
386
- self._lcz_push()
387
-
388
- def push_san(self, san):
389
- self.pc_board.push_san(san)
390
- self._lcz_push()
391
-
392
- def pop(self):
393
- result = self.pc_board.pop()
394
- _lcz_data = self.lcz_stack.pop()
395
- self._lcz_transposition_counter.subtract((_lcz_data.transposition_key,))
396
- return result
397
-
398
- def _plane_bytes_iter(self):
399
- """Get plane bytes... used for _lcz_push"""
400
- pack = self._plane_bytes_struct.pack
401
- pieces_mask = self.pc_board.pieces_mask
402
- for color in (True, False):
403
- for piece_type in range(1,7):
404
- byts = pack(pieces_mask(piece_type, color))
405
- yield byts
406
-
407
- def _lcz_push(self):
408
- """Push data onto the lcz data stack after pushing board moves"""
409
- transposition_key = self.pc_board._transposition_key()
410
- self._lcz_transposition_counter.update((transposition_key,))
411
- repetitions = self._lcz_transposition_counter[transposition_key] - 1
412
- # side_to_move = 0 if we're white, 1 if we're black
413
- side_to_move = 0 if self.pc_board.turn else 1
414
- rule50_count = self.pc_board.halfmove_clock
415
- # Figure out castling rights
416
- if not side_to_move:
417
- # we're white
418
- _c = self.pc_board.castling_rights
419
- us_ooo, us_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1
420
- them_ooo, them_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1
421
- else:
422
- # We're black
423
- _c = self.pc_board.castling_rights
424
- us_ooo, us_oo = (_c>>chess.A8) & 1, (_c>>chess.H8) & 1
425
- them_ooo, them_oo = (_c>>chess.A1) & 1, (_c>>chess.H1) & 1
426
- # Create 13 planes... 6 us, 6 them, repetitions>=1
427
- plane_bytes = b''.join(self._plane_bytes_iter())
428
- repetition = (repetitions>=1)
429
- lcz_data = LeelaBoardData(
430
- plane_bytes, repetition=repetition,
431
- transposition_key=transposition_key,
432
- us_ooo=us_ooo, us_oo=us_oo, them_ooo=them_ooo, them_oo=them_oo,
433
- side_to_move=side_to_move, rule50_count=rule50_count
434
- )
435
- self.lcz_stack.append(lcz_data)
436
-
437
- def serialize_features(self):
438
- '''Get compacted bytes representation of input planes'''
439
- planes = []
440
- curdata = self.lcz_stack[-1]
441
- bytes_false_true = bytes([False]), bytes([True])
442
- bytes_per_history = 97
443
- total_plane_bytes = bytes_per_history * 8
444
- def bytes_iter():
445
- plane_bytes_yielded = 0
446
- for data in self.lcz_stack[-1:-9:-1]:
447
- yield data.plane_bytes
448
- yield bytes_false_true[data.repetition]
449
- plane_bytes_yielded += bytes_per_history
450
- # 104 total piece planes... fill in missing with 0s
451
- yield bytes(total_plane_bytes - plane_bytes_yielded)
452
- # Yield the rest of the constant planes
453
- yield np.packbits((curdata.us_ooo,
454
- curdata.us_oo,
455
- curdata.them_ooo,
456
- curdata.them_oo,
457
- curdata.side_to_move)).tobytes()
458
- yield chr(curdata.rule50_count).encode()
459
- return b''.join(bytes_iter())
460
-
461
- @classmethod
462
- def deserialize_features(cls, serialized):
463
- planes_stack = []
464
- rule50_count = serialized[-1] # last byte is rule 50
465
- board_attrs = np.unpackbits(memoryview(serialized[-2:-1])) # second to last byte
466
- us_ooo, us_oo, them_ooo, them_oo, side_to_move = board_attrs[:5]
467
- bytes_per_history = 97
468
- for history_idx in range(0, bytes_per_history*8, bytes_per_history):
469
- plane_bytes = serialized[history_idx:history_idx+96]
470
- repetition = serialized[history_idx+96]
471
- if not side_to_move:
472
- # we're white
473
- planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
474
- .reshape(12, 8, 8)[::-1])
475
- else:
476
- # We're black
477
- planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
478
- .reshape(12, 8, 8)[::-1]
479
- .reshape(2,6,8,8)[::-1,:,::-1]
480
- .reshape(12, 8,8))
481
- planes_stack.append(planes)
482
- planes_stack.append([flat_planes[repetition]])
483
- planes_stack.append([flat_planes[us_ooo],
484
- flat_planes[us_oo],
485
- flat_planes[them_ooo],
486
- flat_planes[them_oo],
487
- flat_planes[side_to_move],
488
- flat_planes[rule50_count],
489
- flat_planes[0],
490
- flat_planes[1]])
491
- planes = np.concatenate(planes_stack)
492
- return planes
493
-
494
- def lcz_features(self):
495
- '''Get neural network input planes as uint8'''
496
- # print(list(self._planes_iter()))
497
- planes_stack = []
498
- curdata = self.lcz_stack[-1]
499
- planes_yielded = 0
500
- for data in self.lcz_stack[-1:-9:-1]:
501
- plane_bytes = data.plane_bytes
502
- if not curdata.side_to_move:
503
- # we're white
504
- planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
505
- .reshape(12, 8, 8)[::-1])
506
- else:
507
- # We're black
508
- planes = (np.unpackbits(memoryview(plane_bytes))[::-1]
509
- .reshape(12, 8, 8)[::-1]
510
- .reshape(2,6,8,8)[::-1,:,::-1]
511
- .reshape(12, 8,8))
512
- planes_stack.append(planes)
513
- planes_stack.append([flat_planes[data.repetition]])
514
- planes_yielded += 13
515
- empty_planes = [flat_planes[0] for _ in range(104-planes_yielded)]
516
- if empty_planes:
517
- planes_stack.append(empty_planes)
518
- # Yield the rest of the constant planes
519
- planes_stack.append([flat_planes[curdata.us_ooo],
520
- flat_planes[curdata.us_oo],
521
- flat_planes[curdata.them_ooo],
522
- flat_planes[curdata.them_oo],
523
- flat_planes[curdata.side_to_move],
524
- flat_planes[curdata.rule50_count],
525
- flat_planes[0],
526
- flat_planes[1]])
527
- planes = np.concatenate(planes_stack)
528
- return planes
529
-
530
- def lcz_uci_to_idx(self, uci_list):
531
- # Return list of NN policy output indexes for this board position, given uci_list
532
-
533
- # TODO: Perhaps it's possible to just add the uci knight promotion move to the index dict
534
- # currently knight promotions are not in the dict
535
- uci_list = [uci.rstrip('n') for uci in uci_list]
536
-
537
- data = self.lcz_stack[-1]
538
- # uci_to_idx_index =
539
- # White, no-castling => 0
540
- # White, castling => 1
541
- # Black, no-castling => 2
542
- # Black, castling => 3
543
- uci_to_idx_index = (data.us_ooo | data.us_oo) + 2*data.side_to_move
544
- uci_idx_dct = uci_to_idx[uci_to_idx_index]
545
- return [uci_idx_dct[m] for m in uci_list]
546
-
547
- @classmethod
548
- def compress_features(cls, features):
549
- """Compress a features array as returned from lcz_features method"""
550
- features_8 = features.astype(np.uint8)
551
- # Simple compression would do this...
552
- # return zlib.compress(features_8)
553
- piece_plane_bytes = np.packbits(features_8[:-8]).tobytes()
554
- scalar_bytes = features_8[-8:][:,0,0].tobytes()
555
- compressed = zlib.compress(piece_plane_bytes + scalar_bytes)
556
- return compressed
557
-
558
- @classmethod
559
- def decompress_features(cls, compressed_features):
560
- """Decompress a compressed features array from compress_features"""
561
- decompressed = zlib.decompress(compressed_features)
562
- # Simple decompression would do this
563
- # return np.frombuffer(decompressed, dtype=np.uint8).astype(np.float32).reshape(-1,8,8)
564
- piece_plane_bytes = decompressed[:-8]
565
- scalar_bytes = decompressed[-8:]
566
- piece_plane_arr = np.unpackbits(memoryview(piece_plane_bytes))
567
- scalar_arr = np.frombuffer(scalar_bytes, dtype=np.uint8).repeat(64)
568
- result = np.concatenate((piece_plane_arr, scalar_arr)).astype(np.float32).reshape(-1,8,8)
569
- return result
570
-
571
- def unicode(self):
572
- if self.pc_board.is_game_over() or self.is_draw():
573
- result = self.pc_board.result(claim_draw=True)
574
- turnstring = 'Result: {}'.format(result)
575
- else:
576
- turnstring = 'Turn: {}'.format('White' if self.pc_board.turn else 'Black')
577
- boardstr = self.pc_board.unicode() + "\n" + turnstring
578
- return boardstr
579
-
580
- def __repr__(self):
581
- return "LeelaBoard('{}')".format(self.pc_board.fen())
582
-
583
- def _repr_svg_(self):
584
- return self.pc_board._repr_svg_()
585
-
586
- def __str__(self):
587
- if self.pc_board.is_game_over() or self.is_draw():
588
- result = self.pc_board.result(claim_draw=True)
589
- turnstring = 'Result: {}'.format(result)
590
- else:
591
- turnstring = 'Turn: {}'.format('White' if self.pc_board.turn else 'Black')
592
- boardstr = self.pc_board.__str__() + "\n" + turnstring
593
- return boardstr
594
-
595
- def __eq__(self, other):
596
- return self.get_hash_key() == other.get_hash_key()
597
-
598
- def __hash__(self):
599
- return hash(self.get_hash_key())
600
-
601
- def get_hash_key(self):
602
- transposition_key = self.pc_board._transposition_key()
603
- return (transposition_key +
604
- (self._lcz_transposition_counter[transposition_key], self.pc_board.halfmove_clock) +
605
- tuple(self.pc_board.move_stack[-7:])
606
- )
607
-
608
- # lb = LeelaBoard()
609
- # lb.push_uci('c2c4')
610
- #lb.push_uci('c7c5')
611
- #lb.push_uci('d2d3')
612
- #lb.push_uci('c2c4')
613
- #lb.push_uci('b8c6')
614
- # saved_planes = planes
615
- # planes = lb.features()
616
- # output = leela_net(torch.from_numpy(planes).unsqueeze(0))
617
- # output
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/leela_utils.py DELETED
The diff for this file is too large to render. See raw diff
 
our_visualization/models/.DS_Store DELETED
Binary file (6.15 kB)
 
our_visualization/python_chess_customized_svg.py DELETED
@@ -1,414 +0,0 @@
1
- # This file has been copied and slightly modified from python-chess library,
2
- # Copyright (C) 2016-2020 Niklas Fiekas <niklas.fiekas@backscattering.de>.
3
-
4
- # This program is free software: you can redistribute it and/or modify
5
- # it under the terms of the GNU General Public License as published by
6
- # the Free Software Foundation, either version 3 of the License, or
7
- # (at your option) any later version.
8
- #
9
- # This program is distributed in the hope that it will be useful,
10
- # but WITHOUT ANY WARRANTY; without even the implied warranty of
11
- # MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
12
- # GNU General Public License for more details.
13
- #
14
- # You should have received a copy of the GNU General Public License
15
- # along with this program. If not, see <http://www.gnu.org/licenses/>.
16
-
17
- # Piece vector graphics are copyright (C) Colin M.L. Burnett
18
- # <https://en.wikipedia.org/wiki/User:Cburnett> and also licensed under the
19
- # GNU General Public License.
20
-
21
- import chess
22
- import math
23
-
24
- import xml.etree.ElementTree as ET
25
-
26
- from typing import Iterable, Optional, Tuple, Union
27
-
28
- SQUARE_SIZE = 45
29
- MARGIN = 20
30
-
31
- PIECES = {
32
- "b": """<g id="black-bishop" class="black bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zm6-4c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z" fill="#000" stroke-linecap="butt"/><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke="#fff" stroke-linejoin="miter"/></g>""",
33
- # noqa: E501
34
- "k": """<g id="black-king" class="black king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#000" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#000"/><path d="M20 8h5" stroke-linejoin="miter"/><path d="M32 29.5s8.5-4 6.03-9.65C34.15 14 25 18 22.5 24.5l.01 2.1-.01-2.1C20 18 9.906 14 6.997 19.85c-2.497 5.65 4.853 9 4.853 9M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0" stroke="#fff"/></g>""",
35
- # noqa: E501
36
- "n": """<g id="black-knight" class="black knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#000000; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#000000; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#ececec; stroke:#ececec;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#ececec; stroke:#ececec;"/><path d="M 24.55,10.4 L 24.1,11.85 L 24.6,12 C 27.75,13 30.25,14.49 32.5,18.75 C 34.75,23.01 35.75,29.06 35.25,39 L 35.2,39.5 L 37.45,39.5 L 37.5,39 C 38,28.94 36.62,22.15 34.25,17.66 C 31.88,13.17 28.46,11.02 25.06,10.5 L 24.55,10.4 z " style="fill:#ececec; stroke:none;"/></g>""",
37
- # noqa: E501
38
- "p": """<g id="black-pawn" class="black pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
39
- # noqa: E501
40
- "q": """<g id="black-queen" class="black queen" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#000" stroke="none"><circle cx="6" cy="12" r="2.75"/><circle cx="14" cy="9" r="2.75"/><circle cx="22.5" cy="8" r="2.75"/><circle cx="31" cy="9" r="2.75"/><circle cx="39" cy="12" r="2.75"/></g><path d="M9 26c8.5-1.5 21-1.5 27 0l2.5-12.5L31 25l-.3-14.1-5.2 13.6-3-14.5-3 14.5-5.2-13.6L14 25 6.5 13.5 9 26zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11 38.5a35 35 1 0 0 23 0" fill="none" stroke-linecap="butt"/><path d="M11 29a35 35 1 0 1 23 0M12.5 31.5h20M11.5 34.5a35 35 1 0 0 22 0M10.5 37.5a35 35 1 0 0 24 0" fill="none" stroke="#fff"/></g>""",
41
- # noqa: E501
42
- "r": """<g id="black-rook" class="black rook" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12.5 32l1.5-2.5h17l1.5 2.5h-20zM12 36v-4h21v4H12z" stroke-linecap="butt"/><path d="M14 29.5v-13h17v13H14z" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M14 16.5L11 14h23l-3 2.5H14zM11 14V9h4v2h5V9h5v2h5V9h4v5H11z" stroke-linecap="butt"/><path d="M12 35.5h21M13 31.5h19M14 29.5h17M14 16.5h17M11 14h23" fill="none" stroke="#fff" stroke-width="1" stroke-linejoin="miter"/></g>""",
43
- # noqa: E501
44
- "B": """<g id="white-bishop" class="white bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#fff" stroke-linecap="butt"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zM15 32c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z"/></g><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke-linejoin="miter"/></g>""",
45
- # noqa: E501
46
- "K": """<g id="white-king" class="white king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6M20 8h5" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#fff" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#fff"/><path d="M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0"/></g>""",
47
- # noqa: E501
48
- "N": """<g id="white-knight" class="white knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#ffffff; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#ffffff; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#000000; stroke:#000000;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#000000; stroke:#000000;"/></g>""",
49
- # noqa: E501
50
- "P": """<g id="white-pawn" class="white pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" fill="#fff" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
51
- # noqa: E501
52
- "Q": """<g id="white-queen" class="white queen" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M8 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM24.5 7.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM41 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM16 8.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM33 9a2 2 0 1 1-4 0 2 2 0 1 1 4 0z"/><path d="M9 26c8.5-1.5 21-1.5 27 0l2-12-7 11V11l-5.5 13.5-3-15-3 15-5.5-14V25L7 14l2 12zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11.5 30c3.5-1 18.5-1 22 0M12 33.5c6-1 15-1 21 0" fill="none"/></g>""",
53
- # noqa: E501
54
- "R": """<g id="white-rook" class="white rook" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12 36v-4h21v4H12zM11 14V9h4v2h5V9h5v2h5V9h4v5" stroke-linecap="butt"/><path d="M34 14l-3 3H14l-3-3"/><path d="M31 17v12.5H14V17" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M31 29.5l1.5 2.5h-20l1.5-2.5"/><path d="M11 14h23" fill="none" stroke-linejoin="miter"/></g>""",
55
- # noqa: E501
56
- }
57
-
58
- PIECES = {
59
- "b": """<g id="black-bishop" class="black bishop" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zm6-4c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z" fill="#000" stroke-linecap="butt"/><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke="#fff" stroke-linejoin="miter"/></g>""",
60
- # noqa: E501
61
- "k": """<g id="black-king" class="black king" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#000" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#000"/><path d="M20 8h5" stroke-linejoin="miter"/><path d="M32 29.5s8.5-4 6.03-9.65C34.15 14 25 18 22.5 24.5l.01 2.1-.01-2.1C20 18 9.906 14 6.997 19.85c-2.497 5.65 4.853 9 4.853 9M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0" stroke="#fff"/></g>""",
62
- # noqa: E501
63
- "n": """<g id="black-knight" class="black knight" fill="none" fill-rule="evenodd" stroke="#fff" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#000000; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#000000; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#ececec; stroke:#ececec;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#ececec; stroke:#ececec;"/><path d="M 24.55,10.4 L 24.1,11.85 L 24.6,12 C 27.75,13 30.25,14.49 32.5,18.75 C 34.75,23.01 35.75,29.06 35.25,39 L 35.2,39.5 L 37.45,39.5 L 37.5,39 C 38,28.94 36.62,22.15 34.25,17.66 C 31.88,13.17 28.46,11.02 25.06,10.5 L 24.55,10.4 z " style="fill:#ececec; stroke:none;"/></g>""",
64
- # noqa: E501
65
- "p": """<g id="black-pawn" class="black pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" stroke="#fff" stroke-width="1.0" stroke-linecap="round"/></g>""",
66
- # noqa: E501
67
- "q": """<g id="black-queen" class="black queen" fill="#000" fill-rule="evenodd" stroke="#000" stroke-width="1.0" stroke-linecap="round" stroke-linejoin="round"><g fill="#000" stroke="none"><circle cx="6" cy="12" r="2.75"/><circle cx="14" cy="9" r="2.75"/><circle cx="22.5" cy="8" r="2.75"/><circle cx="31" cy="9" r="2.75"/><circle cx="39" cy="12" r="2.75"/></g><path d="M9 26c8.5-1.5 21-1.5 27 0l2.5-12.5L31 25l-.3-14.1-5.2 13.6-3-14.5-3 14.5-5.2-13.6L14 25 6.5 13.5 9 26zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11 38.5a35 35 1 0 0 23 0" fill="none" stroke-linecap="butt"/><path d="M11 29a35 35 1 0 1 23 0M12.5 31.5h20M11.5 34.5a35 35 1 0 0 22 0M10.5 37.5a35 35 1 0 0 24 0" fill="none" stroke="#fff"/></g>""",
68
- # noqa: E501
69
- "r": """<g id="black-rook" class="black rook" fill="#000" fill-rule="evenodd" stroke="#fff" stroke-width="0.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12.5 32l1.5-2.5h17l1.5 2.5h-20zM12 36v-4h21v4H12z" stroke-linecap="butt"/><path d="M14 29.5v-13h17v13H14z" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M14 16.5L11 14h23l-3 2.5H14zM11 14V9h4v2h5V9h5v2h5V9h4v5H11z" stroke-linecap="butt"/><path d="M12 35.5h21M13 31.5h19M14 29.5h17M14 16.5h17M11 14h23" fill="none" stroke="#fff" stroke-width="1" stroke-linejoin="miter"/></g>""",
70
- # noqa: E501
71
- "B": """<g id="white-bishop" class="white bishop" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><g fill="#fff" stroke-linecap="butt"><path d="M9 36c3.39-.97 10.11.43 13.5-2 3.39 2.43 10.11 1.03 13.5 2 0 0 1.65.54 3 2-.68.97-1.65.99-3 .5-3.39-.97-10.11.46-13.5-1-3.39 1.46-10.11.03-13.5 1-1.354.49-2.323.47-3-.5 1.354-1.94 3-2 3-2zM15 32c2.5 2.5 12.5 2.5 15 0 .5-1.5 0-2 0-2 0-2.5-2.5-4-2.5-4 5.5-1.5 6-11.5-5-15.5-11 4-10.5 14-5 15.5 0 0-2.5 1.5-2.5 4 0 0-.5.5 0 2zM25 8a2.5 2.5 0 1 1-5 0 2.5 2.5 0 1 1 5 0z"/></g><path d="M17.5 26h10M15 30h15m-7.5-14.5v5M20 18h5" stroke-linejoin="miter"/></g>""",
72
- # noqa: E501
73
- "K": """<g id="white-king" class="white king" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M22.5 11.63V6M20 8h5" stroke-linejoin="miter"/><path d="M22.5 25s4.5-7.5 3-10.5c0 0-1-2.5-3-2.5s-3 2.5-3 2.5c-1.5 3 3 10.5 3 10.5" fill="#fff" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M11.5 37c5.5 3.5 15.5 3.5 21 0v-7s9-4.5 6-10.5c-4-6.5-13.5-3.5-16 4V27v-3.5c-3.5-7.5-13-10.5-16-4-3 6 5 10 5 10V37z" fill="#fff"/><path d="M11.5 30c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0m-21 3.5c5.5-3 15.5-3 21 0"/></g>""",
74
- # noqa: E501
75
- "N": """<g id="white-knight" class="white knight" fill="none" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M 22,10 C 32.5,11 38.5,18 38,39 L 15,39 C 15,30 25,32.5 23,18" style="fill:#ffffff; stroke:#000000;"/><path d="M 24,18 C 24.38,20.91 18.45,25.37 16,27 C 13,29 13.18,31.34 11,31 C 9.958,30.06 12.41,27.96 11,28 C 10,28 11.19,29.23 10,30 C 9,30 5.997,31 6,26 C 6,24 12,14 12,14 C 12,14 13.89,12.1 14,10.5 C 13.27,9.506 13.5,8.5 13.5,7.5 C 14.5,6.5 16.5,10 16.5,10 L 18.5,10 C 18.5,10 19.28,8.008 21,7 C 22,7 22,10 22,10" style="fill:#ffffff; stroke:#000000;"/><path d="M 9.5 25.5 A 0.5 0.5 0 1 1 8.5,25.5 A 0.5 0.5 0 1 1 9.5 25.5 z" style="fill:#000000; stroke:#000000;"/><path d="M 15 15.5 A 0.5 1.5 0 1 1 14,15.5 A 0.5 1.5 0 1 1 15 15.5 z" transform="matrix(0.866,0.5,-0.5,0.866,9.693,-5.173)" style="fill:#000000; stroke:#000000;"/></g>""",
76
- # noqa: E501
77
- "P": """<g id="white-pawn" class="white pawn"><path d="M22 9c-2.21 0-4 1.79-4 4 0 .89.29 1.71.78 2.38-1.95 1.12-3.28 3.21-3.28 5.62 0 2.03.94 3.84 2.41 5.03-3 1.06-7.41 5.55-7.41 13.47h23c0-7.92-4.41-12.41-7.41-13.47 1.47-1.19 2.41-3 2.41-5.03 0-2.41-1.33-4.5-3.28-5.62.49-.67.78-1.49.78-2.38 0-2.21-1.79-4-4-4z" fill="#fff" stroke="#000" stroke-width="1.5" stroke-linecap="round"/></g>""",
78
- # noqa: E501
79
- "Q": """<g id="white-queen" class="white queen" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M8 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM24.5 7.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM41 12a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM16 8.5a2 2 0 1 1-4 0 2 2 0 1 1 4 0zM33 9a2 2 0 1 1-4 0 2 2 0 1 1 4 0z"/><path d="M9 26c8.5-1.5 21-1.5 27 0l2-12-7 11V11l-5.5 13.5-3-15-3 15-5.5-14V25L7 14l2 12zM9 26c0 2 1.5 2 2.5 4 1 1.5 1 1 .5 3.5-1.5 1-1.5 2.5-1.5 2.5-1.5 1.5.5 2.5.5 2.5 6.5 1 16.5 1 23 0 0 0 1.5-1 0-2.5 0 0 .5-1.5-1-2.5-.5-2.5-.5-2 .5-3.5 1-2 2.5-2 2.5-4-8.5-1.5-18.5-1.5-27 0z" stroke-linecap="butt"/><path d="M11.5 30c3.5-1 18.5-1 22 0M12 33.5c6-1 15-1 21 0" fill="none"/></g>""",
80
- # noqa: E501
81
- "R": """<g id="white-rook" class="white rook" fill="#fff" fill-rule="evenodd" stroke="#000" stroke-width="1.5" stroke-linecap="round" stroke-linejoin="round"><path d="M9 39h27v-3H9v3zM12 36v-4h21v4H12zM11 14V9h4v2h5V9h5v2h5V9h4v5" stroke-linecap="butt"/><path d="M34 14l-3 3H14l-3-3"/><path d="M31 17v12.5H14V17" stroke-linecap="butt" stroke-linejoin="miter"/><path d="M31 29.5l1.5 2.5h-20l1.5-2.5"/><path d="M11 14h23" fill="none" stroke-linejoin="miter"/></g>""",
82
- # noqa: E501
83
- }
84
-
85
- XX = """<g id="xx"><path d="M35.865 9.135a1.89 1.89 0 0 1 0 2.673L25.173 22.5l10.692 10.692a1.89 1.89 0 0 1 0 2.673 1.89 1.89 0 0 1-2.673 0L22.5 25.173 11.808 35.865a1.89 1.89 0 0 1-2.673 0 1.89 1.89 0 0 1 0-2.673L19.827 22.5 9.135 11.808a1.89 1.89 0 0 1 0-2.673 1.89 1.89 0 0 1 2.673 0L22.5 19.827 33.192 9.135a1.89 1.89 0 0 1 2.673 0z" fill="#000" stroke="#fff" stroke-width="1.688"/></g>""" # noqa: E501
86
-
87
- CHECK_GRADIENT = """<radialGradient id="check_gradient"><stop offset="0%" stop-color="#ff0000" stop-opacity="1.0" /><stop offset="50%" stop-color="#e70000" stop-opacity="1.0" /><stop offset="100%" stop-color="#9e0000" stop-opacity="0.0" /></radialGradient>""" # noqa: E501
88
-
89
- DEFAULT_COLORS = {
90
- "square light": "#ffce9e",
91
- "square dark": "#d18b47",
92
- "square dark lastmove": "#aaa23b",
93
- "square light lastmove": "#cdd16a",
94
- }
95
-
96
- class Arrow:
97
- """Details of an arrow to be drawn."""
98
-
99
- def __init__(self, tail: chess.Square, head: chess.Square, *, color: str = "#888", annotation: str = '') -> None:
100
- self.tail = tail
101
- self.head = head
102
- self.color = color
103
- self.annotation = annotation
104
-
105
-
106
- class SvgWrapper(str):
107
- def _repr_svg_(self) -> "SvgWrapper":
108
- return self
109
-
110
-
111
- def _svg(viewbox: int, size: Optional[int]) -> ET.Element:
112
- svg = ET.Element("svg", {
113
- "xmlns": "http://www.w3.org/2000/svg",
114
- "version": "1.1",
115
- "xmlns:xlink": "http://www.w3.org/1999/xlink",
116
- "viewBox": f"0 0 {viewbox:d} {viewbox:d}",
117
- })
118
-
119
- if size is not None:
120
- svg.set("width", str(size))
121
- svg.set("height", str(size))
122
-
123
- return svg
124
-
125
-
126
- def _text(content: str, x: int, y: int, width: int, height: int) -> ET.Element:
127
- t = ET.Element("text", {
128
- "x": str(x + width // 2),
129
- "y": str(y + height // 2),
130
- "font-size": str(max(1, int(min(width, height) * 0.7))),
131
- "text-anchor": "middle",
132
- "alignment-baseline": "middle",
133
- })
134
- t.text = content
135
- return t
136
-
137
-
138
- def piece(piece: chess.Piece, size: Optional[int] = None) -> str:
139
- """
140
- Renders the given :class:`chess.Piece` as an SVG image.
141
- >>> import chess
142
- >>> import chess.svg
143
- >>>
144
- >>> chess.svg.piece(chess.Piece.from_symbol("R")) # doctest: +SKIP
145
- .. image:: ../docs/wR.svg
146
- """
147
- svg = _svg(SQUARE_SIZE, size)
148
- svg.append(ET.fromstring(PIECES[piece.symbol()]))
149
- return SvgWrapper(ET.tostring(svg).decode("utf-8"))
150
-
151
-
152
- def board(board: Optional[chess.BaseBoard] = None, *,
153
- squares: Optional[chess.IntoSquareSet] = None,
154
- flipped: bool = False,
155
- coordinates: bool = True,
156
- lastmove: Optional[chess.Move] = None,
157
- check: Optional[chess.Square] = None,
158
- arrows: Iterable[Union[Arrow, Tuple[chess.Square, chess.Square]]] = (),
159
- size: Optional[int] = None,
160
- style: Optional[str] = None,
161
- square_colors: Iterable[str] = (), #TODO: remove as it is not needed anymore
162
- only_pieces: bool = False) -> str:
163
- """
164
- Renders a board with pieces and/or selected squares as an SVG image.
165
- :param board: A :class:`chess.BaseBoard` for a chessboard with pieces or
166
- ``None`` (the default) for a chessboard without pieces.
167
- :param squares: A :class:`chess.SquareSet` with selected squares.
168
- :param flipped: Pass ``True`` to flip the board.
169
- :param coordinates: Pass ``False`` to disable coordinates in the margin.
170
- :param lastmove: A :class:`chess.Move` to be highlighted.
171
- :param check: A square to be marked as check.
172
- :param arrows: A list of :class:`~chess.svg.Arrow` objects like
173
- ``[chess.svg.Arrow(chess.E2, chess.E4)]`` or a list of tuples like
174
- ``[(chess.E2, chess.E4)]``. An arrow from a square pointing to the same
175
- square is drawn as a circle, like ``[(chess.E2, chess.E2)]``.
176
- :param size: The size of the image in pixels (e.g., ``400`` for a 400 by
177
- 400 board) or ``None`` (the default) for no size limit.
178
- :param style: A CSS stylesheet to include in the SVG image.
179
- >>> import chess
180
- >>> import chess.svg
181
- >>>
182
- >>> board = chess.Board("8/8/8/8/4N3/8/8/8 w - - 0 1")
183
- >>> squares = board.attacks(chess.E4)
184
- >>> chess.svg.board(board=board, squares=squares) # doctest: +SKIP
185
- .. image:: ../docs/Ne4.svg
186
- """
187
- margin = MARGIN if coordinates else 0
188
- svg = _svg(8 * SQUARE_SIZE + 2 * margin, size)
189
-
190
- if style:
191
- ET.SubElement(svg, "style").text = style
192
-
193
- defs = ET.SubElement(svg, "defs")
194
- if board:
195
- for piece_color in chess.COLORS:
196
- for piece_type in chess.PIECE_TYPES:
197
- if board.pieces_mask(piece_type, piece_color):
198
- defs.append(ET.fromstring(PIECES[chess.Piece(piece_type, piece_color).symbol()]))
199
-
200
- squares = chess.SquareSet(squares) if squares else chess.SquareSet()
201
- if squares:
202
- defs.append(ET.fromstring(XX))
203
-
204
- if check is not None and not only_pieces:
205
- defs.append(ET.fromstring(CHECK_GRADIENT))
206
-
207
- for square, bb in enumerate(chess.BB_SQUARES):
208
- file_index = chess.square_file(square)
209
- rank_index = chess.square_rank(square)
210
-
211
- x = (file_index if not flipped else 7 - file_index) * SQUARE_SIZE + margin
212
- y = (7 - rank_index if not flipped else rank_index) * SQUARE_SIZE + margin
213
-
214
- cls = ["square", "light" if chess.BB_LIGHT_SQUARES & bb else "dark"]
215
- if lastmove and square in [lastmove.from_square, lastmove.to_square]:
216
- cls.append("lastmove")
217
- if square_colors == ():
218
- fill_color = DEFAULT_COLORS[" ".join(cls)]
219
- else:
220
- fill_color = square_colors[square]
221
-
222
- cls.append(chess.SQUARE_NAMES[square])
223
- if not only_pieces:
224
- ET.SubElement(svg, "rect", {
225
- "x": str(x),
226
- "y": str(y),
227
- "width": str(SQUARE_SIZE),
228
- "height": str(SQUARE_SIZE),
229
- "class": " ".join(cls),
230
- "stroke": "none",
231
- "fill": fill_color,
232
- })
233
-
234
- if square == check:
235
- ET.SubElement(svg, "rect", {
236
- "x": str(x),
237
- "y": str(y),
238
- "width": str(SQUARE_SIZE),
239
- "height": str(SQUARE_SIZE),
240
- "class": "check",
241
- "fill": "url(#check_gradient)",
242
- })
243
-
244
- # Render pieces.
245
- if board is not None:
246
- piece = board.piece_at(square)
247
- if piece:
248
- ET.SubElement(svg, "use", {
249
- "xlink:href": f"#{chess.COLOR_NAMES[piece.color]}-{chess.PIECE_NAMES[piece.piece_type]}",
250
- "transform": f"translate({x:d}, {y:d})",
251
- })
252
-
253
- # Render selected squares.
254
- if squares is not None and square in squares:
255
- #ET.SubElement(svg, "use", {
256
- # "xlink:href": "#xx",
257
- # "x": str(x),
258
- # "y": str(y),
259
- #})
260
- ET.SubElement(svg, "rect", {
261
- "x": str(x),
262
- "y": str(y),
263
- "width": str(SQUARE_SIZE),
264
- "height": str(SQUARE_SIZE),
265
- "class": "check",
266
- "fill": "none",
267
- "stroke": "#FF0000",
268
- "stroke-width": "5.0",
269
- "rx": "2.5",
270
- "opacity": "0.60"
271
- })
272
-
273
- if coordinates:
274
- for file_index, file_name in enumerate(chess.FILE_NAMES):
275
- x = (file_index if not flipped else 7 - file_index) * SQUARE_SIZE + margin
276
- svg.append(_text(file_name, x, 0, SQUARE_SIZE, margin))
277
- svg.append(_text(file_name, x, margin + 8 * SQUARE_SIZE, SQUARE_SIZE, margin))
278
- for rank_index, rank_name in enumerate(chess.RANK_NAMES):
279
- y = (7 - rank_index if not flipped else rank_index) * SQUARE_SIZE + margin
280
- svg.append(_text(rank_name, 0, y, margin, SQUARE_SIZE))
281
- svg.append(_text(rank_name, margin + 8 * SQUARE_SIZE, y, margin, SQUARE_SIZE))
282
-
283
- for arrow in arrows:
284
- try:
285
- tail, head, color, annotation = arrow.tail, arrow.head, arrow.color, arrow.annotation # type: ignore
286
- except AttributeError:
287
- tail, head = arrow # type: ignore
288
- color = "#888"
289
- annotation = ''
290
-
291
- tail_file = chess.square_file(tail)
292
- tail_rank = chess.square_rank(tail)
293
- head_file = chess.square_file(head)
294
- head_rank = chess.square_rank(head)
295
-
296
- xtail = margin + (tail_file + 0.5 if not flipped else 7.5 - tail_file) * SQUARE_SIZE
297
- ytail = margin + (7.5 - tail_rank if not flipped else tail_rank + 0.5) * SQUARE_SIZE
298
- xhead = margin + (head_file + 0.5 if not flipped else 7.5 - head_file) * SQUARE_SIZE
299
- yhead = margin + (7.5 - head_rank if not flipped else head_rank + 0.5) * SQUARE_SIZE
300
-
301
- if (head_file, head_rank) == (tail_file, tail_rank):
302
- ET.SubElement(svg, "circle", {
303
- "cx": str(xhead),
304
- "cy": str(yhead),
305
- "r": str(SQUARE_SIZE * 0.9 / 2),
306
- "stroke-width": str(SQUARE_SIZE * 0.1),
307
- "stroke": color,
308
- "fill": "none",
309
- "opacity": "0.5",
310
- "class": "circle",
311
- })
312
- else:
313
- # marker_size = 0.75 * SQUARE_SIZE
314
- # marker_margin = 0.1 * SQUARE_SIZE
315
- marker_size = 0.5 * SQUARE_SIZE
316
- marker_margin = 0.05 * SQUARE_SIZE
317
-
318
- dx, dy = xhead - xtail, yhead - ytail
319
- hypot = math.hypot(dx, dy)
320
-
321
- shaft_x = xhead - dx * (marker_size + marker_margin) / hypot
322
- shaft_y = yhead - dy * (marker_size + marker_margin) / hypot
323
-
324
- xtip = xhead - dx * marker_margin / hypot
325
- ytip = yhead - dy * marker_margin / hypot
326
-
327
- x_annot = xtail + (shaft_x - xtail) / 2
328
- y_annot = ytail + (shaft_y - ytail) / 2
329
-
330
- x_annot = xhead - dx * 0.74 * SQUARE_SIZE / hypot # - (xtip - xtail)*(SQUARE_SIZE/2)
331
- y_annot = yhead - dy * 0.74 * SQUARE_SIZE / hypot # - (ytip - ytail)*(SQUARE_SIZE/2)
332
-
333
- ET.SubElement(svg, "line", {
334
- "x1": str(xtail),
335
- "y1": str(ytail),
336
- "x2": str(shaft_x),
337
- "y2": str(shaft_y),
338
- "stroke": color,
339
- "stroke-width": str(SQUARE_SIZE * 0.15),
340
- "opacity": "0.5",
341
- "stroke-linecap": "butt",
342
- "class": "arrow",
343
- })
344
-
345
- marker = [(xtip, ytip),
346
- (shaft_x + dy * 0.5 * marker_size / hypot,
347
- shaft_y - dx * 0.5 * marker_size / hypot),
348
- (shaft_x - dy * 0.5 * marker_size / hypot,
349
- shaft_y + dx * 0.5 * marker_size / hypot)]
350
-
351
- ET.SubElement(svg, "polygon", {
352
- "points": " ".join(str(x) + "," + str(y) for x, y in marker),
353
- "fill": color,
354
- "opacity": "0.5",
355
- "class": "arrow",
356
- })
357
-
358
- for arrow in arrows:
359
- try:
360
- tail, head, color, annotation = arrow.tail, arrow.head, arrow.color, arrow.annotation # type: ignore
361
- except AttributeError:
362
- tail, head = arrow # type: ignore
363
- color = "#888"
364
- annotation = ''
365
-
366
- tail_file = chess.square_file(tail)
367
- tail_rank = chess.square_rank(tail)
368
- head_file = chess.square_file(head)
369
- head_rank = chess.square_rank(head)
370
-
371
- xtail = margin + (tail_file + 0.5 if not flipped else 7.5 - tail_file) * SQUARE_SIZE
372
- ytail = margin + (7.5 - tail_rank if not flipped else tail_rank + 0.5) * SQUARE_SIZE
373
- xhead = margin + (head_file + 0.5 if not flipped else 7.5 - head_file) * SQUARE_SIZE
374
- yhead = margin + (7.5 - head_rank if not flipped else head_rank + 0.5) * SQUARE_SIZE
375
-
376
- marker_size = 0.5 * SQUARE_SIZE
377
- marker_margin = 0.05 * SQUARE_SIZE
378
-
379
- dx, dy = xhead - xtail, yhead - ytail
380
- hypot = math.hypot(dx, dy)
381
-
382
- shaft_x = xhead - dx * (marker_size + marker_margin) / hypot
383
- shaft_y = yhead - dy * (marker_size + marker_margin) / hypot
384
-
385
- xtip = xhead - dx * marker_margin / hypot
386
- ytip = yhead - dy * marker_margin / hypot
387
-
388
- x_annot = xhead - dx * 0.74 * SQUARE_SIZE / hypot
389
- y_annot = yhead - dy * 0.74 * SQUARE_SIZE / hypot
390
-
391
- if annotation != '':
392
- ET.SubElement(svg, "circle", {
393
- "cx": str(x_annot),
394
- "cy": str(y_annot),
395
- "r": str(SQUARE_SIZE * 0.175),
396
- #"r": str(SQUARE_SIZE * 0.2),
397
- "stroke-width": str(SQUARE_SIZE * 0.01),
398
- "stroke": '#000000',
399
- "fill": color,
400
- "opacity": "1.0",
401
- "class": "circle",
402
- })
403
- #style = get_style("'BundledDejavuSans'", str(SQUARE_SIZE * 0.1))
404
- annot = ET.SubElement(svg, "text", {
405
- "x": str(x_annot),
406
- "y": str(y_annot),
407
- "font-size": str(SQUARE_SIZE * 0.2), # max(1, int(min(SQUARE_SIZE, SQUARE_SIZE) * 0.3))),
408
- "text-anchor": "middle",
409
- "dominant-baseline": "middle"
410
- #"alignment-baseline": "middle"
411
- })
412
- annot.text = annotation
413
-
414
- return SvgWrapper(ET.tostring(svg).decode("utf-8"))
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/svg_pieces.py DELETED
@@ -1,31 +0,0 @@
1
- from python_chess_customized_svg import piece
2
- import python_chess_customized_svg as svg
3
- import chess
4
- import base64
5
-
6
- #def get_svg_piece(symbol):
7
- # img = piece(chess.Piece.from_symbol(symbol))
8
- # svg_str = str(img)
9
- # svg_byte = svg_str.encode()
10
- # encoded = base64.b64encode(svg_byte)
11
- # svg_piece = 'data:image/svg+xml;base64,{}'.format(encoded.decode())
12
- # return svg_piece
13
-
14
- def get_svg_board(board, focused_square_ind, only_pieces):
15
- if focused_square_ind is not None:
16
- squares = [focused_square_ind]
17
- else:
18
- squares = []
19
- if board.move_stack:
20
- print('board stack YES')
21
- lastmove = board.peek()
22
- else:
23
- print('board stack NO')
24
- lastmove = None
25
- svg_str = str(svg.board(board, squares=squares, arrows=[], lastmove=lastmove, coordinates=False, only_pieces=only_pieces))
26
- svg_byte = svg_str.encode()
27
- encoded = base64.b64encode(svg_byte)
28
- svg_board = 'data:image/svg+xml;base64,{}'.format(encoded.decode())
29
- return svg_board
30
-
31
- #SVG_PIECES = {piece: get_svg_piece(piece) for piece in ('b', 'k', 'n', 'p', 'q', 'r', 'B', 'K', 'N', 'P', 'Q', 'R')}
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/utils.py DELETED
@@ -1,32 +0,0 @@
1
- from leela_board import LeelaBoard
2
- import chess
3
- import torch
4
-
5
-
6
- def flip_move(move):
7
- from_square = chess.square_mirror(chess.parse_square(move[:2]))
8
- to_square = chess.square_mirror(chess.parse_square(move[2:4]))
9
- promotion = move[4:] if len(move) > 4 else ""
10
- return chess.square_name(from_square) + chess.square_name(to_square) + promotion
11
-
12
-
13
- def flip_board(fen, moves):
14
- temp_board = chess.Board(fen=fen)
15
- return temp_board.mirror().fen(), [flip_move(move) for move in moves]
16
-
17
-
18
- # Helper functions
19
- class ChessBoard:
20
- def __init__(self, fen): # Create new board from fen
21
- self.board = LeelaBoard(fen=fen)
22
- self.t = self.__t()
23
-
24
- def move(self, move): # Move piece on board ("e2e3")
25
- self.board.push_uci(move)
26
- self.t = self.__t()
27
-
28
- def __t(self): # Set board tensor (private method)
29
- return torch.from_numpy(self.board.lcz_features()).float()
30
-
31
- def __str__(self): # Prints board state
32
- return str(self.board)
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
our_visualization/visualization_demo.py DELETED
@@ -1,230 +0,0 @@
1
- import marimo
2
-
3
- __generated_with = "0.8.22"
4
- app = marimo.App(width="medium")
5
-
6
-
7
- @app.cell
8
- def __():
9
- import marimo as mo
10
- return (mo,)
11
-
12
-
13
- @app.cell
14
- def __():
15
- import pandas as pd
16
- df = pd.read_csv("our_visualization/datasets/test_set.csv")
17
- df.head()
18
- return df, pd
19
-
20
-
21
- @app.cell
22
- def __():
23
- import pickle
24
- from utils import ChessBoard
25
- import onnxruntime as ort
26
- from leela_board import _idx_to_move_bn, _idx_to_move_wn
27
- import numpy as np
28
- from onnx2torch import convert
29
- import onnx
30
- import torch
31
- import os
32
-
33
- def get_models(root="/our_visualization/models"):
34
- paths = os.listdir(root)
35
- model_paths = []
36
- for path in paths:
37
- if ".onnx" in path: model_paths.append(os.path.join(root, path))
38
- return model_paths
39
-
40
- def get_activations_from_model(model_path, pattern, fen):
41
- # Write hooks for selected model path
42
- def register_hooks_for_capture(model, pattern):
43
- activations = {}
44
- def get_activation(name):
45
- def hook(module, input, output):
46
- activations[name] = output.detach().numpy()
47
- return hook
48
-
49
- handles = []
50
- for n, m in model.named_modules():
51
- if pattern in n:
52
- handle = m.register_forward_hook(get_activation(n))
53
- handles.append(handle)
54
- return activations, handles
55
-
56
- # Load model and register hooks for it
57
- model = convert(onnx.load(model_path))
58
- act, handles = register_hooks_for_capture(model, pattern)
59
-
60
- # Get fen and pass it through model to generate activations
61
- board = ChessBoard(fen)
62
- inputs = board.t
63
- _, _, _ = model(inputs.unsqueeze(dim=0))
64
-
65
- # Remove handles
66
- [h.remove() for h in handles]
67
- return act
68
- return (
69
- ChessBoard,
70
- convert,
71
- get_activations_from_model,
72
- get_models,
73
- np,
74
- onnx,
75
- ort,
76
- os,
77
- pickle,
78
- torch,
79
- )
80
-
81
-
82
- @app.cell
83
- def __(df, mo):
84
- min_elo, max_elo = df["Rating"].min() // 100 * 100, df["Rating"].max() // 100 * 100
85
- elo_list = [f"{elo}" for elo in range(min_elo, max_elo + 100, 100)]
86
- dropdown_elo = mo.ui.dropdown(value = "1000", options=elo_list, label=f"Select rating in range of {min_elo} - {max_elo}")
87
- dropdown_elo
88
- return dropdown_elo, elo_list, max_elo, min_elo
89
-
90
-
91
- @app.cell
92
- def __(df, dropdown_elo, mo):
93
- unique_themes = set()
94
- df_rated = df[(df["Rating"] >= int(dropdown_elo.value)) & (df["Rating"] <= int(dropdown_elo.value) + 100)]
95
- for i in range(len(df_rated)):
96
- themes = df_rated.iloc[i]["Themes"].split(" ")
97
- for theme in themes: unique_themes.add(theme)
98
- unique_themes_list = list(unique_themes)
99
- unique_themes_list.sort()
100
-
101
- dropdown_themes = mo.ui.dropdown(value=unique_themes_list[0], options=unique_themes_list, label=f"Select puzzle theme")
102
- dropdown_themes
103
- return (
104
- df_rated,
105
- dropdown_themes,
106
- i,
107
- theme,
108
- themes,
109
- unique_themes,
110
- unique_themes_list,
111
- )
112
-
113
-
114
- @app.cell
115
- def __(df_rated, dropdown_themes):
116
- themes_mask = []
117
- def _(themes_mask):
118
- for i in range(len(df_rated)):
119
- themes_new = df_rated.iloc[i]["Themes"].split(" ")
120
- if dropdown_themes.value in themes_new: themes_mask.append(i)
121
- _(themes_mask)
122
- fens = list(df_rated.iloc[themes_mask]["FEN"])
123
- df_rated.iloc[themes_mask][["FEN", "Moves", "Themes", "Rating"]]
124
- return fens, themes_mask
125
-
126
-
127
- @app.cell
128
- def __(fens, mo):
129
- dropdown_fen = mo.ui.dropdown(value = fens[0], options=fens, label="Select FEN")
130
- dropdown_fen
131
- return (dropdown_fen,)
132
-
133
-
134
- @app.cell
135
- def __(df_rated, dropdown_fen, mo):
136
- moves = df_rated[df_rated["FEN"] == dropdown_fen.value]["Moves"].iloc[0].split(" ")
137
- player_moves = moves[1::2]
138
- board_moves = []
139
- def _(board_moves):
140
- for i in range(len(player_moves)):
141
- board_moves.append(moves[:2 * i + 1])
142
- _(board_moves)
143
- moves_dict = {pm: om for pm, om in zip(player_moves, board_moves)}
144
- dropdown_moves = mo.ui.dropdown(options=moves_dict, value=player_moves[0], label="Select which player move to look at")
145
- # print(moves)
146
- dropdown_moves
147
- return board_moves, dropdown_moves, moves, moves_dict, player_moves
148
-
149
-
150
- @app.cell
151
- def __(dropdown_moves, mo):
152
- dropdown_layer = mo.ui.dropdown(value="0", options=[f"{i}" for i in range(15)], label="Select layer (smaller - closer to input)")
153
- focus_square = mo.ui.text_area(value=dropdown_moves.selected_key[:2], placeholder="Input square to look at (e.g. a1, b8, ...")
154
- mo.vstack([dropdown_layer, focus_square])
155
- return dropdown_layer, focus_square
156
-
157
-
158
- @app.cell
159
- def __(ChessBoard, dropdown_fen, dropdown_moves):
160
- def _():
161
- board = ChessBoard(dropdown_fen.value)
162
- for move in dropdown_moves.value:
163
- print(move)
164
- # board.move(move)
165
- return board.board.pc_board.fen()
166
- FEN = _()
167
- return (FEN,)
168
-
169
-
170
- @app.cell
171
- def __(focus_square):
172
- import chess
173
- from global_data import global_data
174
-
175
- focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
176
-
177
- def set_plotting_parameters(act, layer_number, fen):
178
- layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
179
- print(act.keys())
180
- global_data.model = 'test'
181
- global_data.activations = act[layer_key][0, :, ::-1 , :]
182
- print(global_data.activations.shape)
183
- global_data.subplot_rows = 8
184
- global_data.subplot_cols = 4
185
- global_data.board = chess.Board(fen)
186
- global_data.show_all_heads = True
187
- # global_data.selected_head = 1
188
- global_data.visualization_mode = 'ROW'
189
- global_data.focused_square_ind = focus_square_ind
190
- # global_data.heatmap_horizontal_gap = 0.001
191
-
192
- global_data.visualization_mode_is_64x64 = False
193
- global_data.colorscale_mode = "mode1"
194
- global_data.show_colorscale = False
195
- return chess, focus_square_ind, global_data, set_plotting_parameters
196
-
197
-
198
- @app.cell
199
- def __(
200
- FEN,
201
- dropdown_layer,
202
- get_activations_from_model,
203
- get_models,
204
- set_plotting_parameters,
205
- ):
206
- # FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
207
- # board = ChessBoard("r1b2rk1/pp2pp1p/6p1/3Qb2q/1P4n1/2P1BN2/P2N1PPP/R4RK1 w - - 0 14")
208
- # board.move("f3e5")
209
- # FEN = board.board.pc_board.fen()
210
- PATTERN = "mha/QK/softmax"
211
- # PATTERN = "smolgen_weights"
212
- MODEL = get_models()[-1]
213
- ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
214
- set_plotting_parameters(ACTIVATIONS, int(dropdown_layer.value), FEN)
215
- from activation_heatmap import heatmap_figure
216
- fig = heatmap_figure()
217
- fig.update_layout(height=1500, width=1200)
218
- fig
219
- return ACTIVATIONS, MODEL, PATTERN, fig, heatmap_figure
220
-
221
-
222
- @app.cell
223
- def __():
224
- # Add fens after opponents moves
225
- # Default squares of interest
226
- return
227
-
228
-
229
- if __name__ == "__main__":
230
- app.run()