File size: 14,070 Bytes
ae47555
 
 
 
 
 
 
80b78df
ae47555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a770449
77bc910
f292cd1
6cf4c1f
ae47555
 
 
 
 
 
 
 
 
a770449
f292cd1
6cf4c1f
ae47555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
77bc910
f292cd1
ae47555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
cb428cb
 
 
 
 
 
 
 
 
 
 
 
6bfe0f4
cb428cb
 
 
 
ae47555
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
a770449
ae47555
 
 
 
 
f292cd1
 
 
 
 
 
 
 
 
 
 
 
 
ae47555
 
 
 
 
 
 
 
80b78df
 
 
 
ae47555
80b78df
ae47555
80b78df
ae47555
 
 
 
 
 
 
 
77bc910
 
 
 
80b78df
ae47555
 
 
80b78df
d26ac21
 
 
 
ae47555
 
 
 
 
 
 
 
ed14b1f
ae47555
d26ac21
ae47555
99575b1
ae47555
 
 
 
 
 
 
 
bad72d7
 
ae47555
 
 
 
 
 
 
 
 
 
 
 
 
 
cb428cb
 
 
 
 
 
 
 
 
 
 
 
6bfe0f4
cb428cb
 
 
 
ae47555
 
 
f292cd1
 
 
 
 
 
 
 
99575b1
 
 
f292cd1
99575b1
f292cd1
 
6e2887d
 
ae47555
 
 
 
80b78df
 
 
 
ae47555
80b78df
 
bad72d7
80b78df
bad72d7
ae47555
bad72d7
ae47555
80b78df
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
import torch
import torch.nn as nn
import torch.nn.functional as F
import numpy as np
from tqdm import tqdm
import logging
import os
from sklearn.metrics import f1_score, accuracy_score, precision_score, recall_score

logger = logging.getLogger(__name__)

class DistillationTrainer:
    """
    Trainer for knowledge distillation from teacher model (BERT) to student model (LSTM).

    The student is optimised on a weighted mix of a soft-target KL loss against the
    frozen teacher and a hard-label cross-entropy loss.

    When ``num_categories > 1`` the models are expected to emit
    ``num_categories * num_classes`` logits per sample, laid out as consecutive groups
    of ``num_classes`` (one group per category), and labels are expected to have shape
    ``(batch, num_categories)``. Each group is treated as an independent classifier.
    """
    def __init__(
        self, 
        teacher_model, 
        student_model,
        train_loader, 
        val_loader, 
        test_loader=None,
        temperature=2.0,
        alpha=0.5,  # Weight for distillation loss vs. regular loss
        lr=0.001,
        weight_decay=1e-5,
        max_grad_norm=1.0,
        label_mapping=None,
        num_categories=1,
        num_classes=2,
        device=None
    ):
        self.teacher_model = teacher_model
        self.student_model = student_model
        self.train_loader = train_loader
        self.val_loader = val_loader
        self.test_loader = test_loader
        self.temperature = temperature
        self.alpha = alpha
        self.max_grad_norm = max_grad_norm
        self.num_categories = num_categories
        self.num_classes = num_classes
        
        self.device = device if device else torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        logger.info(f"Using device: {self.device}")
        
        # Move models to device
        self.teacher_model.to(self.device)
        self.student_model.to(self.device)
        
        # The teacher is never trained here; keep it in evaluation mode (dropout off).
        self.teacher_model.eval()
        
        # Only the student's parameters are optimised.
        self.optimizer = torch.optim.Adam(
            self.student_model.parameters(), 
            lr=lr, 
            weight_decay=weight_decay
        )
        
        # Halve the LR when validation F1 (mode='max') stops improving for 2 epochs.
        self.scheduler = torch.optim.lr_scheduler.ReduceLROnPlateau(
            self.optimizer, mode='max', factor=0.5, patience=2, verbose=True
        )
        
        # Loss for hard targets
        self.ce_loss = nn.CrossEntropyLoss()
        
        # Tracking metrics
        self.best_val_f1 = 0.0
        self.best_model_state = None  # CPU snapshot of the best student weights
        self.label_mapping = label_mapping

    # ------------------------------------------------------------------ helpers

    def _check_category_labels(self, category_labels, category_index):
        """Raise ValueError if any label of one category lies outside [0, num_classes - 1]."""
        if category_labels.numel() == 0:
            return
        low = int(category_labels.min())
        high = int(category_labels.max())
        if low < 0 or high >= self.num_classes:
            raise ValueError(
                f"Category {category_index} labels out of range "
                f"[0, {self.num_classes - 1}]: min={low}, max={high}"
            )

    def _hard_label_loss(self, logits, labels):
        """
        Cross-entropy against ground-truth labels.

        For multiple categories the CE is computed per ``num_classes``-wide logit group
        against ``labels[:, i]`` and averaged over categories.
        """
        if self.num_categories <= 1:
            return self.ce_loss(logits, labels)

        total_loss = 0.0
        for i in range(self.num_categories):
            start = i * self.num_classes
            category_logits = logits[:, start:start + self.num_classes]  # (batch, num_classes)
            category_labels = labels[:, i]                                # (batch,)
            # Fail loudly: out-of-range targets would otherwise crash inside CE
            # (or trigger an opaque CUDA device-side assert).
            self._check_category_labels(category_labels, i)
            total_loss = total_loss + self.ce_loss(category_logits, category_labels)
        return total_loss / self.num_categories

    def _soft_target_loss(self, student_logits, teacher_logits, temperature):
        """
        Temperature-scaled KL(teacher || student), multiplied by ``temperature ** 2``.

        With one category this equals ``F.kl_div(..., reduction='batchmean')``. With
        several categories the softmax is taken within each logit group (categories are
        independent classifiers) and the per-category KL is averaged, mirroring the
        averaging of the hard-label loss.
        """
        if self.num_categories > 1:
            batch_size = student_logits.size(0)
            student_logits = student_logits.reshape(batch_size, self.num_categories, self.num_classes)
            teacher_logits = teacher_logits.reshape(batch_size, self.num_categories, self.num_classes)

        soft_targets = F.softmax(teacher_logits / temperature, dim=-1)
        log_probs = F.log_softmax(student_logits / temperature, dim=-1)
        # Sum over classes -> (batch,) or (batch, num_categories); mean over the rest.
        per_item_kl = F.kl_div(log_probs, soft_targets, reduction='none').sum(dim=-1)
        # temperature**2 keeps soft-target gradient magnitudes comparable to the CE term.
        return per_item_kl.mean() * (temperature ** 2)

    def _predict(self, logits, threshold=None):
        """
        Turn logits into class predictions.

        Single category: argmax over dim 1, shape (batch,); ``threshold`` is ignored.
        Multiple categories: argmax within each logit group, shape
        (batch, num_categories). If ``threshold`` is given, per-group softmax
        probabilities not above it are zeroed first, so a group with no confident
        class falls back to index 0 (treated as NORMAL / inconclusive).

        Raises:
            ValueError: if the logit width is not divisible by ``num_categories``.
        """
        if self.num_categories <= 1:
            _, preds = torch.max(logits, 1)
            return preds

        batch_size, total_classes = logits.shape
        if total_classes % self.num_categories != 0:
            raise ValueError(f"Error: Number of total classes in the batch must be divisible by {self.num_categories}")

        classes_per_group = total_classes // self.num_categories
        grouped = logits.view(batch_size, self.num_categories, classes_per_group)
        if threshold is None:
            return grouped.argmax(dim=-1)

        # Normalise within each category's group of classes (last dim), not across categories.
        probs = F.softmax(grouped, dim=-1)
        # Keep only confident probabilities to avoid false positives.
        probs = torch.where(probs > threshold, probs, torch.zeros_like(probs))
        return probs.argmax(dim=-1)

    @staticmethod
    def _concat(chunks):
        """Join per-batch 1-D arrays into one flat array (empty int array if there were none)."""
        if not chunks:
            return np.array([], dtype=int)
        return np.concatenate(chunks)

    # --------------------------------------------------------------- public API

    def distillation_loss(self, student_logits, teacher_logits, labels, temperature, alpha):
        """
        Compute the knowledge distillation loss.
        
        Args:
            student_logits: Output from student model
            teacher_logits: Output from teacher model
            labels: Ground truth labels; shape (batch, num_categories) when
                num_categories > 1
            temperature: Temperature for softening probability distributions
            alpha: Weight for distillation loss vs. cross-entropy loss
            
        Returns:
            Scalar tensor ``alpha * distill_loss + (1 - alpha) * ce_loss``.

        Raises:
            ValueError: if a per-category label lies outside [0, num_classes - 1].
        """
        distill_loss = self._soft_target_loss(student_logits, teacher_logits, temperature)
        ce_loss = self._hard_label_loss(student_logits, labels)
        return alpha * distill_loss + (1 - alpha) * ce_loss
    
    def train(self, epochs, save_path='best_distilled_model.pth'):
        """
        Train the student model with knowledge distillation.

        After every epoch the student is evaluated on ``val_loader``. Whenever the
        validation F1 improves, the weights are snapshotted in memory and saved to
        ``save_path`` together with ``label_mapping``. At the end the best snapshot is
        reloaded, and the model is evaluated on ``test_loader`` if one was given.
        """
        logger.info(f"Starting distillation training for {epochs} epochs")
        logger.info(f"Temperature: {self.temperature}, Alpha: {self.alpha}")
        
        for epoch in range(epochs):
            self.student_model.train()
            train_loss = 0.0
            pred_chunks = []
            label_chunks = []
            
            train_iterator = tqdm(self.train_loader, desc=f"Epoch {epoch+1}/{epochs} [Train]")
            for batch in train_iterator:
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)
                
                # Teacher provides soft targets only; no gradients needed.
                with torch.no_grad():
                    teacher_logits = self.teacher_model(
                        input_ids=input_ids,
                        attention_mask=attention_mask
                    )
                
                student_logits = self.student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                
                loss = self.distillation_loss(
                    student_logits, 
                    teacher_logits, 
                    labels, 
                    self.temperature, 
                    self.alpha
                )
                
                self.optimizer.zero_grad()
                loss.backward()
                torch.nn.utils.clip_grad_norm_(self.student_model.parameters(), self.max_grad_norm)
                self.optimizer.step()
                
                train_loss += loss.item()
                
                # Running accuracy. For multiple categories every (sample, category)
                # pair counts as one prediction, hence the flattening.
                preds = self._predict(student_logits.detach())
                pred_chunks.append(preds.cpu().numpy().reshape(-1))
                label_chunks.append(labels.cpu().numpy().reshape(-1))
                
                train_iterator.set_postfix({'loss': f"{loss.item():.4f}"})
            
            train_loss = train_loss / len(self.train_loader)
            train_acc = accuracy_score(self._concat(label_chunks), self._concat(pred_chunks))

            val_loss, val_acc, val_precision, val_recall, val_f1 = self.evaluate()
            
            # Scheduler is in 'max' mode: it reacts to a plateau in validation F1.
            self.scheduler.step(val_f1)
            
            if val_f1 > self.best_val_f1:
                self.best_val_f1 = val_f1
                # state_dict() returns references to the live parameter tensors; a
                # shallow dict copy would silently keep tracking later updates.
                # Take real (CPU) copies so the best weights are actually preserved.
                self.best_model_state = {
                    name: tensor.detach().to('cpu', copy=True)
                    for name, tensor in self.student_model.state_dict().items()
                }
                torch.save({
                    'model_state_dict': self.student_model.state_dict(),
                    'label_mapping': self.label_mapping,
                }, save_path)
                logger.info(f"New best model saved with validation F1: {val_f1:.4f}, accuracy: {val_acc:.4f}")
            
            logger.info(f"Epoch {epoch+1}/{epochs}: "
                      f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, "
                      f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val Precision: {val_precision:.4f}, Val Recall: {val_recall:.4f}, Val F1: {val_f1:.4f}")
            
            print(f"Epoch {epoch+1}/{epochs}: ",
                      f"Train Loss: {train_loss:.4f}, Train Acc: {train_acc:.4f}, ",
                      f"Val Loss: {val_loss:.4f}, Val Acc: {val_acc:.4f}, Val Precision: {val_precision:.4f}, Val Recall: {val_recall:.4f}, Val F1: {val_f1:.4f}")
        
        # Restore the best weights for final evaluation.
        if self.best_model_state is not None:
            self.student_model.load_state_dict(self.best_model_state)
            logger.info(f"Loaded best model with validation F1: {self.best_val_f1:.4f}")
        
        if self.test_loader:
            test_loss, test_acc, test_precision, test_recall, test_f1 = self.evaluate(self.test_loader, "Test")
            logger.info(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
            print(f"Test Loss: {test_loss:.4f}, Test Acc: {test_acc:.4f}, Test F1: {test_f1:.4f}")
    
    def evaluate(self, data_loader=None, phase="Validation", threshold=0.55):
        """
        Evaluate the student model with the hard-label loss (no distillation).

        Args:
            data_loader: Loader to evaluate; defaults to ``self.val_loader``.
            phase: Label shown in the progress bar.
            threshold: For multiple categories, the minimum per-category softmax
                probability for a non-default prediction (see ``_predict``).

        Returns:
            Tuple ``(loss, accuracy, precision, recall, f1)``. Precision, recall and F1
            are weighted averages. With multiple categories every (sample, category)
            pair counts as one prediction.
        """
        if data_loader is None:
            data_loader = self.val_loader
        
        self.student_model.eval()
        eval_loss = 0.0
        pred_chunks = []
        label_chunks = []
        
        with torch.no_grad():
            for batch in tqdm(data_loader, desc=f"[{phase}]"):
                input_ids = batch['input_ids'].to(self.device)
                attention_mask = batch['attention_mask'].to(self.device)
                labels = batch['label'].to(self.device)
                
                student_logits = self.student_model(
                    input_ids=input_ids,
                    attention_mask=attention_mask
                )
                
                eval_loss += self._hard_label_loss(student_logits, labels).item()
                
                preds = self._predict(student_logits, threshold=threshold)
                # Collect flat per-batch arrays; concatenated once after the loop.
                pred_chunks.append(preds.cpu().numpy().reshape(-1))
                label_chunks.append(labels.cpu().numpy().reshape(-1))
        
        eval_loss = eval_loss / len(data_loader)
        all_preds = self._concat(pred_chunks)
        all_labels = self._concat(label_chunks)

        accuracy = accuracy_score(all_labels, all_preds)
        # Weighted averages over classes. zero_division=0 matches sklearn's default
        # value but avoids a warning for classes that are never predicted.
        precision = precision_score(all_labels, all_preds, average='weighted', zero_division=0)
        recall = recall_score(all_labels, all_preds, average='weighted', zero_division=0)
        f1 = f1_score(all_labels, all_preds, average='weighted', zero_division=0)
        
        return eval_loss, accuracy, precision, recall, f1