| from transformers import PreTrainedModel, PretrainedConfig | |
| import torch.nn as nn | |
class EmbeddingClassifierConfig(PretrainedConfig):
    """Configuration for :class:`EmbeddingClassifier`.

    Args:
        input_dim: Size of the incoming embedding vectors (default 1024).
        num_classes: Number of output classes for the classifier head (default 7).
        hidden_size: Width of the intermediate linear layer (default 768).
        **kwargs: Forwarded to :class:`transformers.PretrainedConfig`.
    """

    # Registry key used by the transformers Auto* machinery.
    model_type = "embedding_classifier"

    def __init__(self, input_dim=1024, num_classes=7, hidden_size=768, **kwargs):
        # Let the base config consume its own kwargs first.
        super().__init__(**kwargs)
        # Record the classifier-head hyperparameters on the config instance.
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_size = hidden_size
class EmbeddingClassifier(PreTrainedModel):
    """Two-layer MLP head that maps precomputed embeddings to class logits.

    The network is ``Linear(input_dim -> hidden_size)`` followed by a ReLU
    and ``Linear(hidden_size -> num_classes)``. Returns raw logits (no
    softmax), as is conventional for use with a cross-entropy loss.
    """

    config_class = EmbeddingClassifierConfig

    def __init__(self, config):
        super().__init__(config)
        # Attribute names (fc1 / activation / fc2) define the state_dict
        # keys, so they must stay stable for checkpoint compatibility.
        self.fc1 = nn.Linear(config.input_dim, config.hidden_size)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(config.hidden_size, config.num_classes)

    def forward(self, x):
        """Project *x* through the MLP and return per-class logits."""
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)