File size: 1,120 Bytes
a51fd90
 
 
5d9dac3
a51fd90
 
 
 
 
 
 
 
e01b0f3
 
 
 
 
 
 
a51fd90
e01b0f3
 
 
 
 
 
 
 
 
 
 
 
a51fd90
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_alpha3d import Alpha3DConfig 

class Alpha3DModel(PreTrainedModel):
    config_class = Alpha3DConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        
        # ТОЧНАЯ КОПИЯ ТОГО, ЧТО БЫЛО ПРИ ОБУЧЕНИИ
        # Без циклов, чтобы индексы (0, 1, ... 8) совпали идеально
        self.net = nn.Sequential(
            # Слой 1
            nn.Linear(5, 128),
            nn.BatchNorm1d(128),
            nn.ReLU(),
            
            # Слой 2
            nn.Linear(128, 512),
            nn.BatchNorm1d(512),
            nn.ReLU(),
            
            # Слой 3 (Тут НЕ БЫЛО BatchNorm при обучении!)
            nn.Linear(512, 1024),
            nn.ReLU(),
            
            # Выход
            nn.Linear(1024, config.num_points * 6)
        )

    def forward(self, x):
        out = self.net(x)
        return out.view(-1, self.config.num_points, 6)