seredapj commited on
Commit
1f8b419
·
verified ·
1 Parent(s): c853eef

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +3 -3
app.py CHANGED
@@ -218,13 +218,13 @@ def __(dropdown_method, focus_square, rollout, switch, torch):
218
  a = act[layer_key][0, :, ::-1 , :]
219
  elif dropdown_method.value == "Attention rollout (MIN)":
220
  a = rollout(act, skip_connection=switch.value, parse="min")
221
- a = torch.stack([a for _ in range(32)])
222
  elif dropdown_method.value == "Attention rollout (MAX)":
223
  a = rollout(act, skip_connection=switch.value, parse="max")
224
- a = torch.stack([a for _ in range(32)])
225
  elif dropdown_method.value == "Attention rollout (MEAN)":
226
  a = rollout(act, skip_connection=switch.value, parse="mean")
227
- a = torch.stack([a for _ in range(32)])
228
  return a
229
  focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
230
 
 
218
  a = act[layer_key][0, :, ::-1 , :]
219
  elif dropdown_method.value == "Attention rollout (MIN)":
220
  a = rollout(act, skip_connection=switch.value, parse="min")
221
+ a = torch.stack([a for _ in range(32)]).numpy()
222
  elif dropdown_method.value == "Attention rollout (MAX)":
223
  a = rollout(act, skip_connection=switch.value, parse="max")
224
+ a = torch.stack([a for _ in range(32)]).numpy()
225
  elif dropdown_method.value == "Attention rollout (MEAN)":
226
  a = rollout(act, skip_connection=switch.value, parse="mean")
227
+ a = torch.stack([a for _ in range(32)]).numpy()
228
  return a
229
  focus_square_ind = 8 * (int(focus_square.value[1]) - 1) + ord(focus_square.value[0]) - ord("a")
230