---
license: mit
tags:
- image-classification
- pytorch
- tiny-imagenet
---
# LookThem: Ratio-based Attention
# Results
## MNIST
- 11 epochs of training, train accuracy: 99.02%
- Test accuracy: 98.66%
## CIFAR-10
- 10 epochs of training, train accuracy: 67.89%
- Test accuracy (10 epochs): 73.42%
- 40 epochs of training, train accuracy: 76.63%
- Test accuracy: 79.79%
## Tiny-ImageNet
- 15 epochs of training, train accuracy: 32.07%
- Test accuracy (15 epochs): unknown (the reported results are unreliable)
- 30 epochs of training, train accuracy: 37.20%
- Test accuracy: unknown (the reported results are unreliable)
- File size: 5.72MB
## Tiny-ImageNet 2 (V5)
- 20 epochs of training, train accuracy (with dropout): 36.98%
- Test accuracy: 34.2%
- 40 epochs of training, train top-1 accuracy (no dropout): 59.76%
- Train top-5 accuracy: 83.41%
- Test accuracy: 35.46%
- Test top-5 accuracy: 61.87%
- The code is in the second and third notebooks
# Explanation
I was curious: what if a token looks at other tokens, but without QKV? Instead, two tokens (the current token and another token) are each passed through a transformation, and the results are divided. Say the current token is token A and the other token is token B. The ratio is "transformA(A) / transformB(B)", where each transform is a linear NN, with tanh applied to normalize it (so it doesn't explode). The reverse, "transformA(B) / transformB(A)", is computed the same way. Then the result of "transformA(A) / transformB(B)" is multiplied by A, and the reverse is multiplied by B. The two products are added and divided by 2. That is the new number for that interaction, and it is added to a temporary variable. The loop repeats for each other token (in the code this is vectorized). Finally, that variable is averaged over the interactions. That's the new A.
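The steps above can be sketched with plain scalars. This is only a minimal illustration that follows the description literally: `transform_a` and `transform_b` are hypothetical fixed functions standing in for the learned networks, the names `pair_interaction` and `look_them_update` are made up for this sketch, and the extra per-token transform used in the actual code is left out.

```python
import math

EPS = 1e-7  # small constant added to the denominator, as in the code in this card


def pair_interaction(a, b, transform_a, transform_b):
    """Interaction of the current token a with one other token b (plain floats)."""
    forward = math.tanh(transform_a(a) / (transform_b(b) + EPS))
    reverse = math.tanh(transform_a(b) / (transform_b(a) + EPS))
    # Weight each token by its bounded ratio, then average the two products.
    return (forward * a + reverse * b) / 2


def look_them_update(tokens, transform_a, transform_b):
    """Replace every token by the mean of its interactions with all other tokens."""
    updated = []
    for i, a in enumerate(tokens):
        others = [b for j, b in enumerate(tokens) if j != i]
        total = sum(pair_interaction(a, b, transform_a, transform_b) for b in others)
        updated.append(total / len(others))
    return updated


# Hypothetical transforms (assumptions for illustration, not learned weights).
def transform_a(v):
    return 2.0 * v + 1.0


def transform_b(v):
    return 0.5 * v + 1.0


tokens = [1.0, 2.0, 3.0]
new_tokens = look_them_update(tokens, transform_a, transform_b)
print(new_tokens)
```

Because tanh keeps every ratio in (-1, 1), each interaction is bounded by the magnitudes of the two tokens involved, which is why the division does not blow up the values even when the denominator is small.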
Then I tried it on MNIST (using the LookThem architecture with a single layer), and in just a few epochs I got astonishing results. Because of that, I went deeper and tried CIFAR-10 with a similar architecture, and the results were good too. So I went on to Tiny-ImageNet. The result is... I don't know: the numbers in the notebook are not actual evaluation results (the AI changed the code). Maybe it's not 100% accuracy, but at least it can learn, and it takes very little disk space (just ~5 MB). Those are the results for you all.
There is a lot of room for experimenting: deeper architectures, other activation functions, etc. But even without much tuning of the training parameters, it reaches impressive performance for its size (correct me if I'm wrong about that). So, everyone who has bigger resources is welcome to experiment with this architecture. I trained it on Google Colab's T4, and the code was generated by Gemini 3 Flash (except for the original code).
# Code
## Base Code
```
import torch
import torch.nn as nn


class LookThem(nn.Module):
    def __init__(self):
        super(LookThem, self).__init__()
        # Use ModuleList, as in the original architecture
        self.mod1 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.mod2 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.transform = nn.ModuleList([nn.Linear(1, 1) for _ in range(5)])
        self.mlp = nn.Sequential(
            nn.Linear(5, 10),
            nn.ReLU(),
            nn.Linear(10, 5)
        )

    def forward(self, x):
        # x shape: [batch, 5, 1]
        batch_size = x.size(0)
        new_x = []
        for i in range(5):
            # Initialize the output tensor with shape [batch, 1] for consistency
            out_i = torch.zeros((batch_size, 1), device=x.device)
            count = 0
            for j in range(5):
                if i == j:
                    continue  # Skip equal indices; interactions are only between different tokens
                # Add an epsilon (1e-7) so the denominator is not exactly zero (avoids inf/NaN)
                out_mod2_j = self.mod2[j](x[:, j]) + 1e-7
                out_mod2_i = self.mod2[i](x[:, i]) + 1e-7
                compare = torch.tanh(self.mod1[i](x[:, i]) / out_mod2_j)
                compare2 = torch.tanh(self.mod1[j](x[:, j]) / out_mod2_i)
                # Transform the interaction results
                interaksi = (self.transform[j](compare) * x[:, i] + self.transform[j](compare2) * x[:, j]) / 2
                out_i += interaksi
                count += 1
            # Average the interactions (divide by the number of interacting tokens, i.e. 4)
            out_i = out_i / count
            # Optionally combine the original value (self-identity) with the interaction value,
            # for example by summing them
            # total_fitur_i = x[:, i] + out_i
            new_x.append(out_i)
        # Concatenate back into [batch, 5]
        x_new = torch.cat(new_x, dim=1)
        return self.mlp(x_new)
```
## Vectorized
```
import math

import torch
import torch.nn as nn


class LookThemLayer(nn.Module):
    def __init__(self, num_tokens, in_features, hidden_dim):
        super(LookThemLayer, self).__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features
        # Batched parameters (vectorized): one set of weights per token
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        # x shape: [batch, num_tokens, in_features]
        N = self.num_tokens
        # 1. Einstein summation projections (per-token two-layer MLPs)
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', torch.relu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', torch.relu(h2), self.mod2_w2) + self.mod2_b2
        # 2. Contrast ratio + tanh; entry (i, j) pairs token i with token j
        out_m2_safe = out_m2 + 1e-7
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))
        # 3. Per-token (index j) transformations
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped
        # 4. Contextual interaction
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2
        # 5. Mask out self-interactions (i == j)
        mask = 1.0 - torch.eye(N, device=x.device)
        interaksi_masked = interaksi * mask.view(1, N, N, 1)
        # Average over the N - 1 other tokens
        return interaksi_masked.sum(dim=2) / (N - 1.0)
```
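The key trick in the vectorized version is that broadcasting two `unsqueeze`d tensors yields every pairwise ratio at once. The check below is a small sketch of that idea only, not the full layer: `num` and `den` are random stand-ins (assumptions) for `out_m1` and `out_m2`, and the result is compared against explicit loops over token pairs.

```python
import torch

torch.manual_seed(0)
batch, n_tokens = 2, 4
num = torch.randn(batch, n_tokens, 1)        # stand-in for out_m1
den = torch.rand(batch, n_tokens, 1) + 0.5   # stand-in for out_m2, kept away from zero

# Broadcasting: [b, N, 1, 1] / [b, 1, N, 1] -> [b, N, N, 1],
# where entry (i, j) is tanh(num[i] / den[j]).
pairwise = torch.tanh(num.unsqueeze(2) / den.unsqueeze(1))

# Reference: the same quantity computed with explicit loops over token pairs.
looped = torch.zeros(batch, n_tokens, n_tokens, 1)
for i in range(n_tokens):
    for j in range(n_tokens):
        looped[:, i, j] = torch.tanh(num[:, i] / den[:, j])

# Zero out the diagonal (i == j) and average over the remaining N - 1 partners.
mask = 1.0 - torch.eye(n_tokens)
mean_over_others = (pairwise * mask.view(1, n_tokens, n_tokens, 1)).sum(dim=2) / (n_tokens - 1)

matches = torch.allclose(pairwise, looped)
print(matches, tuple(mean_over_others.shape))
```

Multiplying by the mask rather than indexing keeps the computation a single dense tensor operation, at the cost of computing the unused diagonal entries.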
## Colab notebook in this repo
More details are in the notebooks.
# Reaction
I can hardly believe this simple architecture can achieve this performance, or that the LookThem architecture was born from... "just try that". So, thanks to all of you for reading this Spontaneous Paper (a raw paper for those who are too lazy to polish it).