from typing import List, Optional

import torch
import torch.nn as nn

from barista.models.utils import get_activation_function


class MLP(nn.Module):
    """A multi-layer perceptron with configurable hidden sizes, dropout
    placement, and activation function."""

    def __init__(
        self,
        d_input: int,
        d_out: int,
        layer_list: Optional[List[int]] = None,
        dropout: float = 0.1,
        bias: bool = True,
        use_first_dropout: bool = True,
        use_final_dropout: bool = False,
        use_final_activation: bool = False,
        activation: str = "linear",
        use_identity_stub: bool = True,
        **kwargs,
    ):
        super().__init__()
        self.d_input = d_input
        self.d_out = d_out
        self.layer_list = layer_list
        self.use_first_dropout = use_first_dropout
        self.use_final_dropout = use_final_dropout
        self.use_final_activation = use_final_activation
        self.activation_fn = get_activation_function(activation)
        # Instantiate dropout once as a submodule so it respects
        # train()/eval() mode; constructing a fresh nn.Dropout inside
        # forward() would apply dropout even during evaluation.
        self.dropout = nn.Dropout(dropout)

        current_dim = self.d_input
        self.layers = nn.ModuleList()
        if self.layer_list is not None:
            # Stack one linear layer per requested hidden dimension.
            for dim in self.layer_list:
                self.layers.append(nn.Linear(current_dim, dim, bias=bias))
                current_dim = dim
        elif use_identity_stub:
            # No hidden layers: keep a no-op stub so forward() still applies
            # one round of activation and dropout before the final projection.
            self.layers.append(nn.Identity())
        self.final_layer = nn.Linear(current_dim, self.d_out, bias=bias)

    def forward(self, x, *args, **kwargs):
        if self.use_first_dropout:
            x = self.dropout(x)
        for layer in self.layers:
            x = layer(x)
            x = self.activation_fn(x)
            x = self.dropout(x)
        x = self.final_layer(x)
        if self.use_final_activation:
            x = self.activation_fn(x)
        if self.use_final_dropout:
            x = self.dropout(x)
        return x
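

if __name__ == "__main__":
    # Minimal smoke test (a sketch, not part of the original module): the
    # hidden sizes, batch size, and feature dimensions below are illustrative
    # assumptions, and the default activation="linear" is assumed to map to a
    # no-op in get_activation_function.
    mlp = MLP(d_input=16, d_out=4, layer_list=[32, 32], dropout=0.1)
    mlp.eval()  # disable dropout for a deterministic sanity check
    x = torch.randn(8, 16)
    out = mlp(x)
    print(out.shape)  # expected: torch.Size([8, 4])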