"""Sentiment classifier for text classification."""
from typing import Optional, Tuple, Union
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
# Handle imports for both local usage and HuggingFace Hub
try:
from .configuration_sentiment import SentimentClassifierConfig
except ImportError:
try:
from configuration_sentiment import SentimentClassifierConfig
except ImportError:
from src.models.configuration_sentiment import SentimentClassifierConfig
class SentimentClassifier(PreTrainedModel):
"""
Sentiment classifier for sequence classification.
Outputs:
        Sentiment label (positive/neutral/negative) from a linear classification head.
"""
config_class = SentimentClassifierConfig
def __init__(
self,
config: Optional[SentimentClassifierConfig] = None,
pretrained_model: str = "xlm-roberta-base",
num_labels: int = 3,
dropout: float = 0.1,
hidden_size: Optional[int] = None,
class_weights: Optional[torch.Tensor] = None,
use_flash_attention_2: bool = False,
gradient_checkpointing: bool = False,
):
"""
Initialize sentiment classifier.
Args:
config: Model configuration object.
pretrained_model: Name of the pre-trained model.
num_labels: Number of sentiment classes (default: 3).
dropout: Dropout probability.
hidden_size: Hidden size of the model (auto-detected if None).
class_weights: Tensor of class weights for classification loss.
use_flash_attention_2: Use Flash Attention 2 for faster attention (if available).
gradient_checkpointing: Enable gradient checkpointing to save memory.
"""
# Create config if not provided
if config is None:
config = SentimentClassifierConfig(
pretrained_model=pretrained_model,
num_labels=num_labels,
dropout=dropout,
hidden_size=hidden_size,
)
super().__init__(config)
        # Load pre-trained transformer with optional Flash Attention 2
        encoder_kwargs = {}
        if use_flash_attention_2:
            encoder_kwargs["attn_implementation"] = "flash_attention_2"
        try:
            self.encoder = AutoModel.from_pretrained(config.pretrained_model, **encoder_kwargs)
        except (ImportError, ValueError):
            if not encoder_kwargs:
                raise
            # Flash Attention 2 not available; fall back to the default attention implementation
            self.encoder = AutoModel.from_pretrained(config.pretrained_model)
# Enable gradient checkpointing if requested (saves memory at cost of compute)
if gradient_checkpointing:
self.encoder.gradient_checkpointing_enable()
# Get hidden size
if config.hidden_size is None:
config.hidden_size = self.encoder.config.hidden_size
self.hidden_size = config.hidden_size
self.num_labels = config.num_labels
# Dropout
self.dropout = nn.Dropout(config.dropout)
# Classification head (sentiment label)
self.classifier = nn.Linear(self.hidden_size, self.num_labels)
# Class weights
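        # Registered as a buffer so the weights follow the model across devices (.to/.cuda)
        # and are saved in the state dict; e.g. inverse class frequencies can be passed in
        # to upweight rare sentiment classes in the loss. Defaults to uniform weights.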
self.register_buffer(
"class_weights",
class_weights if class_weights is not None else torch.ones(self.num_labels),
)
# Initialize weights
self.post_init()
def _init_weights(self, module):
"""Initialize head weights."""
if isinstance(module, nn.Linear):
nn.init.xavier_uniform_(module.weight)
if module.bias is not None:
nn.init.zeros_(module.bias)
def forward(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
labels: Optional[torch.Tensor] = None,
return_dict: Optional[bool] = None,
**kwargs,
    ) -> Union[SequenceClassifierOutput, Tuple[torch.Tensor, ...]]:
"""
Forward pass for classification.
Args:
input_ids: Input token IDs [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
labels: Ground truth sentiment labels [batch_size].
            return_dict: Whether to return a SequenceClassifierOutput instead of a plain tuple.
            **kwargs: Additional arguments (accepted for compatibility, ignored).
        Returns:
            SequenceClassifierOutput, or a tuple of (loss, logits) / (logits,) when return_dict is False.
"""
return_dict = return_dict if return_dict is not None else self.config.use_return_dict
# Encode with transformer
outputs = self.encoder(
input_ids=input_ids,
attention_mask=attention_mask,
return_dict=True,
)
# Use [CLS] token representation
pooled_output = outputs.last_hidden_state[:, 0, :]
# Apply dropout
pooled_output = self.dropout(pooled_output)
# Classification head
logits = self.classifier(pooled_output)
# Compute loss if labels provided
loss = None
if labels is not None:
loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
loss = loss_fct(logits, labels)
if not return_dict:
output = (logits,)
return ((loss,) + output) if loss is not None else output
return SequenceClassifierOutput(
loss=loss,
logits=logits,
hidden_states=None,
attentions=None,
)
def predict(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""
Make predictions.
Args:
input_ids: Input token IDs [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
Returns:
Predicted labels [batch_size].
"""
self.eval()
with torch.no_grad():
            outputs = self(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
logits = outputs.logits
label_predictions = torch.argmax(logits, dim=-1)
return label_predictions
def get_probabilities(
self,
input_ids: torch.Tensor,
attention_mask: torch.Tensor,
) -> torch.Tensor:
"""
Get class probabilities.
Args:
input_ids: Input token IDs [batch_size, seq_len].
attention_mask: Attention mask [batch_size, seq_len].
Returns:
Class probabilities [batch_size, num_labels].
"""
self.eval()
with torch.no_grad():
            outputs = self(input_ids=input_ids, attention_mask=attention_mask, return_dict=True)
logits = outputs.logits
probabilities = torch.softmax(logits, dim=-1)
return probabilities
def freeze_encoder(self):
"""Freeze encoder parameters (only train classification head)."""
for param in self.encoder.parameters():
param.requires_grad = False
def unfreeze_encoder(self):
"""Unfreeze encoder parameters."""
for param in self.encoder.parameters():
param.requires_grad = True
def get_num_trainable_params(self) -> int:
"""Get number of trainable parameters."""
return sum(p.numel() for p in self.parameters() if p.requires_grad)
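
# ---------------------------------------------------------------------------
# Minimal usage sketch (illustrative only, not part of the model definition).
# It assumes the default "xlm-roberta-base" checkpoint and its tokenizer are
# reachable, and uses placeholder example texts; adapt names and paths to
# your own setup.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    from transformers import AutoTokenizer

    model = SentimentClassifier(pretrained_model="xlm-roberta-base", num_labels=3)
    tokenizer = AutoTokenizer.from_pretrained("xlm-roberta-base")

    # Optionally train only the classification head by freezing the encoder
    model.freeze_encoder()
    print(f"Trainable parameters: {model.get_num_trainable_params():,}")

    texts = [
        "Great product, works exactly as described.",
        "Terrible quality, broke after one day.",
    ]
    batch = tokenizer(texts, padding=True, truncation=True, return_tensors="pt")

    # Predicted class indices [batch_size] and probabilities [batch_size, num_labels]
    predictions = model.predict(batch["input_ids"], batch["attention_mask"])
    probabilities = model.get_probabilities(batch["input_ids"], batch["attention_mask"])
    print("Predicted labels:", predictions.tolist())
    print("Class probabilities:", probabilities.tolist())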