Abdullah-Nazhat commited on
Commit
4ed4315
·
verified ·
1 Parent(s): 8a01495

Update ttt_glu_vecDyT.py

Browse files
Files changed (1) hide show
  1. ttt_glu_vecDyT.py +2 -2
ttt_glu_vecDyT.py CHANGED
@@ -99,8 +99,8 @@ class TTT(nn.Module):
99
  for param, grad in zip(self.mapping.parameters(), grads):
100
 
101
  param -= 0.01 * grad
102
- probe = self.probe(self.mapping(in_seq[:,seq,:]).detach())
103
- outs.append(probe)
104
  out = torch.stack(outs, dim=1)
105
 
106
  return out
 
99
  for param, grad in zip(self.mapping.parameters(), grads):
100
 
101
  param -= 0.01 * grad
102
+ readout = self.mapping(self.probe(in_seq[:,seq,:])).detach()
103
+ outs.append(readout)
104
  out = torch.stack(outs, dim=1)
105
 
106
  return out