import torch
from torch import nn, Tensor
class VecDyT(nn.Module):
    """Elementwise dynamic tanh: ``y = tanh(alpha * x)`` with a learnable
    per-element scale ``alpha``.

    NOTE(review): ``alpha`` is initialized from a standard normal, so some
    scales start negative and flip the sign of their features — confirm
    this initialization is intended.
    """

    def __init__(self, input_shape):
        super().__init__()
        # One learnable scale per element of ``input_shape``.
        self.alpha = nn.Parameter(torch.randn(input_shape))

    def forward(self, x):
        return torch.tanh(self.alpha * x)
class GatingUnit(nn.Module):
    """Multiplicative gating feed-forward: ``gelu(W1 x) * (W2 x)``.

    Both projections are square (dim -> dim) and bias-free; the GELU branch
    acts as the gate, the linear branch as the value.
    """

    def __init__(self, dim):
        super().__init__()
        self.proj_1 = nn.Linear(dim, dim, bias=False)  # gate branch
        self.proj_2 = nn.Linear(dim, dim, bias=False)  # value branch
        self.gelu = nn.GELU()

    def forward(self, x):
        gate = self.gelu(self.proj_1(x))
        value = self.proj_2(x)
        return gate * value
class TTT(nn.Module):
    """Test-time-training layer.

    At every timestep, ``mapping`` is trained for one SGD step on a
    self-supervised denoising objective built from the current token, then
    probed (through ``Probe``) to produce that timestep's output.

    Note: ``forward`` permanently mutates ``mapping``'s weights and draws
    fresh Gaussian noise each step, so repeated calls are stateful and
    stochastic. It also requires grad mode to be enabled (it calls
    ``torch.autograd.grad``), so it cannot run under ``torch.no_grad()``.

    Args:
        dim: feature dimension of the input sequence.
        lr: inner-loop SGD step size. Defaults to 0.01, the previously
            hard-coded value, so existing callers are unaffected.
    """

    def __init__(self, dim: int, lr: float = 0.01):
        super().__init__()
        self.lr = lr
        self.mapping = nn.Linear(dim, dim, bias=False)  # fast weights, updated at test time
        self.State = nn.Linear(dim, dim, bias=False)    # builds the self-supervised target
        self.Probe = nn.Linear(dim, dim, bias=False)    # builds the readout query

    def forward(self, in_seq: Tensor) -> Tensor:
        # in_seq is indexed as [:, t, :] — assumes (batch, seq_len, dim).
        outs = []
        for t in range(in_seq.size(1)):
            token = in_seq[:, t, :]
            state = self.State(token)
            # Denoising objective: reconstruct the clean state from a noisy view.
            train_view = state + torch.randn_like(state)
            loss = nn.functional.mse_loss(self.mapping(train_view), state)
            # FIX: create_graph=False. The original passed create_graph=True,
            # but the update below runs under no_grad and the readout is
            # detached, so the higher-order graph was never used — it only
            # accumulated memory across the sequence. Gradient values are
            # identical either way.
            grads = torch.autograd.grad(loss, self.mapping.parameters())
            with torch.no_grad():
                for param, grad in zip(self.mapping.parameters(), grads):
                    param -= self.lr * grad
            # Read out through the freshly updated fast weights; detach so
            # the inner loop does not leak into any outer autograd graph.
            outs.append(self.mapping(self.Probe(token)).detach())
        return torch.stack(outs, dim=1)
class TensorMapperBlock(nn.Module):
    """Pre-norm residual block: a DyT-normed TTT memory sub-block followed
    by a DyT-normed gating feed-forward, each wrapped in a skip connection.

    NOTE(review): ``num_patch`` is accepted but never used; it is kept only
    for interface compatibility — confirm whether it was meant to size
    something.
    """

    def __init__(self, dim, num_patch):
        super().__init__()
        self.norm_1 = VecDyT(dim)
        self.norm_2 = VecDyT(dim)
        self.memory = TTT(dim)
        self.feedforward = GatingUnit(dim)

    def forward(self, x):
        # Memory sub-block with residual skip.
        x = x + self.memory(self.norm_1(x))
        # Feed-forward sub-block with residual skip.
        x = x + self.feedforward(self.norm_2(x))
        return x
class TensorMapper(nn.Module):
    """Sequential stack of ``num_layers`` identical TensorMapperBlocks."""

    def __init__(self, d_model, num_patch, num_layers):
        super().__init__()
        # Each block gets its own parameters; nn.Sequential applies them in order.
        blocks = [TensorMapperBlock(d_model, num_patch) for _ in range(num_layers)]
        self.model = nn.Sequential(*blocks)

    def forward(self, x):
        return self.model(x)