# LookThem: Ratio-based Attention

# Explanation

I was curious: what if a token could look at other tokens, but without QKV? Instead, the idea is to transform two tokens (the current token A and another token B) and then divide them: "transformA(A) / transformB(B)", where each transform is a small linear network, with tanh applied for normalization (so the ratio does not explode). The reverse ratio, "transformA(B) / transformB(A)", is computed as well. The first ratio is multiplied by A and the reverse by B; the two products are added and divided by 2. That is the interaction value for this pair, which is accumulated in a temporary variable. The loop repeats for every other token (vectorized in the code), and the accumulated value is then averaged. That average becomes the new A.
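The rule above, written out for a single pair of scalar tokens (a minimal sketch; `transform_a` and `transform_b` are hypothetical stand-ins for the two linear networks in the description):

```python
import torch
import torch.nn as nn

torch.manual_seed(0)

# Hypothetical stand-ins for "transformA" / "transformB".
transform_a = nn.Linear(1, 1)
transform_b = nn.Linear(1, 1)

a = torch.tensor([[0.5]])   # current token A
b = torch.tensor([[-0.3]])  # another token B

eps = 1e-7  # keep the denominators away from an exact zero
ratio_ab = torch.tanh(transform_a(a) / (transform_b(b) + eps))  # transformA(A) / transformB(B)
ratio_ba = torch.tanh(transform_a(b) / (transform_b(a) + eps))  # the reverse
interaction = (ratio_ab * a + ratio_ba * b) / 2                 # averaged pair contribution
print(interaction.shape)  # torch.Size([1, 1])
```

With more than two tokens, this pair contribution is summed over every other token and divided by the number of pairs, exactly as the loop in the original code below does.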

Then I tried it on MNIST (a LookThem architecture with one layer) and got astonishing results in just a few epochs. That pushed me to go deeper and try CIFAR-10 with a similar architecture, where the results were good too. So I went deeper again, to Tiny-ImageNet, where the result is around 50%. That is not 100% accuracy, but it can at least compete with older Tiny-ImageNet architectures, with an even smaller footprint on disk (just ~5 MB). Those are the results I have for you all.

There is plenty of room for experimentation: deeper architectures, other activation functions, and so on. Even without big hyperparameter tuning runs, it reaches SOTA (in the from-scratch category) — correct me if I am wrong about SOTA. So, for everyone with bigger resources: feel free to experiment with this architecture. I trained it on Google Colab's T4, and the code was generated by Gemini 3 Flash (except for the original code).

# Code

## Original Code

```python
import torch
import torch.nn as nn

class LookThem(nn.Module):
    def __init__(self):
        super(LookThem, self).__init__()

        # Using ModuleList, as in the initial architecture
        self.mod1 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.mod2 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.transform = nn.ModuleList([nn.Linear(1, 1) for _ in range(5)])

        self.mlp = nn.Sequential(
            nn.Linear(5, 10),
            nn.ReLU(),
            nn.Linear(10, 5)
        )

    def forward(self, x):
        # x shape: [batch, 5, 1]
        batch_size = x.size(0)
        new_x = []

        for i in range(5):
            # Initialize the output tensor with shape [batch, 1] for consistency
            out_i = torch.zeros((batch_size, 1), device=x.device)
            count = 0

            for j in range(5):
                if i == j:
                    continue  # Skip equal indices; interaction is only between different tokens

                # Add epsilon (1e-7) to prevent division by zero
                out_mod2_j = self.mod2[j](x[:, j]) + 1e-7
                out_mod2_i = self.mod2[i](x[:, i]) + 1e-7

                compare = self.mod1[i](x[:, i]) / out_mod2_j
                compare2 = self.mod1[j](x[:, j]) / out_mod2_i

                # Transform the interaction result
                interaksi = (self.transform[j](compare) * x[:, i] + self.transform[j](compare2) * x[:, j]) / 2
                out_i += interaksi
                count += 1

            # Average the interactions (divide by the number of interacting tokens, i.e. 4)
            out_i = out_i / count

            # The original value (self-identity) could also be combined with the
            # interaction value, e.g. by addition:
            # total_fitur_i = x[:, i] + out_i
            new_x.append(out_i)

        # Recombine into [batch, 5]
        x_new = torch.cat(new_x, dim=1)

        return self.mlp(x_new)
```
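The epsilon added to the `mod2` outputs matters: the ratio divides two learned scalars, and a denominator that lands on exactly zero would produce infinities. A standalone illustration:

```python
import torch

num = torch.tensor([1.0, 1.0])
denom = torch.tensor([0.0, 0.5])

raw = num / denom            # division by zero produces inf
safe = num / (denom + 1e-7)  # the epsilon keeps an exact-zero denominator finite

print(raw)   # tensor([inf, 2.])
print(safe)
```

Note that the epsilon only guards the exact-zero case; very small denominators still give large ratios, which is what the tanh in the enhanced version below bounds.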

## Vectorized

```python
import torch
import torch.nn as nn
import math

class LookThemVectorized(nn.Module):
    def __init__(self, num_tokens=5, in_features=1, hidden_dim=5):
        super(LookThemVectorized, self).__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features
        self.hidden_dim = hidden_dim

        # 1. Batched parameters for Mod1
        # Shape: [num_tokens, in_features, hidden_dim]
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        # Shape: [num_tokens, hidden_dim, 1]
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # 2. Batched parameters for Mod2
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # 3. Batched parameters for the linear j-transform
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        # 4. Final MLP sized to the (dynamic) number of tokens
        self.mlp = nn.Sequential(
            nn.Linear(num_tokens, num_tokens * 2),
            nn.ReLU(),
            nn.Linear(num_tokens * 2, num_tokens)
        )

        self._init_weights()

    def _init_weights(self):
        # Kaiming uniform initialization for stable training
        for w in [self.mod1_w1, self.mod2_w1]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))
        for w in [self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        # x shape is now: [batch, num_tokens, in_features]
        batch_size = x.size(0)
        N = self.num_tokens

        # 1. Run Mod1 and Mod2 in parallel over all tokens
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', torch.relu(h1), self.mod1_w2) + self.mod1_b2  # [batch, N, 1]

        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', torch.relu(h2), self.mod2_w2) + self.mod2_b2  # [batch, N, 1]

        # 2. Compute the i/j ratio combinations via broadcasting
        out_m2_safe = out_m2 + 1e-7
        compare = out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1)   # [batch, N, N, 1]
        compare2 = out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2)  # [batch, N, N, 1]

        # 3. Transform the results indexed by j
        # Reshape the bias so it broadcasts correctly along the j coordinate
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped

        # 4. Feature-weighted interaction using the original features of x
        # x.unsqueeze(2) -> features of token i, x.unsqueeze(1) -> features of token j
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2  # [batch, N, N, in_features]

        # 5. Build a mask to ignore self-interaction (i == j)
        mask = 1.0 - torch.eye(N, device=x.device)
        interaksi_masked = interaksi * mask.view(1, N, N, 1)  # matches the interaction matrix size

        # 6. Average the interactions (divide by N - 1, since self is skipped)
        # Sum over the j dimension (dim=2)
        out_i = interaksi_masked.sum(dim=2) / (N - 1.0)  # [batch, N, in_features]

        # 7. Prepare the tensor for the final MLP
        # Average over in_features to get [batch, N] before the MLP
        x_new = out_i.mean(dim=-1)

        return self.mlp(x_new)
```
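Two tricks in the vectorized version can be checked in isolation: the `einsum` applies an independent linear map per token, and the paired `unsqueeze` calls turn a single division into an all-pairs ratio table. A small standalone sketch (shapes assumed from the comments above):

```python
import torch

torch.manual_seed(0)
batch, N, in_f, hid = 2, 5, 1, 5

# Per-token linear maps via einsum: h[b, t] = x[b, t] @ w[t] + bias[t]
x = torch.randn(batch, N, in_f)
w = torch.randn(N, in_f, hid)
bias = torch.zeros(N, hid)
h = torch.einsum('bti,tij->btj', x, w) + bias
t = 3
assert torch.allclose(h[:, t], x[:, t] @ w[t] + bias[t])

# All-pairs ratios via broadcasting: compare[b, i, j] = m1[b, i] / m2[b, j]
m1 = torch.randn(batch, N, 1)
m2 = torch.randn(batch, N, 1).abs() + 1e-7
compare = m1.unsqueeze(2) / m2.unsqueeze(1)
assert compare.shape == (batch, N, N, 1)
assert torch.allclose(compare[0, 1, 3], m1[0, 1] / m2[0, 3])
```

This is why the double loop in the original code collapses into a few tensor operations.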

## Enhanced code (used in Tiny-ImageNet training)

```python
import math

import torch
import torch.nn as nn

class LookThemLayer(nn.Module):
    def __init__(self, num_tokens, in_features, hidden_dim):
        super(LookThemLayer, self).__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features

        # Batched parameters (vectorized)
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        N = self.num_tokens

        # 1. Einstein-summation projections
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', torch.relu(h1), self.mod1_w2) + self.mod1_b2

        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', torch.relu(h2), self.mod2_w2) + self.mod2_b2

        # 2. Contrast ratios + tanh
        out_m2_safe = out_m2 + 1e-7
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # 3. Spatial j-transformations
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped

        # 4. Contextual interaction
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2

        # 5. Mask out self-bias (i == j)
        mask = 1.0 - torch.eye(N, device=x.device)
        interaksi_masked = interaksi * mask.view(1, N, N, 1)

        return interaksi_masked.sum(dim=2) / (N - 1.0)
```
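The layer's last two steps — multiplying by `1 - eye(N)` and dividing by `N - 1` — are exactly a mean over the other tokens. A small standalone check:

```python
import torch

torch.manual_seed(0)
batch, N, F = 2, 5, 1

interaksi = torch.randn(batch, N, N, F)  # pairwise interactions [b, i, j, f]
mask = 1.0 - torch.eye(N)                # zeros on the diagonal (i == j)
out = (interaksi * mask.view(1, N, N, 1)).sum(dim=2) / (N - 1.0)

# For token i=1, this equals the mean over j in {0, 2, 3, 4}
manual = interaksi[0, 1, [0, 2, 3, 4]].mean(dim=0)
assert torch.allclose(out[0, 1], manual)
```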

## Colab notebook in this repo

# Results

## MNIST
- 11 epochs of training, train accuracy: 99.02%
- Test accuracy: 98.66%

## CIFAR-10
- 10 epochs of training, train accuracy: 67.89%
- Test accuracy (10 epochs): 73.42%
- 40 epochs of training, train accuracy: 76.63%
- Test accuracy: 79.79%

## Tiny-ImageNet
- 15 epochs of training, train accuracy: 32.07%
- Test accuracy (15 epochs): 42.01%
- 30 epochs of training, train accuracy: 37.20%
- Test accuracy: 50.53%
- File size: 5.72 MB

More details are in the notebook.

# Reaction

I can hardly believe that this simple architecture can reach ResNet-34-level performance, or that the LookThem architecture was born from nothing more than "just try it". So, to all of you: thanks for reading this Spontaneous Paper (a raw paper from someone too lazy to polish it).