Update ttt_mlp_LayerNorm.py
Browse files- ttt_mlp_LayerNorm.py +2 -2
ttt_mlp_LayerNorm.py
CHANGED
|
@@ -77,8 +77,8 @@ class TTT(nn.Module):
|
|
| 77 |
for param, grad in zip(self.mapping.parameters(), grads):
|
| 78 |
|
| 79 |
param -= 0.01 * grad
|
| 80 |
-
|
| 81 |
-
outs.append(
|
| 82 |
out = torch.stack(outs, dim=1)
|
| 83 |
|
| 84 |
return out
|
|
|
|
| 77 |
for param, grad in zip(self.mapping.parameters(), grads):
|
| 78 |
|
| 79 |
param -= 0.01 * grad
|
| 80 |
+
readout = self.mapping(self.probe(in_seq[:,seq,:])).detach()
|
| 81 |
+
outs.append(readout)
|
| 82 |
out = torch.stack(outs, dim=1)
|
| 83 |
|
| 84 |
return out
|