from transformers import RobertaModel, AutoTokenizer
from transformers.modeling_outputs import SequenceClassifierOutput
from huggingface_hub import PyTorchModelHubMixin
from torch.nn import CrossEntropyLoss
import torch.nn as nn
import torch


class SentenceBERTClassifier(nn.Module, PyTorchModelHubMixin):
    def __init__(self, model_name="sentence-transformers/all-distilroberta-v1", num_labels=8):
        super().__init__()
        # Pre-trained Sentence-BERT encoder (a RoBERTa-architecture checkpoint).
        self.sbert = RobertaModel.from_pretrained(model_name)
        self.config = self.sbert.config
        self.config.num_labels = num_labels
        self.config.classifier_dropout = 0.05
        self.dropout = nn.Dropout(self.config.classifier_dropout)
        # Linear classification head on top of the pooled sentence embedding.
        self.classifier = nn.Linear(self.config.hidden_size, self.config.num_labels)
    def forward(self, input_ids, attention_mask, labels=None):
        outputs = self.sbert(input_ids=input_ids, attention_mask=attention_mask)
        # Pooled [CLS]-style representation from the RoBERTa pooler.
        pooled_output = outputs.pooler_output
        dropout_output = self.dropout(pooled_output)
        logits = self.classifier(dropout_output)

        # Cross-entropy loss when labels are supplied (training); None at inference.
        loss = None
        if labels is not None:
            loss_fct = CrossEntropyLoss()
            loss = loss_fct(logits.view(-1, self.config.num_labels), labels.view(-1))

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )
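
A minimal usage sketch follows (assumed, not taken from the original source; the example sentence is a placeholder). Because the class mixes in PyTorchModelHubMixin, it also inherits save_pretrained(), push_to_hub(), and from_pretrained() for round-trips to the Hugging Face Hub.

# Hypothetical inference example: tokenize a batch and run a forward pass.
tokenizer = AutoTokenizer.from_pretrained("sentence-transformers/all-distilroberta-v1")
model = SentenceBERTClassifier()

inputs = tokenizer(["an example sentence to classify"],
                   padding=True, truncation=True, return_tensors="pt")
model.eval()
with torch.no_grad():
    output = model(input_ids=inputs["input_ids"],
                   attention_mask=inputs["attention_mask"])

predicted = output.logits.argmax(dim=-1)  # predicted class index per sentence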