File size: 1,827 Bytes
a35137b
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
from typing import List

import torch.nn as nn

from barista.models.utils import get_activation_function


class MLP(nn.Module):
    """Feed-forward stack of ``nn.Linear`` layers with activation and dropout.

    The network is ``[dropout] -> (hidden -> activation -> dropout)* ->
    final_layer -> [activation] -> [dropout]``, where the bracketed stages are
    toggled by the ``use_*`` flags.

    Dropout follows the module's train/eval mode: it is active after
    ``.train()`` and a no-op after ``.eval()``.

    Args:
        d_input: Size of the last dimension of the input.
        d_out: Output size of the final linear layer.
        layer_list: Widths of the hidden linear layers, in order. ``None``
            means no hidden linear layers (see ``use_identity_stub``).
        dropout: Dropout probability used at every dropout site.
        bias: Passed as ``bias=`` to every ``nn.Linear`` created here.
        use_first_dropout: Apply dropout to the input before any layer.
        use_final_dropout: Apply dropout after the final layer (and after the
            final activation, if enabled).
        use_final_activation: Apply the activation after the final layer.
        activation: Name forwarded to ``get_activation_function``; the set of
            accepted names is defined by that helper.
        use_identity_stub: Only consulted when ``layer_list`` is ``None``. If
            true, a single ``nn.Identity`` is registered as the sole hidden
            layer, so the activation and dropout are still applied once
            before the final layer.
        **kwargs: Accepted for call-site compatibility and ignored.
    """

    def __init__(
        self,
        d_input: int,
        d_out: int,
        layer_list: List[int] = None,
        dropout: float = 0.1,
        bias: bool = True,
        use_first_dropout: bool = True,
        use_final_dropout: bool = False,
        use_final_activation: bool = False,
        activation: str = "linear",
        use_identity_stub: bool = True,
        **kwargs
    ):
        super().__init__()

        self.d_input = d_input
        self.d_out = d_out
        self.layer_list = layer_list
        self.dropout = dropout
        self.use_first_dropout = use_first_dropout
        self.use_final_dropout = use_final_dropout
        self.use_final_activation = use_final_activation
        self.activation_fn = get_activation_function(activation)

        # Track the running feature width so each Linear chains onto the
        # previous one; it ends up as the input width of the final layer.
        current_dim = self.d_input
        self.layers = nn.ModuleList()
        if self.layer_list is not None:
            for dim in self.layer_list:
                self.layers.append(nn.Linear(current_dim, dim, bias=bias))
                current_dim = dim
        elif use_identity_stub:
            self.layers.append(nn.Identity())

        self.final_layer = nn.Linear(current_dim, self.d_out, bias=bias)

    def _apply_dropout(self, x):
        """Apply dropout with probability ``self.dropout``, honoring train/eval mode."""
        # Functional dropout keyed on self.training: a throwaway nn.Dropout
        # built inside forward() is always in training mode and would keep
        # dropping activations even after model.eval().
        return nn.functional.dropout(x, p=self.dropout, training=self.training)

    def forward(self, x, *args, **kwargs):
        """Run the MLP on ``x``.

        Args:
            x: Input tensor; its last dimension must match what the first
                linear layer expects.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The output of the final linear layer, after the optional final
            activation and dropout.
        """
        if self.use_first_dropout:
            x = self._apply_dropout(x)
        for layer in self.layers:
            x = layer(x)
            x = self.activation_fn(x)
            x = self._apply_dropout(x)
        x = self.final_layer(x)
        if self.use_final_activation:
            x = self.activation_fn(x)
        if self.use_final_dropout:
            x = self._apply_dropout(x)
        return x