import torch.nn as nn
class SimpleMLP(nn.Module):
    """Feed-forward classifier: [Linear -> ReLU -> Dropout] per hidden layer,
    followed by a final Linear projection to ``n_classes`` logits.

    Args:
        input_dim: size of each input feature vector.
        hidden_dims: list of hidden-layer widths (may be empty, in which case
            the model is a single Linear(input_dim, n_classes)).
        n_classes: number of output classes (dimension of the returned logits).
        dropout: dropout probability applied after each hidden ReLU.
    """

    def __init__(self, input_dim, hidden_dims, n_classes, dropout=0.1):
        super().__init__()
        layers = []
        # dims holds the layer widths: [input_dim, h1, h2, ...]; consecutive
        # pairs define each hidden Linear transform.
        dims = [input_dim] + hidden_dims
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_d, out_d))
            layers.append(nn.ReLU())
            # Bug fix: this Dropout was commented out, leaving the `dropout`
            # parameter with no effect. Re-enabled so the argument is honored
            # (no-op in eval mode or with dropout=0.0).
            layers.append(nn.Dropout(dropout))
        layers.append(nn.Linear(dims[-1], n_classes))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return raw class logits of shape (..., n_classes)."""
        return self.model(x)