docker-neural-memory / src /memory /ttt_layer.py
macayaven's picture
Upload folder using huggingface_hub
dd41762 verified
Raw
History Blame Contribute Delete
3.17 kB
"""
Test-Time Training (TTT) Layer.
The hidden state is a machine learning model.
The update rule is a step of self-supervised learning.
Based on: https://arxiv.org/abs/2407.04620
"""
import copy
import torch
import torch.nn as nn
import torch.nn.functional as functional
from torch import Tensor
class TTTLayer(nn.Module):
    """
    Test-Time Training (TTT) layer.

    The hidden state is itself a small model (``hidden_model``) whose weights
    are updated with one gradient-descent step per token during the forward
    pass.

    Notes:
        * Every ``forward`` call works on a private deep copy of
          ``hidden_model``; the registered module's weights are not modified
          by ``forward``.
        * The tensor returned by ``forward`` is detached, so ``forward`` does
          not propagate gradients to ``hidden_model``, ``to_kv`` or ``eta``.
    """

    def __init__(self, dim: int, variant: str = "linear"):
        """
        Initialize TTT layer.

        Args:
            dim: Input/output dimension.
            variant: "linear" for TTT-Linear, "mlp" for TTT-MLP.

        Raises:
            ValueError: If ``variant`` is neither "linear" nor "mlp".
        """
        super().__init__()
        self.dim = dim
        self.variant = variant
        self.hidden_model: nn.Module

        if variant == "linear":
            # TTT-Linear: hidden state is a bias-free linear map dim -> dim.
            self.hidden_model = nn.Linear(dim, dim, bias=False)
        elif variant == "mlp":
            # TTT-MLP: hidden state is a two-layer MLP with a 4x-wide ReLU layer.
            self.hidden_model = nn.Sequential(
                nn.Linear(dim, dim * 4),
                nn.ReLU(),
                nn.Linear(dim * 4, dim),
            )
        else:
            raise ValueError(f"Unknown variant: {variant}. Use 'linear' or 'mlp'.")

        # Projects the input to a concatenated key/value pair (2 * dim).
        # Only the value half is used (as the regression target) in forward().
        self.to_kv = nn.Linear(dim, dim * 2)

        # Step size of the inner gradient-descent update, initialised to 0.1.
        self.eta = nn.Parameter(torch.tensor(0.1))

    def forward(self, x: Tensor) -> Tensor:
        """
        Process a sequence with test-time training.

        For each position ``t`` the current hidden model predicts ``y_t`` from
        ``x_t``; ``y_t`` is emitted, and the hidden model then takes one
        gradient step on ``mse_loss(y_t, v_t)``, where ``v_t`` is the value
        half of ``to_kv(x_t)``. The emitted ``y_t`` is therefore produced by
        the weights *before* the update at step ``t``.

        The inner update also works when the caller has disabled gradient
        tracking with ``torch.no_grad()`` and when the layer's parameters are
        frozen.

        Args:
            x: Input tensor [batch, seq_len, dim].

        Returns:
            Detached output tensor [batch, seq_len, dim]. For ``seq_len == 0``
            an empty tensor of shape [batch, 0, dim] is returned.
        """
        batch, seq_len, dim = x.shape
        if seq_len == 0:
            # torch.cat() rejects an empty list, so short-circuit here.
            return x.new_empty((batch, 0, dim))

        # Outputs are detached and gradients are only taken w.r.t. the copied
        # hidden model, so nothing upstream of it needs an autograd graph.
        inputs = x.detach()
        with torch.no_grad():
            # Loop-invariant: project the whole sequence once, keep the values.
            _keys, targets = self.to_kv(inputs).chunk(2, dim=-1)
        step_size = self.eta.detach()

        # Private per-call copy of the hidden model ("fast weights").
        fast_model = copy.deepcopy(self.hidden_model)
        # The copy inherits requires_grad flags; force them on so the inner
        # update also works when the layer itself has been frozen.
        fast_model.requires_grad_(True)
        fast_params = list(fast_model.parameters())

        outputs = []
        # The inner update needs autograd even if the caller disabled it.
        with torch.enable_grad():
            for t in range(seq_len):
                x_t = inputs[:, t : t + 1, :]
                v_t = targets[:, t : t + 1, :]

                # Predict with the current fast weights.
                y_t = fast_model(x_t)

                # Self-supervised reconstruction loss and its gradients.
                loss = functional.mse_loss(y_t, v_t)
                grads = torch.autograd.grad(loss, fast_params, create_graph=False)

                # One in-place SGD step on the fast weights.
                with torch.no_grad():
                    for param, grad in zip(fast_params, grads):
                        param -= step_size * grad

                outputs.append(y_t.detach())

        return torch.cat(outputs, dim=1)
class TTTLinear(TTTLayer):
    """Convenience subclass: a TTTLayer fixed to the "linear" variant."""

    def __init__(self, dim: int) -> None:
        # Delegate to the base layer with the variant pinned.
        super().__init__(dim=dim, variant="linear")
class TTTMLP(TTTLayer):
    """Convenience subclass: a TTTLayer fixed to the "mlp" variant."""

    def __init__(self, dim: int) -> None:
        # Delegate to the base layer with the variant pinned.
        super().__init__(dim=dim, variant="mlp")