Update app.py
Browse files
app.py
CHANGED
|
@@ -168,7 +168,7 @@ def __(ChessBoard, dropdown_fen, dropdown_moves):
|
|
| 168 |
|
| 169 |
@app.cell
|
| 170 |
def __(torch):
|
| 171 |
-
def rollout(x, skip_connection=True, parse="min"):
|
| 172 |
attns = []
|
| 173 |
for k, v in x.items():
|
| 174 |
v = v[0, :, ::-1, :]
|
|
@@ -179,10 +179,9 @@ def __(torch):
|
|
| 179 |
item = torch.max(v, dim=0).values
|
| 180 |
elif parse == "mean":
|
| 181 |
item = torch.mean(v, dim=0)
|
| 182 |
-
|
| 183 |
-
roll = torch.prod(torch.stack(attns), dim=0)
|
| 184 |
return roll
|
| 185 |
-
|
| 186 |
return (rollout,)
|
| 187 |
|
| 188 |
|
|
@@ -202,30 +201,68 @@ def __(mo):
|
|
| 202 |
|
| 203 |
|
| 204 |
@app.cell
|
| 205 |
-
def __(
|
| 206 |
-
|
| 207 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 208 |
|
| 209 |
|
| 210 |
@app.cell
|
| 211 |
-
def __(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 212 |
import chess
|
| 213 |
from global_data import global_data
|
| 214 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 215 |
def parse_activations(act, layer_number=None):
|
| 216 |
if dropdown_method.value == "Attention visualization":
|
| 217 |
layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
|
| 218 |
a = act[layer_key][0, :, ::-1 , :]
|
| 219 |
elif dropdown_method.value == "Attention rollout (MIN)":
|
| 220 |
-
|
| 221 |
-
|
|
|
|
|
|
|
|
|
|
| 222 |
elif dropdown_method.value == "Attention rollout (MAX)":
|
| 223 |
-
|
| 224 |
-
|
|
|
|
|
|
|
|
|
|
| 225 |
elif dropdown_method.value == "Attention rollout (MEAN)":
|
| 226 |
-
|
| 227 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 228 |
return a
|
|
|
|
| 229 |
focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
|
| 230 |
|
| 231 |
def set_plotting_parameters(act, fen):
|
|
@@ -246,6 +283,7 @@ def __(dropdown_method, focus_square, rollout, switch, torch):
|
|
| 246 |
global_data.show_colorscale = False
|
| 247 |
return (
|
| 248 |
chess,
|
|
|
|
| 249 |
focus_square_ind,
|
| 250 |
global_data,
|
| 251 |
parse_activations,
|
|
|
|
| 168 |
|
| 169 |
@app.cell
|
| 170 |
def __(torch):
|
| 171 |
+
def rollout(x, skip_last_layers=0, skip_connection=True, parse="min"):
|
| 172 |
attns = []
|
| 173 |
for k, v in x.items():
|
| 174 |
v = v[0, :, ::-1, :]
|
|
|
|
| 179 |
item = torch.max(v, dim=0).values
|
| 180 |
elif parse == "mean":
|
| 181 |
item = torch.mean(v, dim=0)
|
| 182 |
+
attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
|
| 183 |
+
roll = torch.prod(torch.stack(attns)[:skip_last_layers], dim=0)
|
| 184 |
return roll
|
|
|
|
| 185 |
return (rollout,)
|
| 186 |
|
| 187 |
|
|
|
|
| 201 |
|
| 202 |
|
| 203 |
@app.cell
def __(mo):
    # Toggle: when on, parse_activations ignores the selected rollout depth and
    # instead scans depths 0..14, keeping the rollout with the largest peak
    # value (via find_max).
    max_value_switch = mo.ui.switch(value=False, label="use rollout layers with max value")
    max_value_switch  # last expression in a marimo cell renders the widget
    return (max_value_switch,)
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@app.cell
def __(mo):
    # Toggle: when on, parse_activations collapses the rollout over its first
    # axis and broadcasts the result to a per-square highlight array
    # (presumably 8x8 board squares — confirm against the plotting code).
    highlight_squares_switch = mo.ui.switch(value=False, label="Use rollout to highlight squares")
    highlight_squares_switch  # last expression in a marimo cell renders the widget
    return (highlight_squares_switch,)
|
| 215 |
|
| 216 |
|
| 217 |
@app.cell
|
| 218 |
+
def __(
|
| 219 |
+
dropdown_layer,
|
| 220 |
+
dropdown_method,
|
| 221 |
+
focus_square,
|
| 222 |
+
highlight_squares_switch,
|
| 223 |
+
max_value_switch,
|
| 224 |
+
rollout,
|
| 225 |
+
switch,
|
| 226 |
+
torch,
|
| 227 |
+
):
|
| 228 |
import chess
|
| 229 |
from global_data import global_data
|
| 230 |
|
| 231 |
+
def find_max(a):
    """Return the slice of `a` (along dim 0, excluding index 0) whose
    largest element is greatest.

    `a` is a stack of candidate rollouts, shape (L, ...); index 0 is
    deliberately excluded from the search (presumably the trivial
    depth-0 rollout — confirm against the caller).
    """
    flat = a.reshape(a.shape[0], -1)
    per_slice_peak = torch.max(flat, dim=1).values
    # argmax over slices 1..L-1, then shift back to an index into `a`.
    best = torch.argmax(per_slice_peak[1:]) + 1
    return a[best]
|
| 236 |
+
|
| 237 |
def parse_activations(act, layer_number=None):
    """Turn captured attention activations into the tensor/array the plot uses.

    act: mapping of layer-name -> activation, indexed as
        [batch, head, query, key]. NOTE(review): the `::-1` slice implies
        these are numpy arrays (torch tensors reject negative steps) —
        confirm against the capture hooks.
    layer_number: layer to display for plain "Attention visualization";
        ignored by the rollout methods.
    """
    # Dropdown label -> reduction mode understood by rollout(). Collapses the
    # three previously-triplicated rollout branches into one. An unknown
    # dropdown value now raises KeyError (previously `a` was unbound ->
    # NameError at return), which is a clearer failure.
    rollout_mode = {
        "Attention rollout (MIN)": "min",
        "Attention rollout (MAX)": "max",
        "Attention rollout (MEAN)": "mean",
    }
    if dropdown_method.value == "Attention visualization":
        # Rebuild the requested layer's key from the layer-0 key.
        layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
        a = act[layer_key][0, :, ::-1 , :]
    else:
        parse = rollout_mode[dropdown_method.value]
        if not max_value_switch.value:
            # Single rollout at the user-selected depth.
            a = rollout(
                act,
                skip_last_layers=int(dropdown_layer.value),
                skip_connection=switch.value,
                parse=parse,
            )
        else:
            # Scan every depth 0..14 and keep the rollout with the largest
            # peak value.
            a = torch.stack(
                [
                    rollout(act, skip_last_layers=_i, skip_connection=switch.value, parse=parse)
                    for _i in range(0, 15)
                ],
                dim=0,
            )
            a = find_max(a)
        if highlight_squares_switch.value:
            # Collapse to one value per square, then broadcast back to the
            # 32 x 64 x ... shape the plotting code expects, as numpy.
            a = a.max(dim=0).values
            a = torch.stack([a for _ in range(64)], dim=0)
            a = torch.stack([a for _ in range(32)]).numpy()
    return a
|
| 265 |
+
|
| 266 |
focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
|
| 267 |
|
| 268 |
def set_plotting_parameters(act, fen):
|
|
|
|
| 283 |
global_data.show_colorscale = False
|
| 284 |
return (
|
| 285 |
chess,
|
| 286 |
+
find_max,
|
| 287 |
focus_square_ind,
|
| 288 |
global_data,
|
| 289 |
parse_activations,
|