Upload sentiment classifier trained on Amazon Reviews
Browse files
- sentiment_classifier.py +20 -9
sentiment_classifier.py
CHANGED
|
@@ -1,10 +1,11 @@
|
|
| 1 |
"""Sentiment classifier for text classification."""
|
| 2 |
|
| 3 |
-
from typing import Dict, Optional
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
from transformers import AutoModel, PreTrainedModel
|
|
|
|
| 8 |
|
| 9 |
# Handle imports for both local usage and HuggingFace Hub
|
| 10 |
try:
|
|
@@ -104,8 +105,9 @@ class SentimentClassifier(PreTrainedModel):
|
|
| 104 |
input_ids: torch.Tensor,
|
| 105 |
attention_mask: torch.Tensor,
|
| 106 |
labels: Optional[torch.Tensor] = None,
|
|
|
|
| 107 |
**kwargs,
|
| 108 |
-
) -> Dict[str, torch.Tensor]:
|
| 109 |
"""
|
| 110 |
Forward pass for classification.
|
| 111 |
|
|
@@ -113,11 +115,14 @@ class SentimentClassifier(PreTrainedModel):
|
|
| 113 |
input_ids: Input token IDs [batch_size, seq_len].
|
| 114 |
attention_mask: Attention mask [batch_size, seq_len].
|
| 115 |
labels: Ground truth sentiment labels [batch_size].
|
|
|
|
| 116 |
**kwargs: Additional arguments.
|
| 117 |
|
| 118 |
Returns:
|
| 119 |
-
|
| 120 |
"""
|
|
|
|
|
|
|
| 121 |
# Encode with transformer
|
| 122 |
outputs = self.encoder(
|
| 123 |
input_ids=input_ids,
|
|
@@ -140,10 +145,16 @@ class SentimentClassifier(PreTrainedModel):
|
|
| 140 |
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
|
| 141 |
loss = loss_fct(logits, labels)
|
| 142 |
|
| 143 |
-
|
| 144 |
-
|
| 145 |
-
|
| 146 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 147 |
|
| 148 |
def predict(
|
| 149 |
self,
|
|
@@ -164,7 +175,7 @@ class SentimentClassifier(PreTrainedModel):
|
|
| 164 |
|
| 165 |
with torch.no_grad():
|
| 166 |
outputs = self.forward(input_ids, attention_mask)
|
| 167 |
-
logits = outputs
|
| 168 |
label_predictions = torch.argmax(logits, dim=-1)
|
| 169 |
|
| 170 |
return label_predictions
|
|
@@ -188,7 +199,7 @@ class SentimentClassifier(PreTrainedModel):
|
|
| 188 |
|
| 189 |
with torch.no_grad():
|
| 190 |
outputs = self.forward(input_ids, attention_mask)
|
| 191 |
-
logits = outputs
|
| 192 |
probabilities = torch.softmax(logits, dim=-1)
|
| 193 |
|
| 194 |
return probabilities
|
|
|
|
| 1 |
"""Sentiment classifier for text classification."""
|
| 2 |
|
| 3 |
+
from typing import Dict, Optional, Union
|
| 4 |
|
| 5 |
import torch
|
| 6 |
import torch.nn as nn
|
| 7 |
from transformers import AutoModel, PreTrainedModel
|
| 8 |
+
from transformers.modeling_outputs import SequenceClassifierOutput
|
| 9 |
|
| 10 |
# Handle imports for both local usage and HuggingFace Hub
|
| 11 |
try:
|
|
|
|
| 105 |
input_ids: torch.Tensor,
|
| 106 |
attention_mask: torch.Tensor,
|
| 107 |
labels: Optional[torch.Tensor] = None,
|
| 108 |
+
return_dict: Optional[bool] = None,
|
| 109 |
**kwargs,
|
| 110 |
+
) -> Union[SequenceClassifierOutput, Dict[str, torch.Tensor]]:
|
| 111 |
"""
|
| 112 |
Forward pass for classification.
|
| 113 |
|
|
|
|
| 115 |
input_ids: Input token IDs [batch_size, seq_len].
|
| 116 |
attention_mask: Attention mask [batch_size, seq_len].
|
| 117 |
labels: Ground truth sentiment labels [batch_size].
|
| 118 |
+
return_dict: Whether to return a SequenceClassifierOutput instead of a plain tuple. Defaults to self.config.use_return_dict when None.
|
| 119 |
**kwargs: Additional arguments.
|
| 120 |
|
| 121 |
Returns:
|
| 122 |
+
SequenceClassifierOutput with loss and logits when return_dict is true; otherwise a tuple (loss, logits), or (logits,) when loss is None.
|
| 123 |
"""
|
| 124 |
+
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
|
| 125 |
+
|
| 126 |
# Encode with transformer
|
| 127 |
outputs = self.encoder(
|
| 128 |
input_ids=input_ids,
|
|
|
|
| 145 |
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
|
| 146 |
loss = loss_fct(logits, labels)
|
| 147 |
|
| 148 |
+
if not return_dict:
|
| 149 |
+
output = (logits,)
|
| 150 |
+
return ((loss,) + output) if loss is not None else output
|
| 151 |
+
|
| 152 |
+
return SequenceClassifierOutput(
|
| 153 |
+
loss=loss,
|
| 154 |
+
logits=logits,
|
| 155 |
+
hidden_states=None,
|
| 156 |
+
attentions=None,
|
| 157 |
+
)
|
| 158 |
|
| 159 |
def predict(
|
| 160 |
self,
|
|
|
|
| 175 |
|
| 176 |
with torch.no_grad():
|
| 177 |
outputs = self.forward(input_ids, attention_mask)
|
| 178 |
+
logits = outputs.logits
|
| 179 |
label_predictions = torch.argmax(logits, dim=-1)
|
| 180 |
|
| 181 |
return label_predictions
|
|
|
|
| 199 |
|
| 200 |
with torch.no_grad():
|
| 201 |
outputs = self.forward(input_ids, attention_mask)
|
| 202 |
+
logits = outputs.logits
|
| 203 |
probabilities = torch.softmax(logits, dim=-1)
|
| 204 |
|
| 205 |
return probabilities
|