seredapj commited on
Commit
c853eef
·
verified ·
1 Parent(s): 1fa4f92

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +69 -8
app.py CHANGED
@@ -167,17 +167,70 @@ def __(ChessBoard, dropdown_fen, dropdown_moves):
167
 
168
 
169
  @app.cell
170
- def __(focus_square):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
171
  import chess
172
  from global_data import global_data
173
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
174
  focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
175
 
176
- def set_plotting_parameters(act, layer_number, fen):
177
- layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
178
- print(act.keys())
179
  global_data.model = 'test'
180
- global_data.activations = act[layer_key][0, :, ::-1 , :]
181
  print(global_data.activations.shape)
182
  global_data.subplot_rows = 8
183
  global_data.subplot_cols = 4
@@ -191,7 +244,13 @@ def __(focus_square):
191
  global_data.visualization_mode_is_64x64 = False
192
  global_data.colorscale_mode = "mode1"
193
  global_data.show_colorscale = False
194
- return chess, focus_square_ind, global_data, set_plotting_parameters
 
 
 
 
 
 
195
 
196
 
197
  @app.cell
@@ -200,6 +259,7 @@ def __(
200
  dropdown_layer,
201
  get_activations_from_model,
202
  get_models,
 
203
  set_plotting_parameters,
204
  ):
205
  # FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
@@ -210,12 +270,13 @@ def __(
210
  # PATTERN = "smolgen_weights"
211
  MODEL = get_models()[-1]
212
  ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
213
- set_plotting_parameters(ACTIVATIONS, int(dropdown_layer.value), FEN)
 
214
  from activation_heatmap import heatmap_figure
215
  fig = heatmap_figure()
216
  fig.update_layout(height=1500, width=1200)
217
  fig
218
- return ACTIVATIONS, MODEL, PATTERN, fig, heatmap_figure
219
 
220
 
221
  @app.cell
 
167
 
168
 
169
  @app.cell
170
+ def __(torch):
171
+ def rollout(x, skip_connection=True, parse="min"):
172
+ attns = []
173
+ for k, v in x.items():
174
+ v = v[0, :, ::-1, :]
175
+ v = torch.tensor(v.copy())
176
+ if parse == "min":
177
+ item = torch.min(v, dim=0).values
178
+ elif parse == "max":
179
+ item = torch.max(v, dim=0).values
180
+ elif parse == "mean":
181
+ item = torch.mean(v, dim=0)
182
+ attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
183
+ roll = torch.prod(torch.stack(attns), dim=0)
184
+ return roll.numpy()
185
+
186
+ return (rollout,)
187
+
188
+
189
+ @app.cell
190
+ def __(mo):
191
+ METHODS = ["Attention visualization", "Attention rollout (MIN)", "Attention rollout (MEAN)", "Attention rollout (MAX)"]
192
+ dropdown_method = mo.ui.dropdown(options=METHODS, value=METHODS[0], label="Select XAI method")
193
+ dropdown_method
194
+ return METHODS, dropdown_method
195
+
196
+
197
+ @app.cell
198
+ def __(mo):
199
+ switch = mo.ui.switch(value=False, label="To use skip connection in rollout")
200
+ switch
201
+ return (switch,)
202
+
203
+
204
+ @app.cell
205
+ def __(switch):
206
+ switch.value
207
+ return
208
+
209
+
210
+ @app.cell
211
+ def __(dropdown_method, focus_square, rollout, switch, torch):
212
  import chess
213
  from global_data import global_data
214
 
215
+ def parse_activations(act, layer_number=None):
216
+ if dropdown_method.value == "Attention visualization":
217
+ layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
218
+ a = act[layer_key][0, :, ::-1 , :]
219
+ elif dropdown_method.value == "Attention rollout (MIN)":
220
+ a = rollout(act, skip_connection=switch.value, parse="min")
221
+ a = torch.stack([a for _ in range(32)])
222
+ elif dropdown_method.value == "Attention rollout (MAX)":
223
+ a = rollout(act, skip_connection=switch.value, parse="max")
224
+ a = torch.stack([a for _ in range(32)])
225
+ elif dropdown_method.value == "Attention rollout (MEAN)":
226
+ a = rollout(act, skip_connection=switch.value, parse="mean")
227
+ a = torch.stack([a for _ in range(32)])
228
+ return a
229
  focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
230
 
231
+ def set_plotting_parameters(act, fen):
 
 
232
  global_data.model = 'test'
233
+ global_data.activations = act
234
  print(global_data.activations.shape)
235
  global_data.subplot_rows = 8
236
  global_data.subplot_cols = 4
 
244
  global_data.visualization_mode_is_64x64 = False
245
  global_data.colorscale_mode = "mode1"
246
  global_data.show_colorscale = False
247
+ return (
248
+ chess,
249
+ focus_square_ind,
250
+ global_data,
251
+ parse_activations,
252
+ set_plotting_parameters,
253
+ )
254
 
255
 
256
  @app.cell
 
259
  dropdown_layer,
260
  get_activations_from_model,
261
  get_models,
262
+ parse_activations,
263
  set_plotting_parameters,
264
  ):
265
  # FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
 
270
  # PATTERN = "smolgen_weights"
271
  MODEL = get_models()[-1]
272
  ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
273
+ a = parse_activations(ACTIVATIONS, layer_number=int(dropdown_layer.value))
274
+ set_plotting_parameters(a, FEN)
275
  from activation_heatmap import heatmap_figure
276
  fig = heatmap_figure()
277
  fig.update_layout(height=1500, width=1200)
278
  fig
279
+ return ACTIVATIONS, MODEL, PATTERN, a, fig, heatmap_figure
280
 
281
 
282
  @app.cell