File size: 995 Bytes
98e2ea5
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
import torch.nn as nn


class MLP(nn.Module):
    """Feed-forward stack: ``num_hidden_layers`` x (Linear -> ReLU [-> drop]), then Linear.

    The final ``nn.Linear`` projects to ``output_size`` and is not followed by
    an activation. With ``num_hidden_layers=0`` the network is a single
    ``nn.Linear(input_size, output_size)``.

    Args:
        input_size: ``in_features`` of the first linear layer.
        hidden_size: width of every hidden linear layer.
        output_size: ``out_features`` of the final linear layer.
        num_hidden_layers: number of (Linear -> ReLU [-> drop]) groups placed
            before the output layer.
        bias: forwarded to every ``nn.Linear``; defaults to ``False``.
        drop_module: optional module appended after each ReLU. The SAME
            instance is reused in every hidden group, as is the single ReLU
            instance stored in ``self.activation``.
    """

    def __init__(
        self,
        input_size,
        hidden_size,
        output_size,
        num_hidden_layers=1,
        bias=False,
        drop_module=None,
    ):
        super().__init__()
        # Plain Python list kept as a public attribute; the modules in it are
        # registered with PyTorch through ``self.fc_layers`` below.
        self.layer_list = []

        # Assignment order matters: it fixes the submodule order, and hence
        # the ``state_dict`` key order (activation, drop_module, fc_layers).
        self.activation = nn.ReLU()
        self.drop_module = drop_module
        self.num_hidden_layers = num_hidden_layers

        # fan_ins[k] is the input width of the k-th linear layer; the last
        # entry feeds the output projection.
        fan_ins = [input_size] + [hidden_size] * num_hidden_layers
        for fan_in in fan_ins[:-1]:
            self.layer_list.append(nn.Linear(fan_in, hidden_size, bias=bias))
            self.layer_list.append(self.activation)
            if self.drop_module is not None:
                self.layer_list.append(self.drop_module)
        self.layer_list.append(nn.Linear(fan_ins[-1], output_size, bias=bias))

        self.fc_layers = nn.Sequential(*self.layer_list)

    def forward(self, mlp_input):
        """Run ``mlp_input`` through the sequential stack and return the result.

        Each ``nn.Linear`` acts on the last dimension of its input, so that
        dimension is expected to equal ``input_size``.
        """
        output = self.fc_layers(mlp_input)
        return output