Update ttt_glu_vecDyT.py
Browse files- ttt_glu_vecDyT.py +2 -2
ttt_glu_vecDyT.py
CHANGED
|
@@ -99,8 +99,8 @@ class TTT(nn.Module):
|
|
| 99 |
for param, grad in zip(self.mapping.parameters(), grads):
|
| 100 |
|
| 101 |
param -= 0.01 * grad
|
| 102 |
-
|
| 103 |
-
outs.append(
|
| 104 |
out = torch.stack(outs, dim=1)
|
| 105 |
|
| 106 |
return out
|
|
|
|
| 99 |
for param, grad in zip(self.mapping.parameters(), grads):
|
| 100 |
|
| 101 |
param -= 0.01 * grad
|
| 102 |
+
readout = self.mapping(self.probe(in_seq[:,seq,:])).detach()
|
| 103 |
+
outs.append(readout)
|
| 104 |
out = torch.stack(outs, dim=1)
|
| 105 |
|
| 106 |
return out
|