Update ttt_glu_LayerNorm.py
Browse files- ttt_glu_LayerNorm.py +2 -2
ttt_glu_LayerNorm.py
CHANGED
|
@@ -87,8 +87,8 @@ class TTT(nn.Module):
|
|
| 87 |
for param, grad in zip(self.mapping.parameters(), grads):
|
| 88 |
|
| 89 |
param -= 0.01 * grad
|
| 90 |
-
|
| 91 |
-
outs.append(
|
| 92 |
out = torch.stack(outs, dim=1)
|
| 93 |
|
| 94 |
return out
|
|
|
|
| 87 |
for param, grad in zip(self.mapping.parameters(), grads):
|
| 88 |
|
| 89 |
param -= 0.01 * grad
|
| 90 |
+
readout = self.mapping(self.probe(in_seq[:,seq,:])).detach()
|
| 91 |
+
outs.append(readout)
|
| 92 |
out = torch.stack(outs, dim=1)
|
| 93 |
|
| 94 |
return out
|