Abdullah-Nazhat commited on
Commit
8a01495
·
verified ·
1 Parent(s): bb1c7a0

Update ttt_mlp_LayerNorm.py

Browse files
Files changed (1) hide show
  1. ttt_mlp_LayerNorm.py +2 -2
ttt_mlp_LayerNorm.py CHANGED
@@ -77,8 +77,8 @@ class TTT(nn.Module):
77
  for param, grad in zip(self.mapping.parameters(), grads):
78
 
79
  param -= 0.01 * grad
80
- probe = self.probe(self.mapping(in_seq[:,seq,:]).detach())
81
- outs.append(probe)
82
  out = torch.stack(outs, dim=1)
83
 
84
  return out
 
77
  for param, grad in zip(self.mapping.parameters(), grads):
78
 
79
  param -= 0.01 * grad
80
+ readout = self.mapping(self.probe(in_seq[:,seq,:])).detach()
81
+ outs.append(readout)
82
  out = torch.stack(outs, dim=1)
83
 
84
  return out