prostochel097 commited on
Commit
e01b0f3
·
verified ·
1 Parent(s): 5d9dac3

Update modeling_alpha3d.py

Browse files
Files changed (1) hide show
  1. modeling_alpha3d.py +19 -12
modeling_alpha3d.py CHANGED
@@ -10,19 +10,26 @@ class Alpha3DModel(PreTrainedModel):
10
  super().__init__(config)
11
  self.config = config
12
 
13
- layers = []
14
- input_dim = 5
15
-
16
- prev_dim = input_dim
17
- for h_dim in config.hidden_dims:
18
- layers.append(nn.Linear(prev_dim, h_dim))
19
- layers.append(nn.BatchNorm1d(h_dim))
20
- layers.append(nn.ReLU())
21
- prev_dim = h_dim
22
 
23
- layers.append(nn.Linear(prev_dim, config.num_points * 6))
24
-
25
- self.net = nn.Sequential(*layers)
 
 
 
 
 
 
 
 
 
26
 
27
  def forward(self, x):
28
  out = self.net(x)
 
10
  super().__init__(config)
11
  self.config = config
12
 
13
+ # ТОЧНАЯ КОПИЯ ТОГО, ЧТО БЫЛО ПРИ ОБУЧЕНИИ
14
+ # Без циклов, чтобы индексы (0, 1, ... 8) совпали идеально
15
+ self.net = nn.Sequential(
16
+ # Слой 1
17
+ nn.Linear(5, 128),
18
+ nn.BatchNorm1d(128),
19
+ nn.ReLU(),
 
 
20
 
21
+ # Слой 2
22
+ nn.Linear(128, 512),
23
+ nn.BatchNorm1d(512),
24
+ nn.ReLU(),
25
+
26
+ # Слой 3 (Тут НЕ БЫЛО BatchNorm при обучении!)
27
+ nn.Linear(512, 1024),
28
+ nn.ReLU(),
29
+
30
+ # Выход
31
+ nn.Linear(1024, config.num_points * 6)
32
+ )
33
 
34
  def forward(self, x):
35
  out = self.net(x)