# NOTE: removed non-Python residue from a git-blame/web viewer (file size,
# commit hashes, and a line-number gutter) that made the file unparseable.
"""Sentiment classifier for text classification."""
from typing import Dict, Optional, Union
import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput
# Handle imports for both local usage and HuggingFace Hub
try:
from .configuration_sentiment import SentimentClassifierConfig
except ImportError:
try:
from configuration_sentiment import SentimentClassifierConfig
except ImportError:
from src.models.configuration_sentiment import SentimentClassifierConfig
class SentimentClassifier(PreTrainedModel):
    """
    Sentiment classifier for sequence classification.

    Wraps a pre-trained transformer encoder (loaded via ``AutoModel``) with a
    dropout layer and a linear head over the [CLS] token representation.

    Outputs:
        Sentiment label (positive/neutral/negative) - classification
    """
    config_class = SentimentClassifierConfig

    def __init__(
        self,
        config: Optional[SentimentClassifierConfig] = None,
        pretrained_model: str = "xlm-roberta-base",
        num_labels: int = 3,
        dropout: float = 0.1,
        hidden_size: Optional[int] = None,
        class_weights: Optional[torch.Tensor] = None,
        use_flash_attention_2: bool = False,
        gradient_checkpointing: bool = False,
    ):
        """
        Initialize sentiment classifier.

        Args:
            config: Model configuration object. When provided it takes
                precedence over ``pretrained_model``, ``num_labels``,
                ``dropout`` and ``hidden_size``.
            pretrained_model: Name of the pre-trained model.
            num_labels: Number of sentiment classes (default: 3).
            dropout: Dropout probability.
            hidden_size: Hidden size of the model (auto-detected if None).
            class_weights: Tensor of class weights for classification loss.
            use_flash_attention_2: Use Flash Attention 2 for faster attention;
                silently falls back to the default implementation if it is
                not available.
            gradient_checkpointing: Enable gradient checkpointing to save memory.
        """
        # Create config if not provided
        if config is None:
            config = SentimentClassifierConfig(
                pretrained_model=pretrained_model,
                num_labels=num_labels,
                dropout=dropout,
                hidden_size=hidden_size,
            )
        super().__init__(config)

        # Load the pre-trained transformer, optionally with Flash Attention 2.
        # BUGFIX: the original wrapped only a dict assignment in try/except —
        # an operation that can never raise — so when Flash Attention 2 was
        # unavailable, from_pretrained() itself failed and the documented
        # fallback never ran. Wrap the actual load instead.
        if use_flash_attention_2:
            try:
                self.encoder = AutoModel.from_pretrained(
                    config.pretrained_model,
                    attn_implementation="flash_attention_2",
                )
            except (ImportError, ValueError):
                # Flash Attention 2 not installed / not supported by this
                # architecture — fall back to the default attention.
                self.encoder = AutoModel.from_pretrained(config.pretrained_model)
        else:
            self.encoder = AutoModel.from_pretrained(config.pretrained_model)

        # Enable gradient checkpointing if requested (saves memory at cost of compute)
        if gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        # Auto-detect hidden size from the encoder when not set explicitly.
        if config.hidden_size is None:
            config.hidden_size = self.encoder.config.hidden_size
        self.hidden_size = config.hidden_size
        self.num_labels = config.num_labels

        # Dropout applied to the pooled [CLS] representation.
        self.dropout = nn.Dropout(config.dropout)

        # Classification head (sentiment label).
        self.classifier = nn.Linear(self.hidden_size, self.num_labels)

        # Stored as a buffer so the weights follow .to(device) and are saved
        # in the state dict, but are never updated by the optimizer.
        self.register_buffer(
            "class_weights",
            class_weights if class_weights is not None else torch.ones(self.num_labels),
        )

        # Initialize weights (triggers _init_weights on newly created modules).
        self.post_init()

    def _init_weights(self, module):
        """Initialize head weights: Xavier-uniform weight, zero bias."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[SequenceClassifierOutput, Dict[str, torch.Tensor]]:
        """
        Forward pass for classification.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            labels: Ground truth sentiment labels [batch_size].
            return_dict: Whether to return a SequenceClassifierOutput
                (True) or a plain tuple (False). Defaults to
                ``config.use_return_dict``.
            **kwargs: Additional arguments (ignored; accepted for Trainer
                compatibility).

        Returns:
            SequenceClassifierOutput with ``loss`` and ``logits``, or a
            tuple ``(loss, logits)`` / ``(logits,)`` when return_dict is
            False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode with transformer (always request the dict form internally
        # so last_hidden_state is accessible by name).
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Use [CLS] token representation as the pooled sequence embedding.
        pooled_output = outputs.last_hidden_state[:, 0, :]

        # Regularize, then project to class logits.
        pooled_output = self.dropout(pooled_output)
        logits = self.classifier(pooled_output)

        # Compute (optionally class-weighted) cross-entropy loss if labels provided.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Make predictions. Note: switches the model to eval mode as a side effect.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Predicted labels [batch_size].
        """
        self.eval()
        with torch.no_grad():
            # BUGFIX: force return_dict=True — with config.use_return_dict
            # False, forward() returns a tuple and `.logits` would raise.
            outputs = self.forward(input_ids, attention_mask, return_dict=True)
            logits = outputs.logits
            label_predictions = torch.argmax(logits, dim=-1)
        return label_predictions

    def get_probabilities(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get class probabilities. Note: switches the model to eval mode as a
        side effect.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Class probabilities [batch_size, num_labels].
        """
        self.eval()
        with torch.no_grad():
            # BUGFIX: force return_dict=True — see predict().
            outputs = self.forward(input_ids, attention_mask, return_dict=True)
            logits = outputs.logits
            probabilities = torch.softmax(logits, dim=-1)
        return probabilities

    def freeze_encoder(self):
        """Freeze encoder parameters (only train classification head)."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        """Unfreeze encoder parameters."""
        for param in self.encoder.parameters():
            param.requires_grad = True

    def get_num_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)