Other
English
File size: 553 Bytes
e703e79
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
import torch.nn as nn

class SimpleMLP(nn.Module):
    """Fully connected classifier: ``Linear -> ReLU`` per hidden layer, then a final ``Linear``.

    Args:
        input_dim: Number of input features (size of the last dimension of ``x``).
        hidden_dims: Hidden-layer widths, in order. Any iterable of ints
            (list, tuple, ...) is accepted; an empty one yields a single
            ``Linear(input_dim, n_classes)``.
        n_classes: Number of output features produced by the final layer.
        dropout: Dropout probability inserted after each hidden ``ReLU``, but
            only when ``use_dropout`` is True. With the default
            ``use_dropout=False`` this value is ignored, as it always was.
        use_dropout: Keyword-only opt-in for dropout. Off by default so the
            ``nn.Sequential`` layout -- and therefore ``state_dict`` keys such
            as ``model.2.weight`` -- matches models built before this option
            existed.
    """

    def __init__(self, input_dim, hidden_dims, n_classes, dropout=0.1, *,
                 use_dropout=False):
        super().__init__()
        # Unpack instead of `[input_dim] + hidden_dims` so tuples and other
        # iterables work, not only lists.
        dims = [input_dim, *hidden_dims]
        layers = []
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_d, out_d))
            layers.append(nn.ReLU())
            if use_dropout:
                layers.append(nn.Dropout(dropout))
        # Output head has no activation: forward() returns unnormalized scores.
        layers.append(nn.Linear(dims[-1], n_classes))
        # Attribute name `model` is part of the state_dict key prefix; keep it.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Apply the MLP to ``x`` (last dim ``input_dim``); no softmax is applied."""
        return self.model(x)