| """Sentiment classifier for text classification.""" |
|
|
| from typing import Dict, Optional, Union |
|
|
| import torch |
| import torch.nn as nn |
| from transformers import AutoModel, PreTrainedModel |
| from transformers.modeling_outputs import SequenceClassifierOutput |
|
|
| |
| try: |
| from .configuration_sentiment import SentimentClassifierConfig |
| except ImportError: |
| try: |
| from configuration_sentiment import SentimentClassifierConfig |
| except ImportError: |
| from src.models.configuration_sentiment import SentimentClassifierConfig |
|
|
|
|
class SentimentClassifier(PreTrainedModel):
    """
    Sentiment classifier for sequence classification.

    Wraps a pre-trained transformer encoder with a dropout + linear head
    applied to the first ([CLS]) token representation.

    Outputs:
        Sentiment label (positive/neutral/negative) - classification
    """

    config_class = SentimentClassifierConfig

    def __init__(
        self,
        config: Optional[SentimentClassifierConfig] = None,
        pretrained_model: str = "xlm-roberta-base",
        num_labels: int = 3,
        dropout: float = 0.1,
        hidden_size: Optional[int] = None,
        class_weights: Optional[torch.Tensor] = None,
        use_flash_attention_2: bool = False,
        gradient_checkpointing: bool = False,
    ):
        """
        Initialize sentiment classifier.

        Args:
            config: Model configuration object. When provided, it takes
                precedence over the individual keyword arguments below.
            pretrained_model: Name of the pre-trained model.
            num_labels: Number of sentiment classes (default: 3).
            dropout: Dropout probability.
            hidden_size: Hidden size of the model (auto-detected if None).
            class_weights: Tensor of class weights for classification loss.
            use_flash_attention_2: Use Flash Attention 2 for faster attention (if available).
            gradient_checkpointing: Enable gradient checkpointing to save memory.
        """
        if config is None:
            config = SentimentClassifierConfig(
                pretrained_model=pretrained_model,
                num_labels=num_labels,
                dropout=dropout,
                hidden_size=hidden_size,
            )

        super().__init__(config)

        # BUG FIX: the original wrapped a plain dict assignment in
        # try/except, which can never raise, so the fallback was dead code.
        # Flash Attention 2 unavailability only surfaces when the encoder
        # is actually instantiated, so the fallback belongs around
        # from_pretrained() itself.
        if use_flash_attention_2:
            try:
                self.encoder = AutoModel.from_pretrained(
                    config.pretrained_model,
                    attn_implementation="flash_attention_2",
                )
            except (ImportError, ValueError):
                # flash-attn not installed or unsupported by this
                # architecture/dtype: fall back to the default attention.
                self.encoder = AutoModel.from_pretrained(config.pretrained_model)
        else:
            self.encoder = AutoModel.from_pretrained(config.pretrained_model)

        if gradient_checkpointing:
            # Trades compute for memory during training.
            self.encoder.gradient_checkpointing_enable()

        # Auto-detect the hidden size from the loaded encoder when the
        # config did not pin one.
        if config.hidden_size is None:
            config.hidden_size = self.encoder.config.hidden_size

        self.hidden_size = config.hidden_size
        self.num_labels = config.num_labels

        self.dropout = nn.Dropout(config.dropout)
        self.classifier = nn.Linear(self.hidden_size, self.num_labels)

        # Registered as a buffer so the weights follow .to(device) and are
        # persisted in the state dict alongside the model parameters.
        self.register_buffer(
            "class_weights",
            class_weights if class_weights is not None else torch.ones(self.num_labels),
        )

        self.post_init()

    def _init_weights(self, module):
        """Initialize head weights (Xavier-uniform, zero bias).

        NOTE(review): post_init() applies this to every nn.Linear in the
        model. Recent transformers versions skip modules already loaded via
        from_pretrained (they are flagged as initialized), so the encoder
        should be untouched — verify against the pinned transformers version.
        """
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[SequenceClassifierOutput, Dict[str, torch.Tensor]]:
        """
        Forward pass for classification.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            labels: Ground truth sentiment labels [batch_size].
            return_dict: Whether to return a SequenceClassifierOutput or tuple.
            **kwargs: Additional arguments (ignored).

        Returns:
            SequenceClassifierOutput (or a tuple when return_dict is False)
            containing loss (when labels are given) and logits.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Use the first-token ([CLS]) hidden state as the pooled sequence
        # representation.
        pooled_output = outputs.last_hidden_state[:, 0, :]
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        loss = None
        if labels is not None:
            # class_weights defaults to all-ones, i.e. an unweighted loss.
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    @torch.no_grad()
    def _inference_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run an eval-mode forward pass and return raw logits.

        Shared by predict() / get_probabilities(). BUG FIX: return_dict=True
        is forced here — the original accessed `.logits` on the result of
        forward(), which may be a plain tuple when the config sets
        use_return_dict to False. Note this switches the module to eval mode
        as a side effect (matching the original behavior).
        """
        self.eval()
        return self.forward(input_ids, attention_mask, return_dict=True).logits

    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Make predictions.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Predicted labels [batch_size].
        """
        return torch.argmax(self._inference_logits(input_ids, attention_mask), dim=-1)

    def get_probabilities(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get class probabilities.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Class probabilities [batch_size, num_labels].
        """
        return torch.softmax(self._inference_logits(input_ids, attention_mask), dim=-1)

    def freeze_encoder(self):
        """Freeze encoder parameters (only train classification head)."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        """Unfreeze encoder parameters."""
        for param in self.encoder.parameters():
            param.requires_grad = True

    def get_num_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)
|
|