alpha3D-v1 / modeling_alpha3d.py
prostochel097's picture
Update modeling_alpha3d.py
e01b0f3 verified
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from .configuration_alpha3d import Alpha3DConfig
class Alpha3DModel(PreTrainedModel):
config_class = Alpha3DConfig
def __init__(self, config):
super().__init__(config)
self.config = config
# ТОЧНАЯ КОПИЯ ТОГО, ЧТО БЫЛО ПРИ ОБУЧЕНИИ
# Без циклов, чтобы индексы (0, 1, ... 8) совпали идеально
self.net = nn.Sequential(
# Слой 1
nn.Linear(5, 128),
nn.BatchNorm1d(128),
nn.ReLU(),
# Слой 2
nn.Linear(128, 512),
nn.BatchNorm1d(512),
nn.ReLU(),
# Слой 3 (Тут НЕ БЫЛО BatchNorm при обучении!)
nn.Linear(512, 1024),
nn.ReLU(),
# Выход
nn.Linear(1024, config.num_points * 6)
)
def forward(self, x):
out = self.net(x)
return out.view(-1, self.config.num_points, 6)