English
Shanci's picture
Upload folder using huggingface_hub
e703e79 verified
raw
history blame contribute delete
553 Bytes
import torch.nn as nn
class SimpleMLP(nn.Module):
    """Fully connected classifier: ``Linear -> ReLU`` hidden layers, then a
    final ``Linear`` projection to ``n_classes`` outputs.

    Args:
        input_dim: Number of input features (size of the last dimension of
            the input tensor, as consumed by the first ``nn.Linear``).
        hidden_dims: Iterable of hidden-layer widths, applied in order
            (list, tuple, range, ...). May be empty, in which case the model
            is a single ``Linear(input_dim, n_classes)`` layer.
        n_classes: Number of output features, one per class.
        dropout: Accepted for call-site compatibility but currently NOT
            applied; no dropout layer is inserted.
    """

    def __init__(self, input_dim, hidden_dims, n_classes, dropout=0.1):
        super().__init__()
        # Unpacking (instead of ``[input_dim] + hidden_dims``) accepts any
        # iterable of widths, not only a list.
        dims = [input_dim, *hidden_dims]
        layers = []
        for in_d, out_d in zip(dims[:-1], dims[1:]):
            layers.append(nn.Linear(in_d, out_d))
            layers.append(nn.ReLU())
        # NOTE: ``dropout`` is not applied. nn.Sequential names its children
        # by position, so inserting nn.Dropout between layers would shift the
        # parameter keys in state_dict (e.g. "model.2.weight" ->
        # "model.3.weight") and break loading weights saved from this model.
        layers.append(nn.Linear(dims[-1], n_classes))
        # The attribute name "model" is the prefix of every state_dict key.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the raw (un-normalized, no softmax) class scores for ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``n_classes``.
        """
        return self.model(x)