Spaces:
Sleeping
Sleeping
| """ | |
| Test-Time Training (TTT) Layer. | |
| The hidden state is a machine learning model. | |
| The update rule is a step of self-supervised learning. | |
| Based on: https://arxiv.org/abs/2407.04620 | |
| """ | |
| import copy | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as functional | |
| from torch import Tensor | |
| class TTTLayer(nn.Module): | |
| """ | |
| Test-Time Training layer. | |
| The hidden state is itself a learnable model that updates | |
| via gradient descent during the forward pass. | |
| """ | |
| def __init__(self, dim: int, variant: str = "linear"): | |
| """ | |
| Initialize TTT layer. | |
| Args: | |
| dim: Input/output dimension | |
| variant: "linear" for TTT-Linear, "mlp" for TTT-MLP | |
| """ | |
| super().__init__() | |
| self.dim = dim | |
| self.variant = variant | |
| self.hidden_model: nn.Module | |
| if variant == "linear": | |
| # TTT-Linear: Hidden state is a linear model | |
| self.hidden_model = nn.Linear(dim, dim, bias=False) | |
| elif variant == "mlp": | |
| # TTT-MLP: Hidden state is a two-layer MLP | |
| self.hidden_model = nn.Sequential( | |
| nn.Linear(dim, dim * 4), | |
| nn.ReLU(), | |
| nn.Linear(dim * 4, dim), | |
| ) | |
| else: | |
| raise ValueError(f"Unknown variant: {variant}. Use 'linear' or 'mlp'.") | |
| # Project input to key/value for self-supervised learning | |
| self.to_kv = nn.Linear(dim, dim * 2) | |
| # Learnable learning rate | |
| self.eta = nn.Parameter(torch.tensor(0.1)) | |
| def forward(self, x: Tensor) -> Tensor: | |
| """ | |
| Process sequence with test-time training. | |
| Args: | |
| x: Input tensor [batch, seq_len, dim] | |
| Returns: | |
| Output tensor [batch, seq_len, dim] | |
| """ | |
| _batch, seq_len, _dim = x.shape | |
| # Clone hidden model for this sequence (mini-batch gradient descent) | |
| hidden_state = copy.deepcopy(self.hidden_model) | |
| outputs = [] | |
| for t in range(seq_len): | |
| # Current token | |
| x_t = x[:, t : t + 1, :] | |
| # Self-supervised target: reconstruct from key-value | |
| kv = self.to_kv(x_t) | |
| _k, v = kv.chunk(2, dim=-1) | |
| # Forward through hidden state | |
| y_t = hidden_state(x_t) | |
| # Compute loss and update hidden state | |
| loss = functional.mse_loss(y_t, v) | |
| # Compute gradients | |
| grads = torch.autograd.grad(loss, list(hidden_state.parameters()), create_graph=False) | |
| # Update hidden state weights | |
| with torch.no_grad(): | |
| for param, grad in zip(hidden_state.parameters(), grads): | |
| param -= self.eta * grad | |
| outputs.append(y_t.detach()) | |
| return torch.cat(outputs, dim=1) | |
| class TTTLinear(TTTLayer): | |
| """TTT-Linear: Hidden state is a linear model (faster).""" | |
| def __init__(self, dim: int): | |
| super().__init__(dim, variant="linear") | |
| class TTTMLP(TTTLayer): | |
| """TTT-MLP: Hidden state is a two-layer MLP (more expressive).""" | |
| def __init__(self, dim: int): | |
| super().__init__(dim, variant="mlp") | |