Update app.py
Browse files
app.py
CHANGED
|
@@ -167,17 +167,70 @@ def __(ChessBoard, dropdown_fen, dropdown_moves):
|
|
| 167 |
|
| 168 |
|
| 169 |
@app.cell
|
| 170 |
-
def __(
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 171 |
import chess
|
| 172 |
from global_data import global_data
|
| 173 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 174 |
focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
|
| 175 |
|
| 176 |
-
def set_plotting_parameters(act,
|
| 177 |
-
layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
|
| 178 |
-
print(act.keys())
|
| 179 |
global_data.model = 'test'
|
| 180 |
-
global_data.activations = act
|
| 181 |
print(global_data.activations.shape)
|
| 182 |
global_data.subplot_rows = 8
|
| 183 |
global_data.subplot_cols = 4
|
|
@@ -191,7 +244,13 @@ def __(focus_square):
|
|
| 191 |
global_data.visualization_mode_is_64x64 = False
|
| 192 |
global_data.colorscale_mode = "mode1"
|
| 193 |
global_data.show_colorscale = False
|
| 194 |
-
return
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 195 |
|
| 196 |
|
| 197 |
@app.cell
|
|
@@ -200,6 +259,7 @@ def __(
|
|
| 200 |
dropdown_layer,
|
| 201 |
get_activations_from_model,
|
| 202 |
get_models,
|
|
|
|
| 203 |
set_plotting_parameters,
|
| 204 |
):
|
| 205 |
# FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
@@ -210,12 +270,13 @@ def __(
|
|
| 210 |
# PATTERN = "smolgen_weights"
|
| 211 |
MODEL = get_models()[-1]
|
| 212 |
ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
|
| 213 |
-
|
|
|
|
| 214 |
from activation_heatmap import heatmap_figure
|
| 215 |
fig = heatmap_figure()
|
| 216 |
fig.update_layout(height=1500, width=1200)
|
| 217 |
fig
|
| 218 |
-
return ACTIVATIONS, MODEL, PATTERN, fig, heatmap_figure
|
| 219 |
|
| 220 |
|
| 221 |
@app.cell
|
|
|
|
| 167 |
|
| 168 |
|
| 169 |
@app.cell
|
| 170 |
+
def __(torch):
|
| 171 |
+
def rollout(x, skip_connection=True, parse="min"):
|
| 172 |
+
attns = []
|
| 173 |
+
for k, v in x.items():
|
| 174 |
+
v = v[0, :, ::-1, :]
|
| 175 |
+
v = torch.tensor(v.copy())
|
| 176 |
+
if parse == "min":
|
| 177 |
+
item = torch.min(v, dim=0).values
|
| 178 |
+
elif parse == "max":
|
| 179 |
+
item = torch.max(v, dim=0).values
|
| 180 |
+
elif parse == "mean":
|
| 181 |
+
item = torch.mean(v, dim=0)
|
| 182 |
+
attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
|
| 183 |
+
roll = torch.prod(torch.stack(attns), dim=0)
|
| 184 |
+
return roll.numpy()
|
| 185 |
+
|
| 186 |
+
return (rollout,)
|
| 187 |
+
|
| 188 |
+
|
| 189 |
+
@app.cell
|
| 190 |
+
def __(mo):
|
| 191 |
+
METHODS = ["Attention visualization", "Attention rollout (MIN)", "Attention rollout (MEAN)", "Attention rollout (MAX)"]
|
| 192 |
+
dropdown_method = mo.ui.dropdown(options=METHODS, value=METHODS[0], label="Select XAI method")
|
| 193 |
+
dropdown_method
|
| 194 |
+
return METHODS, dropdown_method
|
| 195 |
+
|
| 196 |
+
|
| 197 |
+
@app.cell
|
| 198 |
+
def __(mo):
|
| 199 |
+
switch = mo.ui.switch(value=False, label="To use skip connection in rollout")
|
| 200 |
+
switch
|
| 201 |
+
return (switch,)
|
| 202 |
+
|
| 203 |
+
|
| 204 |
+
@app.cell
|
| 205 |
+
def __(switch):
|
| 206 |
+
switch.value
|
| 207 |
+
return
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
@app.cell
|
| 211 |
+
def __(dropdown_method, focus_square, rollout, switch, torch):
|
| 212 |
import chess
|
| 213 |
from global_data import global_data
|
| 214 |
|
| 215 |
+
def parse_activations(act, layer_number=None):
|
| 216 |
+
if dropdown_method.value == "Attention visualization":
|
| 217 |
+
layer_key = [k for k in act.keys() if "0" in k][0].replace("0", f"{layer_number}")
|
| 218 |
+
a = act[layer_key][0, :, ::-1 , :]
|
| 219 |
+
elif dropdown_method.value == "Attention rollout (MIN)":
|
| 220 |
+
a = rollout(act, skip_connection=switch.value, parse="min")
|
| 221 |
+
a = torch.stack([a for _ in range(32)])
|
| 222 |
+
elif dropdown_method.value == "Attention rollout (MAX)":
|
| 223 |
+
a = rollout(act, skip_connection=switch.value, parse="max")
|
| 224 |
+
a = torch.stack([a for _ in range(32)])
|
| 225 |
+
elif dropdown_method.value == "Attention rollout (MEAN)":
|
| 226 |
+
a = rollout(act, skip_connection=switch.value, parse="mean")
|
| 227 |
+
a = torch.stack([a for _ in range(32)])
|
| 228 |
+
return a
|
| 229 |
focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
|
| 230 |
|
| 231 |
+
def set_plotting_parameters(act, fen):
|
|
|
|
|
|
|
| 232 |
global_data.model = 'test'
|
| 233 |
+
global_data.activations = act
|
| 234 |
print(global_data.activations.shape)
|
| 235 |
global_data.subplot_rows = 8
|
| 236 |
global_data.subplot_cols = 4
|
|
|
|
| 244 |
global_data.visualization_mode_is_64x64 = False
|
| 245 |
global_data.colorscale_mode = "mode1"
|
| 246 |
global_data.show_colorscale = False
|
| 247 |
+
return (
|
| 248 |
+
chess,
|
| 249 |
+
focus_square_ind,
|
| 250 |
+
global_data,
|
| 251 |
+
parse_activations,
|
| 252 |
+
set_plotting_parameters,
|
| 253 |
+
)
|
| 254 |
|
| 255 |
|
| 256 |
@app.cell
|
|
|
|
| 259 |
dropdown_layer,
|
| 260 |
get_activations_from_model,
|
| 261 |
get_models,
|
| 262 |
+
parse_activations,
|
| 263 |
set_plotting_parameters,
|
| 264 |
):
|
| 265 |
# FEN = "rnbqkbnr/pppppppp/8/8/8/8/PPPPPPPP/RNBQKBNR w KQkq - 0 1"
|
|
|
|
| 270 |
# PATTERN = "smolgen_weights"
|
| 271 |
MODEL = get_models()[-1]
|
| 272 |
ACTIVATIONS = get_activations_from_model(MODEL, PATTERN, FEN)
|
| 273 |
+
a = parse_activations(ACTIVATIONS, layer_number=int(dropdown_layer.value))
|
| 274 |
+
set_plotting_parameters(a, FEN)
|
| 275 |
from activation_heatmap import heatmap_figure
|
| 276 |
fig = heatmap_figure()
|
| 277 |
fig.update_layout(height=1500, width=1200)
|
| 278 |
fig
|
| 279 |
+
return ACTIVATIONS, MODEL, PATTERN, a, fig, heatmap_figure
|
| 280 |
|
| 281 |
|
| 282 |
@app.cell
|