anpmts committed on
Commit
abb6dd8
·
verified ·
1 Parent(s): b3578be

Upload sentiment classifier trained on Amazon Reviews

Browse files
Files changed (2) hide show
  1. model.safetensors +1 -1
  2. sentiment_classifier.py +6 -0
model.safetensors CHANGED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:25bcb5c098e6ee12a1982c57d0ae4af0e03db286684b66c37283561f7a7563c7
3
  size 1112208144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0a615a2066e4a0f86b39d9bcd8dedc63e6adc3e785b38ad31f60dfa8baad4c4b
3
  size 1112208144
sentiment_classifier.py CHANGED
@@ -36,6 +36,7 @@ class SentimentClassifier(PreTrainedModel):
36
  hidden_size: Optional[int] = None,
37
  class_weights: Optional[torch.Tensor] = None,
38
  use_flash_attention_2: bool = False,
 
39
  ):
40
  """
41
  Initialize sentiment classifier.
@@ -48,6 +49,7 @@ class SentimentClassifier(PreTrainedModel):
48
  hidden_size: Hidden size of the model (auto-detected if None).
49
  class_weights: Tensor of class weights for classification loss.
50
  use_flash_attention_2: Use Flash Attention 2 for faster attention (if available).
 
51
  """
52
  # Create config if not provided
53
  if config is None:
@@ -71,6 +73,10 @@ class SentimentClassifier(PreTrainedModel):
71
 
72
  self.encoder = AutoModel.from_pretrained(config.pretrained_model, **encoder_kwargs)
73
 
 
 
 
 
74
  # Get hidden size
75
  if config.hidden_size is None:
76
  config.hidden_size = self.encoder.config.hidden_size
 
36
  hidden_size: Optional[int] = None,
37
  class_weights: Optional[torch.Tensor] = None,
38
  use_flash_attention_2: bool = False,
39
+ gradient_checkpointing: bool = False,
40
  ):
41
  """
42
  Initialize sentiment classifier.
 
49
  hidden_size: Hidden size of the model (auto-detected if None).
50
  class_weights: Tensor of class weights for classification loss.
51
  use_flash_attention_2: Use Flash Attention 2 for faster attention (if available).
52
+ gradient_checkpointing: Enable gradient checkpointing to save memory.
53
  """
54
  # Create config if not provided
55
  if config is None:
 
73
 
74
  self.encoder = AutoModel.from_pretrained(config.pretrained_model, **encoder_kwargs)
75
 
76
+ # Enable gradient checkpointing if requested (saves memory at cost of compute)
77
+ if gradient_checkpointing:
78
+ self.encoder.gradient_checkpointing_enable()
79
+
80
  # Get hidden size
81
  if config.hidden_size is None:
82
  config.hidden_size = self.encoder.config.hidden_size