anpmts committed on
Commit
b3578be
·
verified ·
1 Parent(s): 877fdb8

Upload sentiment classifier trained on Amazon Reviews

Browse files
Files changed (1) hide show
  1. sentiment_classifier.py +20 -9
sentiment_classifier.py CHANGED
@@ -1,10 +1,11 @@
1
  """Sentiment classifier for text classification."""
2
 
3
- from typing import Dict, Optional
4
 
5
  import torch
6
  import torch.nn as nn
7
  from transformers import AutoModel, PreTrainedModel
 
8
 
9
  # Handle imports for both local usage and HuggingFace Hub
10
  try:
@@ -104,8 +105,9 @@ class SentimentClassifier(PreTrainedModel):
104
  input_ids: torch.Tensor,
105
  attention_mask: torch.Tensor,
106
  labels: Optional[torch.Tensor] = None,
 
107
  **kwargs,
108
- ) -> Dict[str, torch.Tensor]:
109
  """
110
  Forward pass for classification.
111
 
@@ -113,11 +115,14 @@ class SentimentClassifier(PreTrainedModel):
113
  input_ids: Input token IDs [batch_size, seq_len].
114
  attention_mask: Attention mask [batch_size, seq_len].
115
  labels: Ground truth sentiment labels [batch_size].
 
116
  **kwargs: Additional arguments.
117
 
118
  Returns:
119
- Dictionary containing loss and logits.
120
  """
 
 
121
  # Encode with transformer
122
  outputs = self.encoder(
123
  input_ids=input_ids,
@@ -140,10 +145,16 @@ class SentimentClassifier(PreTrainedModel):
140
  loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
141
  loss = loss_fct(logits, labels)
142
 
143
- return {
144
- "loss": loss,
145
- "logits": logits,
146
- }
 
 
 
 
 
 
147
 
148
  def predict(
149
  self,
@@ -164,7 +175,7 @@ class SentimentClassifier(PreTrainedModel):
164
 
165
  with torch.no_grad():
166
  outputs = self.forward(input_ids, attention_mask)
167
- logits = outputs["logits"]
168
  label_predictions = torch.argmax(logits, dim=-1)
169
 
170
  return label_predictions
@@ -188,7 +199,7 @@ class SentimentClassifier(PreTrainedModel):
188
 
189
  with torch.no_grad():
190
  outputs = self.forward(input_ids, attention_mask)
191
- logits = outputs["logits"]
192
  probabilities = torch.softmax(logits, dim=-1)
193
 
194
  return probabilities
 
1
  """Sentiment classifier for text classification."""
2
 
3
+ from typing import Dict, Optional, Union
4
 
5
  import torch
6
  import torch.nn as nn
7
  from transformers import AutoModel, PreTrainedModel
8
+ from transformers.modeling_outputs import SequenceClassifierOutput
9
 
10
  # Handle imports for both local usage and HuggingFace Hub
11
  try:
 
105
  input_ids: torch.Tensor,
106
  attention_mask: torch.Tensor,
107
  labels: Optional[torch.Tensor] = None,
108
+ return_dict: Optional[bool] = None,
109
  **kwargs,
110
+ ) -> Union[SequenceClassifierOutput, Dict[str, torch.Tensor]]:
111
  """
112
  Forward pass for classification.
113
 
 
115
  input_ids: Input token IDs [batch_size, seq_len].
116
  attention_mask: Attention mask [batch_size, seq_len].
117
  labels: Ground truth sentiment labels [batch_size].
118
+ return_dict: Whether to return a SequenceClassifierOutput or dict.
119
  **kwargs: Additional arguments.
120
 
121
  Returns:
122
+ SequenceClassifierOutput or dictionary containing loss and logits.
123
  """
124
+ return_dict = return_dict if return_dict is not None else self.config.use_return_dict
125
+
126
  # Encode with transformer
127
  outputs = self.encoder(
128
  input_ids=input_ids,
 
145
  loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
146
  loss = loss_fct(logits, labels)
147
 
148
+ if not return_dict:
149
+ output = (logits,)
150
+ return ((loss,) + output) if loss is not None else output
151
+
152
+ return SequenceClassifierOutput(
153
+ loss=loss,
154
+ logits=logits,
155
+ hidden_states=None,
156
+ attentions=None,
157
+ )
158
 
159
  def predict(
160
  self,
 
175
 
176
  with torch.no_grad():
177
  outputs = self.forward(input_ids, attention_mask)
178
+ logits = outputs.logits
179
  label_predictions = torch.argmax(logits, dim=-1)
180
 
181
  return label_predictions
 
199
 
200
  with torch.no_grad():
201
  outputs = self.forward(input_ids, attention_mask)
202
+ logits = outputs.logits
203
  probabilities = torch.softmax(logits, dim=-1)
204
 
205
  return probabilities