| import torch.nn as nn | |
class SimpleMLP(nn.Module):
    """Multi-layer perceptron classifier.

    Builds a stack of ``Linear -> ReLU -> Dropout`` blocks, one per entry
    in ``hidden_dims``, followed by a final ``Linear`` projection to
    ``n_classes`` raw logits (no softmax is applied).

    Args:
        input_dim: Size of each input feature vector.
        hidden_dims: List of hidden-layer widths; may be empty, in which
            case the model is a single ``Linear(input_dim, n_classes)``.
        n_classes: Number of output logits.
        dropout: Dropout probability applied after each hidden ReLU.
            (Previously this parameter was silently ignored because the
            ``nn.Dropout`` layer was commented out; it is now active.)
    """

    def __init__(self, input_dim, hidden_dims, n_classes, dropout=0.1):
        super().__init__()
        layers = []
        dims = [input_dim] + hidden_dims
        # Pair consecutive widths to build each hidden block.
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_d, out_d))
            layers.append(nn.ReLU())
            # Fix: dropout was commented out, leaving the `dropout`
            # parameter unused; re-enable it so the argument takes effect.
            layers.append(nn.Dropout(dropout))
        # Output head: last hidden width (or input_dim if no hidden layers).
        layers.append(nn.Linear(dims[-1], n_classes))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits for input batch ``x``."""
        return self.model(x)