# Mango-Metrics-NLM
# feat: Phi-3.5-MoE multi-agent model repository
# commit: c8b77b5
"""
Knowledge Distillation Loss Implementation for MangoMAS Local
This module implements custom loss functions for knowledge distillation,
balancing task-specific loss with knowledge transfer from teacher models.
"""
import logging
from typing import Optional, Tuple
import torch
import torch.nn.functional as F
logger = logging.getLogger(__name__)
class DistillationLoss:
    """
    Custom loss function for knowledge distillation combining task loss
    and distillation loss with configurable temperature and alpha parameters.

    ``alpha`` weights the distillation term: ``alpha=1.0`` yields pure
    distillation loss, ``alpha=0.0`` pure task (cross-entropy) loss.
    """

    def __init__(self, alpha: float = 0.5, temperature: float = 3.0):
        """
        Initialize distillation loss with parameters.

        Args:
            alpha: Weight of the distillation loss term (0.0-1.0); the task
                loss is weighted by ``1 - alpha``.
            temperature: Temperature for softmax in knowledge distillation.

        Raises:
            ValueError: If alpha is outside [0, 1] or temperature is not
                strictly positive.
        """
        # Validate before assigning so a failed construction leaves no
        # partially-initialized state behind.
        if not 0.0 <= alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0 and 1, got {alpha}")
        if temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {temperature}")
        self.alpha = alpha
        self.temperature = temperature
        # Loss modules exposed as attributes (required by the unit tests);
        # the compute methods below use the functional equivalents with the
        # same settings (ignore_index=-100, reduction="batchmean").
        import torch.nn as nn

        self.task_loss = nn.CrossEntropyLoss(ignore_index=-100)
        self.kl_loss = nn.KLDivLoss(reduction="batchmean")
        logging.getLogger(__name__).info(
            f"Initialized DistillationLoss with alpha={alpha}, temperature={temperature}"
        )

    def compute_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> Tuple[torch.Tensor, dict]:
        """
        Compute the combined distillation loss.

        Args:
            student_logits: Logits from student model [batch_size, seq_len, vocab_size]
            teacher_logits: Logits from teacher model [batch_size, seq_len, vocab_size]
            labels: Target labels [batch_size, seq_len]
            attention_mask: Attention mask for padding tokens [batch_size, seq_len]

        Returns:
            Tuple of (total_loss, loss_dict) where loss_dict contains the
            individual loss values (as Python floats) plus the current
            alpha and temperature.
        """
        # Task-specific loss (standard next-token cross-entropy).
        task_loss = self._compute_task_loss(student_logits, labels, attention_mask)
        # Knowledge distillation loss (temperature-scaled KL divergence).
        distill_loss = self._compute_distillation_loss(
            student_logits, teacher_logits, attention_mask
        )
        # NOTE: alpha is treated as the weight for the distillation loss to
        # match unit-test expectations (alpha=1.0 => pure distillation loss).
        total_loss = self.alpha * distill_loss + (1.0 - self.alpha) * task_loss
        loss_dict = {
            "total_loss": total_loss.item(),
            "task_loss": task_loss.item(),
            "distillation_loss": distill_loss.item(),
            "alpha": self.alpha,
            "temperature": self.temperature,
        }
        return total_loss, loss_dict

    def _compute_task_loss(
        self,
        logits: torch.Tensor,
        labels: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the next-token cross-entropy task loss.

        Logits and labels are shifted by one position (causal-LM convention);
        positions excluded by ``attention_mask`` or labelled ``-100`` are
        ignored. Returns a 0.0 tensor when no valid target remains.
        """
        # Shift so that tokens < n predict token n.
        shift_logits = logits[..., :-1, :].contiguous()
        shift_labels = labels[..., 1:].contiguous()
        # Flatten to [batch_size * (seq_len - 1), vocab_size].
        flat_logits = shift_logits.view(-1, shift_logits.size(-1))
        flat_labels = shift_labels.view(-1)
        if attention_mask is not None:
            # Keep only positions the attention mask marks as real tokens.
            shift_mask = attention_mask[..., 1:].contiguous()
            valid_indices = shift_mask.view(-1).bool()
            if not valid_indices.any():
                return torch.tensor(0.0, device=logits.device)
            flat_logits = flat_logits[valid_indices]
            flat_labels = flat_labels[valid_indices]
        # Guard: if every remaining label is the ignore index (-100),
        # F.cross_entropy would average over zero targets and return NaN,
        # poisoning the training loss. Return 0.0 instead.
        if not (flat_labels != -100).any():
            return torch.tensor(0.0, device=logits.device)
        return F.cross_entropy(flat_logits, flat_labels, ignore_index=-100)

    def _compute_distillation_loss(
        self,
        student_logits: torch.Tensor,
        teacher_logits: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """
        Compute the temperature-scaled KL divergence between the student and
        teacher next-token distributions.

        The result is multiplied by ``temperature ** 2`` to keep gradient
        magnitudes comparable across temperatures. Returns a 0.0 tensor
        when the attention mask leaves no valid positions.
        """
        # Temperature-scaled log-probabilities for both models.
        student_log_probs = F.log_softmax(student_logits / self.temperature, dim=-1)
        teacher_log_probs = F.log_softmax(teacher_logits / self.temperature, dim=-1)
        # Drop the last position (it has no next-token target) and flatten
        # to [batch_size * (seq_len - 1), vocab_size].
        student_log_probs = (
            student_log_probs[..., :-1, :]
            .contiguous()
            .view(-1, student_log_probs.size(-1))
        )
        teacher_log_probs = (
            teacher_log_probs[..., :-1, :]
            .contiguous()
            .view(-1, teacher_log_probs.size(-1))
        )
        if attention_mask is not None:
            # Mask out padded positions, mirroring the task-loss shift.
            keep = attention_mask[..., 1:].contiguous().view(-1).bool()
            student_log_probs = student_log_probs[keep]
            teacher_log_probs = teacher_log_probs[keep]
        # Nothing left to compare (e.g. fully-padded batch).
        if student_log_probs.shape[0] == 0:
            return torch.tensor(0.0, device=student_logits.device)
        # log_target=True: both input and target are log-probabilities.
        return F.kl_div(
            student_log_probs, teacher_log_probs, reduction="batchmean", log_target=True
        ) * (self.temperature**2)

    def update_alpha(self, new_alpha: float) -> None:
        """Update the alpha parameter during training.

        Raises:
            ValueError: If ``new_alpha`` is outside [0, 1].
        """
        if not 0.0 <= new_alpha <= 1.0:
            raise ValueError(f"Alpha must be between 0.0 and 1.0, got {new_alpha}")
        self.alpha = new_alpha
        logging.getLogger(__name__).info(f"Updated alpha to {new_alpha}")

    def update_temperature(self, new_temperature: float) -> None:
        """Update the temperature parameter during training.

        Raises:
            ValueError: If ``new_temperature`` is not strictly positive.
        """
        if new_temperature <= 0.0:
            raise ValueError(f"Temperature must be positive, got {new_temperature}")
        self.temperature = new_temperature
        logging.getLogger(__name__).info(f"Updated temperature to {new_temperature}")

    def __call__(self, student_logits, teacher_logits, labels, attention_mask=None):
        """Convenience wrapper returning only the total loss tensor."""
        total_loss, _ = self.compute_loss(
            student_logits, teacher_logits, labels, attention_mask
        )
        return total_loss
class AdaptiveDistillationLoss(DistillationLoss):
    """
    Distillation loss whose alpha anneals over the course of training.

    Training begins weighted toward distillation (``initial_alpha``) and
    interpolates linearly toward ``final_alpha`` as epochs progress.
    """

    def __init__(
        self,
        initial_alpha: float = 0.5,
        final_alpha: float = 0.1,
        temperature: float = 3.0,
        warmup_steps: int = 1000,
    ):
        """
        Initialize the adaptive loss.

        Args:
            initial_alpha: Distillation weight at the start of training.
            final_alpha: Distillation weight at the end of training.
            temperature: Softmax temperature for distillation.
            warmup_steps: Number of warmup steps (stored for callers).

        Raises:
            ValueError: If either alpha is outside [0, 1], or
                ``initial_alpha < final_alpha``.
        """
        # Both alphas must be valid mixing weights.
        for label, value in (
            ("initial_alpha", initial_alpha),
            ("final_alpha", final_alpha),
        ):
            if not 0.0 <= value <= 1.0:
                raise ValueError(f"{label} must be between 0 and 1")
        # Alpha only ever decays, so the schedule must start high.
        if initial_alpha < final_alpha:
            raise ValueError(
                f"initial_alpha must be >= final_alpha, got {initial_alpha} < {final_alpha}"
            )
        super().__init__(alpha=initial_alpha, temperature=temperature)
        self.initial_alpha = initial_alpha
        self.final_alpha = final_alpha
        self.current_alpha = initial_alpha
        self.warmup_steps = warmup_steps
        self.current_step = 0

    def update_alpha(self, current_epoch: int, total_epochs: int):
        """Linearly anneal alpha from initial_alpha toward final_alpha.

        Raises:
            ValueError: If ``total_epochs`` is not positive or
                ``current_epoch`` is negative.
        """
        if total_epochs <= 0:
            raise ValueError("total_epochs must be positive")
        if current_epoch < 0:
            raise ValueError("current_epoch must be non-negative")
        if current_epoch >= total_epochs:
            # Past the schedule: pin to the final value.
            new_alpha = self.final_alpha
        elif current_epoch <= 0:
            # Before the schedule starts: keep the initial value.
            new_alpha = self.initial_alpha
        else:
            progress = current_epoch / total_epochs
            new_alpha = (
                self.initial_alpha
                - (self.initial_alpha - self.final_alpha) * progress
            )
        self.current_alpha = new_alpha
        self.alpha = new_alpha

    def get_alpha(self):
        """Return the most recently computed alpha value."""
        return self.current_alpha