Update interactor.py
Browse files- interactor.py +3 -3
interactor.py
CHANGED
|
@@ -51,7 +51,7 @@ class InteractionUnit(nn.Module):
|
|
| 51 |
return x
|
| 52 |
|
| 53 |
class InteractorBlock(nn.Module):
|
| 54 |
-
def __init__(self, d_model
|
| 55 |
super().__init__()
|
| 56 |
|
| 57 |
|
|
@@ -79,11 +79,11 @@ class InteractorBlock(nn.Module):
|
|
| 79 |
|
| 80 |
|
| 81 |
class Interactor(nn.Module):
|
| 82 |
-
def __init__(self, d_model,
|
| 83 |
super().__init__()
|
| 84 |
|
| 85 |
self.model = nn.Sequential(
|
| 86 |
-
*[InteractorBlock(d_model
|
| 87 |
)
|
| 88 |
|
| 89 |
def forward(self, x):
|
|
|
|
| 51 |
return x
|
| 52 |
|
| 53 |
class InteractorBlock(nn.Module):
|
| 54 |
+
def __init__(self, d_model):
|
| 55 |
super().__init__()
|
| 56 |
|
| 57 |
|
|
|
|
| 79 |
|
| 80 |
|
| 81 |
class Interactor(nn.Module):
|
| 82 |
+
def __init__(self, d_model, num_layers):
|
| 83 |
super().__init__()
|
| 84 |
|
| 85 |
self.model = nn.Sequential(
|
| 86 |
+
*[InteractorBlock(d_model) for _ in range(num_layers)]
|
| 87 |
)
|
| 88 |
|
| 89 |
def forward(self, x):
|