Upload RobertaEmotion
Browse files- configuration_roberta_emotion.py +2 -1
- modeling_roberta_emotion.py +3 -5
- pytorch_model.bin +1 -1
configuration_roberta_emotion.py
CHANGED
|
@@ -1,7 +1,8 @@
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
|
|
|
|
| 3 |
class RobertaEmotionConfig(PretrainedConfig):
|
| 4 |
model_type = "ma2za/roberta-emotion"
|
| 5 |
|
| 6 |
def __init__(self, **kwargs):
|
| 7 |
-
super().__init__(**kwargs)
|
|
|
|
| 1 |
from transformers import PretrainedConfig
|
| 2 |
|
| 3 |
+
|
| 4 |
class RobertaEmotionConfig(PretrainedConfig):
|
| 5 |
model_type = "ma2za/roberta-emotion"
|
| 6 |
|
| 7 |
def __init__(self, **kwargs):
|
| 8 |
+
super().__init__(**kwargs)
|
modeling_roberta_emotion.py
CHANGED
|
@@ -1,13 +1,11 @@
|
|
|
|
|
| 1 |
from torch.nn import CrossEntropyLoss
|
| 2 |
-
|
| 3 |
-
import torch
|
| 4 |
-
|
| 5 |
from transformers import AutoModel, PreTrainedModel
|
| 6 |
-
|
| 7 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 8 |
|
| 9 |
from .configuration_roberta_emotion import RobertaEmotionConfig
|
| 10 |
|
|
|
|
| 11 |
class RobertaEmotion(PreTrainedModel):
|
| 12 |
config_class = RobertaEmotionConfig
|
| 13 |
|
|
@@ -29,4 +27,4 @@ class RobertaEmotion(PreTrainedModel):
|
|
| 29 |
loss_fct = CrossEntropyLoss()
|
| 30 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 31 |
|
| 32 |
-
return SequenceClassifierOutput(loss=loss, logits=logits)
|
|
|
|
| 1 |
+
import torch
|
| 2 |
from torch.nn import CrossEntropyLoss
|
|
|
|
|
|
|
|
|
|
| 3 |
from transformers import AutoModel, PreTrainedModel
|
|
|
|
| 4 |
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 5 |
|
| 6 |
from .configuration_roberta_emotion import RobertaEmotionConfig
|
| 7 |
|
| 8 |
+
|
| 9 |
class RobertaEmotion(PreTrainedModel):
|
| 10 |
config_class = RobertaEmotionConfig
|
| 11 |
|
|
|
|
| 27 |
loss_fct = CrossEntropyLoss()
|
| 28 |
loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
|
| 29 |
|
| 30 |
+
return SequenceClassifierOutput(loss=loss, logits=logits)
|
pytorch_model.bin
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-
oid sha256:
|
| 3 |
size 498674549
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:e69dd62f8db826d469c8d82f95887b24557ba057283a4994b3d8aca3c918e251
|
| 3 |
size 498674549
|