import torch
import torch.nn as nn


class Tox21SNN(nn.Module):
    """Fully connected network using SELU activations and AlphaDropout.

    The network is ``n_layers`` hidden blocks, each ``Linear -> SELU ->
    AlphaDropout`` with ``hidden_dim`` units, followed by one plain
    ``Linear`` output layer that emits ``out_features`` raw scores (no
    final activation is applied).

    Args:
        in_features: Size of the last dimension of the input tensor.
        hidden_dim: Width of every hidden layer.
        n_layers: Number of hidden blocks. ``0`` yields a single
            ``Linear(in_features, out_features)`` layer.
        dropout: Drop probability passed to ``nn.AlphaDropout``.
        out_features: Number of outputs. Defaults to 12 (presumably one
            per Tox21 task -- confirm against the training code).

    Raises:
        ValueError: If ``n_layers`` is negative.
    """

    def __init__(
        self,
        in_features: int,
        hidden_dim: int = 768,
        n_layers: int = 8,
        dropout: float = 0.05,
        out_features: int = 12,
    ) -> None:
        super().__init__()
        if n_layers < 0:
            raise ValueError(f"n_layers must be >= 0, got {n_layers}")

        self.in_features = in_features
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.out_features = out_features

        # SELU and AlphaDropout hold no parameters, so one instance of each
        # is reused in every hidden block.
        activation = nn.SELU()
        drop = nn.AlphaDropout(p=dropout)

        # Widths seen by the hidden stack: the input width followed by
        # n_layers hidden widths. The output width is deliberately NOT part
        # of this list, so the last hidden layer keeps hidden_dim units
        # rather than being squeezed down to the output size.
        widths = [in_features] + [hidden_dim] * n_layers

        layers = []
        for fan_in, fan_out in zip(widths, widths[1:]):
            layers.extend([nn.Linear(fan_in, fan_out), activation, drop])

        # Output head: no activation or dropout after it.
        # NOTE: a layout whose last hidden layer is out_features wide (with
        # an out_features -> out_features head) has different parameter
        # shapes; state dicts saved from such a model will not load here.
        layers.append(nn.Linear(widths[-1], out_features))

        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Run ``x`` through the layer stack and return the raw outputs."""
        return self.model(x)