# agri_flow_classifier_stella / modeling_embedding_classifier.py
# Last commit: "Custom modeling" (234ecab), by rachitavya
# File size: 855 bytes
from transformers import PreTrainedModel, PretrainedConfig
import torch.nn as nn
class EmbeddingClassifierConfig(PretrainedConfig):
    """Configuration holding the layer sizes for ``EmbeddingClassifier``.

    Args:
        input_dim: Stored as ``self.input_dim``; used by the classifier as
            the input width of its first linear layer.
        num_classes: Stored as ``self.num_classes``; used by the classifier
            as the output width of its last linear layer.
        hidden_size: Stored as ``self.hidden_size``; used by the classifier
            as the width of the layer between the two linear layers.
        **kwargs: Forwarded unchanged to ``PretrainedConfig.__init__``.
    """

    # Identifier string for this configuration type.
    model_type = "embedding_classifier"

    def __init__(
        self,
        input_dim: int = 1024,
        num_classes: int = 7,
        hidden_size: int = 768,
        **kwargs,
    ) -> None:
        # The base class consumes the remaining keyword arguments first;
        # the classifier-specific sizes are recorded afterwards.
        super().__init__(**kwargs)

        self.input_dim = input_dim
        self.num_classes = num_classes
        self.hidden_size = hidden_size
class EmbeddingClassifier(PreTrainedModel):
    """Two-layer feed-forward classifier: Linear -> ReLU -> Linear.

    Layer sizes are read from an ``EmbeddingClassifierConfig``
    (``input_dim`` -> ``hidden_size`` -> ``num_classes``).
    """

    config_class = EmbeddingClassifierConfig

    def __init__(self, config):
        """Build the layers from the sizes stored on ``config``.

        Args:
            config: Object providing ``input_dim``, ``hidden_size`` and
                ``num_classes``; also passed to ``PreTrainedModel.__init__``.
        """
        super().__init__(config)

        in_features = config.input_dim
        hidden_features = config.hidden_size
        out_features = config.num_classes

        # Modules are created in the same order as before (fc1, activation,
        # fc2) so parameter registration order is unchanged.
        self.fc1 = nn.Linear(in_features, hidden_features)
        self.activation = nn.ReLU()
        self.fc2 = nn.Linear(hidden_features, out_features)

    def forward(self, x):
        """Run ``x`` through fc1, ReLU and fc2 and return the result.

        Args:
            x: Input handed to ``self.fc1``; its last dimension has to match
                ``config.input_dim`` for the linear layer to accept it.

        Returns:
            The output of ``self.fc2``. No softmax or other normalisation
            is applied here.
        """
        hidden = self.activation(self.fc1(x))
        return self.fc2(hidden)