faisalsns's picture
initial commit
7bb1e78
raw
history blame contribute delete
768 Bytes
# Create a standalone file with just the model class
import torch
import torch.nn as nn
class TabularModel(nn.Module):
    """Feed-forward MLP for flat (tabular) feature vectors.

    The network is an ``nn.Sequential`` stored on ``self.model``. For each
    width in ``hidden_sizes`` it appends the block
    ``Linear -> BatchNorm1d -> ReLU -> Dropout``, then a final ``Linear`` that
    projects to ``output_size``. No activation follows the last ``Linear``, so
    ``forward`` returns unnormalized outputs.

    The ``state_dict`` keys (``model.0.weight``, ``model.1.running_mean``, ...)
    are derived from this exact layer order; changing the order or the
    ``model`` attribute name breaks loading of previously saved weights.

    Args:
        input_size: Number of input features (``in_features`` of the first
            ``Linear``).
        hidden_sizes: Iterable of hidden-layer widths. May be empty, in which
            case the model is a single ``Linear(input_size, output_size)``.
        output_size: Number of output units of the final ``Linear``.
        dropout_rate: Dropout probability used after every hidden block
            (``nn.Dropout`` rejects values outside ``[0, 1]``).
    """

    def __init__(self, input_size, hidden_sizes, output_size, dropout_rate=0.2):
        super().__init__()
        layers = []
        in_features = input_size
        for width in hidden_sizes:
            layers.extend([
                nn.Linear(in_features, width),
                nn.BatchNorm1d(width),
                nn.ReLU(),
                nn.Dropout(dropout_rate),
            ])
            # The next block consumes this block's output width.
            in_features = width
        # Output head: deliberately no activation, so callers get raw scores
        # (presumably a loss such as cross-entropy applies any final transform).
        layers.append(nn.Linear(in_features, output_size))
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the stacked layers and return the final output.

        ``x`` is presumably a ``(batch, input_size)`` float tensor; TODO confirm
        with callers. When ``hidden_sizes`` is non-empty, ``BatchNorm1d`` in
        training mode needs more than one sample per batch.
        """
        return self.model(x)