Spaces:
Runtime error
Runtime error
| import torch; torch.manual_seed(0) | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| import torch.utils | |
| import torch.distributions | |
| import matplotlib.pyplot as plt; plt.rcParams['figure.dpi'] = 200 | |
| device = 'cuda' if torch.cuda.is_available() else 'cpu' | |
def get_activation(activation):
    """Return the activation callable registered under ``activation``.

    Supported names: ``'tanh'``, ``'relu'``, ``'mish'``, ``'sigmoid'``,
    ``'leakyrelu'`` and ``'exp'``.

    Args:
        activation: Name of the activation function.

    Returns:
        A callable that maps a tensor to a tensor.

    Raises:
        ValueError: If ``activation`` is not one of the supported names.
    """
    # torch.tanh computes the same function as F.tanh; it is used here for
    # consistency with torch.sigmoid / torch.exp and because F.tanh emits a
    # deprecation warning on some PyTorch releases.
    activations = {
        'tanh': torch.tanh,
        'relu': F.relu,
        'mish': F.mish,
        'sigmoid': torch.sigmoid,
        'leakyrelu': F.leaky_relu,
        'exp': torch.exp,
    }
    try:
        return activations[activation]
    except (KeyError, TypeError):
        # TypeError covers unhashable arguments (e.g. a list), so every
        # unsupported input is still reported as ValueError.
        raise ValueError(
            f"Unknown activation {activation!r}; "
            f"expected one of {sorted(activations)}"
        ) from None
class SimpleNet(nn.Module):
    """Fully connected feed-forward network built from ``nn.Linear`` layers.

    The layer widths are ``input_dim -> hidden_dims... -> output_dim``. Every
    layer except the last is followed by dropout and then the hidden
    activation (in that order, i.e. dropout acts on the pre-activation
    values). The last layer's output is returned as-is unless ``final_activ``
    is given, in which case that activation is applied to it.

    Args:
        input_dim: Size of the input features fed to the first linear layer.
        hidden_dims: Iterable of hidden layer widths (list, tuple, ...); may
            be empty, which yields a single linear layer.
        output_dim: Size of the output of the last linear layer.
        activation: Name of the hidden activation, resolved through
            ``get_activation``.
        dropout: Dropout probability passed to ``nn.Dropout``.
        final_activ: Optional name of an activation applied to the output of
            the last layer; ``None`` (default) leaves the output untouched.

    Raises:
        ValueError: Propagated from ``get_activation`` when ``activation`` or
            ``final_activ`` is not a supported name.
    """

    def __init__(self, input_dim, hidden_dims, output_dim, activation, dropout, final_activ=None):
        super().__init__()
        self.linears = nn.ModuleList()
        self.dropouts = nn.ModuleList()
        self.output_dim = output_dim

        # Unpacking (instead of list concatenation) accepts tuples and other
        # iterables for hidden_dims, not only lists.
        dims = [input_dim, *hidden_dims, output_dim]
        for d_in, d_out in zip(dims[:-1], dims[1:]):
            self.linears.append(nn.Linear(d_in, d_out))
            # One dropout module per linear layer; the one paired with the
            # last layer is never applied in forward().
            self.dropouts.append(nn.Dropout(dropout))

        self.activation = get_activation(activation)
        self.n_layers = len(self.linears)
        # Kept for backward compatibility with code that reads this attribute.
        self.layer_range = range(self.n_layers)

        # final_activ is only set as an attribute when one was requested.
        self.use_final_activ = final_activ is not None
        if self.use_final_activ:
            self.final_activ = get_activation(final_activ)

    def forward(self, x):
        """Run ``x`` through all layers and return the network output.

        Args:
            x: Input tensor accepted by the first ``nn.Linear`` layer.

        Returns:
            The output of the last linear layer, passed through the final
            activation when one was configured.
        """
        last_index = self.n_layers - 1
        for i_layer, (layer, dropout) in enumerate(zip(self.linears, self.dropouts)):
            x = layer(x)
            # Hidden layers only: dropout first, then the activation.
            if i_layer != last_index:
                x = self.activation(dropout(x))
        if self.use_final_activ:
            x = self.final_activ(x)
        return x