# rasayan-tox21 / src/model.py
# root
# Initial commit: Rasayan Tox21 SNN Ensemble
# 0024d0e
import torch
import torch.nn as nn
class Tox21SNN(nn.Module):
    """Self-normalizing MLP (SELU + AlphaDropout) with one output per task.

    ``forward`` returns the raw output of the final ``nn.Linear``; no sigmoid
    is applied here. The outputs are presumably per-task logits meant for
    ``BCEWithLogitsLoss`` / ``torch.sigmoid`` in the caller -- TODO confirm
    against the training code.

    Layout of ``self.model`` (an ``nn.Sequential``) for ``n_layers >= 1``::

        [Linear -> SELU -> AlphaDropout] * n_layers -> Linear(n_tasks, n_tasks)

    The first ``n_layers - 1`` blocks are ``hidden_dim`` wide and the last
    block narrows to ``n_tasks``. For ``n_layers == 0`` the model is a single
    ``Linear(in_features, n_tasks)``.

    NOTE(review): the ``n_tasks``-wide last hidden block looks like an
    unintended bottleneck. With the defaults it gives
    ``... -> 768 -> 12 -> SELU -> AlphaDropout -> Linear(12, 12)``. It is kept
    on purpose so that parameter shapes and ``state_dict`` keys match the
    original layout, and any checkpoints trained with this class still load.
    Changing it requires retraining.

    NOTE(review): ``nn.Linear`` keeps PyTorch's default initialization, not
    the LeCun-normal initialization that self-normalizing networks assume.
    Confirm whether this is intentional.

    Args:
        in_features: Width of the input feature vector.
        hidden_dim: Width of the hidden blocks (all but the last).
        n_layers: Number of Linear -> SELU -> AlphaDropout blocks before the
            output layer; must be >= 0.
        dropout: AlphaDropout probability used after every hidden block.
        n_tasks: Number of outputs. Defaults to 12, the value the original
            code hard-coded (presumably the 12 Tox21 assays).

    Raises:
        ValueError: If ``n_layers`` is negative.
    """

    def __init__(
        self,
        in_features: int,
        hidden_dim: int = 768,
        n_layers: int = 8,
        dropout: float = 0.05,
        n_tasks: int = 12,
    ):
        super().__init__()
        if n_layers < 0:
            raise ValueError(f"n_layers must be >= 0, got {n_layers}")
        self.in_features = in_features
        self.hidden_dim = hidden_dim
        self.n_layers = n_layers
        self.n_tasks = n_tasks

        # Output width of each hidden block. The last block is n_tasks wide
        # (not hidden_dim) to reproduce the original layout; see the NOTE in
        # the class docstring.
        hidden_widths = [hidden_dim] * max(n_layers - 1, 0)
        if n_layers > 0:
            hidden_widths.append(n_tasks)

        layers = []
        prev_width = in_features
        for width in hidden_widths:
            # Fresh activation/dropout modules per block rather than one shared
            # instance, so hooks and module listings see each layer separately.
            # They hold no parameters, so state_dict keys are unaffected.
            layers.extend(
                [nn.Linear(prev_width, width), nn.SELU(), nn.AlphaDropout(p=dropout)]
            )
            prev_width = width
        layers.append(nn.Linear(prev_width, n_tasks))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the MLP to ``x``.

        Args:
            x: Tensor whose last dimension is ``in_features`` (presumably
                ``(batch, in_features)`` -- TODO confirm with callers).

        Returns:
            Tensor whose last dimension is ``n_tasks``: the raw output of the
            final linear layer, with no sigmoid applied.
        """
        return self.model(x)