Abdullah-Nazhat committed on
Commit
8cbf94e
·
verified ·
1 Parent(s): 73dbe16

Update context_prelu.py

Browse files
Files changed (1) hide show
  1. context_prelu.py +23 -3
context_prelu.py CHANGED
@@ -1,7 +1,23 @@
1
  import torch
2
  from torch import nn
3
 
 
4
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
5
 
6
 
7
  class Context_PReLUBlock(nn.Module):
@@ -12,7 +28,7 @@ class Context_PReLUBlock(nn.Module):
12
 
13
  self.context_prelu = nn.PReLU(d_model * num_tokens)
14
  self.token_norm = nn.LayerNorm(d_model)
15
-
16
 
17
 
18
  def forward(self, x):
@@ -32,7 +48,12 @@ class Context_PReLUBlock(nn.Module):
32
  readout = self.context_prelu(context)
33
 
34
  x = readout.reshape([dim0,dim1,dim2])
35
-
 
 
 
 
 
36
 
37
  out = x + residual
38
 
@@ -59,4 +80,3 @@ class Context_PReLU(nn.Module):
59
 
60
 
61
 
62
-
 
1
  import torch
2
  from torch import nn
3
 
4
class MLP(nn.Module):
    """Two-layer position-wise feed-forward block with a GELU nonlinearity.

    Maps the last dimension ``dim -> dim -> dim`` through two bias-free
    linear projections: ``proj_2(gelu(proj_1(x)))``.
    """

    def __init__(self, dim):
        super().__init__()
        # Square, bias-free projections: the feature width is unchanged.
        self.proj_1 = nn.Linear(dim, dim, bias=False)
        self.proj_2 = nn.Linear(dim, dim, bias=False)
        self.gelu = nn.GELU()

    def forward(self, x):
        """Apply proj_1 -> GELU -> proj_2 along the last dimension of ``x``."""
        hidden = self.gelu(self.proj_1(x))
        return self.proj_2(hidden)
21
 
22
 
23
  class Context_PReLUBlock(nn.Module):
 
28
 
29
  self.context_prelu = nn.PReLU(d_model * num_tokens)
30
  self.token_norm = nn.LayerNorm(d_model)
31
+ self.local_mapping = MLP(d_model)
32
 
33
 
34
  def forward(self, x):
 
48
  readout = self.context_prelu(context)
49
 
50
  x = readout.reshape([dim0,dim1,dim2])
51
+
52
+ x = x + residual
53
+
54
+ residual = x
55
+
56
+ x = self.local_mapping(x)
57
 
58
  out = x + residual
59
 
 
80
 
81
 
82