File size: 1,120 Bytes
a51fd90 5d9dac3 a51fd90 e01b0f3 a51fd90 e01b0f3 a51fd90 | 1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 | import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_alpha3d import Alpha3DConfig
class Alpha3DModel(PreTrainedModel):
config_class = Alpha3DConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# ТОЧНАЯ КОПИЯ ТОГО, ЧТО БЫЛО ПРИ ОБУЧЕНИИ
# Без циклов, чтобы индексы (0, 1, ... 8) совпали идеально
self.net = nn.Sequential(
# Слой 1
nn.Linear(5, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
# Слой 2
nn.Linear(128, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
# Слой 3 (Тут НЕ БЫЛО BatchNorm при обучении!)
nn.Linear(512, 1024),
nn.ReLU(),
# Выход
nn.Linear(1024, config.num_points * 6)
)
def forward(self, x):
out = self.net(x)
return out.view(-1, self.config.num_points, 6) |