Upload sentiment classifier trained on Amazon Reviews
Browse files
- model.safetensors +1 -1
- sentiment_classifier.py +6 -0
model.safetensors
CHANGED
|
@@ -1,3 +1,3 @@
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
-oid sha256:
|
| 3 |
size 1112208144
|
|
|
|
| 1 |
version https://git-lfs.github.com/spec/v1
|
| 2 |
+oid sha256:0a615a2066e4a0f86b39d9bcd8dedc63e6adc3e785b38ad31f60dfa8baad4c4b
|
| 3 |
size 1112208144
|
sentiment_classifier.py
CHANGED
|
@@ -36,6 +36,7 @@ class SentimentClassifier(PreTrainedModel):
|
|
| 36 |
hidden_size: Optional[int] = None,
|
| 37 |
class_weights: Optional[torch.Tensor] = None,
|
| 38 |
use_flash_attention_2: bool = False,
|
|
|
|
| 39 |
):
|
| 40 |
"""
|
| 41 |
Initialize sentiment classifier.
|
|
@@ -48,6 +49,7 @@ class SentimentClassifier(PreTrainedModel):
|
|
| 48 |
hidden_size: Hidden size of the model (auto-detected if None).
|
| 49 |
class_weights: Tensor of class weights for classification loss.
|
| 50 |
use_flash_attention_2: Use Flash Attention 2 for faster attention (if available).
|
|
|
|
| 51 |
"""
|
| 52 |
# Create config if not provided
|
| 53 |
if config is None:
|
|
@@ -71,6 +73,10 @@ class SentimentClassifier(PreTrainedModel):
|
|
| 71 |
|
| 72 |
self.encoder = AutoModel.from_pretrained(config.pretrained_model, **encoder_kwargs)
|
| 73 |
|
|
|
|
|
|
|
|
|
|
|
|
|
| 74 |
# Get hidden size
|
| 75 |
if config.hidden_size is None:
|
| 76 |
config.hidden_size = self.encoder.config.hidden_size
|
|
|
|
| 36 |
hidden_size: Optional[int] = None,
|
| 37 |
class_weights: Optional[torch.Tensor] = None,
|
| 38 |
use_flash_attention_2: bool = False,
|
| 39 |
+gradient_checkpointing: bool = False,
|
| 40 |
):
|
| 41 |
"""
|
| 42 |
Initialize sentiment classifier.
|
|
|
|
| 49 |
hidden_size: Hidden size of the model (auto-detected if None).
|
| 50 |
class_weights: Tensor of class weights for classification loss.
|
| 51 |
use_flash_attention_2: Use Flash Attention 2 for faster attention (if available).
|
| 52 |
+gradient_checkpointing: Enable gradient checkpointing to save memory.
|
| 53 |
"""
|
| 54 |
# Create config if not provided
|
| 55 |
if config is None:
|
|
|
|
| 73 |
|
| 74 |
self.encoder = AutoModel.from_pretrained(config.pretrained_model, **encoder_kwargs)
|
| 75 |
|
| 76 |
+# Enable gradient checkpointing if requested (saves memory at cost of compute)
|
| 77 |
+if gradient_checkpointing:
|
| 78 |
+self.encoder.gradient_checkpointing_enable()
|
| 79 |
+
|
| 80 |
# Get hidden size
|
| 81 |
if config.hidden_size is None:
|
| 82 |
config.hidden_size = self.encoder.config.hidden_size
|