# Mango-Metrics-NLM
# feat: Phi-3.5-MoE multi-agent model repository
# commit: c8b77b5
"""
Loss Functions
=============
This module implements various loss functions for neural network training,
including cross-entropy, KL divergence, and custom loss functions for
the MangoMAS multi-agent system.
"""
import logging
from abc import ABC, abstractmethod
from typing import Dict, Optional, Any
import torch
import torch.nn.functional as F
logger = logging.getLogger(__name__)
class LossFunction(ABC):
    """Common interface shared by every loss in this module.

    Subclasses implement :meth:`forward`; instances are directly callable
    and simply delegate to it.
    """

    def __init__(self, reduction: str = 'mean'):
        # How per-element losses are aggregated: 'mean', 'sum', or 'none'.
        self.reduction = reduction

    @abstractmethod
    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the loss of *predictions* measured against *targets*."""
        ...

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Allow the loss object to be used like a function."""
        return self.forward(predictions, targets)
class CrossEntropyLoss(LossFunction):
    """
    Cross-entropy loss for classification tasks
    Mathematical formulation:
    L = -∑(y_i * log(ŷ_i))
    Where y_i is the true label and ŷ_i is the predicted probability.
    """
    def __init__(self, reduction: str = 'mean', label_smoothing: float = 0.0,
                 weight: Optional[torch.Tensor] = None):
        """
        Args:
            reduction: 'mean', 'sum', or 'none'.
            label_smoothing: Smoothing factor in [0, 1]; only applied in the
                multi-class branch with class-index targets.
            weight: Optional per-class weights (multi-class branch) or
                per-element weights (binary branch).
        """
        super().__init__(reduction)
        self.label_smoothing = label_smoothing
        self.weight = weight

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute cross-entropy loss
        Args:
            predictions: Model predictions (logits)
            targets: True labels (class indices, one-hot/soft distributions,
                or binary labels depending on the prediction shape)
        Returns:
            Computed loss
        """
        if predictions.dim() > 1 and predictions.size(1) > 1:
            # Multi-class classification.
            if targets.dim() == 1:
                # Targets are class indices.
                # BUG FIX: F.nll_loss does not accept `label_smoothing`, so the
                # previous nll_loss call raised TypeError. F.cross_entropy fuses
                # log_softmax + NLL and supports label_smoothing directly.
                loss = F.cross_entropy(predictions, targets,
                                       weight=self.weight,
                                       reduction=self.reduction,
                                       label_smoothing=self.label_smoothing)
            else:
                # Targets are one-hot (or soft) distributions over classes.
                # NOTE: weight and label_smoothing are not applied here.
                log_probs = F.log_softmax(predictions, dim=1)
                loss = -(targets * log_probs).sum(dim=1)
                if self.reduction == 'mean':
                    loss = loss.mean()
                elif self.reduction == 'sum':
                    loss = loss.sum()
        else:
            # Binary classification on raw logits.
            loss = F.binary_cross_entropy_with_logits(predictions, targets.float(),
                                                      weight=self.weight,
                                                      reduction=self.reduction)
        return loss
class KLDivergenceLoss(LossFunction):
    """
    Kullback-Leibler divergence loss for distribution matching.

    Mathematical formulation:
    KL(P||Q) = ∑ P(x) * log(P(x)/Q(x))
    """

    def __init__(self, reduction: str = 'mean', log_target: bool = False):
        super().__init__(reduction)
        # When True, both inputs are taken to already be log-probabilities.
        self.log_target = log_target

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute the KL divergence between predictions and targets.

        Args:
            predictions: Log-probabilities when ``log_target`` is True;
                otherwise raw logits (log_softmax is applied here first).
            targets: Log-probabilities when ``log_target`` is True,
                plain probabilities otherwise.
        Returns:
            Computed KL divergence loss
        """
        if not self.log_target:
            # Targets are plain probabilities; move predictions to log space.
            return F.kl_div(F.log_softmax(predictions, dim=-1), targets,
                            reduction=self.reduction, log_target=False)
        # Both inputs already live in log space.
        return F.kl_div(predictions, targets,
                        reduction=self.reduction, log_target=True)
class MSELoss(LossFunction):
    """
    Mean Squared Error loss for regression tasks.

    Mathematical formulation:
    L = (1/n) * ∑(y_i - ŷ_i)²
    """

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the squared error between predictions and the true values,
        aggregated according to ``self.reduction``."""
        return F.mse_loss(predictions, targets, reduction=self.reduction)
class HuberLoss(LossFunction):
    """
    Huber loss (smooth L1 loss) for robust regression.

    Mathematical formulation:
    L = { 0.5 * (y - ŷ)², if |y - ŷ| < δ
        { δ * (|y - ŷ| - 0.5 * δ), otherwise
    """

    def __init__(self, reduction: str = 'mean', delta: float = 1.0):
        super().__init__(reduction)
        # Threshold at which the loss transitions from quadratic to linear.
        self.delta = delta

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Return the Huber loss of predictions against the true values,
        aggregated according to ``self.reduction``."""
        return F.huber_loss(predictions, targets,
                            reduction=self.reduction, delta=self.delta)
class FocalLoss(LossFunction):
    """
    Focal loss for addressing class imbalance
    Mathematical formulation:
    FL = -α(1-p_t)^γ * log(p_t)
    Where p_t is the predicted probability for the true class.
    """
    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        """
        Args:
            alpha: Global scaling factor on the loss.
            gamma: Focusing exponent; larger values down-weight easy examples.
            reduction: 'mean', 'sum', or 'none'.
        """
        super().__init__(reduction)
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss
        Args:
            predictions: Model predictions (logits)
            targets: True labels (class indices or one-hot/soft distributions)
        Returns:
            Computed focal loss
        """
        # FIX: work in log space via log_softmax instead of
        # log(softmax(x) + 1e-8) — the previous form was numerically unstable
        # and the epsilon biased the loss for confident predictions.
        log_probs = F.log_softmax(predictions, dim=1)
        if targets.dim() == 1:
            # Targets are class indices: pick out log p_t directly.
            log_pt = log_probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        else:
            # Targets are one-hot (or soft) distributions.
            pt_soft = (log_probs.exp() * targets).sum(dim=1)
            log_pt = pt_soft.clamp_min(1e-12).log()
        pt = log_pt.exp()
        # Down-weight well-classified examples by (1 - p_t)^gamma.
        focal_weight = self.alpha * (1 - pt) ** self.gamma
        loss = -focal_weight * log_pt
        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()
        return loss
class ContrastiveLoss(LossFunction):
    """
    Contrastive loss for learning representations.

    Mathematical formulation (matching the implementation):
    L = y * d² + (1-y) * max(0, margin - d)²
    Where d is the distance between embeddings and y is the similarity
    label (1 = similar pair, 0 = dissimilar pair).
    """

    def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
        super().__init__(reduction)
        # Dissimilar pairs are only penalized while closer than this margin.
        self.margin = margin

    def forward(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Compute contrastive loss.

        Args:
            embeddings1: First set of embeddings
            embeddings2: Second set of embeddings
            labels: Similarity labels (1 for similar, 0 for dissimilar)
        Returns:
            Computed contrastive loss
        """
        # Euclidean distance between paired embeddings.
        dist = F.pairwise_distance(embeddings1, embeddings2)
        y = labels.float()
        # Similar pairs (y=1) are pulled together; dissimilar pairs (y=0)
        # are pushed apart until they reach the margin.
        loss = y * dist.pow(2) + (1.0 - y) * F.relu(self.margin - dist).pow(2)
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss
class CustomAgentLoss(LossFunction):
    """
    Composite loss function for MangoMAS agents.

    Weighted sum of a classification task loss, an optional hidden-state
    consistency term, and a (currently placeholder) regularization term.
    """

    def __init__(self, task_loss_weight: float = 1.0,
                 consistency_loss_weight: float = 0.1,
                 regularization_weight: float = 0.01):
        super().__init__('mean')
        # Mixing weights for the three loss components.
        self.task_loss_weight = task_loss_weight
        self.consistency_loss_weight = consistency_loss_weight
        self.regularization_weight = regularization_weight
        # Component losses.
        self.task_loss = CrossEntropyLoss()
        self.consistency_loss = MSELoss()

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor,
                model_outputs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """
        Compute the combined agent loss.

        Args:
            predictions: Model predictions
            targets: True labels
            model_outputs: Optional dict; when it contains 'hidden_states'
                (a sequence of tensors) a consistency term is added.
        Returns:
            Weighted sum of the loss components.
        """
        task_term = self.task_loss(predictions, targets)

        # Consistency term: average MSE between consecutive hidden states.
        consistency_term = torch.tensor(0.0, device=predictions.device)
        if model_outputs is not None and 'hidden_states' in model_outputs:
            states = model_outputs['hidden_states']
            if len(states) > 1:
                pairwise = [self.consistency_loss(a, b)
                            for a, b in zip(states, states[1:])]
                consistency_term = sum(pairwise) / (len(states) - 1)

        # L2 regularization placeholder — would be computed from model
        # parameters in practice.
        regularization_term = torch.tensor(0.0, device=predictions.device)

        return (self.task_loss_weight * task_term
                + self.consistency_loss_weight * consistency_term
                + self.regularization_weight * regularization_term)
class LossFunctionFactory:
    """Factory helpers for constructing loss functions by name."""

    @staticmethod
    def create_loss_function(loss_type: str, **kwargs) -> LossFunction:
        """Instantiate the loss registered under *loss_type* (case-insensitive).

        Raises:
            ValueError: if *loss_type* does not name a known loss.
        """
        registry = {
            'cross_entropy': CrossEntropyLoss,
            'kl_divergence': KLDivergenceLoss,
            'mse': MSELoss,
            'huber': HuberLoss,
            'focal': FocalLoss,
            'contrastive': ContrastiveLoss,
            'custom_agent': CustomAgentLoss,
        }
        key = loss_type.lower()
        if key not in registry:
            raise ValueError(f"Unknown loss function type: {loss_type}")
        return registry[key](**kwargs)

    @staticmethod
    def get_default_config(loss_type: str) -> Dict[str, Any]:
        """Return the default keyword arguments for *loss_type*
        (empty dict when the name is unknown)."""
        defaults = {
            'cross_entropy': {
                'reduction': 'mean',
                'label_smoothing': 0.0,
            },
            'kl_divergence': {
                'reduction': 'mean',
                'log_target': False,
            },
            'mse': {
                'reduction': 'mean',
            },
            'huber': {
                'reduction': 'mean',
                'delta': 1.0,
            },
            'focal': {
                'alpha': 1.0,
                'gamma': 2.0,
                'reduction': 'mean',
            },
            'contrastive': {
                'margin': 1.0,
                'reduction': 'mean',
            },
            'custom_agent': {
                'task_loss_weight': 1.0,
                'consistency_loss_weight': 0.1,
                'regularization_weight': 0.01,
            },
        }
        return defaults.get(loss_type.lower(), {})