"""
Loss Functions
==============

This module implements various loss functions for neural network training,
including cross-entropy, KL divergence, and custom loss functions for the
MangoMAS multi-agent system.
"""

import logging
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterable, Optional

import torch
import torch.nn.functional as F

logger = logging.getLogger(__name__)


class LossFunction(ABC):
    """Abstract base class for all loss functions"""

    def __init__(self, reduction: str = 'mean'):
        self.reduction = reduction

    @abstractmethod
    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute the loss"""

    def __call__(self, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Call the loss function

        All positional and keyword arguments are forwarded to ``forward``
        unchanged, so subclasses whose ``forward`` takes extra inputs
        (e.g. ``ContrastiveLoss``, ``CustomAgentLoss``) are callable too.
        """
        return self.forward(*args, **kwargs)

    def _reduce(self, loss: torch.Tensor) -> torch.Tensor:
        """
        Apply ``self.reduction`` to an unreduced loss tensor

        'mean' and 'sum' collapse the tensor to a scalar; any other value
        (including 'none') returns the tensor unreduced.
        """
        if self.reduction == 'mean':
            return loss.mean()
        if self.reduction == 'sum':
            return loss.sum()
        return loss


class CrossEntropyLoss(LossFunction):
    """
    Cross-entropy loss for classification tasks

    Mathematical formulation:
        L = -∑(y_i * log(ŷ_i))

    Where y_i is the true label and ŷ_i is the predicted probability.
    """

    def __init__(self, reduction: str = 'mean', label_smoothing: float = 0.0,
                 weight: Optional[torch.Tensor] = None):
        super().__init__(reduction)
        self.label_smoothing = label_smoothing
        self.weight = weight

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute cross-entropy loss

        Args:
            predictions: Model predictions (logits). With more than one
                dimension and more than one entry along dim 1 the input is
                treated as multi-class; otherwise as binary logits.
            targets: True labels. For multi-class input either class indices
                or a tensor of the same shape as ``predictions`` (one-hot /
                class probabilities). For binary input, 0/1 labels.

        Returns:
            Computed loss
        """
        if predictions.dim() > 1 and predictions.size(1) > 1:
            # Multi-class classification.
            if targets.shape == predictions.shape:
                # Dense (one-hot / probability) targets: F.cross_entropy
                # needs them in floating point to treat them as such.
                targets = targets.to(predictions.dtype)
            # F.cross_entropy applies log-softmax itself and, unlike
            # F.nll_loss, accepts the label_smoothing argument.
            return F.cross_entropy(predictions, targets,
                                   weight=self.weight,
                                   reduction=self.reduction,
                                   label_smoothing=self.label_smoothing)

        # Binary classification
        binary_targets = targets.float()
        if (binary_targets.shape != predictions.shape
                and binary_targets.numel() == predictions.numel()):
            # e.g. logits of shape (N, 1) with labels of shape (N,)
            binary_targets = binary_targets.reshape(predictions.shape)
        return F.binary_cross_entropy_with_logits(predictions, binary_targets,
                                                  weight=self.weight,
                                                  reduction=self.reduction)


class KLDivergenceLoss(LossFunction):
    """
    Kullback-Leibler divergence loss for distribution matching

    Mathematical formulation:
        KL(P||Q) = ∑ P(x) * log(P(x)/Q(x))
    """

    def __init__(self, reduction: str = 'mean', log_target: bool = False):
        super().__init__(reduction)
        self.log_target = log_target

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute KL divergence loss

        Args:
            predictions: Predicted distribution. When ``log_target`` is True
                it is passed to ``F.kl_div`` as-is (log probabilities
                expected); otherwise ``log_softmax`` over the last dimension
                is applied first.
            targets: Target distribution (log probabilities when
                ``log_target`` is True, probabilities otherwise)

        Returns:
            Computed KL divergence loss. ``self.reduction`` is handed to
            ``F.kl_div`` unchanged, so 'batchmean' is accepted as well.
        """
        if self.log_target:
            # Both predictions and targets are in log space
            return F.kl_div(predictions, targets,
                            reduction=self.reduction, log_target=True)

        # Convert predictions to log space, targets are probabilities
        log_predictions = F.log_softmax(predictions, dim=-1)
        return F.kl_div(log_predictions, targets,
                        reduction=self.reduction, log_target=False)


class MSELoss(LossFunction):
    """
    Mean Squared Error loss for regression tasks

    Mathematical formulation:
        L = (1/n) * ∑(y_i - ŷ_i)²
    """

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute MSE loss

        Args:
            predictions: Model predictions
            targets: True values

        Returns:
            Computed MSE loss
        """
        return F.mse_loss(predictions, targets, reduction=self.reduction)


class HuberLoss(LossFunction):
    """
    Huber loss for robust regression

    Mathematical formulation:
        L = { 0.5 * (y - ŷ)²,            if |y - ŷ| < δ
            { δ * (|y - ŷ| - 0.5 * δ),   otherwise
    """

    def __init__(self, reduction: str = 'mean', delta: float = 1.0):
        super().__init__(reduction)
        self.delta = delta

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute Huber loss

        Args:
            predictions: Model predictions
            targets: True values

        Returns:
            Computed Huber loss
        """
        return F.huber_loss(predictions, targets,
                            reduction=self.reduction, delta=self.delta)


class FocalLoss(LossFunction):
    """
    Focal loss for addressing class imbalance

    Mathematical formulation:
        FL = -α(1-p_t)^γ * log(p_t)

    Where p_t is the predicted probability for the true class.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        super().__init__(reduction)
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss

        Args:
            predictions: Model predictions (logits), classes along dim 1
            targets: True labels, either 1-D class indices or one-hot
                encoded with the same shape as ``predictions``

        Returns:
            Computed focal loss
        """
        # Compute probabilities
        probs = F.softmax(predictions, dim=1)

        # Get probabilities for true classes
        if targets.dim() == 1:
            # Targets are class indices
            pt = probs.gather(1, targets.unsqueeze(1)).squeeze(1)
        else:
            # Targets are one-hot encoded
            pt = (probs * targets).sum(dim=1)

        # Down-weight well-classified samples; the epsilon keeps log finite
        # when the true-class probability underflows to zero.
        focal_weight = self.alpha * (1 - pt) ** self.gamma
        loss = -focal_weight * torch.log(pt + 1e-8)
        return self._reduce(loss)


class ContrastiveLoss(LossFunction):
    """
    Contrastive loss for learning representations

    Mathematical formulation:
        L = y * d² + (1-y) * max(0, margin - d)²

    Where d is the distance between embeddings and y is the similarity
    label (1 for similar, 0 for dissimilar).
    """

    def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
        super().__init__(reduction)
        self.margin = margin

    def forward(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Compute contrastive loss

        Args:
            embeddings1: First set of embeddings
            embeddings2: Second set of embeddings
            labels: Similarity labels (1 for similar, 0 for dissimilar)

        Returns:
            Computed contrastive loss
        """
        # Compute Euclidean distance
        distance = F.pairwise_distance(embeddings1, embeddings2)
        similar = labels.float()

        # Similar pairs are pulled together; dissimilar pairs are pushed
        # apart until they are at least `margin` away from each other.
        positive_loss = similar * distance.pow(2)
        negative_loss = (1 - similar) * F.relu(self.margin - distance).pow(2)
        return self._reduce(positive_loss + negative_loss)


class CustomAgentLoss(LossFunction):
    """
    Custom loss function for MangoMAS agents

    Combines multiple loss components to optimize agent performance.
    """

    def __init__(self, task_loss_weight: float = 1.0,
                 consistency_loss_weight: float = 0.1,
                 regularization_weight: float = 0.01):
        super().__init__('mean')
        self.task_loss_weight = task_loss_weight
        self.consistency_loss_weight = consistency_loss_weight
        self.regularization_weight = regularization_weight

        # Initialize component losses
        self.task_loss = CrossEntropyLoss()
        self.consistency_loss = MSELoss()

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor,
                model_outputs: Optional[Dict[str, torch.Tensor]] = None,
                parameters: Optional[Iterable[torch.Tensor]] = None) -> torch.Tensor:
        """
        Compute custom agent loss

        Args:
            predictions: Model predictions
            targets: True labels
            model_outputs: Additional model outputs for consistency loss.
                Only the 'hidden_states' entry is read; it must support
                ``len()`` and slicing.
            parameters: Optional tensors (e.g. ``model.parameters()``) to
                include in the L2 penalty. When omitted the regularization
                term is zero.

        Returns:
            Computed custom loss: the weighted sum of the task,
            consistency and regularization terms
        """
        # Task-specific loss
        task_loss = self.task_loss(predictions, targets)

        # Consistency loss (if model outputs provided): mean MSE between
        # each pair of consecutive hidden states.
        consistency_loss = task_loss.new_zeros(())
        hidden_states = None
        if model_outputs is not None:
            hidden_states = model_outputs.get('hidden_states')
        if hidden_states is not None and len(hidden_states) > 1:
            pair_losses = [
                self.consistency_loss(current, following)
                for current, following in zip(hidden_states, hidden_states[1:])
            ]
            # Built out-of-place so no dtype is silently downcast.
            consistency_loss = torch.stack(pair_losses).mean()

        # Regularization loss (L2 penalty)
        regularization_loss = task_loss.new_zeros(())
        if parameters is not None:
            regularization_loss = regularization_loss + sum(
                param.pow(2).sum() for param in parameters
            )

        # Combine losses
        return (self.task_loss_weight * task_loss
                + self.consistency_loss_weight * consistency_loss
                + self.regularization_weight * regularization_loss)


class LossFunctionFactory:
    """Factory class for creating loss functions"""

    # Lower-case type name -> loss class
    _REGISTRY = {
        'cross_entropy': CrossEntropyLoss,
        'kl_divergence': KLDivergenceLoss,
        'mse': MSELoss,
        'huber': HuberLoss,
        'focal': FocalLoss,
        'contrastive': ContrastiveLoss,
        'custom_agent': CustomAgentLoss,
    }

    # Lower-case type name -> default constructor keyword arguments
    _DEFAULT_CONFIGS = {
        'cross_entropy': {
            'reduction': 'mean',
            'label_smoothing': 0.0,
        },
        'kl_divergence': {
            'reduction': 'mean',
            'log_target': False,
        },
        'mse': {
            'reduction': 'mean',
        },
        'huber': {
            'reduction': 'mean',
            'delta': 1.0,
        },
        'focal': {
            'alpha': 1.0,
            'gamma': 2.0,
            'reduction': 'mean',
        },
        'contrastive': {
            'margin': 1.0,
            'reduction': 'mean',
        },
        'custom_agent': {
            'task_loss_weight': 1.0,
            'consistency_loss_weight': 0.1,
            'regularization_weight': 0.01,
        },
    }

    @staticmethod
    def create_loss_function(loss_type: str, **kwargs) -> LossFunction:
        """
        Create a loss function instance

        Args:
            loss_type: Name of the loss (case-insensitive)
            **kwargs: Passed to the loss class constructor

        Returns:
            A new instance of the requested loss

        Raises:
            ValueError: If ``loss_type`` is not a known loss name
        """
        loss_class = LossFunctionFactory._REGISTRY.get(loss_type.lower())
        if loss_class is None:
            raise ValueError(f"Unknown loss function type: {loss_type}")
        return loss_class(**kwargs)

    @staticmethod
    def get_default_config(loss_type: str) -> Dict[str, Any]:
        """
        Get default configuration for loss function

        Args:
            loss_type: Name of the loss (case-insensitive)

        Returns:
            A fresh dict of default constructor arguments, or an empty dict
            for an unknown name
        """
        defaults = LossFunctionFactory._DEFAULT_CONFIGS.get(loss_type.lower(), {})
        # Copy so callers can mutate the result without touching the table.
        return dict(defaults)