Abdullah-Nazhat commited on
Commit
075a642
·
verified ·
1 Parent(s): d4deac8

Update interactor.py

Browse files
Files changed (1) hide show
  1. interactor.py +13 -13
interactor.py CHANGED
@@ -14,36 +14,39 @@ class MemoryUnit(nn.Module):
14
  self.proj_1 = nn.Linear(dim,dim)
15
  self.proj_2 = nn.Linear(dim,dim)
16
  self.proj_3 = nn.Linear(dim,dim)
 
17
 
18
  def forward(self, x):
19
 
20
  x = self.norm_token(x)
21
  u, v = x, x
22
  u = self.proj_1(u)
23
- u = self.norm_token(u)
24
  v = self.proj_2(v)
25
  g = u * v
26
  x = self.proj_3(g)
27
- x = self.norm_token(x)
28
 
29
  return x
30
 
31
  class InteractionUnit(nn.Module):
32
  def __init__(self,dim,score_dim):
33
  super().__init__()
34
-
35
-
36
  self.norm_token = nn.LayerNorm(dim)
37
- self.norm_score = nn.LayerNorm(score_dim)
38
 
39
  def forward(self, x):
40
 
41
  x = self.norm_token(x)
42
- q,k,v = x,x,x
43
- score = torch.matmul(q, k.transpose(-1, -2))
44
- interaction = self.norm_score(score)
45
- x = torch.matmul(interaction,v)
46
- x = self.norm_token(x)
 
 
 
47
 
48
  return x
49
 
@@ -91,6 +94,3 @@ class Interactor(nn.Module):
91
 
92
 
93
 
94
-
95
-
96
-
 
14
  self.proj_1 = nn.Linear(dim,dim)
15
  self.proj_2 = nn.Linear(dim,dim)
16
  self.proj_3 = nn.Linear(dim,dim)
17
+ self.gelu = nn.GELU()
18
 
19
  def forward(self, x):
20
 
21
  x = self.norm_token(x)
22
  u, v = x, x
23
  u = self.proj_1(u)
24
+ u = self.gelu(u)
25
  v = self.proj_2(v)
26
  g = u * v
27
  x = self.proj_3(g)
28
+
29
 
30
  return x
31
 
32
  class InteractionUnit(nn.Module):
33
  def __init__(self,dim,score_dim):
34
  super().__init__()
35
+
 
36
  self.norm_token = nn.LayerNorm(dim)
37
+ self.gelu = nn.GELU()
38
 
39
  def forward(self, x):
40
 
41
  x = self.norm_token(x)
42
+ dim0 = x.shape[0]
43
+ dim1 = x.shape[1]
44
+ dim2 = x.shape[2]
45
+ x = x.reshape([dim0,dim1*dim2])
46
+ x = self.gelu(x)
47
+ x = x.reshape([dim0,dim1,dim2])
48
+
49
+
50
 
51
  return x
52
 
 
94
 
95
 
96