# Source provenance (was raw GitHub page residue; commented out so the file parses):
# author: Rui Wan — commit message: "upload model" — commit: 6977ea8
import torch
class NeuralNetwork(torch.nn.Module):
    """Simple fully-connected feed-forward network.

    Builds a Linear layer for each consecutive pair in ``layer_sizes``,
    inserting an ``activation`` instance after every hidden Linear layer.
    If ``dropout_rate > 0``, dropout is applied after each activation
    (never after a Linear layer, so the raw output is never dropped).

    Args:
        layer_sizes: Layer widths, ``[input, hidden..., output]``.
            Must contain at least two entries.
        dropout_rate: Dropout probability in [0, 1); 0 disables dropout.
        activation: Activation module *class* (instantiated per hidden layer).

    Raises:
        ValueError: If ``layer_sizes`` has fewer than two entries.
    """

    def __init__(self, layer_sizes, dropout_rate=0.0, activation=torch.nn.ReLU):
        super().__init__()
        # Guard early: with < 2 sizes, layer_sizes[-2] below would raise a
        # confusing IndexError (or build a degenerate net).
        if len(layer_sizes) < 2:
            raise ValueError("layer_sizes must contain at least an input and an output size")
        if dropout_rate > 0:
            self.dropout_layer = torch.nn.Dropout(dropout_rate)
        self.layer_sizes = layer_sizes
        self.layers = torch.nn.ModuleList()
        # Hidden stack: Linear followed by a fresh activation instance.
        for i in range(len(layer_sizes) - 2):
            self.layers.append(torch.nn.Linear(layer_sizes[i], layer_sizes[i + 1]))
            self.layers.append(activation())
        # Output layer: no activation, no dropout.
        self.layers.append(torch.nn.Linear(layer_sizes[-2], layer_sizes[-1]))
        self.init_weights()

    def init_weights(self):
        """Initialize every Linear layer: Xavier-normal weights, zero biases."""
        for layer in self.layers:
            if isinstance(layer, torch.nn.Linear):
                torch.nn.init.xavier_normal_(layer.weight)
                torch.nn.init.zeros_(layer.bias)

    def forward(self, x):
        """Propagate ``x`` through the layer stack and return the output tensor."""
        for layer in self.layers:
            x = layer(x)
            # nn.Dropout already no-ops in eval mode, so no explicit
            # self.training check is needed here. Dropout follows
            # activations only — never the (final) Linear layers.
            if hasattr(self, 'dropout_layer') and not isinstance(layer, torch.nn.Linear):
                x = self.dropout_layer(x)
        return x

    def predict(self, x):
        """Inference-mode forward pass (eval mode, gradients disabled).

        Restores the module's previous train/eval mode on exit — the
        original left the network permanently in eval mode, silently
        disabling dropout for any subsequent training steps.
        """
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                return self.forward(x)
        finally:
            self.train(was_training)