File size: 855 Bytes
234ecab
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
from transformers import PreTrainedModel, PretrainedConfig
import torch.nn as nn

class EmbeddingClassifierConfig(PretrainedConfig):
    """Configuration holding the layer sizes of ``EmbeddingClassifier``.

    Args:
        input_dim: Width of the input embedding fed to the first linear layer.
        num_classes: Number of output logits produced by the final layer.
        hidden_size: Width of the hidden layer between the two linear layers.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    model_type = "embedding_classifier"

    def __init__(self, input_dim=1024, num_classes=7, hidden_size=768, **kwargs):
        # Let the base class consume its own options before we add ours.
        super().__init__(**kwargs)
        self.input_dim, self.num_classes, self.hidden_size = (
            input_dim,
            num_classes,
            hidden_size,
        )

class EmbeddingClassifier(PreTrainedModel):
    """Two-layer MLP mapping an input embedding to class logits.

    Layout: ``fc1`` (input_dim -> hidden_size) -> ReLU -> ``fc2``
    (hidden_size -> num_classes). All sizes come from the model config.
    """

    config_class = EmbeddingClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        self.fc1 = nn.Linear(config.input_dim, config.hidden_size)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x, labels=None):
        """Compute class logits, plus a cross-entropy loss when labels are given.

        Args:
            x: Input tensor whose last dimension is ``config.input_dim``
                (``nn.Linear`` operates on the last dimension).
            labels: Optional target class indices. Assumed to be an integer
                tensor of shape ``(batch,)`` paired with ``x`` of shape
                ``(batch, input_dim)``. TODO: confirm against callers.

        Returns:
            When ``labels`` is None, the logits tensor (last dimension
            ``config.num_classes``), exactly as before. Otherwise a
            ``(loss, logits)`` tuple. The loss comes first because HF
            ``Trainer`` treats the first element of a tuple output as the loss.
        """
        logits = self.fc2(self.activation(self.fc1(x)))
        if labels is None:
            return logits
        # cross_entropy applies log-softmax internally, so pass raw logits.
        loss = nn.functional.cross_entropy(logits, labels)
        return loss, logits