# firstTestModel / modeling_simple_mlp.py
# Uploaded by go76dof: SimpleMLPForClassification (commit ed52679, verified)
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from simple_mlp_configuration import SimpleMLPConfig
from transformers.modeling_outputs import SequenceClassifierOutput
class SimpleMLPForClassification(PreTrainedModel):
    """A minimal MLP classifier packaged as a Hugging Face ``PreTrainedModel``.

    Architecture: Linear(input_dim -> hidden_dim) -> ReLU ->
    Dropout(dropout_rate) -> Linear(hidden_dim -> num_classes).
    Operates on pre-computed embeddings (``inputs_embeds``); there is no
    tokenizer/embedding layer in this model.
    """

    config_class = SimpleMLPConfig

    def __init__(self, config):
        # super().__init__ stores the config as self.config, so no extra
        # `self.config = config` assignment is needed.
        super().__init__(config)
        self.num_labels = config.num_classes
        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(config.dropout_rate)
        self.fc2 = nn.Linear(config.hidden_dim, config.num_classes)
        # Standard HF hook: initializes weights and applies final processing.
        self.post_init()

    def forward(self, inputs_embeds, labels=None, return_dict=None):
        """Run the MLP on pre-computed embeddings.

        Args:
            inputs_embeds: float tensor whose last dimension is
                ``config.input_dim``.
            labels: optional tensor of class indices; when given, a
                cross-entropy loss over ``num_labels`` classes is computed.
            return_dict: whether to return a ``SequenceClassifierOutput``;
                defaults to ``self.config.use_return_dict``.

        Returns:
            ``SequenceClassifierOutput(loss=..., logits=...)`` when
            ``return_dict`` is truthy, otherwise the tuple
            ``(loss, logits)`` (or ``(logits,)`` when no labels are given).
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        hidden = self.dropout(self.activation(self.fc1(inputs_embeds)))
        logits = self.fc2(hidden)

        loss = None
        if labels is not None:
            # Flatten so the loss also works for batched/multi-dim inputs.
            loss_fct = nn.CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
        )