---
license: mit
tags:
- image-classification
- pytorch
- tiny-imagenet
---

# LookThem: Ratio-based Attention

# Results
## MNIST
- 11-epoch training, train accuracy: 99.02%
- Test accuracy: 98.66%
## CIFAR-10
- 10-epoch training, train accuracy: 67.89%
- Test accuracy (10 epochs): 73.42%
- 40-epoch training, train accuracy: 76.63%
- Test accuracy: 79.79%
## Tiny-ImageNet
- 15-epoch training, train accuracy: 32.07%
- Test accuracy (15 epochs): unknown (the logged results are unreliable; see Explanation)
- 30-epoch training, train accuracy: 37.20%
- Test accuracy: unknown (the logged results are unreliable)
- File size: 5.72 MB
## Tiny-ImageNet 2 (V5)
- 20-epoch training, train accuracy (with dropout): 36.98%
- Test accuracy: 34.2%
- 40-epoch training, train Top-1 accuracy (no dropout): 59.76%
- Train Top-5 accuracy: 83.41%
- Test Top-1 accuracy: 35.46%
- Test Top-5 accuracy: 61.87%
- Code in the second and third notebooks

# Explanation
I was curious: what if a token looks at other tokens, but without QKV? Instead, the idea is to transform two tokens (the current token and another token), then divide them. Say the current token is token A and the other token is token B. The division is "transformA(A) / transformB(B)", where each transform is a small linear NN, with tanh for normalization (so the ratio doesn't explode). The reverse, "transformA(B) / transformB(A)", is computed too. Then the result of "transformA(A) / transformB(B)" is multiplied by A, the reverse is multiplied by B, the two are added, and the sum is divided by 2. That is the value for that interaction; add it to a temp variable. Loop again for the other token interactions (in the code this is vectorized). Then average that variable. That is the new A.
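In symbols, one interaction between tokens $x_i$ and $x_j$ looks like this (my notation for the base code below: $f_i$ is `mod1[i]`, $g_i$ is `mod2[i]`, $T_j$ is `transform[j]`, and $\epsilon = 10^{-7}$):

$$c_{ij} = \tanh\!\left(\frac{f_i(x_i)}{g_j(x_j) + \epsilon}\right), \qquad c_{ji} = \tanh\!\left(\frac{f_j(x_j)}{g_i(x_i) + \epsilon}\right)$$

$$x_i' = \frac{1}{N-1} \sum_{j \neq i} \frac{T_j(c_{ij})\,x_i + T_j(c_{ji})\,x_j}{2}$$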

Then I tried it on MNIST (a LookThem architecture with one layer), and within just a few epochs I got astonishing results. Because of that, I dug deeper and tried CIFAR-10 with a similar architecture, and the results were good too. So I went deeper still, to Tiny-ImageNet. The result is... I don't know; the results in the notebook are not valid evaluation results (the AI changed the code). Maybe not 100% accuracy, but at least it can learn, with even less memory on disk (just ~5 MB). Those are the results for you all.

There is plenty of room for experimenting, like deeper architectures, other activation functions, etc. Even without big training-parameter tuning, it reaches impressive performance (for its size)... correct me if I'm wrong about that. So, for everyone who has bigger resources, you can all experiment with this architecture. I trained it on Google Colab's T4, and the code was generated by Gemini 3 Flash (except for the original code).

# Code

## Base Code
```
import torch
import torch.nn as nn

class LookThem(nn.Module):
    def __init__(self):
        super(LookThem, self).__init__()

        # Using ModuleList, as in the original architecture
        self.mod1 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.mod2 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.transform = nn.ModuleList([nn.Linear(1, 1) for _ in range(5)])

        self.mlp = nn.Sequential(
            nn.Linear(5, 10),
            nn.ReLU(),
            nn.Linear(10, 5)
        )

    def forward(self, x):
        # x shape: [batch, 5, 1]
        batch_size = x.size(0)
        new_x = []

        for i in range(5):
            # Initialize the output tensor with shape [batch, 1] for consistency
            out_i = torch.zeros((batch_size, 1), device=x.device)
            count = 0

            for j in range(5):
                if i == j:
                    continue  # Skip when indices match; interactions are only between different tokens

                # Add epsilon (1e-7) to guard against division by zero
                out_mod2_j = self.mod2[j](x[:, j]) + 1e-7
                out_mod2_i = self.mod2[i](x[:, i]) + 1e-7

                compare = torch.tanh(self.mod1[i](x[:, i]) / out_mod2_j)
                compare2 = torch.tanh(self.mod1[j](x[:, j]) / out_mod2_i)

                # Transform the interaction result
                interaksi = (self.transform[j](compare) * x[:, i] + self.transform[j](compare2) * x[:, j]) / 2
                out_i += interaksi
                count += 1

            # Average the interactions (divide by the number of interacting tokens, i.e. 4)
            out_i = out_i / count

            # Optionally combine the original value (self-identity) with the
            # interaction value, e.g. by adding them:
            # total_fitur_i = x[:, i] + out_i
            new_x.append(out_i)

        # Concatenate back to [batch, 5]
        x_new = torch.cat(new_x, dim=1)

        return self.mlp(x_new)
```
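A minimal smoke test for the base module (my own sketch, not from the notebooks; the batch size is arbitrary):

```
# Toy run: 5 tokens with 1 feature each, batch of 8 (arbitrary).
model = LookThem()
x = torch.randn(8, 5, 1)  # [batch, tokens, features]
out = model(x)
print(out.shape)  # torch.Size([8, 5])
```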
## Vectorized
```
import math

class LookThemLayer(nn.Module):
    def __init__(self, num_tokens, in_features, hidden_dim):
        super(LookThemLayer, self).__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features

        # Batched Parameters (Vectorized)
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        N = self.num_tokens
        
        # 1. Einstein Summation Projections
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', torch.relu(h1), self.mod1_w2) + self.mod1_b2

        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', torch.relu(h2), self.mod2_w2) + self.mod2_b2

        # 2. Contrast ratio + tanh
        out_m2_safe = out_m2 + 1e-7
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # 3. Spatial J Transformations
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped

        # 4. Contextual Interaction
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2
        
        # 5. Mask out the self term (i == j)
        mask = 1.0 - torch.eye(N, device=x.device)
        interaksi_masked = interaksi * mask.view(1, N, N, 1)

        return interaksi_masked.sum(dim=2) / (N - 1.0)
```
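The same kind of shape check for the vectorized layer (again my own sketch; the sizes are just examples):

```
# The layer is shape-preserving: [batch, num_tokens, in_features] in and out.
layer = LookThemLayer(num_tokens=5, in_features=1, hidden_dim=5)
x = torch.randn(8, 5, 1)
out = layer(x)
print(out.shape)  # torch.Size([8, 5, 1])
```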

## Colab notebooks in this repo

More details are in the notebooks.

# Reaction
I still can't believe this simple architecture can achieve this performance, or the realization that this LookThem architecture was born from... "just try that". So, thanks for reading this Spontaneous Paper (a raw paper from those who are too lazy to polish it).