"""Small two-layer MLP classifier packaged as a Hugging Face model/config pair."""

from transformers import PreTrainedModel, PretrainedConfig
import torch.nn as nn


class EmbeddingClassifierConfig(PretrainedConfig):
    """Configuration for :class:`EmbeddingClassifier`.

    Args:
        input_dim: Size of the last dimension of the input fed to ``fc1``.
        num_classes: Number of output units produced by ``fc2``.
        hidden_size: Width of the hidden layer between ``fc1`` and ``fc2``.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier stored in serialized configs for this model family.
    model_type = "embedding_classifier"

    def __init__(
        self,
        input_dim: int = 1024,
        num_classes: int = 7,
        hidden_size: int = 768,
        **kwargs,
    ) -> None:
        # Let the base config consume its own options first, then record ours.
        super().__init__(**kwargs)
        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_size = hidden_size


class EmbeddingClassifier(PreTrainedModel):
    """Linear -> ReLU -> Linear head over fixed-size input vectors.

    NOTE(review): the name suggests the inputs are precomputed embeddings;
    this file does not show where they come from. Confirm against callers.
    """

    config_class = EmbeddingClassifierConfig

    def __init__(self, config: EmbeddingClassifierConfig) -> None:
        super().__init__(config)
        in_features = config.input_dim
        width = config.hidden_size
        n_out = config.num_classes
        # Attribute names and creation order are kept stable: they define the
        # state-dict keys and the order in which parameters are initialized.
        self.fc1 = nn.Linear(in_features, width)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(width, n_out)

    def forward(self, x):
        """Map ``x`` (last dim == ``config.input_dim``) to ``num_classes`` outputs.

        Returns the raw output of ``fc2``. No softmax or other normalization
        is applied.
        """
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)