| import torch.nn as nn | |
| import torch.nn.functional as F | |
def mlp(
    input_dim,
    hidden_dim,
    output_dim,
    hidden_depth,
    output_mod=None,
    batchnorm=False,
    activation=nn.ReLU,
):
    """Assemble a fully connected network as an ``nn.Sequential``.

    With ``hidden_depth == 0`` the result is a single
    ``Linear(input_dim, output_dim)``.  Otherwise it is one
    ``Linear(input_dim, hidden_dim)`` block, followed by
    ``hidden_depth - 1`` ``Linear(hidden_dim, hidden_dim)`` blocks and a
    final ``Linear(hidden_dim, output_dim)``.  Each hidden block is
    ``Linear -> [BatchNorm1d] -> activation``.

    Args:
        input_dim: Number of input features of the first linear layer.
        hidden_dim: Width of every hidden layer.
        output_dim: Number of output features of the last linear layer.
        hidden_depth: Number of hidden blocks; ``0`` yields a single
            linear layer.
        output_mod: Optional module instance appended after the last
            linear layer.
        batchnorm: If true, insert ``nn.BatchNorm1d(hidden_dim)`` between
            each hidden linear layer and its activation.
        activation: Callable that builds an activation module.  It is
            called with ``inplace=True`` when it accepts that keyword and
            with no arguments otherwise.

    Returns:
        The ``nn.Sequential`` holding the layers in order.
    """

    def make_activation():
        # Activations such as nn.Tanh or nn.Sigmoid take no ``inplace``
        # keyword; passing it raises TypeError, so fall back to the
        # no-argument constructor for those.
        try:
            return activation(inplace=True)
        except TypeError:
            return activation()

    def hidden_block(in_features):
        # Linear -> (optional BatchNorm1d) -> activation, created in that order.
        block = [nn.Linear(in_features, hidden_dim)]
        if batchnorm:
            block.append(nn.BatchNorm1d(hidden_dim))
        block.append(make_activation())
        return block

    if hidden_depth == 0:
        mods = [nn.Linear(input_dim, output_dim)]
    else:
        mods = hidden_block(input_dim)
        for _ in range(hidden_depth - 1):
            mods += hidden_block(hidden_dim)
        mods.append(nn.Linear(hidden_dim, output_dim))

    if output_mod is not None:
        mods.append(output_mod)

    return nn.Sequential(*mods)
def weight_init(m):
    """Initialise an ``nn.Linear`` layer in place; ignore any other module.

    The weight matrix receives an orthogonal initialisation and the bias,
    when the layer has one, is filled with zeros.  Intended to be passed
    to ``nn.Module.apply``.
    """
    if not isinstance(m, nn.Linear):
        return

    nn.init.orthogonal_(m.weight.data)

    # Layers built with ``bias=False`` carry ``bias = None``; nothing to zero.
    bias = m.bias
    if bias is not None:
        bias.data.fill_(0.0)
class MLP(nn.Module):
    """``nn.Module`` wrapper around the network produced by ``mlp``.

    The constructor arguments are forwarded unchanged to ``mlp``; the
    resulting network is stored as ``self.trunk`` and ``weight_init`` is
    then applied to this module via ``nn.Module.apply``.
    """

    def __init__(
        self,
        input_dim,
        hidden_dim,
        output_dim,
        hidden_depth,
        output_mod=None,
        batchnorm=False,
        activation=nn.ReLU,
    ):
        super().__init__()

        # Optional settings are gathered once and handed over by keyword.
        trunk_options = {
            "output_mod": output_mod,
            "batchnorm": batchnorm,
            "activation": activation,
        }
        self.trunk = mlp(
            input_dim, hidden_dim, output_dim, hidden_depth, **trunk_options
        )

        # Runs after the trunk is registered so its layers are visited too.
        self.apply(weight_init)

    def forward(self, x):
        """Return the output of the trunk network for input ``x``."""
        network = self.trunk
        return network(x)