import torch import torch.nn as nn from transformers import PreTrainedModel from .configuration_alpha3d import Alpha3DConfig class Alpha3DModel(PreTrainedModel): config_class = Alpha3DConfig def __init__(self, config): super().__init__(config) self.config = config # ТОЧНАЯ КОПИЯ ТОГО, ЧТО БЫЛО ПРИ ОБУЧЕНИИ # Без циклов, чтобы индексы (0, 1, ... 8) совпали идеально self.net = nn.Sequential( # Слой 1 nn.Linear(5, 128), nn.BatchNorm1d(128), nn.ReLU(), # Слой 2 nn.Linear(128, 512), nn.BatchNorm1d(512), nn.ReLU(), # Слой 3 (Тут НЕ БЫЛО BatchNorm при обучении!) nn.Linear(512, 1024), nn.ReLU(), # Выход nn.Linear(1024, config.num_points * 6) ) def forward(self, x): out = self.net(x) return out.view(-1, self.config.num_points, 6)