Update app.py
Browse files
app.py
CHANGED
|
@@ -218,13 +218,13 @@ def __(dropdown_method, focus_square, rollout, switch, torch):
|
|
| 218 |
a = act[layer_key][0, :, ::-1 , :]
|
| 219 |
elif dropdown_method.value == "Attention rollout (MIN)":
|
| 220 |
a = rollout(act, skip_connection=switch.value, parse="min")
|
| 221 |
-
a = torch.stack([a for _ in range(32)])
|
| 222 |
elif dropdown_method.value == "Attention rollout (MAX)":
|
| 223 |
a = rollout(act, skip_connection=switch.value, parse="max")
|
| 224 |
-
a = torch.stack([a for _ in range(32)])
|
| 225 |
elif dropdown_method.value == "Attention rollout (MEAN)":
|
| 226 |
a = rollout(act, skip_connection=switch.value, parse="mean")
|
| 227 |
-
a = torch.stack([a for _ in range(32)])
|
| 228 |
return a
|
| 229 |
focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
|
| 230 |
|
|
|
|
| 218 |
a = act[layer_key][0, :, ::-1 , :]
|
| 219 |
elif dropdown_method.value == "Attention rollout (MIN)":
|
| 220 |
a = rollout(act, skip_connection=switch.value, parse="min")
|
| 221 |
+
a = torch.stack([a for _ in range(32)]).numpy()
|
| 222 |
elif dropdown_method.value == "Attention rollout (MAX)":
|
| 223 |
a = rollout(act, skip_connection=switch.value, parse="max")
|
| 224 |
+
a = torch.stack([a for _ in range(32)]).numpy()
|
| 225 |
elif dropdown_method.value == "Attention rollout (MEAN)":
|
| 226 |
a = rollout(act, skip_connection=switch.value, parse="mean")
|
| 227 |
+
a = torch.stack([a for _ in range(32)]).numpy()
|
| 228 |
return a
|
| 229 |
focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
|
| 230 |
|