| """ |
| Knowledge Distillation Loss Implementation for MangoMAS Local |
| |
| This module implements custom loss functions for knowledge distillation, |
| balancing task-specific loss with knowledge transfer from teacher models. |
| """ |
|
|
import logging
from typing import Optional, Tuple

import torch
import torch.nn as nn
import torch.nn.functional as F
|
|
# Module-level logger; configuration (handlers/level) is inherited from the host application.
logger = logging.getLogger(__name__)
|
|
|
|
class DistillationLoss:
    """
    Custom loss function for knowledge distillation combining task loss
    and distillation loss with configurable temperature and alpha parameters.

    The combined objective is::

        total = alpha * distillation_loss + (1 - alpha) * task_loss

    where the task loss is a shifted next-token cross-entropy and the
    distillation loss is a temperature-scaled KL divergence between the
    student and teacher token distributions.
    """

    def __init__(self, alpha: float = 0.5, temperature: float = 3.0):
        """
        Initialize distillation loss with parameters.

        Args:
            alpha: Balance between task loss and distillation loss (0.0-1.0).
                Higher values weight the distillation term more heavily.
            temperature: Temperature for softmax in knowledge distillation.
                Must be strictly positive.

        Raises:
            ValueError: If alpha is outside [0, 1] or temperature is not positive.
        """
        # Validate before assigning so a half-initialized instance is never observable.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {alpha}")
        if temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {temperature}")

        self.alpha = alpha
        self.temperature = temperature

        # Retained for backward compatibility with callers that use these modules
        # directly; the internal compute paths call the functional equivalents
        # (F.cross_entropy / F.kl_div) instead.
        self.task_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")

        logging.getLogger(__name__).info(
            f"Initialized DistillationLoss with alpha={alpha}, temperature={temperature}"
        )

    def compute_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the combined distillation loss.

        Args:
            student_logits: Logits from student model [batch_size, seq_len, vocab_size]
            teacher_logits: Logits from teacher model [batch_size, seq_len, vocab_size]
            labels: Target labels [batch_size, seq_len]
            attention_mask: Attention mask for padding tokens [batch_size, seq_len];
                positions with value 0 are excluded from both loss terms.

        Returns:
            Tuple of (total_loss, loss_dict) where loss_dict contains the
            individual loss values plus the current alpha and temperature.
        """
        task_loss = self._compute_task_loss(student_logits, labels, attention_mask)
        distill_loss = self._compute_distillation_loss(
            student_logits, teacher_logits, attention_mask
        )

        # alpha weights the knowledge-transfer term; (1 - alpha) the task term.
        total_loss = self.alpha * distill_loss + (1.0 - self.alpha) * task_loss

        loss_dict = {
            "total_loss": total_loss.item(),
            "task_loss": task_loss.item(),
            "distillation_loss": distill_loss.item(),
            "alpha": self.alpha,
            "temperature": self.temperature,
        }

        return total_loss, loss_dict

    def _compute_task_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the shifted next-token cross-entropy task loss.

        Positions masked out by ``attention_mask`` or labelled ``-100`` do not
        contribute. Returns 0.0 when no valid target remains (the previous
        implementation produced NaN when every surviving label was ``-100``).
        """
        # Standard causal-LM shift: the logit at position t predicts token t + 1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)

        if attention_mask is not None:
            # Shift the mask the same way as the labels, then drop padded positions.
            shift_mask = attention_mask[..., 1:].contiguous()
            valid_indices = shift_mask.view(-1).bool()
            flat_logits = flat_logits[valid_indices]
            flat_labels = flat_labels[valid_indices]

        # Guard against an empty mask or an all-ignored label tensor, for which
        # F.cross_entropy with ignore_index=-100 would return NaN.
        if (flat_labels != -100).sum() == 0:
            return torch.tensor(0.0, device=logits.device)

        return F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

    def _compute_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute the temperature-scaled KL-divergence distillation loss.

        Both distributions are softened by ``self.temperature``; the result is
        rescaled by ``temperature ** 2`` so gradient magnitudes stay comparable
        across temperatures.
        """
        # Soften both distributions; log-space targets with log_target=True avoid
        # materializing teacher probabilities.
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1)

        # Drop the final position to mirror the causal shift used by the task loss.
        vocab_size = student_log_probs.size(-1)
        student_flat = student_log_probs[..., :-1, :].contiguous().view(-1, vocab_size)
        teacher_flat = teacher_log_probs[..., :-1, :].contiguous().view(-1, vocab_size)

        if attention_mask is not None:
            keep = attention_mask[..., 1:].contiguous().view(-1).bool()
            student_flat = student_flat[keep]
            teacher_flat = teacher_flat[keep]

        # Nothing left to compare (e.g. an all-zero mask): contribute zero loss.
        if student_flat.shape[0] == 0:
            return torch.tensor(0.0, device=student_logits.device)

        return F.kl_div(
            student_flat, teacher_flat, reduction="batchmean", log_target=True
        ) * (self.temperature**2)

    def update_alpha(self, new_alpha: float) -> None:
        """Update the alpha parameter during training.

        Raises:
            ValueError: If new_alpha is outside [0, 1].
        """
        if not 0.0 <= new_alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0.0 and 1.0, got {new_alpha}")
        self.alpha = new_alpha
        logging.getLogger(__name__).info(f"Updated alpha to {new_alpha}")

    def update_temperature(self, new_temperature: float) -> None:
        """Update the temperature parameter during training.

        Raises:
            ValueError: If new_temperature is not strictly positive.
        """
        if new_temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {new_temperature}")
        self.temperature = new_temperature
        logging.getLogger(__name__).info(f"Updated temperature to {new_temperature}")

    def __call__(self, student_logits, teacher_logits, labels, attention_mask=None):
        """Convenience wrapper returning only the scalar total loss."""
        total_loss, _ = self.compute_loss(
            student_logits, teacher_logits, labels, attention_mask
        )
        return total_loss
|
|
|
|
class AdaptiveDistillationLoss(DistillationLoss):
    """
    Distillation loss whose alpha is annealed over the course of training.

    Early epochs emphasize knowledge transfer from the teacher
    (``initial_alpha``); as training progresses the weight decays linearly
    toward ``final_alpha``, shifting focus onto the task loss.
    """

    def __init__(
        self,
        initial_alpha: float = 0.5,
        final_alpha: float = 0.1,
        temperature: float = 3.0,
        warmup_steps: int = 1000,
    ):
        # Both endpoints must be valid mixing weights.
        for name, value in (("initial_alpha", initial_alpha), ("final_alpha", final_alpha)):
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{name} must be between 0 and 1")
        if initial_alpha < final_alpha:
            raise ValueError(
                f"initial_alpha must be >= final_alpha, got {initial_alpha} < {final_alpha}"
            )

        # Start training at the distillation-heavy end of the schedule.
        super().__init__(alpha=initial_alpha, temperature=temperature)
        self.initial_alpha = initial_alpha
        self.final_alpha = final_alpha
        self.current_alpha = initial_alpha
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def update_alpha(self, current_epoch: int, total_epochs: int):
        """Update alpha based on current epoch and total epochs."""
        if total_epochs <= 0:
            raise ValueError("total_epochs must be positive")
        if current_epoch < 0:
            raise ValueError("current_epoch must be non-negative")

        # Clamp to the stored endpoints outside the schedule; interpolate inside.
        if current_epoch >= total_epochs:
            annealed = self.final_alpha
        elif current_epoch == 0:
            annealed = self.initial_alpha
        else:
            fraction_done = current_epoch / total_epochs
            annealed = (
                self.initial_alpha
                - (self.initial_alpha - self.final_alpha) * fraction_done
            )

        self.current_alpha = annealed
        self.alpha = annealed

    def get_alpha(self):
        """Return the alpha value currently in effect."""
        return self.current_alpha
|
|