"""Sentiment classifier for text classification.""" from typing import Dict, Optional, Union import torch import torch.nn as nn from transformers import AutoModel, PreTrainedModel from transformers.modeling_outputs import SequenceClassifierOutput # Handle imports for both local usage and HuggingFace Hub try: from .configuration_sentiment import SentimentClassifierConfig except ImportError: try: from configuration_sentiment import SentimentClassifierConfig except ImportError: from src.models.configuration_sentiment import SentimentClassifierConfig class SentimentClassifier(PreTrainedModel): """ Sentiment classifier for sequence classification. Outputs: Sentiment label (positive/neutral/negative) - classification """ config_class = SentimentClassifierConfig def __init__( self, config: Optional[SentimentClassifierConfig] = None, pretrained_model: str = "xlm-roberta-base", num_labels: int = 3, dropout: float = 0.1, hidden_size: Optional[int] = None, class_weights: Optional[torch.Tensor] = None, use_flash_attention_2: bool = False, gradient_checkpointing: bool = False, ): """ Initialize sentiment classifier. Args: config: Model configuration object. pretrained_model: Name of the pre-trained model. num_labels: Number of sentiment classes (default: 3). dropout: Dropout probability. hidden_size: Hidden size of the model (auto-detected if None). class_weights: Tensor of class weights for classification loss. use_flash_attention_2: Use Flash Attention 2 for faster attention (if available). gradient_checkpointing: Enable gradient checkpointing to save memory. """ # Create config if not provided if config is None: config = SentimentClassifierConfig( pretrained_model=pretrained_model, num_labels=num_labels, dropout=dropout, hidden_size=hidden_size, ) super().__init__(config) # Load pre-trained transformer with optional Flash Attention 2 encoder_kwargs = {} if use_flash_attention_2: try: encoder_kwargs["attn_implementation"] = "flash_attention_2" except Exception: # Flash Attention 2 not available, will use default pass self.encoder = AutoModel.from_pretrained(config.pretrained_model, **encoder_kwargs) # Enable gradient checkpointing if requested (saves memory at cost of compute) if gradient_checkpointing: self.encoder.gradient_checkpointing_enable() # Get hidden size if config.hidden_size is None: config.hidden_size = self.encoder.config.hidden_size self.hidden_size = config.hidden_size self.num_labels = config.num_labels # Dropout self.dropout = nn.Dropout(config.dropout) # Classification head (sentiment label) self.classifier = nn.Linear(self.hidden_size, self.num_labels) # Class weights self.register_buffer( "class_weights", class_weights if class_weights is not None else torch.ones(self.num_labels), ) # Initialize weights self.post_init() def _init_weights(self, module): """Initialize head weights.""" if isinstance(module, nn.Linear): nn.init.xavier_uniform_(module.weight) if module.bias is not None: nn.init.zeros_(module.bias) def forward( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, labels: Optional[torch.Tensor] = None, return_dict: Optional[bool] = None, **kwargs, ) -> Union[SequenceClassifierOutput, Dict[str, torch.Tensor]]: """ Forward pass for classification. Args: input_ids: Input token IDs [batch_size, seq_len]. attention_mask: Attention mask [batch_size, seq_len]. labels: Ground truth sentiment labels [batch_size]. return_dict: Whether to return a SequenceClassifierOutput or dict. **kwargs: Additional arguments. Returns: SequenceClassifierOutput or dictionary containing loss and logits. """ return_dict = return_dict if return_dict is not None else self.config.use_return_dict # Encode with transformer outputs = self.encoder( input_ids=input_ids, attention_mask=attention_mask, return_dict=True, ) # Use [CLS] token representation pooled_output = outputs.last_hidden_state[:, 0, :] # Apply dropout pooled_output = self.dropout(pooled_output) # Classification head logits = self.classifier(pooled_output) # Compute loss if labels provided loss = None if labels is not None: loss_fct = nn.CrossEntropyLoss(weight=self.class_weights) loss = loss_fct(logits, labels) if not return_dict: output = (logits,) return ((loss,) + output) if loss is not None else output return SequenceClassifierOutput( loss=loss, logits=logits, hidden_states=None, attentions=None, ) def predict( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: """ Make predictions. Args: input_ids: Input token IDs [batch_size, seq_len]. attention_mask: Attention mask [batch_size, seq_len]. Returns: Predicted labels [batch_size]. """ self.eval() with torch.no_grad(): outputs = self.forward(input_ids, attention_mask) logits = outputs.logits label_predictions = torch.argmax(logits, dim=-1) return label_predictions def get_probabilities( self, input_ids: torch.Tensor, attention_mask: torch.Tensor, ) -> torch.Tensor: """ Get class probabilities. Args: input_ids: Input token IDs [batch_size, seq_len]. attention_mask: Attention mask [batch_size, seq_len]. Returns: Class probabilities [batch_size, num_labels]. """ self.eval() with torch.no_grad(): outputs = self.forward(input_ids, attention_mask) logits = outputs.logits probabilities = torch.softmax(logits, dim=-1) return probabilities def freeze_encoder(self): """Freeze encoder parameters (only train classification head).""" for param in self.encoder.parameters(): param.requires_grad = False def unfreeze_encoder(self): """Unfreeze encoder parameters.""" for param in self.encoder.parameters(): param.requires_grad = True def get_num_trainable_params(self) -> int: """Get number of trainable parameters.""" return sum(p.numel() for p in self.parameters() if p.requires_grad)