File size: 1,393 Bytes
ed52679
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
import torch
import torch.nn as nn
from transformers import PreTrainedModel
from simple_mlp_configuration import SimpleMLPConfig
from transformers.modeling_outputs import SequenceClassifierOutput

class SimpleMLPForClassification(PreTrainedModel):
    """A minimal two-layer MLP sequence classifier in the Hugging Face API.

    Unlike typical transformer models, this one consumes pre-computed feature
    vectors directly via ``inputs_embeds`` (there is no tokenizer or embedding
    layer) and maps them to per-class logits through
    fc1 -> ReLU -> dropout -> fc2.
    """

    config_class = SimpleMLPConfig

    def __init__(self, config):
        super().__init__(config)
        self.config = config
        self.num_labels = config.num_classes

        # Two-layer perceptron: input_dim -> hidden_dim -> num_classes.
        # NOTE: attribute names are part of the state-dict keys; keep stable.
        self.fc1 = nn.Linear(config.input_dim, config.hidden_dim)
        self.activation = nn.ReLU()
        self.dropout = nn.Dropout(config.dropout_rate)
        self.fc2 = nn.Linear(config.hidden_dim, config.num_classes)

        # Standard HF hook: weight initialization and final bookkeeping.
        self.post_init()

    def forward(self, inputs_embeds, labels=None, return_dict=None):
        """Compute logits for ``inputs_embeds``; add CE loss when ``labels`` given.

        Returns a ``SequenceClassifierOutput`` when ``return_dict`` resolves to
        True, otherwise a tuple ``(loss, logits)`` / ``(logits,)``.
        """
        if return_dict is None:
            return_dict = self.config.use_return_dict

        hidden = self.dropout(self.activation(self.fc1(inputs_embeds)))
        logits = self.fc2(hidden)

        loss = None
        if labels is not None:
            # Flatten to (N, num_labels) vs (N,) as CrossEntropyLoss expects.
            loss = nn.CrossEntropyLoss()(
                logits.view(-1, self.num_labels), labels.view(-1)
            )

        if return_dict:
            return SequenceClassifierOutput(
                loss=loss,
                logits=logits,
            )
        return (loss, logits) if loss is not None else (logits,)