---
license: mit
tags:
- image-classification
- pytorch
- tiny-imagenet
---

# LookThem: Ratio-based Attention

# Results
## MNIST
- 11 epochs of training, train accuracy: 99.02%
- Test accuracy: 98.66%
## CIFAR-10
- 10 epochs of training, train accuracy: 67.89%
- Test accuracy (10 epochs): 73.42%
- 40 epochs of training, train accuracy: 76.63%
- Test accuracy: 79.79%
## Tiny-ImageNet
- 15 epochs of training, train accuracy: 32.07%
- Test accuracy (15 epochs): unknown (the reported results are not reliable)
- 30 epochs of training, train accuracy: 37.20%
- Test accuracy: unknown (the reported results are not reliable)
- File size: 5.72 MB
## Tiny-ImageNet 2 (V5)
- 20 epochs of training, train accuracy: 36.98%
- Test accuracy: 34.2%
- 40 epochs of training, train accuracy: 46.58%
- Test accuracy: 35.46%
- Code is in the second and third notebooks

# Explanation
I was curious: what if a token looked at other tokens, but without QKV? Instead, two tokens (the current token and another token) are each transformed, and one result is divided by the other. Call the current token A and the other token B. The ratio is `transformA(A) / transformB(B)`, where each transform is a small network of linear layers, and tanh is applied to the ratio so that it does not explode. The reverse ratio, `transformA(B) / transformB(A)`, is computed the same way. The first ratio is multiplied by A and the reverse ratio by B (in the code, each ratio first passes through one more small linear layer). The two products are added and divided by 2. That is the value for this interaction, and it is added to a temporary variable. The loop then repeats for the next token (the actual code is vectorized). Finally, the accumulated value is averaged, and that average is the new A.

Then I tried it on MNIST (a LookThem architecture with one layer), and in just a few epochs I got astonishing results. Because of that, I went further and tried CIFAR-10 with a similar architecture, and the results were good too. So I went on to Tiny-ImageNet. The result is... I don't know: the numbers in the notebook are not real evaluation results (the AI changed the code). It may not reach 100% accuracy, but at least it can learn, and it takes very little disk space (only about 5 MB). Those are the results for you all.

There is plenty of room to experiment: deeper architectures, other activation functions, and so on. Even without much hyperparameter tuning, it reaches impressive performance for its size (correct me if I'm wrong about that). So anyone with more resources is welcome to experiment with this architecture. I trained it on a Google Colab T4, and the code was generated by Gemini 3 Flash (except for the original code).

# Code
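The pairwise update described in the Explanation section can be sketched with plain Python on scalar tokens. This is only a minimal illustration under simplifying assumptions: `interact`, `update_token`, and the two lambda transforms are hypothetical names and stand-ins, one pair of transforms is shared by all tokens (the real code keeps separate `mod1`/`mod2` networks per token), and the extra `transform` layer applied to each ratio in the real code is left out.

```python
import math

def interact(a, b, transform_a, transform_b, eps=1e-7):
    """Value of one A-B interaction for scalar tokens (illustrative only)."""
    # Forward ratio transformA(A) / transformB(B), squashed by tanh.
    forward = math.tanh(transform_a(a) / (transform_b(b) + eps))
    # Reverse ratio transformA(B) / transformB(A), squashed the same way.
    reverse = math.tanh(transform_a(b) / (transform_b(a) + eps))
    # Weight A by the forward ratio and B by the reverse ratio, then halve.
    return (forward * a + reverse * b) / 2

def update_token(index, tokens, transform_a, transform_b):
    """New value of tokens[index]: the mean interaction with every other token."""
    others = [t for k, t in enumerate(tokens) if k != index]
    total = sum(interact(tokens[index], other, transform_a, transform_b)
                for other in others)
    return total / len(others)

# Hypothetical stand-ins for the learned transforms (assumptions, not from the repo).
transform_a = lambda v: 2.0 * v + 1.0
transform_b = lambda v: v + 3.0

tokens = [0.5, -1.0, 2.0]
new_tokens = [update_token(i, tokens, transform_a, transform_b)
              for i in range(len(tokens))]
print(new_tokens)
```

In this simplified sketch, tanh bounds each ratio to (-1, 1), so every interaction value stays within the average magnitude of the two tokens involved.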
## Base Code
```
import torch
import torch.nn as nn

class LookThem(nn.Module):
    def __init__(self):
        super().__init__()

        # Use ModuleList, as in the original architecture
        self.mod1 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.mod2 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.transform = nn.ModuleList([nn.Linear(1, 1) for _ in range(5)])

        self.mlp = nn.Sequential(
            nn.Linear(5, 10),
            nn.ReLU(),
            nn.Linear(10, 5)
        )

    def forward(self, x):
        # x shape: [batch, 5, 1]
        batch_size = x.size(0)
        new_x = []

        for i in range(5):
            # Initialize the output tensor with shape [batch, 1] for consistency
            out_i = torch.zeros((batch_size, 1), device=x.device)
            count = 0

            for j in range(5):
                if i == j:
                    continue  # Skip identical indices; interactions are only between different tokens

                # Add a small epsilon (1e-7) to guard against a denominator of exactly zero
                out_mod2_j = self.mod2[j](x[:, j]) + 1e-7
                out_mod2_i = self.mod2[i](x[:, i]) + 1e-7

                compare = torch.tanh(self.mod1[i](x[:, i]) / out_mod2_j)
                compare2 = torch.tanh(self.mod1[j](x[:, j]) / out_mod2_i)

                # Transform the interaction result
                interaksi = (self.transform[j](compare) * x[:, i] + self.transform[j](compare2) * x[:, j]) / 2
                out_i += interaksi
                count += 1

            # Average the interactions (divide by the number of interacting tokens, i.e. 4)
            out_i = out_i / count

            # Optionally combine the original value (self-identity) with the interaction value,
            # e.g. by adding them
            # total_fitur_i = x[:, i] + out_i
            new_x.append(out_i)

        # Concatenate back into [batch, 5]
        x_new = torch.cat(new_x, dim=1)

        return self.mlp(x_new)
```
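The base code guards the division by adding `1e-7` to the denominator. Because the denominator network can output negative values, an output of `-1e-7` would still give a denominator of exactly zero. The sketch below shows one possible sign-preserving alternative on plain floats. It is an assumption for illustration, not part of the original code, and `safe_denominator` is a hypothetical helper name.

```python
import math

def safe_denominator(value: float, eps: float = 1e-7) -> float:
    """Keep |value| at least eps while preserving its sign (hypothetical helper)."""
    if abs(value) >= eps:
        return value
    # copysign treats 0.0 as positive, so an exact zero becomes +eps.
    return math.copysign(eps, value)

# Adding eps alone can still land exactly on zero for a negative output.
shifted = -1e-7 + 1e-7             # the "+ 1e-7" approach gives 0.0 here
guarded = safe_denominator(-1e-7)  # stays at -1e-7, away from zero
print(shifted, guarded)
```

A tensor version would need the elementwise equivalent of this check. Whether it changes training behavior in practice is untested here.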
## Vectorized
```
import math

import torch
import torch.nn as nn

class LookThemLayer(nn.Module):
    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features

        # Batched parameters (vectorized): one small MLP per token for mod1 and mod2
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        # Overwrite the randn initial values of all weight tensors
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        # x shape: [batch, num_tokens, in_features]
        N = self.num_tokens

        # 1. Einstein summation projections
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', torch.relu(h1), self.mod1_w2) + self.mod1_b2

        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', torch.relu(h2), self.mod2_w2) + self.mod2_b2

        # 2. Contrast ratio + tanh
        # compare[b, i, j] = mod1_i(x_i) / mod2_j(x_j); compare2[b, i, j] = mod1_j(x_j) / mod2_i(x_i)
        out_m2_safe = out_m2 + 1e-7
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # 3. Spatial J transformations (the transform is indexed by j)
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped

        # 4. Contextual interaction
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2

        # 5. Mask out self-interactions (i == j)
        mask = 1.0 - torch.eye(N, device=x.device)
        interaksi_masked = interaksi * mask.view(1, N, N, 1)

        # Average over the N - 1 other tokens; output shape [batch, num_tokens, in_features]
        return interaksi_masked.sum(dim=2) / (N - 1.0)
```
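To get a feel for how small one layer is, the parameters declared in `LookThemLayer.__init__` can be counted by hand. The helper below is a hypothetical sketch (`count_lookthem_params` is not in the repo) and it covers a single layer only, not the full models from the notebooks.

```python
def count_lookthem_params(num_tokens: int, in_features: int, hidden_dim: int) -> int:
    """Count the parameters declared in LookThemLayer.__init__ (hypothetical helper)."""
    # Each of mod1 and mod2 holds, per token: w1, b1, w2, b2.
    per_mod = in_features * hidden_dim + hidden_dim + hidden_dim * 1 + 1
    # trans_w (1 x 1) and trans_b (1) per token.
    per_trans = 1 + 1
    return num_tokens * (2 * per_mod + per_trans)

total = count_lookthem_params(num_tokens=5, in_features=1, hidden_dim=5)
print(total)  # 170
```

With 5 tokens, 1 input feature, and a hidden size of 5, this matches the per-token modules of the base code (`mod1`, `mod2`, and `transform`, without the final `mlp`).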
## Colab notebook in this repo

More details are in the notebook.

# Reaction
I can hardly believe that such a simple architecture reaches this performance, or that LookThem was born from nothing more than "just try that". So, to all of you: thanks for reading this Spontaneous Paper (a raw paper for those who are too lazy to polish it).