Abdullah-Nazhat commited on
Commit
a444e08
·
verified ·
1 Parent(s): 44213b2

Update smallformer.py

Browse files
Files changed (1) hide show
  1. smallformer.py +20 -2
smallformer.py CHANGED
@@ -5,6 +5,24 @@ import math
5
 
6
 
7
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
8
 
9
 
10
 
@@ -13,14 +31,14 @@ class LocalMappingUnit(nn.Module):
13
  super().__init__()
14
 
15
 
16
- self.mapping = nn.PReLU(dim)
17
  self.norm = nn.LayerNorm(dim,elementwise_affine=False)
18
 
19
 
20
  def forward(self, x):
21
 
22
  x = self.norm(x)
23
- x = self.mapping(x[-1])
24
 
25
  return x
26
 
 
5
 
6
 
7
 
8
+ class MLP(nn.Module):
9
+
10
+ def __init__(self,dim):
11
+ super().__init__()
12
+ self.proj_1 = nn.Linear(dim,dim,bias=False)
13
+ self.proj_2 = nn.Linear(dim,dim,bias=False)
14
+ self.gelu = nn.GELU()
15
+
16
+
17
+ def forward(self, x):
18
+
19
+ x = self.proj_1(x)
20
+ x = self.gelu(x)
21
+ x = self.proj_2(x)
22
+
23
+
24
+ return x
25
+
26
 
27
 
28
 
 
31
  super().__init__()
32
 
33
 
34
+ self.mapping = MLP(dim)
35
  self.norm = nn.LayerNorm(dim,elementwise_affine=False)
36
 
37
 
38
  def forward(self, x):
39
 
40
  x = self.norm(x)
41
+ x = self.mapping(x)
42
 
43
  return x
44