Neweret's picture
Upload 6 files
dd06c8e verified
raw
history blame
667 Bytes
import torch
import torch.nn as nn
import torch.nn.functional as F
class SimpleClassifier(nn.Module):
    """Two-hidden-layer MLP classifier.

    Each hidden stage is ``Linear -> LayerNorm -> GELU -> Dropout``; a final
    ``Linear`` layer maps to ``num_classes`` outputs. No softmax is applied,
    so ``forward`` returns raw scores (logits).
    """

    def __init__(
        self,
        input_dim: int,
        num_classes: int,
        p_dropout: float = 0.3,
        *,
        hidden_dim1: int = 256,
        hidden_dim2: int = 128,
    ) -> None:
        """Build the layers.

        Args:
            input_dim: Size of the last dimension of the input tensor.
            num_classes: Number of output scores produced per sample.
            p_dropout: Drop probability of the shared dropout module.
            hidden_dim1: Width of the first hidden layer.
            hidden_dim2: Width of the second hidden layer.
        """
        super().__init__()
        # Attribute names and registration order are kept stable so that
        # state_dict keys stay the same for existing checkpoints.
        self.linear1 = nn.Linear(input_dim, hidden_dim1)
        self.ln1 = nn.LayerNorm(hidden_dim1)
        # A single Dropout module is reused after both hidden stages.
        self.dropout = nn.Dropout(p_dropout)
        self.linear2 = nn.Linear(hidden_dim1, hidden_dim2)
        self.ln2 = nn.LayerNorm(hidden_dim2)
        self.linear_out = nn.Linear(hidden_dim2, num_classes)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """Return unnormalized class scores for ``x``.

        Args:
            x: Tensor whose last dimension has size ``input_dim``.

        Returns:
            Tensor with the same leading dimensions as ``x`` and a last
            dimension of size ``num_classes``.
        """
        x = F.gelu(self.ln1(self.linear1(x)))
        x = self.dropout(x)
        x = F.gelu(self.ln2(self.linear2(x)))
        x = self.dropout(x)
        return self.linear_out(x)