Update ttt_mlp_vecDyT.py
Browse files- ttt_mlp_vecDyT.py +2 -2
ttt_mlp_vecDyT.py
CHANGED
|
@@ -93,8 +93,8 @@ class TTT(nn.Module):
|
|
| 93 |
for param, grad in zip(self.mapping.parameters(), grads):
|
| 94 |
|
| 95 |
param -= 0.01 * grad
|
| 96 |
-
|
| 97 |
-
outs.append(
|
| 98 |
out = torch.stack(outs, dim=1)
|
| 99 |
|
| 100 |
return out
|
|
|
|
| 93 |
for param, grad in zip(self.mapping.parameters(), grads):
|
| 94 |
|
| 95 |
param -= 0.01 * grad
|
| 96 |
+
readout = self.mapping(self.probe(in_seq[:,seq,:])).detach()
|
| 97 |
+
outs.append(readout)
|
| 98 |
out = torch.stack(outs, dim=1)
|
| 99 |
|
| 100 |
return out
|