ASomeoneWhoInterestedWithAI committed on
Commit
a5137a6
·
verified ·
1 Parent(s): 2ab4f6e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +89 -0
app.py CHANGED
@@ -18,6 +18,95 @@ if not os.path.exists(MODEL_PATH):
18
  print("Download complete!")
19
 
20
  # --- DEFINE YOUR MODEL ARCHITECTURE (TETAP SAMA) ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
21
  # ... (Salin definisi kelas LookThemLayer, LiteResidualBlock, dan LookThemV8MNIST Anda di sini) ...
22
 
23
  # --- LOAD WEIGHTS ON CPU/GPU ---
 
18
  print("Download complete!")
19
 
20
  # --- DEFINE YOUR MODEL ARCHITECTURE (TETAP SAMA) ---
21
class LookThemLayer(nn.Module):
    """Pairwise token-interaction layer with per-token parameters.

    Each token owns two small two-layer MLPs (``mod1`` and ``mod2``) that map
    its feature vector to one scalar each.  For every ordered pair of tokens
    the ``mod1`` scalar of one token is divided by the ``mod2`` scalar of the
    other and squashed with ``tanh``; the result goes through a per-token
    affine map (``trans_w`` / ``trans_b``) and is used to scale the features
    of the two tokens.  The scaled features are averaged over all *other*
    tokens (the diagonal is masked out).

    Args:
        num_tokens: Number of tokens ``N`` (size of dim 1 of the input).
        in_features: Feature size of each token (size of dim 2 of the input).
        hidden_dim: Hidden width of the per-token MLPs.

    Shape:
        Input ``(batch, num_tokens, in_features)``; output has the same shape.
    """

    def __init__(self, num_tokens, in_features, hidden_dim):
        super().__init__()
        self.num_tokens = num_tokens
        # Parameter names and creation order are kept as-is so that existing
        # checkpoints (state_dict keys) continue to load.
        self.mod1_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod1_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod1_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod1_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.mod2_w1 = nn.Parameter(torch.randn(num_tokens, in_features, hidden_dim))
        self.mod2_b1 = nn.Parameter(torch.zeros(num_tokens, hidden_dim))
        self.mod2_w2 = nn.Parameter(torch.randn(num_tokens, hidden_dim, 1))
        self.mod2_b2 = nn.Parameter(torch.zeros(num_tokens, 1))
        self.trans_w = nn.Parameter(torch.randn(num_tokens, 1, 1))
        self.trans_b = nn.Parameter(torch.zeros(num_tokens, 1))
        self._init_weights()

    def _init_weights(self):
        """Re-initialise the four MLP weight tensors with Xavier-uniform.

        ``trans_w`` keeps its ``randn`` initialisation and all biases stay zero.
        """
        for w in [self.mod1_w1, self.mod2_w1, self.mod1_w2, self.mod2_w2]:
            nn.init.xavier_uniform_(w)

    def forward(self, x):
        """Apply the pairwise interaction.

        Args:
            x: Tensor of shape ``(batch, num_tokens, in_features)``.

        Returns:
            Tensor of shape ``(batch, num_tokens, in_features)``.
        """
        N = self.num_tokens
        eps = 1e-6

        # Per-token MLPs: (b, N, in) -> (b, N, hidden) -> (b, N, 1).
        h1 = torch.einsum("bti,tij->btj", x, self.mod1_w1) + self.mod1_b1
        out_m1 = torch.einsum("btj,tjk->btk", F.gelu(h1), self.mod1_w2) + self.mod1_b2
        h2 = torch.einsum("bti,tij->btj", x, self.mod2_w1) + self.mod2_b1
        out_m2 = torch.einsum("btj,tjk->btk", F.gelu(h2), self.mod2_w2) + self.mod2_b2

        # Keep the denominator at least `eps` away from zero while preserving
        # its sign.  `torch.sign` returns 0 for an exact 0, which would turn
        # the "safe" denominator back into 0 (and 0/0 into NaN), so an exact
        # zero is treated as positive instead.
        magnitude = torch.clamp(torch.abs(out_m2), min=eps)
        out_m2_safe = torch.where(out_m2 < 0, -magnitude, magnitude)

        # compare[b, i, j]  = tanh(m1_i / m2_j)
        # compare2[b, i, j] = tanh(m1_j / m2_i)
        compare = torch.tanh(out_m1.unsqueeze(2) / out_m2_safe.unsqueeze(1))
        compare2 = torch.tanh(out_m1.unsqueeze(1) / out_m2_safe.unsqueeze(2))

        # Affine map whose weight and bias are selected by the second token index j.
        bias = self.trans_b.view(1, 1, N, 1)
        trans_compare = torch.einsum("bije,jef->bijf", compare, self.trans_w) + bias
        trans_compare2 = torch.einsum("bije,jef->bijf", compare2, self.trans_w) + bias

        # (b, N, N, in): features of token i and token j, each scaled by its
        # comparison score, then averaged.
        interaksi = (trans_compare * x.unsqueeze(2) + trans_compare2 * x.unsqueeze(1)) / 2

        # Drop self-interactions (i == j) and average over the remaining tokens.
        mask = (1.0 - torch.eye(N, device=x.device)).view(1, N, N, 1)
        # With a single token there are no partners; avoid dividing 0 by 0.
        partners = float(max(N - 1, 1))
        return (interaksi * mask).sum(dim=2) / partners
58
+
59
class LiteResidualBlock(nn.Module):
    """Residual MLP block followed by layer normalisation.

    Computes ``LayerNorm(x + MLP(x))`` where the MLP is
    ``Linear -> GELU -> Dropout -> Linear`` and keeps the feature size ``dim``.

    Args:
        dim: Size of the last dimension of the input and output.
        dropout: Dropout probability used inside the MLP.
    """

    def __init__(self, dim, dropout=0.05):
        super().__init__()
        # Layer order (and therefore the Sequential indices 0..3) is part of
        # the checkpoint layout, so it is spelled out one layer per line.
        mlp_layers = [
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(dim, dim),
        ]
        self.block = nn.Sequential(*mlp_layers)
        self.norm = nn.LayerNorm(dim)

    def forward(self, x):
        """Return the normalised sum of ``x`` and its MLP transform."""
        summed = x + self.block(x)
        return self.norm(summed)
66
+
67
class LookThemV8MNIST(nn.Module):
    """Two-stream convolutional classifier built around ``LookThemLayer``.

    A single-channel image is processed by two convolutional streams (stride 2
    and stride 1), each pooled to an 8x8 map with 8 channels.  Every spatial
    position becomes one of 64 tokens.  The per-stream tokens pass through
    their own ``LookThemLayer``, are concatenated, mixed again, compressed and
    finally classified into 10 logits by a residual MLP head.

    Input: ``(batch, 1, H, W)``; output: ``(batch, 10)``.
    """

    def __init__(self):
        super().__init__()
        # Submodule names and creation order are kept so that saved weights
        # map onto the same state_dict keys.
        self.stream_a = self._conv_stream(stride=2)
        self.stream_b = self._conv_stream(stride=1)

        self.lookthemA = LookThemLayer(64, 8, 32)
        self.lookthemB = LookThemLayer(64, 8, 32)
        self.lookthem_comb = LookThemLayer(64, 16, 32)
        self.comb_norm = nn.LayerNorm(16)

        self.FFN1 = nn.Conv1d(16, 8, 1)
        self.lookthem2 = LookThemLayer(64, 8, 32)
        self.FFN2 = nn.Conv1d(8, 8, 1)

        self.compressor = nn.Conv1d(8, 4, 1)
        self.input_proj = nn.Linear(64 * 4, 128)
        self.res_blocks = nn.Sequential(
            LiteResidualBlock(128),
            LiteResidualBlock(128),
        )
        self.head = nn.Sequential(
            nn.Linear(128, 128),
            nn.GELU(),
            nn.Linear(128, 10),
        )

    @staticmethod
    def _conv_stream(stride):
        """Build one conv stream: two 3x3 conv/BN/GELU stages, then 8x8 max-pool."""
        return nn.Sequential(
            nn.Conv2d(1, 4, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(4),
            nn.GELU(),
            nn.Conv2d(4, 8, kernel_size=3, stride=stride, padding=1),
            nn.BatchNorm2d(8),
            nn.GELU(),
            nn.AdaptiveMaxPool2d((8, 8)),
        )

    def forward(self, x):
        """Return class logits of shape ``(batch, 10)`` for images ``x``."""
        batch = x.size(0)

        # (batch, 8, 8, 8) feature maps -> (batch, 64 tokens, 8 features).
        tokens_a = self.stream_a(x).view(batch, 8, 64).transpose(1, 2)
        feat_a = self.lookthemA(tokens_a)
        tokens_b = self.stream_b(x).view(batch, 8, 64).transpose(1, 2)
        feat_b = self.lookthemB(tokens_b)

        # Fuse both streams along the feature axis: (batch, 64, 16).
        fused = torch.cat([feat_a, feat_b], dim=2)
        fused = self.comb_norm(self.lookthem_comb(fused))

        # Conv1d expects channels first, so flip to (batch, 16, 64) and back.
        tokens = self.FFN1(fused.transpose(1, 2)).transpose(1, 2)
        skip = tokens

        # Second interaction stage with a residual connection, channels first.
        mixed = self.lookthem2(tokens).transpose(1, 2)
        mixed = self.FFN2(mixed) + skip.transpose(1, 2)

        # (batch, 4, 64) -> (batch, 256) -> (batch, 128) -> (batch, 10).
        flat = self.compressor(mixed).flatten(1)
        hidden = self.res_blocks(self.input_proj(flat))
        return self.head(hidden)
110
  # ... (Salin definisi kelas LookThemLayer, LiteResidualBlock, dan LookThemV8MNIST Anda di sini) ...
111
 
112
  # --- LOAD WEIGHTS ON CPU/GPU ---