# phi35-moe-multimodal / src/multi_agent_training/modular_distillation_loss.py
# Branch: Mango-Metrics-NLM — feat: Phi-3.5-MoE multi-agent model repository (c8b77b5)
"""
Refactored Knowledge Distillation Loss using modular architecture.
This module implements a clean, testable loss function that follows the interface contracts
and provides better separation of concerns.
"""
import logging
import math
from typing import Any, Dict, Optional

import torch
import torch.nn as nn
import torch.nn.functional as F

from ..core.base_components import BaseLossFunction
from ..core.exceptions import TrainingError
logger = logging.getLogger(__name__)
class ModularDistillationLoss(BaseLossFunction):
    """
    Modular distillation loss function implementing clean interface contracts.

    Combines a supervised task loss (next-token cross-entropy) with a
    temperature-scaled KL-divergence distillation term:

        total = alpha * task_loss + (1 - alpha) * distillation_loss
    """

    def __init__(
        self,
        alpha: float = 0.5,
        temperature: float = 3.0,
        task_loss_fn: Optional[nn.Module] = None,
    ):
        """
        Initialize the modular distillation loss.

        Args:
            alpha: Balance between task loss and distillation loss (0.0-1.0)
            temperature: Temperature for softmax in knowledge distillation (> 0)
            task_loss_fn: Custom task loss function (defaults to
                CrossEntropyLoss with ignore_index=-100)

        Raises:
            ValueError: If alpha is outside [0, 1] or temperature is not positive.
        """
        super().__init__({"alpha": alpha, "temperature": temperature})
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {alpha}")
        if temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {temperature}")
        self.alpha = alpha
        self.temperature = temperature
        self.task_loss_fn = task_loss_fn or nn.CrossEntropyLoss(ignore_index=-100)
        # Used on the unmasked path; the masked path reduces manually so that
        # padding tokens can be excluded (see _compute_distillation_loss).
        self.kl_loss_fn = nn.KLDivLoss(reduction="batchmean")
        logger.info(
            f"Initialized ModularDistillationLoss with alpha={alpha}, temperature={temperature}"
        )

    def compute(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: Optional[torch.Tensor] = None,
        attention_mask: Optional[torch.Tensor] = None,
        **kwargs,
    ) -> torch.Tensor:
        """
        Compute the combined distillation loss.

        Args:
            student_logits: Logits from student model
                (assumed [batch, seq, vocab] — TODO confirm with callers)
            teacher_logits: Logits from teacher model (same shape as student)
            labels: Target labels for task loss; if None, task loss is 0
            attention_mask: Attention mask for valid tokens (1 = valid)

        Returns:
            Combined scalar loss tensor

        Raises:
            TrainingError: If any step of the loss computation fails.
        """
        try:
            # Task loss is only meaningful when labels are available.
            task_loss = torch.tensor(0.0, device=student_logits.device)
            if labels is not None:
                task_loss = self._compute_task_loss(
                    student_logits, labels, attention_mask
                )
            # Compute distillation loss
            distill_loss = self._compute_distillation_loss(
                student_logits, teacher_logits, attention_mask
            )
            # Convex combination of the two objectives.
            total_loss = self.alpha * task_loss + (1 - self.alpha) * distill_loss
            # Track metrics (.item() detaches scalars for logging).
            metrics = {
                "task_loss": task_loss.item(),
                "distillation_loss": distill_loss.item(),
                "total_loss": total_loss.item(),
                "alpha": self.alpha,
                "temperature": self.temperature,
            }
            self._track_metrics(metrics)
            self._track_loss(total_loss.item())
            return total_loss
        except Exception as e:
            # Chain the original exception so the root cause is preserved
            # in tracebacks (previously the cause was dropped).
            raise TrainingError(
                f"Loss computation failed: {str(e)}", "LOSS_COMPUTATION_ERROR"
            ) from e

    def _compute_task_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute standard next-token cross-entropy task loss.

        Logits are shifted left and labels right so that position t
        predicts token t+1 (causal-LM convention).
        """
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten to [batch * (seq-1), vocab] and [batch * (seq-1)].
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        if attention_mask is not None:
            # Shift the mask the same way to stay aligned with shifted labels.
            shift_attention_mask = attention_mask[..., 1:].contiguous()
            flat_attention_mask = shift_attention_mask.view(-1)
            # Only compute loss on non-masked tokens.
            active_logits = flat_logits[flat_attention_mask == 1]
            active_labels = flat_labels[flat_attention_mask == 1]
            if active_logits.size(0) > 0:
                return self.task_loss_fn(active_logits, active_labels)
            # Every token masked out: contribute zero instead of NaN.
            return torch.tensor(0.0, device=logits.device)
        return self.task_loss_fn(flat_logits, flat_labels)

    def _compute_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Compute knowledge distillation loss using KL divergence.

        When an attention mask is given, padded positions are excluded and
        the sum is normalized by batch size, matching the "batchmean"
        semantics of the unmasked path.
        """
        # Temperature-scaled distributions; student in log space for kl_div.
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_probs = F.softmax(teacher_logits / self.temperature, dim=-1)
        if attention_mask is not None:
            # BUG FIX: the previous code multiplied the (log-)probability
            # tensors by the mask BEFORE KLDivLoss, which corrupts the
            # distributions (masked log-probs become log(1) = 0 and still
            # contribute p*log p terms) instead of excluding those tokens.
            # Compute per-token KL (summed over the vocab dim) and zero out
            # masked positions explicitly.
            per_token_kl = F.kl_div(
                student_log_probs, teacher_probs, reduction="none"
            ).sum(dim=-1)
            mask = attention_mask.to(per_token_kl.dtype)
            # Normalize by batch size to mirror reduction="batchmean".
            kl_loss = (per_token_kl * mask).sum() / max(student_logits.size(0), 1)
        else:
            kl_loss = self.kl_loss_fn(student_log_probs, teacher_probs)
        # Scale by T^2 so gradient magnitudes stay comparable across
        # temperatures (Hinton et al., 2015).
        return kl_loss * (self.temperature**2)

    def update_alpha(self, new_alpha: float) -> None:
        """Update the alpha parameter for dynamic loss weighting.

        Raises:
            ValueError: If new_alpha is outside [0, 1].
        """
        if not 0.0 <= new_alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {new_alpha}")
        old_alpha = self.alpha
        self.alpha = new_alpha
        # Keep the serializable config in sync with the live attribute.
        self.config["alpha"] = new_alpha
        logger.info(f"Updated alpha from {old_alpha} to {new_alpha}")

    def update_temperature(self, new_temperature: float) -> None:
        """Update the temperature parameter for dynamic distillation.

        Raises:
            ValueError: If new_temperature is not positive.
        """
        if new_temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {new_temperature}")
        old_temperature = self.temperature
        self.temperature = new_temperature
        self.config["temperature"] = new_temperature
        logger.info(f"Updated temperature from {old_temperature} to {new_temperature}")
class AdaptiveDistillationLoss(ModularDistillationLoss):
    """
    Adaptive distillation loss that adjusts alpha based on training progress.

    Alpha starts at its initial value and is re-derived once per compute()
    call according to the configured adaptation strategy.
    """

    def __init__(
        self,
        alpha: float = 0.5,
        temperature: float = 3.0,
        adaptation_strategy: str = "linear_decay",
        adaptation_config: Optional[Dict[str, Any]] = None,
    ):
        """
        Initialize adaptive distillation loss.

        Args:
            alpha: Initial alpha value
            temperature: Temperature for distillation
            adaptation_strategy: Strategy for adapting alpha
                ("linear_decay", "cosine_decay", "step_decay"); any other
                value disables adaptation
            adaptation_config: Configuration for the strategy. Supported keys:
                "total_steps" (linear/cosine decay, default 1000),
                "decay_steps" and "decay_factor" (step decay)
        """
        super().__init__(alpha, temperature)
        self.initial_alpha = alpha
        self.adaptation_strategy = adaptation_strategy
        self.adaptation_config = adaptation_config or {}
        self.step_count = 0
        logger.info(
            f"Initialized AdaptiveDistillationLoss with strategy: {adaptation_strategy}"
        )

    def compute(self, *args, **kwargs) -> torch.Tensor:
        """Compute loss with adaptive alpha adjustment."""
        # Adjust alpha for the current step, then delegate to the parent.
        self._update_alpha_adaptive()
        self.step_count += 1
        return super().compute(*args, **kwargs)

    def _update_alpha_adaptive(self) -> None:
        """Update alpha based on the selected adaptation strategy."""
        # Guard against total_steps <= 0 to avoid division by zero.
        total_steps = max(self.adaptation_config.get("total_steps", 1000), 1)
        if self.adaptation_strategy == "linear_decay":
            # Linearly decay alpha from its initial value to 0.
            progress = min(self.step_count / total_steps, 1.0)
            new_alpha = self.initial_alpha * (1.0 - progress)
        elif self.adaptation_strategy == "cosine_decay":
            # Cosine decay from initial alpha to 0 over total_steps.
            progress = min(self.step_count / total_steps, 1.0)
            new_alpha = self.initial_alpha * (1.0 + math.cos(math.pi * progress)) / 2
        elif self.adaptation_strategy == "step_decay":
            # Multiply by decay_factor once per threshold already passed
            # (recomputed from initial_alpha each call, so it is idempotent).
            decay_steps = self.adaptation_config.get("decay_steps", [500, 750])
            decay_factor = self.adaptation_config.get("decay_factor", 0.5)
            new_alpha = self.initial_alpha
            for decay_step in decay_steps:
                if self.step_count >= decay_step:
                    new_alpha *= decay_factor
        else:
            # Unknown strategy: leave alpha unchanged (deliberate no-op).
            new_alpha = self.alpha
        # Apply only meaningful changes; keep self.config in sync, matching
        # the behavior of update_alpha() (previously config drifted out of
        # sync when alpha was adapted here).
        if abs(new_alpha - self.alpha) > 1e-6:
            self.alpha = new_alpha
            self.config["alpha"] = new_alpha