import torch
import torch.nn as nn
from transformers import PretrainedConfig, XLMRobertaForSequenceClassification
class CustomConfig(PretrainedConfig):
    """XLM-R config extended with the output size of the emotion head."""

    model_type = "custom_model"

    def __init__(self, num_emotion_labels=18, **kwargs):
        super().__init__(**kwargs)
        self.num_emotion_labels = num_emotion_labels
class CustomModel(XLMRobertaForSequenceClassification):
    config_class = CustomConfig

    def __init__(self, config):
        super().__init__(config)
        self.num_emotion_labels = config.num_emotion_labels

        # Second head: a multi-label emotion classifier on top of the [CLS] state.
        self.dropout_emotion = nn.Dropout(config.hidden_dropout_prob)
        self.emotion_classifier = nn.Sequential(
            nn.Linear(config.hidden_size, 512),
            nn.Mish(),
            nn.Dropout(0.3),
            nn.Linear(512, self.num_emotion_labels),
        )
        # Initialize only the new linear layers. The helper is deliberately not
        # named `_init_weights`: PreTrainedModel applies a method of that name
        # to every module during `post_init()`, so overriding it would silently
        # change how the whole backbone is initialized.
        self._init_classifier_weights(self.emotion_classifier[0])
        self._init_classifier_weights(self.emotion_classifier[3])

    def _init_classifier_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
            if module.bias is not None:
                module.bias.data.zero_()
    def forward(self, input_ids=None, attention_mask=None, sentiment=None, labels=None):
        # `sentiment` is accepted so a Trainer can pass that dataset column
        # through, but it is not used inside the forward pass.
        outputs = self.roberta(input_ids=input_ids, attention_mask=attention_mask)
        sequence_output = outputs[0]
        if sequence_output.dim() != 3:
            raise ValueError(
                f"Expected sequence_output to have 3 dimensions, got {sequence_output.shape}"
            )

        # Emotion head: dropout on the [CLS] hidden state, then the MLP.
        cls_hidden_states = self.dropout_emotion(sequence_output[:, 0, :])
        emotion_logits = self.emotion_classifier(cls_hidden_states)

        # Sentiment head: the stock XLM-R classification head, run under
        # no_grad so it receives no gradient updates. RobertaClassificationHead
        # selects the [CLS] token itself, so it takes the full sequence output.
        with torch.no_grad():
            sentiment_logits = self.classifier(sequence_output)

        if labels is not None:
            # Placeholder per-class weights (all ones); swap in real positive-class
            # weights when the emotion labels are imbalanced.
            class_weights = torch.ones(self.num_emotion_labels, device=labels.device)
            loss_fct = nn.BCEWithLogitsLoss(pos_weight=class_weights)
            # BCEWithLogitsLoss expects float multi-hot targets.
            loss = loss_fct(emotion_logits, labels.float())
            return {"loss": loss, "emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}
        return {"emotion_logits": emotion_logits, "sentiment_logits": sentiment_logits}