| """ |
| Refactored Knowledge Distillation Loss using modular architecture. |
| |
| This module implements a clean, testable loss function that follows the interface contracts |
| and provides better separation of concerns. |
| """ |
|
|
import logging
import math
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..core.base_components import BaseLossFunction
from ..core.exceptions import TrainingError
|
|
| logger = logging.getLogger(__name__) |
|
|
|
|
class ModularDistillationLoss(BaseLossFunction):
    """
    Modular distillation loss function implementing clean interface contracts.

    Combines a next-token cross-entropy task loss with a temperature-scaled
    KL-divergence distillation loss:

        total = alpha * task_loss + (1 - alpha) * distillation_loss
    """

    def __init__(
        self,
        alpha: float = 0.5,
        temperature: float = 3.0,
        task_loss_fn: Optional[nn.Module] = None,
    ):
        """
        Initialize the modular distillation loss.

        Args:
            alpha: Balance between task loss and distillation loss (0.0-1.0)
            temperature: Temperature for softmax in knowledge distillation
            task_loss_fn: Custom task loss function (defaults to CrossEntropyLoss)

        Raises:
            ValueError: If alpha is outside [0, 1] or temperature is not positive.
        """
        super().__init__({"alpha": alpha, "temperature": temperature})

        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {alpha}")
        if temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {temperature}")

        self.alpha = alpha
        self.temperature = temperature
        # ignore_index=-100 matches the common convention for masked label
        # positions (e.g. padding) produced by tokenizer collators.
        self.task_loss_fn = task_loss_fn or nn.CrossEntropyLoss(ignore_index=-100)
        self.kl_loss_fn = nn.KLDivLoss(reduction="batchmean")

        logger.info(
            f"Initialized ModularDistillationLoss with alpha={alpha}, temperature={temperature}"
        )

    def compute(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the combined distillation loss.

        Args:
            student_logits: Logits from student model
            teacher_logits: Logits from teacher model (same shape as student)
            labels: Target labels for task loss; task loss is 0 when omitted
            attention_mask: Attention mask for valid tokens (1 = valid)

        Returns:
            Combined scalar loss tensor

        Raises:
            TrainingError: If any stage of the loss computation fails.
        """
        try:
            # Task loss defaults to 0 when no labels are supplied
            # (pure-distillation mode).
            task_loss = torch.tensor(0.0, device=student_logits.device)
            if labels is not None:
                task_loss = self._compute_task_loss(
                    student_logits, labels, attention_mask
                )

            distill_loss = self._compute_distillation_loss(
                student_logits, teacher_logits, attention_mask
            )

            total_loss = self.alpha * task_loss + (1 - self.alpha) * distill_loss

            # .item() forces a host sync; acceptable here since metrics are
            # collected once per step.
            metrics = {
                "task_loss": task_loss.item(),
                "distillation_loss": distill_loss.item(),
                "total_loss": total_loss.item(),
                "alpha": self.alpha,
                "temperature": self.temperature,
            }
            self._track_metrics(metrics)
            self._track_loss(total_loss.item())

            return total_loss

        except TrainingError:
            # Already a domain error; don't double-wrap it.
            raise
        except Exception as e:
            # Chain the cause so the original traceback is preserved.
            raise TrainingError(
                f"Loss computation failed: {str(e)}", "LOSS_COMPUTATION_ERROR"
            ) from e

    def _compute_task_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute standard cross-entropy task loss.

        Logits/labels are shifted by one position for next-token prediction.
        When an attention mask is given, only positions with mask == 1
        contribute; if no position is valid, returns 0 instead of NaN.
        """
        # Shift so that position t predicts token t+1.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()

        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)

        if attention_mask is not None:
            # Shift the mask the same way as the labels.
            shift_attention_mask = attention_mask[..., 1:].contiguous()
            flat_attention_mask = shift_attention_mask.view(-1)

            active_logits = flat_logits[flat_attention_mask == 1]
            active_labels = flat_labels[flat_attention_mask == 1]

            if active_logits.size(0) > 0:
                return self.task_loss_fn(active_logits, active_labels)
            else:
                # All positions masked out: no signal, avoid NaN from an
                # empty reduction.
                return torch.tensor(0.0, device=logits.device)

        return self.task_loss_fn(flat_logits, flat_labels)

    def _compute_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute knowledge distillation loss using KL divergence.

        The result is scaled by temperature**2 to keep gradient magnitudes
        comparable across temperatures (Hinton et al., 2015).
        """
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)

        if attention_mask is not None:
            # Per-token KL summed over the vocabulary, then averaged over the
            # valid tokens only. The previous implementation multiplied the
            # (log-)distributions by the mask and used "batchmean", which
            # divided by the batch size regardless of padding, so the loss
            # scale depended on how much padding the batch contained.
            per_token_kl = F.kl_div(
                student_log_probs, teacher_probs, reduction="none"
            ).sum(dim=-1)
            mask = attention_mask.to(per_token_kl.dtype)
            # Clamp so an all-masked batch yields 0 rather than NaN.
            valid_tokens = mask.sum().clamp(min=1.0)
            kl_loss = (per_token_kl * mask).sum() / valid_tokens
        else:
            kl_loss = self.kl_loss_fn(student_log_probs, teacher_probs)

        return kl_loss * (self.temperature**2)

    def update_alpha(self, new_alpha: float) -> None:
        """Update the alpha parameter for dynamic loss weighting.

        Raises:
            ValueError: If new_alpha is outside [0, 1].
        """
        if not 0.0 <= new_alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {new_alpha}")

        old_alpha = self.alpha
        self.alpha = new_alpha
        self.config["alpha"] = new_alpha

        logger.info(f"Updated alpha from {old_alpha} to {new_alpha}")

    def update_temperature(self, new_temperature: float) -> None:
        """Update the temperature parameter for dynamic distillation.

        Raises:
            ValueError: If new_temperature is not positive.
        """
        if new_temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {new_temperature}")

        old_temperature = self.temperature
        self.temperature = new_temperature
        self.config["temperature"] = new_temperature

        logger.info(f"Updated temperature from {old_temperature} to {new_temperature}")
|
|
|
|
class AdaptiveDistillationLoss(ModularDistillationLoss):
    """
    Adaptive distillation loss that adjusts alpha based on training progress.

    Supported strategies:
        - "linear_decay": alpha decays linearly from its initial value to 0
          over ``total_steps``.
        - "cosine_decay": alpha follows a half-cosine from its initial value
          to 0 over ``total_steps``.
        - "step_decay": alpha is multiplied by ``decay_factor`` once the step
          count passes each milestone in ``decay_steps``.

    Any other strategy name leaves alpha unchanged.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        temperature: float = 3.0,
        adaptation_strategy: str = "linear_decay",
        adaptation_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize adaptive distillation loss.

        Args:
            alpha: Initial alpha value
            temperature: Temperature for distillation
            adaptation_strategy: Strategy for adapting alpha
                ("linear_decay", "cosine_decay", "step_decay")
            adaptation_config: Configuration for the strategy. Recognized keys:
                "total_steps" (default 1000); for "step_decay" additionally
                "decay_steps" (default [500, 750]) and "decay_factor"
                (default 0.5).
        """
        super().__init__(alpha, temperature)

        self.initial_alpha = alpha
        self.adaptation_strategy = adaptation_strategy
        self.adaptation_config = adaptation_config or {}
        self.step_count = 0

        logger.info(
            f"Initialized AdaptiveDistillationLoss with strategy: {adaptation_strategy}"
        )

    def compute(self, *args, **kwargs) -> torch.Tensor:
        """Compute loss with adaptive alpha adjustment."""
        # Adjust alpha for the current step, then delegate to the base loss.
        self._update_alpha_adaptive()
        self.step_count += 1
        return super().compute(*args, **kwargs)

    def _update_alpha_adaptive(self) -> None:
        """Update alpha based on the selected adaptation strategy."""
        total_steps = self.adaptation_config.get("total_steps", 1000)

        if self.adaptation_strategy == "linear_decay":
            # Linear ramp from initial_alpha down to 0 over total_steps.
            progress = min(self.step_count / total_steps, 1.0)
            new_alpha = self.initial_alpha * (1.0 - progress)

        elif self.adaptation_strategy == "cosine_decay":
            # Half-cosine schedule: smooth decay from initial_alpha to 0.
            progress = min(self.step_count / total_steps, 1.0)
            new_alpha = self.initial_alpha * (1 + math.cos(math.pi * progress)) / 2

        elif self.adaptation_strategy == "step_decay":
            # Multiply by decay_factor once per milestone already passed.
            decay_steps = self.adaptation_config.get("decay_steps", [500, 750])
            decay_factor = self.adaptation_config.get("decay_factor", 0.5)

            new_alpha = self.initial_alpha
            for decay_step in decay_steps:
                if self.step_count >= decay_step:
                    new_alpha *= decay_factor
        else:
            # Unknown strategy: keep the current alpha unchanged.
            new_alpha = self.alpha

        # Apply only meaningful changes to avoid log/metric churn.
        if abs(new_alpha - self.alpha) > 1e-6:
            self.alpha = new_alpha
            # Keep the stored config in sync with self.alpha, matching the
            # contract of ModularDistillationLoss.update_alpha (the original
            # bypassed it and left config["alpha"] stale).
            self.config["alpha"] = new_alpha
|