Abdullah-Nazhat commited on
Commit
bb1c7a0
·
verified ·
1 Parent(s): 9774165

Update ttt_mlp_vecDyT.py

Browse files
Files changed (1) hide show
  1. ttt_mlp_vecDyT.py +2 -2
ttt_mlp_vecDyT.py CHANGED
@@ -93,8 +93,8 @@ class TTT(nn.Module):
93
  for param, grad in zip(self.mapping.parameters(), grads):
94
 
95
  param -= 0.01 * grad
96
- probe = self.probe(self.mapping(in_seq[:,seq,:]).detach())
97
- outs.append(probe)
98
  out = torch.stack(outs, dim=1)
99
 
100
  return out
 
93
  for param, grad in zip(self.mapping.parameters(), grads):
94
 
95
  param -= 0.01 * grad
96
+ readout = self.mapping(self.probe(in_seq[:,seq,:])).detach()
97
+ outs.append(readout)
98
  out = torch.stack(outs, dim=1)
99
 
100
  return out