| import torch | |
| import torch.nn as nn | |
| from transformers import PreTrainedModel | |
| from .configuration_alpha3d import Alpha3DConfig | |
| class Alpha3DModel(PreTrainedModel): | |
| config_class = Alpha3DConfig | |
| def __init__(self, config): | |
| super().__init__(config) | |
| self.config = config | |
| # ТОЧНАЯ КОПИЯ ТОГО, ЧТО БЫЛО ПРИ ОБУЧЕНИИ | |
| # Без циклов, чтобы индексы (0, 1, ... 8) совпали идеально | |
| self.net = nn.Sequential( | |
| # Слой 1 | |
| nn.Linear(5, 128), | |
| nn.BatchNorm1d(128), | |
| nn.ReLU(), | |
| # Слой 2 | |
| nn.Linear(128, 512), | |
| nn.BatchNorm1d(512), | |
| nn.ReLU(), | |
| # Слой 3 (Тут НЕ БЫЛО BatchNorm при обучении!) | |
| nn.Linear(512, 1024), | |
| nn.ReLU(), | |
| # Выход | |
| nn.Linear(1024, config.num_points * 6) | |
| ) | |
| def forward(self, x): | |
| out = self.net(x) | |
| return out.view(-1, self.config.num_points, 6) |