Update interactor.py
Browse files- interactor.py +13 -13
interactor.py
CHANGED
|
@@ -14,36 +14,39 @@ class MemoryUnit(nn.Module):
|
|
| 14 |
self.proj_1 = nn.Linear(dim,dim)
|
| 15 |
self.proj_2 = nn.Linear(dim,dim)
|
| 16 |
self.proj_3 = nn.Linear(dim,dim)
|
|
|
|
| 17 |
|
| 18 |
def forward(self, x):
|
| 19 |
|
| 20 |
x = self.norm_token(x)
|
| 21 |
u, v = x, x
|
| 22 |
u = self.proj_1(u)
|
| 23 |
-
u = self.
|
| 24 |
v = self.proj_2(v)
|
| 25 |
g = u * v
|
| 26 |
x = self.proj_3(g)
|
| 27 |
-
|
| 28 |
|
| 29 |
return x
|
| 30 |
|
| 31 |
class InteractionUnit(nn.Module):
|
| 32 |
def __init__(self,dim,score_dim):
|
| 33 |
super().__init__()
|
| 34 |
-
|
| 35 |
-
|
| 36 |
self.norm_token = nn.LayerNorm(dim)
|
| 37 |
-
self.
|
| 38 |
|
| 39 |
def forward(self, x):
|
| 40 |
|
| 41 |
x = self.norm_token(x)
|
| 42 |
-
|
| 43 |
-
|
| 44 |
-
|
| 45 |
-
x =
|
| 46 |
-
x = self.
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
return x
|
| 49 |
|
|
@@ -91,6 +94,3 @@ class Interactor(nn.Module):
|
|
| 91 |
|
| 92 |
|
| 93 |
|
| 94 |
-
|
| 95 |
-
|
| 96 |
-
|
|
|
|
| 14 |
self.proj_1 = nn.Linear(dim,dim)
|
| 15 |
self.proj_2 = nn.Linear(dim,dim)
|
| 16 |
self.proj_3 = nn.Linear(dim,dim)
|
| 17 |
+
self.gelu = nn.GELU()
|
| 18 |
|
| 19 |
def forward(self, x):
|
| 20 |
|
| 21 |
x = self.norm_token(x)
|
| 22 |
u, v = x, x
|
| 23 |
u = self.proj_1(u)
|
| 24 |
+
u = self.gelu(u)
|
| 25 |
v = self.proj_2(v)
|
| 26 |
g = u * v
|
| 27 |
x = self.proj_3(g)
|
| 28 |
+
|
| 29 |
|
| 30 |
return x
|
| 31 |
|
| 32 |
class InteractionUnit(nn.Module):
|
| 33 |
def __init__(self,dim,score_dim):
|
| 34 |
super().__init__()
|
| 35 |
+
|
|
|
|
| 36 |
self.norm_token = nn.LayerNorm(dim)
|
| 37 |
+
self.gelu = nn.GELU()
|
| 38 |
|
| 39 |
def forward(self, x):
|
| 40 |
|
| 41 |
x = self.norm_token(x)
|
| 42 |
+
dim0 = x.shape[0]
|
| 43 |
+
dim1 = x.shape[1]
|
| 44 |
+
dim2 = x.shape[2]
|
| 45 |
+
x = x.reshape([dim0,dim1*dim2])
|
| 46 |
+
x = self.gelu(x)
|
| 47 |
+
x = x.reshape([dim0,dim1,dim2])
|
| 48 |
+
|
| 49 |
+
|
| 50 |
|
| 51 |
return x
|
| 52 |
|
|
|
|
| 94 |
|
| 95 |
|
| 96 |
|
|
|
|
|
|
|
|
|