ASomeoneWhoInterestedWithAI committed d3d0ded (verified) · 1 parent: db5e8ba — Upload README.md
# LookThem: Ratio-based Attention

# Explanation
I was curious: what if a token could look at other tokens, but without QKV? Instead, the idea is to transform two tokens (the current token and another token) and divide the results. Call the current token A and the other token B. The ratio is `transformA(A) / transformB(B)`, where each transform is a small linear network, with tanh applied to normalize the ratio so it doesn't explode. The reverse ratio, `transformA(B) / transformB(A)`, is computed as well. Then the first ratio is multiplied by A, the reverse by B, the two terms are added, and the sum is divided by 2. That is the contribution of this pair of tokens, which is accumulated in a temporary variable. The loop repeats for every other token (in the code this is vectorized), and the accumulated value is averaged. That average is the new A.
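For concreteness, here is a minimal scalar sketch of a single A↔B interaction as described above. The two linear stand-ins (`transform_a`, `transform_b`) are illustrative assumptions, not the trained networks:

```python
import math

# Hypothetical stand-ins for the two learned transforms
# (illustrative assumptions, not the trained networks)
def transform_a(t):
    return 2.0 * t + 0.5

def transform_b(t):
    return 1.5 * t + 1.0

def pair_update(a, b, eps=1e-7):
    """Contribution of the (A, B) pair to the new value of A."""
    # Ratio in one direction, squashed with tanh so it cannot explode
    r_ab = math.tanh(transform_a(a) / (transform_b(b) + eps))
    # ...and the reverse direction
    r_ba = math.tanh(transform_a(b) / (transform_b(a) + eps))
    # Weight each original token by its ratio, then average the two terms
    return (r_ab * a + r_ba * b) / 2

contribution = pair_update(0.3, -0.8)
```

Averaging `pair_update(a, b)` over every other token `b` then gives the new value of `a`.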
Then I tried it on MNIST (a LookThem architecture with one layer) and got astonishing results in just a few epochs. Encouraged by that, I dug deeper and tried CIFAR-10 with a similar architecture, and the results were good too. So I went further, to Tiny-ImageNet, where the result is around 50% accuracy. That's not perfect, but it can at least compete with older Tiny-ImageNet architectures, with an even smaller footprint on disk (just ~5 MB). Those are the results for you all.
There is plenty of room for experimentation: deeper architectures, other activation functions, and so on. Even without big tuning of the training parameters it reaches strong results for a trained-from-scratch model (correct me if I'm wrong about it being SOTA in that category). So if you have bigger resources, you can all experiment with this architecture. I trained it on a Google Colab T4, and the code was generated by Gemini 3 Flash (except for the original code).
# Code

## Original Code

```python
import torch
import torch.nn as nn

class LookThem(nn.Module):
    def __init__(self):
        super(LookThem, self).__init__()

        # Use ModuleLists, as in the initial architecture
        self.mod1 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.mod2 = nn.ModuleList([nn.Sequential(nn.Linear(1, 5), nn.ReLU(), nn.Linear(5, 1)) for _ in range(5)])
        self.transform = nn.ModuleList([nn.Linear(1, 1) for _ in range(5)])

        self.mlp = nn.Sequential(
            nn.Linear(5, 10),
            nn.ReLU(),
            nn.Linear(10, 5)
        )

    def forward(self, x):
        # x shape: [batch, 5, 1]
        batch_size = x.size(0)
        new_x = []

        for i in range(5):
            # Initialize the output tensor with shape [batch, 1] for consistency
            out_i = torch.zeros((batch_size, 1), device=x.device)
            count = 0

            for j in range(5):
                if i == j:
                    continue  # Skip equal indices; interactions are only between different tokens

                # Add epsilon (1e-7) to prevent division by zero
                out_mod2_j = self.mod2[j](x[:, j]) + 1e-7
                out_mod2_i = self.mod2[i](x[:, i]) + 1e-7

                compare = self.mod1[i](x[:, i]) / out_mod2_j
                compare2 = self.mod1[j](x[:, j]) / out_mod2_i

                # Transform the interaction result
                interaksi = (self.transform[j](compare) * x[:, i] + self.transform[j](compare2) * x[:, j]) / 2
                out_i += interaksi
                count += 1

            # Average the interactions (divide by the number of interacting tokens, i.e. 4)
            out_i = out_i / count

            # The original value (self-identity) could also be combined with the
            # interaction value, e.g. by addition:
            # total_fitur_i = x[:, i] + out_i
            new_x.append(out_i)

        # Concatenate back into [batch, 5]
        x_new = torch.cat(new_x, dim=1)

        return self.mlp(x_new)
```
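The jump from this ModuleList version to the vectorized one rests on a single identity: one `einsum` over a batched weight tensor reproduces a loop of per-token matmuls. A standalone check with random tensors (shapes chosen to match the 5-token, scalar-feature setup):

```python
import torch

torch.manual_seed(0)
B, N, IN, H = 4, 5, 1, 5

# Batched weights: one [IN, H] matrix per token
w = torch.randn(N, IN, H)
x = torch.randn(B, N, IN)

# One einsum applies every token's matrix at once
batched = torch.einsum('bti,tij->btj', x, w)

# Equivalent per-token loop (what a ModuleList does conceptually)
looped = torch.stack([x[:, t] @ w[t] for t in range(N)], dim=1)

assert torch.allclose(batched, looped)
```

The same identity underlies the `'btj,tjk->btk'` projections and, with an extra token axis, the `'bije,jef->bijf'` transform over `j`.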
## Vectorized

```python
import torch
import torch.nn as nn
import math

class LookThemVectorized(nn.Module):
    def __init__(self, num_tokens=5, in_features=1, hidden_dim=5):
        super(LookThemVectorized, self).__init__()

        self.num_tokens = num_tokens
        self.in_features = in_features
        self.hidden_dim = hidden_dim

        # 1. Batched parameters for mod1
        # Shape: [num_tokens, in_features, hidden_dim]
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        # Shape: [num_tokens, hidden_dim, 1]
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # 2. Batched parameters for mod2
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        # 3. Batched parameters for the linear transform over j
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))

        # 4. Final MLP, sized to the (dynamic) number of tokens
        self.mlp = nn.Sequential(
            nn.Linear(num_tokens, num_tokens * 2),
            nn.ReLU(),
            nn.Linear(num_tokens * 2, num_tokens)
        )

        self._init_weights()

    def _init_weights(self):
        # Kaiming-uniform initialization for stable training
        for w in [self.mod1_w1, self.mod2_w1]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))
        for w in [self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        # x shape is now: [Batch, num_tokens, in_features]
        batch_size = x.size(0)
        N = self.num_tokens

        # 1. Run mod1 and mod2 in parallel for all tokens
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', torch.relu(h1), self.mod1_w2) + self.mod1_b2  # [Batch, N, 1]

        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', torch.relu(h2), self.mod2_w2) + self.mod2_b2  # [Batch, N, 1]

        # 2. Compute the i/j ratio combinations via broadcasting
        out_m2_safe = out_m2 + 1e-7
        compare = out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1)   # [Batch, N, N, 1]
        compare2 = out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2)  # [Batch, N, N, 1]

        # 3. Transform the results, indexed by j
        # Reshape the bias so it broadcasts along the j coordinate
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped

        # 4. Compute the feature-weighted interactions using the original features from x
        # x.unsqueeze(2) -> features of token i, x.unsqueeze(1) -> features of token j
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2  # [Batch, N, N, in_features]

        # 5. Mask out self-interactions (i == j)
        mask = 1.0 - torch.eye(N, device=x.device)
        interaksi_masked = interaksi * mask.view(1, N, N, 1)  # matches the interaction matrix

        # 6. Average the interactions (divide by N - 1, since self is skipped)
        # Sum over the j dimension (dim=2)
        out_i = interaksi_masked.sum(dim=2) / (N - 1.0)  # [Batch, N, in_features]

        # 7. Prepare the tensor for the final MLP
        # Average over in_features so the shape becomes [Batch, N] before the MLP
        x_new = out_i.mean(dim=-1)

        return self.mlp(x_new)
```
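The broadcasting in step 2 is the heart of the vectorization. A standalone check of what each entry of `compare` holds, using random tensors in place of the mod1/mod2 outputs:

```python
import torch

torch.manual_seed(0)
B, N = 2, 5
out_m1 = torch.randn(B, N, 1)
out_m2 = torch.rand(B, N, 1) + 0.5  # keep denominators away from zero for this demo

# [B, N, 1, 1] / [B, 1, N, 1] broadcasts to [B, N, N, 1]
compare = out_m1.unsqueeze(2) / out_m2.unsqueeze(1)
assert compare.shape == (B, N, N, 1)

# Entry (i, j) is transformA(token_i) / transformB(token_j)
assert torch.allclose(compare[0, 1, 3], out_m1[0, 1] / out_m2[0, 3])
```

Swapping which operand gets `unsqueeze(1)` produces `compare2`, the reverse ratio for each pair.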
## Enhanced code (used in Tiny-ImageNet training)

```python
import math

import torch
import torch.nn as nn

class LookThemLayer(nn.Module):
    def __init__(self, num_tokens, in_features, hidden_dim):
        super(LookThemLayer, self).__init__()
        self.num_tokens = num_tokens
        self.in_features = in_features

        # Batched parameters (vectorized)
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))

        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2, self.trans_w]:
            nn.init.kaiming_uniform_(w, a=math.sqrt(5))

    def forward(self, x):
        N = self.num_tokens

        # 1. Einstein-summation projections
        h1 = torch.einsum('bti,tij->btj', x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum('btj,tjk->btk', torch.relu(h1), self.mod1_w2) + self.mod1_b2

        h2 = torch.einsum('bti,tij->btj', x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum('btj,tjk->btk', torch.relu(h2), self.mod2_w2) + self.mod2_b2

        # 2. Contrast ratios, squashed with tanh
        out_m2_safe = out_m2 + 1e-7
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # 3. Transformations over the j index
        bias_reshaped = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum('bije,jef->bijf', compare, self.trans_w) + bias_reshaped
        trans_compare2 = torch.einsum('bije,jef->bijf', compare2, self.trans_w) + bias_reshaped

        # 4. Contextual interaction
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2

        # 5. Mask out self-interactions (i == j)
        mask = 1.0 - torch.eye(N, device=x.device)
        interaksi_masked = interaksi * mask.view(1, N, N, 1)

        return interaksi_masked.sum(dim=2) / (N - 1.0)
```
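The diagonal mask plus the division by N − 1 in the layer above is just a mean over the N − 1 partner tokens. A standalone cross-check, with a random tensor standing in for the interaction map:

```python
import torch

torch.manual_seed(0)
B, N, F = 2, 4, 3
interactions = torch.randn(B, N, N, F)  # stand-in for the [B, N, N, F] interaction map

# Zero the diagonal (i == j), sum over j, divide by the N - 1 partners
mask = 1.0 - torch.eye(N)
pooled = (interactions * mask.view(1, N, N, 1)).sum(dim=2) / (N - 1.0)

# Cross-check against an explicit mean that skips j == i
i = 1
partners = [j for j in range(N) if j != i]
manual = interactions[:, i, partners].mean(dim=1)
assert torch.allclose(pooled[:, i], manual)
```

Because the layer returns `[Batch, N, in_features]` (the same shape it consumes), it can be stacked or followed by any pooling head.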
## Colab notebook in this repo

# Results
## MNIST
- 11 epochs of training, train accuracy: 99.02%
- Test accuracy: 98.66%
## CIFAR-10
- 10 epochs of training, train accuracy: 67.89%
- Test accuracy (10 epochs): 73.42%
- 40 epochs of training, train accuracy: 76.63%
- Test accuracy (40 epochs): 79.79%
## Tiny-ImageNet
- 15 epochs of training, train accuracy: 32.07%
- Test accuracy (15 epochs): 42.01%
- 30 epochs of training, train accuracy: 37.20%
- Test accuracy (30 epochs): 50.53%
- File size: 5.72 MB

More detail in the notebook.
# Reaction
I can't believe such a simple architecture can reach ResNet-34-level performance, or that the LookThem architecture was born from a simple "just try it". So, thanks to all of you for reading this Spontaneous Paper (a raw paper for those who are too lazy to polish it).