seredapj committed on
Commit
5bb75ee
·
verified ·
1 Parent(s): 35bba36

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +52 -14
app.py CHANGED
@@ -168,7 +168,7 @@ def __(ChessBoard, dropdown_fen, dropdown_moves):
168
 
169
  @app.cell
170
  def __(torch):
171
- def rollout(x, skip_connection=True, parse="min"):
172
  attns = []
173
  for k, v in x.items():
174
  v = v[0, :, ::-1, :]
@@ -179,10 +179,9 @@ def __(torch):
179
  item = torch.max(v, dim=0).values
180
  elif parse == "mean":
181
  item = torch.mean(v, dim=0)
182
- attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
183
- roll = torch.prod(torch.stack(attns), dim=0)
184
  return roll
185
-
186
  return (rollout,)
187
 
188
 
@@ -202,30 +201,68 @@ def __(mo):
202
 
203
 
204
  @app.cell
205
- def __(switch):
206
- switch.value
207
- return
 
 
 
 
 
 
 
 
208
 
209
 
210
  @app.cell
211
- def __(dropdown_method, focus_square, rollout, switch, torch):
 
 
 
 
 
 
 
 
 
212
  import chess
213
  from global_data import global_data
214
 
 
 
 
 
 
 
215
  def parse_activations(act, layer_number=None):
216
  if dropdown_method.value == "Attention visualization":
217
  layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
218
  a = act[layer_key][0, :, ::-1 , :]
219
  elif dropdown_method.value == "Attention rollout (MIN)":
220
- a = rollout(act, skip_connection=switch.value, parse="min")
221
- a = torch.stack([a for _ in range(32)]).numpy()
 
 
 
222
  elif dropdown_method.value == "Attention rollout (MAX)":
223
- a = rollout(act, skip_connection=switch.value, parse="max")
224
- a = torch.stack([a for _ in range(32)]).numpy()
 
 
 
225
  elif dropdown_method.value == "Attention rollout (MEAN)":
226
- a = rollout(act, skip_connection=switch.value, parse="mean")
227
- a = torch.stack([a for _ in range(32)]).numpy()
 
 
 
 
 
 
 
 
228
  return a
 
229
  focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
230
 
231
  def set_plotting_parameters(act, fen):
@@ -246,6 +283,7 @@ def __(dropdown_method, focus_square, rollout, switch, torch):
246
  global_data.show_colorscale = False
247
  return (
248
  chess,
 
249
  focus_square_ind,
250
  global_data,
251
  parse_activations,
 
168
 
169
  @app.cell
170
  def __(torch):
171
+ def rollout(x, skip_last_layers=0, skip_connection=True, parse="min"):
172
  attns = []
173
  for k, v in x.items():
174
  v = v[0, :, ::-1, :]
 
179
  item = torch.max(v, dim=0).values
180
  elif parse == "mean":
181
  item = torch.mean(v, dim=0)
182
+ attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
183
+ roll = torch.prod(torch.stack(attns)[:skip_last_layers], dim=0)
184
  return roll
 
185
  return (rollout,)
186
 
187
 
 
201
 
202
 
203
  @app.cell
204
+ def __(mo):
205
+ max_value_switch = mo.ui.switch(value=False, label="use rollout layers with max value")
206
+ max_value_switch
207
+ return (max_value_switch,)
208
+
209
+
210
+ @app.cell
211
+ def __(mo):
212
+ highlight_squares_switch = mo.ui.switch(value=False, label="Use rollout to highlight squares")
213
+ highlight_squares_switch
214
+ return (highlight_squares_switch,)
215
 
216
 
217
  @app.cell
218
+ def __(
219
+ dropdown_layer,
220
+ dropdown_method,
221
+ focus_square,
222
+ highlight_squares_switch,
223
+ max_value_switch,
224
+ rollout,
225
+ switch,
226
+ torch,
227
+ ):
228
  import chess
229
  from global_data import global_data
230
 
231
+ def find_max(a):
232
+ ar = a.reshape(a.shape[0], -1)
233
+ i = torch.max(ar, dim=1).values
234
+ im = torch.argmax(i[1:])
235
+ return a[im + 1]
236
+
237
  def parse_activations(act, layer_number=None):
238
  if dropdown_method.value == "Attention visualization":
239
  layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
240
  a = act[layer_key][0, :, ::-1 , :]
241
  elif dropdown_method.value == "Attention rollout (MIN)":
242
+ if not max_value_switch.value:
243
+ a = rollout(act, skip_last_layers=int(dropdown_layer.value), skip_connection=switch.value, parse="min")
244
+ else:
245
+ a = torch.stack([rollout(act, skip_last_layers=_i, skip_connection=switch.value, parse="min") for _i in range(0, 15)], dim=0)
246
+ a = find_max(a)
247
  elif dropdown_method.value == "Attention rollout (MAX)":
248
+ if not max_value_switch.value:
249
+ a = rollout(act, skip_last_layers=int(dropdown_layer.value), skip_connection=switch.value, parse="max")
250
+ else:
251
+ a = torch.stack([rollout(act, skip_last_layers=_i, skip_connection=switch.value, parse="max") for _i in range(0, 15)], dim=0)
252
+ a = find_max(a)
253
  elif dropdown_method.value == "Attention rollout (MEAN)":
254
+ if not max_value_switch.value:
255
+ a = rollout(act, skip_last_layers=int(dropdown_layer.value), skip_connection=switch.value, parse="mean")
256
+ else:
257
+ a = torch.stack([rollout(act, skip_last_layers=_i, skip_connection=switch.value, parse="mean") for _i in range(0, 15)], dim=0)
258
+ a = find_max(a)
259
+ if dropdown_method.value != "Attention visualization":
260
+ if highlight_squares_switch.value:
261
+ a = a.max(dim=0).values
262
+ a = torch.stack([a for _ in range(64)], dim=0)
263
+ a = torch.stack([a for _ in range(32)]).numpy()
264
  return a
265
+
266
  focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
267
 
268
  def set_plotting_parameters(act, fen):
 
283
  global_data.show_colorscale = False
284
  return (
285
  chess,
286
+ find_max,
287
  focus_square_ind,
288
  global_data,
289
  parse_activations,