File size: 7,310 Bytes
26f3ae9
 
b3578be
26f3ae9
 
 
 
b3578be
26f3ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abb6dd8
26f3ae9
 
 
 
 
 
 
 
 
 
 
 
abb6dd8
26f3ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
abb6dd8
 
 
 
26f3ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3578be
26f3ae9
b3578be
26f3ae9
 
 
 
 
 
 
b3578be
26f3ae9
 
 
b3578be
26f3ae9
b3578be
 
26f3ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3578be
 
 
 
 
 
 
 
 
 
26f3ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3578be
26f3ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b3578be
26f3ae9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
"""Sentiment classifier for text classification."""

from typing import Dict, Optional, Union

import torch
import torch.nn as nn
from transformers import AutoModel, PreTrainedModel
from transformers.modeling_outputs import SequenceClassifierOutput

# Handle imports for both local usage and HuggingFace Hub
try:
    from .configuration_sentiment import SentimentClassifierConfig
except ImportError:
    try:
        from configuration_sentiment import SentimentClassifierConfig
    except ImportError:
        from src.models.configuration_sentiment import SentimentClassifierConfig


class SentimentClassifier(PreTrainedModel):
    """
    Sentiment classifier for sequence classification.

    Wraps a pre-trained transformer encoder (loaded via ``AutoModel``) with a
    dropout layer and a single linear head over the [CLS] token representation.

    Outputs:
        Sentiment label (positive/neutral/negative) - classification
    """

    config_class = SentimentClassifierConfig

    def __init__(
        self,
        config: Optional[SentimentClassifierConfig] = None,
        pretrained_model: str = "xlm-roberta-base",
        num_labels: int = 3,
        dropout: float = 0.1,
        hidden_size: Optional[int] = None,
        class_weights: Optional[torch.Tensor] = None,
        use_flash_attention_2: bool = False,
        gradient_checkpointing: bool = False,
    ):
        """
        Initialize sentiment classifier.

        Args:
            config: Model configuration object. When provided, the individual
                hyperparameter arguments (``pretrained_model``, ``num_labels``,
                ``dropout``, ``hidden_size``) are ignored in favor of the
                config's values.
            pretrained_model: Name of the pre-trained model.
            num_labels: Number of sentiment classes (default: 3).
            dropout: Dropout probability.
            hidden_size: Hidden size of the model (auto-detected if None).
            class_weights: Tensor of class weights for classification loss.
            use_flash_attention_2: Use Flash Attention 2 for faster attention;
                silently falls back to the default implementation if loading
                with Flash Attention 2 fails.
            gradient_checkpointing: Enable gradient checkpointing to save memory.
        """
        # Create config if not provided
        if config is None:
            config = SentimentClassifierConfig(
                pretrained_model=pretrained_model,
                num_labels=num_labels,
                dropout=dropout,
                hidden_size=hidden_size,
            )

        super().__init__(config)

        # Load the pre-trained transformer. Flash Attention 2 availability can
        # only be detected by attempting the load, so wrap the actual
        # from_pretrained call (a dict assignment alone can never raise, so a
        # try/except around kwargs preparation would be dead code).
        if use_flash_attention_2:
            try:
                self.encoder = AutoModel.from_pretrained(
                    config.pretrained_model,
                    attn_implementation="flash_attention_2",
                )
            except (ImportError, ValueError, RuntimeError):
                # Flash Attention 2 not available; use default attention.
                self.encoder = AutoModel.from_pretrained(config.pretrained_model)
        else:
            self.encoder = AutoModel.from_pretrained(config.pretrained_model)

        # Enable gradient checkpointing if requested (saves memory at cost of compute)
        if gradient_checkpointing:
            self.encoder.gradient_checkpointing_enable()

        # Auto-detect hidden size from the encoder when not set explicitly.
        if config.hidden_size is None:
            config.hidden_size = self.encoder.config.hidden_size

        self.hidden_size = config.hidden_size
        self.num_labels = config.num_labels

        # Dropout applied to the pooled [CLS] representation.
        self.dropout = nn.Dropout(config.dropout)

        # Classification head (sentiment label)
        self.classifier = nn.Linear(self.hidden_size, self.num_labels)

        # Class weights stored as a buffer so they move with the model's
        # device/dtype and are saved in the state dict. Defaults to uniform.
        self.register_buffer(
            "class_weights",
            class_weights if class_weights is not None else torch.ones(self.num_labels),
        )

        # Initialize weights (delegates to _init_weights via PreTrainedModel).
        self.post_init()

    def _init_weights(self, module):
        """Initialize head weights with Xavier-uniform; zero the biases."""
        if isinstance(module, nn.Linear):
            nn.init.xavier_uniform_(module.weight)
            if module.bias is not None:
                nn.init.zeros_(module.bias)

    def forward(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        return_dict: Optional[bool] = None,
        **kwargs,
    ) -> Union[SequenceClassifierOutput, Dict[str, torch.Tensor]]:
        """
        Forward pass for classification.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].
            labels: Ground truth sentiment labels [batch_size].
            return_dict: Whether to return a SequenceClassifierOutput or tuple.
            **kwargs: Additional arguments (ignored; accepted for Trainer
                compatibility).

        Returns:
            SequenceClassifierOutput, or a tuple of (loss, logits) /
            (logits,) when ``return_dict`` is False.
        """
        return_dict = return_dict if return_dict is not None else self.config.use_return_dict

        # Encode with transformer
        outputs = self.encoder(
            input_ids=input_ids,
            attention_mask=attention_mask,
            return_dict=True,
        )

        # Use [CLS] token representation (first position of the sequence).
        pooled_output = outputs.last_hidden_state[:, 0, :]

        # Apply dropout
        pooled_output = self.dropout(pooled_output)

        # Classification head
        logits = self.classifier(pooled_output)

        # Compute weighted cross-entropy loss if labels are provided.
        loss = None
        if labels is not None:
            loss_fct = nn.CrossEntropyLoss(weight=self.class_weights)
            loss = loss_fct(logits, labels)

        if not return_dict:
            output = (logits,)
            return ((loss,) + output) if loss is not None else output

        return SequenceClassifierOutput(
            loss=loss,
            logits=logits,
            hidden_states=None,
            attentions=None,
        )

    def _inference_logits(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """Run the model in eval mode without gradients and return raw logits.

        Note: this switches the module to eval mode as a side effect, matching
        the behavior of ``predict``/``get_probabilities``.
        """
        self.eval()
        with torch.no_grad():
            outputs = self.forward(input_ids, attention_mask)
        return outputs.logits

    def predict(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Make predictions.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Predicted labels [batch_size].
        """
        logits = self._inference_logits(input_ids, attention_mask)
        return torch.argmax(logits, dim=-1)

    def get_probabilities(
        self,
        input_ids: torch.Tensor,
        attention_mask: torch.Tensor,
    ) -> torch.Tensor:
        """
        Get class probabilities.

        Args:
            input_ids: Input token IDs [batch_size, seq_len].
            attention_mask: Attention mask [batch_size, seq_len].

        Returns:
            Class probabilities [batch_size, num_labels].
        """
        logits = self._inference_logits(input_ids, attention_mask)
        return torch.softmax(logits, dim=-1)

    def freeze_encoder(self):
        """Freeze encoder parameters (only train classification head)."""
        for param in self.encoder.parameters():
            param.requires_grad = False

    def unfreeze_encoder(self):
        """Unfreeze encoder parameters."""
        for param in self.encoder.parameters():
            param.requires_grad = True

    def get_num_trainable_params(self) -> int:
        """Get number of trainable parameters."""
        return sum(p.numel() for p in self.parameters() if p.requires_grad)