File size: 768 Bytes
# Standalone file containing only the model class.
import torch
import torch.nn as nn
class TabularModel(nn.Module):
    """Feed-forward network assembled from a configurable stack of hidden stages.

    Every entry in ``hidden_sizes`` contributes one stage of
    ``Linear -> BatchNorm1d -> ReLU -> Dropout``. A final ``Linear`` layer maps
    the width of the last stage (or ``input_size`` when ``hidden_sizes`` is
    empty) to ``output_size``. No activation is applied after the final layer.
    """

    def __init__(self, input_size, hidden_sizes, output_size, dropout_rate=0.2):
        """Build the layer stack.

        Args:
            input_size: Number of input features of the first ``Linear`` layer.
            hidden_sizes: Iterable of ints, one per hidden stage, giving the
                width of each stage in order. May be empty, in which case the
                network is a single ``Linear`` layer.
            output_size: Number of output features of the final ``Linear`` layer.
            dropout_rate: Probability passed to every ``nn.Dropout`` layer.
        """
        super().__init__()

        # NOTE: the construction order below and the attribute name ``model``
        # determine the parameter names in ``state_dict()`` (``model.0.weight``,
        # ``model.1.running_mean``, ...). Keep both unchanged so weights saved
        # from this architecture keep matching.
        layers = []
        prev_size = input_size
        for hidden_size in hidden_sizes:
            layers.extend([
                nn.Linear(prev_size, hidden_size),
                nn.BatchNorm1d(hidden_size),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])
            # The next stage consumes this stage's output width.
            prev_size = hidden_size

        layers.append(nn.Linear(prev_size, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the final layer's output.

        Args:
            x: Input tensor; presumably shaped ``(batch, input_size)`` --
                TODO confirm against callers.

        Returns:
            The output of the last ``Linear`` layer.
        """
        return self.model(x)