File size: 3,171 Bytes
dd41762
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
"""
Test-Time Training (TTT) Layer.

The hidden state is a machine learning model.
The update rule is a step of self-supervised learning.

Based on: https://arxiv.org/abs/2407.04620
"""

import copy

import torch
import torch.nn as nn
import torch.nn.functional as functional
from torch import Tensor


class TTTLayer(nn.Module):
    """
    Test-Time Training layer.

    The hidden state is itself a learnable model that updates
    via gradient descent during the forward pass.

    Each sequence in a batch starts from its own fresh copy of
    ``hidden_model`` and takes one gradient-descent step (step size ``eta``)
    per token on ``mse(hidden(x_t), v_t)``, where ``v_t`` is the value half of
    ``to_kv(x_t)``. ``self.hidden_model`` itself is never modified by
    ``forward``.
    """

    def __init__(self, dim: int, variant: str = "linear"):
        """
        Initialize TTT layer.

        Args:
            dim: Input/output dimension
            variant: "linear" for TTT-Linear, "mlp" for TTT-MLP

        Raises:
            ValueError: If ``variant`` is neither "linear" nor "mlp".
        """
        super().__init__()
        self.dim = dim
        self.variant = variant
        self.hidden_model: nn.Module

        if variant == "linear":
            # TTT-Linear: Hidden state is a bias-free linear model
            self.hidden_model = nn.Linear(dim, dim, bias=False)
        elif variant == "mlp":
            # TTT-MLP: Hidden state is a two-layer MLP (dim -> 4*dim -> dim)
            self.hidden_model = nn.Sequential(
                nn.Linear(dim, dim * 4),
                nn.ReLU(),
                nn.Linear(dim * 4, dim),
            )
        else:
            raise ValueError(f"Unknown variant: {variant}. Use 'linear' or 'mlp'.")

        # Projects each token to a (key, value) pair; the value half is the
        # regression target of the inner self-supervised loss.
        self.to_kv = nn.Linear(dim, dim * 2)

        # Step size of the inner gradient-descent update. Registered as a
        # Parameter, but see the NOTE in ``_run_sequence`` on gradient flow.
        self.eta = nn.Parameter(torch.tensor(0.1))

    def forward(self, x: Tensor) -> Tensor:
        """
        Process sequence with test-time training.

        Every batch element is processed independently: it gets its own copy
        of the hidden model, so tokens of one sequence never influence the
        updates (and therefore the outputs) of another. Safe to call under
        ``torch.no_grad()`` and on a layer frozen with ``requires_grad_(False)``.

        Args:
            x: Input tensor [batch, seq_len, dim]

        Returns:
            Output tensor [batch, seq_len, dim]. Empty if batch or seq_len is 0.
        """
        batch, seq_len, _dim = x.shape
        if batch == 0 or seq_len == 0:
            # torch.cat rejects an empty list; return a correctly-shaped empty result.
            return x.new_zeros((batch, seq_len, self.dim))

        per_sequence = [self._run_sequence(x[b : b + 1]) for b in range(batch)]
        return torch.cat(per_sequence, dim=0)

    def _run_sequence(self, seq: Tensor) -> Tensor:
        """Run test-time training over one sequence of shape [1, seq_len, dim]."""
        # Private scratch copy: in-place updates below never touch self.hidden_model.
        hidden_state = copy.deepcopy(self.hidden_model)
        params = list(hidden_state.parameters())
        for param in params:
            # The inner step needs gradients even if the layer was frozen
            # with requires_grad_(False); only the scratch copy is affected.
            param.requires_grad_(True)

        outputs = []
        for t in range(seq.shape[1]):
            # Current token, kept 3-D: [1, 1, dim]
            x_t = seq[:, t : t + 1, :]

            # Self-supervised target is the value projection of the token.
            # NOTE(review): the key half `_k` is computed but unused; the hidden
            # model is fed the raw token x_t. Confirm whether k was meant as input.
            _k, v = self.to_kv(x_t).chunk(2, dim=-1)

            # Re-enable autograd locally so the inner step also works when
            # forward() is called under torch.no_grad().
            with torch.enable_grad():
                y_t = hidden_state(x_t)
                loss = functional.mse_loss(y_t, v)
                grads = torch.autograd.grad(loss, params, create_graph=False)

            # One online gradient-descent step per token on the scratch copy.
            with torch.no_grad():
                for param, grad in zip(params, grads):
                    param -= self.eta * grad

            # The emitted output uses the weights from *before* this token's update.
            # NOTE(review): detaching here (and updating under no_grad) means no
            # gradient from a downstream loss reaches x, eta, to_kv or
            # hidden_model. Confirm this is intended.
            outputs.append(y_t.detach())

        return torch.cat(outputs, dim=1)


class TTTLinear(TTTLayer):
    """TTT layer whose hidden state is a linear map (the cheaper variant)."""

    def __init__(self, dim: int):
        """
        Build a TTTLayer preconfigured with the "linear" hidden model.

        Args:
            dim: Input/output dimension.
        """
        super().__init__(dim=dim, variant="linear")


class TTTMLP(TTTLayer):
    """TTT layer whose hidden state is a two-layer MLP (the more expressive variant)."""

    def __init__(self, dim: int):
        """
        Build a TTTLayer preconfigured with the "mlp" hidden model.

        Args:
            dim: Input/output dimension.
        """
        super().__init__(dim=dim, variant="mlp")