File size: 2,870 Bytes
6ab54d2
 
 
 
 
 
 
 
 
ea63f8b
6ab54d2
 
 
 
 
 
ea63f8b
6ab54d2
 
ea63f8b
6ab54d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea63f8b
6ab54d2
ea63f8b
6ab54d2
ea63f8b
6ab54d2
 
 
 
ea63f8b
6ab54d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ea63f8b
6ab54d2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
import torch
from torch import nn, Tensor



class VecDyT(nn.Module):
    """Element-wise learnable tanh squashing, ``tanh(alpha * x)``.

    ``alpha`` is a learnable tensor of shape ``input_shape``, drawn from a
    standard normal at construction time and broadcast against ``x`` in
    ``forward``. There is no output scale or shift parameter.

    Args:
        input_shape: Shape of ``alpha`` (an int or a sequence of ints), passed
            straight to ``torch.randn``.
    """

    def __init__(self, input_shape):
        super().__init__()
        self.alpha = nn.Parameter(torch.randn(input_shape))

    def forward(self, x):
        """Return ``tanh(alpha * x)`` with ``alpha`` broadcast against ``x``."""
        return self.alpha.mul(x).tanh()
        
        
class GatingUnit(nn.Module):
    """Multiplicative gated feed-forward: ``gelu(proj_1(x)) * proj_2(x)``.

    Two bias-free ``dim -> dim`` linear projections are applied to the same
    input; the first goes through GELU and gates the second element-wise.

    Args:
        dim: Feature size of the last input axis (input and output).
    """

    def __init__(self, dim):
        super().__init__()
        # proj_1 is created before proj_2 so that, under a fixed seed, each
        # projection receives the same random initialisation as before.
        self.proj_1 = nn.Linear(dim, dim, bias=False)
        self.proj_2 = nn.Linear(dim, dim, bias=False)
        self.gelu = nn.GELU()

    def forward(self, x):
        gate = self.gelu(self.proj_1(x))
        value = self.proj_2(x)
        return gate * value

class TTT(nn.Module):
    """Test-time-training memory layer over a ``(batch, seq, dim)`` sequence.

    At each position ``t`` a linear "fast weight" (initialised from
    ``self.mapping.weight``) takes one SGD step on a denoising reconstruction
    loss: ``State(x_t)`` plus unit Gaussian noise must be mapped back to
    ``State(x_t)``. The updated fast weight is then applied to ``Probe(x_t)``
    to produce the output at position ``t``.

    The fast weights exist only for the duration of one ``forward`` call:
    they are threaded functionally through the sequence and never written
    back into ``self.mapping``, so calls do not leak state into each other.
    When grad mode is on, the inner update is differentiated through
    (``create_graph=True``), so the outer loss trains ``State``, ``Probe`` and
    the initial weight ``mapping``.

    Args:
        dim: Feature size of the input and output.
        lr: Inner-loop SGD step size (default ``0.01``).
    """

    def __init__(self, dim: int, lr: float = 0.01):
        super().__init__()
        self.mapping = nn.Linear(dim, dim, bias=False)
        self.State = nn.Linear(dim, dim, bias=False)
        self.Probe = nn.Linear(dim, dim, bias=False)
        self.lr = lr

    def forward(self, in_seq: Tensor) -> Tensor:
        """Run the per-position inner update and readout.

        Args:
            in_seq: Rank-3 tensor indexed as ``[:, t, :]``; presumably
                ``(batch, seq, dim)``.

        Returns:
            Tensor of shape ``(in_seq.size(0), in_seq.size(1), dim)``.
        """
        # Only build the second-order graph when the caller can backprop;
        # under torch.no_grad() first-order inner steps are sufficient.
        differentiable = torch.is_grad_enabled()
        weight = self.mapping.weight
        outs = []

        for t in range(in_seq.size(1)):
            x_t = in_seq[:, t, :]
            state = self.State(x_t)
            train_view = state + torch.randn_like(state)

            # The inner step always needs autograd, even if the caller
            # disabled it (e.g. evaluation under torch.no_grad()).
            with torch.enable_grad():
                if not weight.requires_grad:
                    weight = weight.detach().requires_grad_()
                loss = nn.functional.mse_loss(
                    nn.functional.linear(train_view, weight), state
                )
                (grad,) = torch.autograd.grad(
                    loss, weight, create_graph=differentiable
                )

            # Functional (out-of-place) update: keeps the module's own
            # parameters untouched and the graph valid for outer backward.
            weight = weight - self.lr * grad
            outs.append(nn.functional.linear(self.Probe(x_t), weight))

        if not outs:
            # Zero-length sequence: torch.stack([]) would raise.
            return in_seq.new_empty(
                (in_seq.size(0), 0, self.mapping.out_features)
            )
        return torch.stack(outs, dim=1)
            


class TensorMapperBlock(nn.Module):
    """Pre-norm residual block: TTT memory sub-layer, then gated feed-forward.

    Computes ``h = memory(norm_1(x)) + x`` and returns
    ``feedforward(norm_2(h)) + h``.

    Args:
        dim: Feature size of the last input axis.
        num_patch: Accepted for interface compatibility; not used here.
    """

    def __init__(self, dim, num_patch):
        super().__init__()
        # Creation order is preserved: each submodule draws from the global
        # RNG, so reordering would change seeded initialisation and the
        # parameter/state_dict ordering.
        self.norm_1 = VecDyT(dim)
        self.norm_2 = VecDyT(dim)
        self.memory = TTT(dim)
        self.feedforward = GatingUnit(dim)

    def forward(self, x):
        mixed = self.memory(self.norm_1(x)) + x
        return self.feedforward(self.norm_2(mixed)) + mixed


class TensorMapper(nn.Module):
    """Stack of ``num_layers`` ``TensorMapperBlock`` layers applied in order.

    Args:
        d_model: Feature size passed to every block.
        num_patch: Forwarded unchanged to every block.
        num_layers: Number of stacked blocks.
    """

    def __init__(self, d_model, num_patch, num_layers):
        super().__init__()
        layers = []
        for _ in range(num_layers):
            layers.append(TensorMapperBlock(d_model, num_patch))
        # Registered as ``model`` so state_dict keys stay "model.<i>. ...".
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Call the Sequential itself (not its children) so any hooks
        # registered on ``self.model`` still fire.
        stack = self.model
        return stack(x)