from typing import Iterable

import torch
import torch.nn as nn


class SimpleMLP(nn.Module):
    """Fully connected network built as an ``nn.Sequential`` stack.

    Each hidden layer is ``Linear -> ReLU``. A final ``Linear`` maps to
    ``n_classes`` outputs. No activation is applied to the output, so the
    result is unnormalized scores.

    Args:
        input_dim: Size of the last dimension of the input tensor.
        hidden_dims: Widths of the hidden layers, in order. Any iterable of
            ints is accepted (list, tuple, ...). An empty iterable yields a
            single ``Linear(input_dim, n_classes)``.
        n_classes: Size of the last dimension of the output.
        dropout: Currently NOT applied; it is accepted only so existing call
            sites keep working (see the note in ``__init__``).
    """

    def __init__(
        self,
        input_dim: int,
        hidden_dims: Iterable[int],
        n_classes: int,
        dropout: float = 0.1,
    ) -> None:
        super().__init__()
        # Unpack into a fresh list so tuples/generators work too; plain
        # list concatenation raised TypeError for non-list hidden_dims.
        dims = [input_dim, *hidden_dims]
        layers = []
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_d, out_d))
            layers.append(nn.ReLU())
            # NOTE(review): `dropout` is deliberately not used here. Inserting
            # an nn.Dropout would shift the nn.Sequential indices (and thus the
            # state_dict keys, e.g. "model.2.weight"), breaking checkpoints
            # saved from this architecture, and would change training
            # behaviour. Enable it only together with a checkpoint migration.
        layers.append(nn.Linear(dims[-1], n_classes))
        self.model = nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Apply the layer stack to ``x`` and return the output scores."""
        return self.model(x)