seredapj commited on
Commit
35bba36
·
verified ·
1 Parent(s): 1f8b419

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +1 -1
app.py CHANGED
@@ -181,7 +181,7 @@ def __(torch):
181
  item = torch.mean(v, dim=0)
182
  attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
183
  roll = torch.prod(torch.stack(attns), dim=0)
184
- return roll.numpy()
185
 
186
  return (rollout,)
187
 
 
181
  item = torch.mean(v, dim=0)
182
  attns.append(item + int(skip_connection) * torch.eye(item.shape[0]))
183
  roll = torch.prod(torch.stack(attns), dim=0)
184
+ return roll
185
 
186
  return (rollout,)
187