Update modeling_alpha3d.py
Browse files- modeling_alpha3d.py +19 -12
modeling_alpha3d.py
CHANGED
|
@@ -10,19 +10,26 @@ class Alpha3DModel(PreTrainedModel):
|
|
| 10 |
super().__init__(config)
|
| 11 |
self.config = config
|
| 12 |
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
layers.append(nn.ReLU())
|
| 21 |
-
prev_dim = h_dim
|
| 22 |
|
| 23 |
-
|
| 24 |
-
|
| 25 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 26 |
|
| 27 |
def forward(self, x):
|
| 28 |
out = self.net(x)
|
|
|
|
| 10 |
super().__init__(config)
|
| 11 |
self.config = config
|
| 12 |
|
| 13 |
+
# ТОЧНАЯ КОПИЯ ТОГО, ЧТО БЫЛО ПРИ ОБУЧЕНИИ
|
| 14 |
+
# Без циклов, чтобы индексы (0, 1, ... 8) совпали идеально
|
| 15 |
+
self.net = nn.Sequential(
|
| 16 |
+
# Слой 1
|
| 17 |
+
nn.Linear(5, 128),
|
| 18 |
+
nn.BatchNorm1d(128),
|
| 19 |
+
nn.ReLU(),
|
|
|
|
|
|
|
| 20 |
|
| 21 |
+
# Слой 2
|
| 22 |
+
nn.Linear(128, 512),
|
| 23 |
+
nn.BatchNorm1d(512),
|
| 24 |
+
nn.ReLU(),
|
| 25 |
+
|
| 26 |
+
# Слой 3 (Тут НЕ БЫЛО BatchNorm при обучении!)
|
| 27 |
+
nn.Linear(512, 1024),
|
| 28 |
+
nn.ReLU(),
|
| 29 |
+
|
| 30 |
+
# Выход
|
| 31 |
+
nn.Linear(1024, config.num_points * 6)
|
| 32 |
+
)
|
| 33 |
|
| 34 |
def forward(self, x):
|
| 35 |
out = self.net(x)
|