Abdullah-Nazhat commited on
Commit
bf679b1
·
verified ·
1 Parent(s): 4ed4315

Update ttt_glu_LayerNorm.py

Browse files
Files changed (1) hide show
  1. ttt_glu_LayerNorm.py +2 -2
ttt_glu_LayerNorm.py CHANGED
@@ -87,8 +87,8 @@ class TTT(nn.Module):
87
  for param, grad in zip(self.mapping.parameters(), grads):
88
 
89
  param -= 0.01 * grad
90
- probe = self.probe(self.mapping(in_seq[:,seq,:]).detach())
91
- outs.append(probe)
92
  out = torch.stack(outs, dim=1)
93
 
94
  return out
 
87
  for param, grad in zip(self.mapping.parameters(), grads):
88
 
89
  param -= 0.01 * grad
90
+ readout = self.mapping(self.probe(in_seq[:,seq,:])).detach()
91
+ outs.append(readout)
92
  out = torch.stack(outs, dim=1)
93
 
94
  return out