|
|
"""Sentiment classifier for text classification.""" |
|
|
|
|
|
from typing import Dict, Optional, Union |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
from transformers import AutoModel, PreTrainedModel |
|
|
from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
|
|
|
|
|
|
try: |
|
|
from .configuration_sentiment import SentimentClassifierConfig |
|
|
except ImportError: |
|
|
try: |
|
|
from configuration_sentiment import SentimentClassifierConfig |
|
|
except ImportError: |
|
|
from src.models.configuration_sentiment import SentimentClassifierConfig |
|
|
|
|
|
|
|
|
class SentimentClassifier(PreTrainedModel):
    """
    Sentiment classifier for sequence classification.

    A pre-trained transformer encoder (loaded via ``AutoModel``) followed by
    dropout and a linear classification head applied to the first token's
    hidden state.

    Outputs:
        Sentiment label (positive/neutral/negative) - classification
    """

    config_class = SentimentClassifierConfig

    def __init__(
        self,
        config: Optional[SentimentClassifierConfig] = None,
        pretrained_model: str = "xlm-roberta-base",
        num_labels: int = 3,
        dropout: float = 0.1,
        hidden_size: Optional[int] = None,
        class_weights: Optional[torch.Tensor] = None,
        use_flash_attention_2: bool = False,
        gradient_checkpointing: bool = False,
    ):
        """
        Initialize sentiment classifier.

        Args:
            config: Model configuration object. If None, one is built from the
                remaining keyword arguments.
            pretrained_model: Name of the pre-trained model.
            num_labels: Number of sentiment classes (default: 3).
            dropout: Dropout probability.
            hidden_size: Hidden size of the model (auto-detected if None).
            class_weights: Tensor of class weights for classification loss.
            use_flash_attention_2: Use Flash Attention 2 for faster attention
                (if available; loading raises if flash-attn is not installed).
            gradient_checkpointing: Enable gradient checkpointing to save memory.
        """
        if config is None:
            config = SentimentClassifierConfig(
                pretrained_model=pretrained_model,
                num_labels=num_labels,
                dropout=dropout,
                hidden_size=hidden_size,
            )

        super().__init__(config)

        encoder_kwargs = {}
        if use_flash_attention_2:
            # BUGFIX: the original wrapped this plain dict assignment in
            # try/except — a dict assignment cannot raise, so the except branch
            # was dead code. Any flash-attn availability error surfaces in
            # from_pretrained() below, exactly as before.
            encoder_kwargs["attn_implementation"] = "flash_attention_2"

        self.encoder = AutoModel.from_pretrained(config.pretrained_model, **encoder_kwargs)

        if gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        # Auto-detect hidden size from the loaded encoder when not specified.
        if config.hidden_size is None:
            config.hidden_size = self.encoder.config.hidden_size

        self.hidden_size = config.hidden_size
        self.num_labels = config.num_labels

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(self.hidden_size, self.num_labels)

        # Buffer (not parameter) so the weights follow .to(device)/state_dict
        # but receive no gradients. Defaults to uniform weighting.
        self.register_buffer(
            "class_weights",
            class_weights if class_weights is not None else torch.ones(self.num_labels),
        )

        self.post_init()

    def _init_weights(self, module):
        """Initialize head weights (Xavier-uniform linear layers, zero bias)."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[SequenceClassifierOutput, Dict[str, torch.Tensor]]:
        """
        Forward pass for classification.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            labels: Ground truth sentiment labels [batch_size].
            return_dict: Whether to return a SequenceClassifierOutput or tuple.
            **kwargs: Additional arguments (ignored).

        Returns:
            SequenceClassifierOutput, or a tuple of (loss, logits) /
            (logits,) when return_dict is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Pool with the first token's hidden state (CLS-style pooling).
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # Weighted cross-entropy; class_weights is a registered buffer so
            # it already lives on the right device.
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Make predictions.

        Runs in eval mode under no_grad, then restores the previous
        training/eval mode.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Predicted labels [batch_size].
        """
        # BUGFIX: the original left the model permanently in eval mode after
        # this call; restore the caller's mode so mid-training predictions
        # don't silently disable dropout for subsequent training steps.
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # return_dict=True so .logits is safe regardless of config.
                outputs = self.forward(input_ids, attention_mask, return_dict=True)
                label_predictions = torch.argmax(outputs.logits, dim=-1)
        finally:
            self.train(was_training)

        return label_predictions

    def get_probabilities(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get class probabilities.

        Runs in eval mode under no_grad, then restores the previous
        training/eval mode.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Class probabilities [batch_size, num_labels].
        """
        # BUGFIX: restore the caller's training/eval mode (see predict()).
        was_training = self.training
        self.eval()
        try:
            with torch.no_grad():
                # return_dict=True so .logits is safe regardless of config.
                outputs = self.forward(input_ids, attention_mask, return_dict=True)
                probabilities = torch.softmax(outputs.logits, dim=-1)
        finally:
            self.train(was_training)

        return probabilities

    def freeze_encoder(self):
        """Freeze encoder parameters (only train classification head)."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        """Unfreeze encoder parameters."""
        for param in self.encoder.parameters():
            param.requires_grad = True

    def get_num_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|
|