|
|
""" |
|
|
Loss Functions |
|
|
============= |
|
|
|
|
|
This module implements various loss functions for neural network training, |
|
|
including cross-entropy, KL divergence, and custom loss functions for |
|
|
the MangoMAS multi-agent system. |
|
|
""" |
|
|
|
|
|
import logging |
|
|
from abc import ABC, abstractmethod |
|
|
from typing import Dict, Optional, Any |
|
|
import torch |
|
|
import torch.nn.functional as F |
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
class LossFunction(ABC):
    """Base interface shared by every loss in this module.

    Subclasses implement :meth:`forward`; instances are callable and
    delegate to it, mirroring the ``torch.nn`` convention.
    """

    def __init__(self, reduction: str = 'mean'):
        # How per-element losses are aggregated: 'mean', 'sum', or 'none'.
        self.reduction = reduction

    @abstractmethod
    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Compute and return the loss for the given predictions/targets."""

    def __call__(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """Invoke :meth:`forward` so instances can be used like functions."""
        return self.forward(predictions, targets)
|
|
|
|
|
|
|
|
class CrossEntropyLoss(LossFunction):
    """
    Cross-entropy loss for classification tasks

    Mathematical formulation:
        L = -∑(y_i * log(ŷ_i))

    Where y_i is the true label and ŷ_i is the predicted probability.

    Supports three target layouts:
      * class indices (1-D ``targets``) against multi-class logits,
      * soft/one-hot target distributions (2-D ``targets``),
      * binary targets against single-logit predictions.
    """

    def __init__(self, reduction: str = 'mean', label_smoothing: float = 0.0,
                 weight: Optional[torch.Tensor] = None):
        """
        Args:
            reduction: 'mean', 'sum', or 'none'.
            label_smoothing: Smoothing factor in [0, 1) applied to class-index targets.
            weight: Optional per-class rescaling weights.
        """
        super().__init__(reduction)
        self.label_smoothing = label_smoothing
        self.weight = weight

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute cross-entropy loss

        Args:
            predictions: Model predictions (logits)
            targets: True labels (class indices, soft distributions, or binary)

        Returns:
            Computed loss
        """
        if predictions.dim() > 1 and predictions.size(1) > 1:
            if targets.dim() == 1:
                # BUG FIX: the previous code passed `label_smoothing` to
                # F.nll_loss, which does not accept that keyword and raised
                # TypeError at runtime. F.cross_entropy fuses
                # log_softmax + nll_loss and supports label smoothing directly.
                loss = F.cross_entropy(predictions, targets, weight=self.weight,
                                       reduction=self.reduction,
                                       label_smoothing=self.label_smoothing)
            else:
                # Soft/one-hot targets: explicit -∑ y * log ŷ per sample.
                # NOTE(review): label_smoothing and weight are not applied on
                # this path, matching the original behavior.
                log_probs = F.log_softmax(predictions, dim=1)
                loss = -(targets * log_probs).sum(dim=1)
                if self.reduction == 'mean':
                    loss = loss.mean()
                elif self.reduction == 'sum':
                    loss = loss.sum()
        else:
            # Single-logit predictions: treat as binary classification.
            loss = F.binary_cross_entropy_with_logits(predictions, targets.float(),
                                                      weight=self.weight,
                                                      reduction=self.reduction)

        return loss
|
|
|
|
|
|
|
|
class KLDivergenceLoss(LossFunction):
    """
    Kullback-Leibler divergence loss for distribution matching

    Mathematical formulation:
        KL(P||Q) = ∑ P(x) * log(P(x)/Q(x))
    """

    def __init__(self, reduction: str = 'mean', log_target: bool = False):
        """
        Args:
            reduction: Reduction mode passed through to ``F.kl_div``.
            log_target: If True, both inputs are taken to already be in log
                space; otherwise ``predictions`` are treated as raw logits
                and log-softmaxed internally.
        """
        super().__init__(reduction)
        self.log_target = log_target

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute KL divergence loss

        Args:
            predictions: Log-probabilities when ``log_target`` is True;
                otherwise raw logits (log-softmax is applied here).
            targets: Target distribution (log-probabilities when
                ``log_target`` is True, probabilities otherwise).

        Returns:
            Computed KL divergence loss
        """
        if self.log_target:
            # Inputs already live in log space; hand them straight through.
            return F.kl_div(predictions, targets,
                            reduction=self.reduction, log_target=True)

        # Normalize raw logits into log-probabilities before comparing.
        log_predictions = F.log_softmax(predictions, dim=-1)
        return F.kl_div(log_predictions, targets,
                        reduction=self.reduction, log_target=False)
|
|
|
|
|
|
|
|
class MSELoss(LossFunction):
    """
    Mean Squared Error loss for regression tasks

    Mathematical formulation:
        L = (1/n) * ∑(y_i - ŷ_i)²
    """

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute MSE loss

        Args:
            predictions: Model predictions
            targets: True values

        Returns:
            Computed MSE loss
        """
        # Thin wrapper over the functional implementation; the configured
        # reduction ('mean', 'sum', or 'none') is applied by F.mse_loss.
        return F.mse_loss(predictions, targets, reduction=self.reduction)
|
|
|
|
|
|
|
|
class HuberLoss(LossFunction):
    """
    Huber loss (smooth L1 loss) for robust regression

    Mathematical formulation:
        L = { 0.5 * (y - ŷ)²,           if |y - ŷ| < δ
            { δ * (|y - ŷ| - 0.5 * δ),  otherwise
    """

    def __init__(self, reduction: str = 'mean', delta: float = 1.0):
        """
        Args:
            reduction: 'mean', 'sum', or 'none'.
            delta: Threshold where the loss switches from quadratic to linear.
        """
        super().__init__(reduction)
        self.delta = delta

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute Huber loss

        Args:
            predictions: Model predictions
            targets: True values

        Returns:
            Computed Huber loss
        """
        # Delegate to the functional form with the configured delta/reduction.
        return F.huber_loss(predictions, targets,
                            reduction=self.reduction, delta=self.delta)
|
|
|
|
|
|
|
|
class FocalLoss(LossFunction):
    """
    Focal loss for addressing class imbalance

    Mathematical formulation:
        FL = -α(1-p_t)^γ * log(p_t)

    Where p_t is the predicted probability for the true class.
    """

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        """
        Args:
            alpha: Global scaling factor α.
            gamma: Focusing parameter γ; larger values down-weight easy examples.
            reduction: 'mean', 'sum', or 'none'.
        """
        super().__init__(reduction)
        self.alpha = alpha
        self.gamma = gamma

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss

        Args:
            predictions: Model predictions (logits)
            targets: True labels (class indices or soft distributions)

        Returns:
            Computed focal loss
        """
        # FIX: derive log(p_t) from log_softmax instead of log(softmax + 1e-8).
        # softmax can underflow to exactly 0 for very negative logits, which
        # made the old log(pt + eps) saturate at log(1e-8) and produce
        # uninformative gradients; log_softmax is numerically stable.
        log_probs = F.log_softmax(predictions, dim=1)
        probs = log_probs.exp()

        if targets.dim() == 1:
            # Hard class-index targets: select the true-class entries.
            idx = targets.unsqueeze(1)
            pt = probs.gather(1, idx).squeeze(1)
            log_pt = log_probs.gather(1, idx).squeeze(1)
        else:
            # Soft/one-hot targets: expected probability under the target
            # distribution; keep the eps guard since pt is a mixture.
            pt = (probs * targets).sum(dim=1)
            log_pt = torch.log(pt + 1e-8)

        # Down-weight well-classified examples by (1 - p_t)^γ.
        focal_weight = self.alpha * (1 - pt) ** self.gamma
        loss = -focal_weight * log_pt

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss
|
|
|
|
|
|
|
|
class ContrastiveLoss(LossFunction):
    """
    Contrastive loss for learning representations

    Mathematical formulation:
        L = y * d² + (1 - y) * max(0, margin - d)²

    Where d is the Euclidean distance between embeddings and y is the
    similarity label (1 for similar pairs, 0 for dissimilar pairs).

    FIX: the formula previously documented here swapped the y/(1-y) terms
    relative to what the code computes. The implementation below — similar
    pairs pulled together, dissimilar pairs pushed beyond the margin — is
    the intended behavior and is unchanged.
    """

    def __init__(self, margin: float = 1.0, reduction: str = 'mean'):
        """
        Args:
            margin: Distance beyond which dissimilar pairs incur no penalty.
            reduction: 'mean', 'sum', or 'none'.
        """
        super().__init__(reduction)
        self.margin = margin

    def forward(self, embeddings1: torch.Tensor, embeddings2: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Compute contrastive loss

        Args:
            embeddings1: First set of embeddings
            embeddings2: Second set of embeddings
            labels: Similarity labels (1 for similar, 0 for dissimilar)

        Returns:
            Computed contrastive loss
        """
        # Euclidean distance between paired embeddings.
        distance = F.pairwise_distance(embeddings1, embeddings2)

        # Similar pairs (y=1): penalize any distance.
        positive_loss = labels.float() * distance.pow(2)
        # Dissimilar pairs (y=0): penalize only when closer than the margin.
        negative_loss = (1 - labels.float()) * F.relu(self.margin - distance).pow(2)

        loss = positive_loss + negative_loss

        if self.reduction == 'mean':
            loss = loss.mean()
        elif self.reduction == 'sum':
            loss = loss.sum()

        return loss
|
|
|
|
|
|
|
|
class CustomAgentLoss(LossFunction):
    """
    Custom loss function for MangoMAS agents

    Combines a cross-entropy task loss, an MSE consistency penalty between
    consecutive hidden states, and a regularization term (currently always
    zero — a placeholder for future use).
    """

    def __init__(self, task_loss_weight: float = 1.0,
                 consistency_loss_weight: float = 0.1,
                 regularization_weight: float = 0.01):
        """
        Args:
            task_loss_weight: Weight of the primary classification loss.
            consistency_loss_weight: Weight of the hidden-state consistency term.
            regularization_weight: Weight of the (placeholder) regularization term.
        """
        super().__init__('mean')
        self.task_loss_weight = task_loss_weight
        self.consistency_loss_weight = consistency_loss_weight
        self.regularization_weight = regularization_weight

        # Component losses reused across forward calls.
        self.task_loss = CrossEntropyLoss()
        self.consistency_loss = MSELoss()

    def forward(self, predictions: torch.Tensor, targets: torch.Tensor,
                model_outputs: Optional[Dict[str, torch.Tensor]] = None) -> torch.Tensor:
        """
        Compute custom agent loss

        Args:
            predictions: Model predictions
            targets: True labels
            model_outputs: Additional model outputs; if it contains
                'hidden_states' (a sequence of tensors), an MSE penalty
                between each consecutive pair is averaged into the loss.

        Returns:
            Weighted sum of the loss components.
        """
        task_loss = self.task_loss(predictions, targets)

        # Mean MSE between consecutive hidden states; zero when unavailable.
        consistency_loss = torch.tensor(0.0, device=predictions.device)
        if model_outputs is not None and 'hidden_states' in model_outputs:
            hidden_states = model_outputs['hidden_states']
            if len(hidden_states) > 1:
                pairwise = [self.consistency_loss(prev, nxt)
                            for prev, nxt in zip(hidden_states[:-1], hidden_states[1:])]
                consistency_loss = sum(pairwise) / len(pairwise)

        # Placeholder term: always zero in the current implementation.
        regularization_loss = torch.tensor(0.0, device=predictions.device)

        return (self.task_loss_weight * task_loss
                + self.consistency_loss_weight * consistency_loss
                + self.regularization_weight * regularization_loss)
|
|
|
|
|
|
|
|
class LossFunctionFactory:
    """Factory for constructing loss functions by name."""

    @staticmethod
    def create_loss_function(loss_type: str, **kwargs) -> LossFunction:
        """Instantiate the loss class registered under ``loss_type``.

        Args:
            loss_type: Case-insensitive identifier (e.g. 'mse', 'focal').
            **kwargs: Forwarded to the loss class constructor.

        Returns:
            A configured LossFunction instance.

        Raises:
            ValueError: If ``loss_type`` is not recognized.
        """
        registry = {
            'cross_entropy': CrossEntropyLoss,
            'kl_divergence': KLDivergenceLoss,
            'mse': MSELoss,
            'huber': HuberLoss,
            'focal': FocalLoss,
            'contrastive': ContrastiveLoss,
            'custom_agent': CustomAgentLoss,
        }

        key = loss_type.lower()
        if key not in registry:
            raise ValueError(f"Unknown loss function type: {loss_type}")

        return registry[key](**kwargs)

    @staticmethod
    def get_default_config(loss_type: str) -> Dict[str, Any]:
        """Return the default keyword arguments for ``loss_type``.

        Unknown types yield an empty dict rather than raising.
        """
        configs = {
            'cross_entropy': {
                'reduction': 'mean',
                'label_smoothing': 0.0,
            },
            'kl_divergence': {
                'reduction': 'mean',
                'log_target': False,
            },
            'mse': {
                'reduction': 'mean',
            },
            'huber': {
                'reduction': 'mean',
                'delta': 1.0,
            },
            'focal': {
                'alpha': 1.0,
                'gamma': 2.0,
                'reduction': 'mean',
            },
            'contrastive': {
                'margin': 1.0,
                'reduction': 'mean',
            },
            'custom_agent': {
                'task_loss_weight': 1.0,
                'consistency_loss_weight': 0.1,
                'regularization_weight': 0.01,
            },
        }

        return configs.get(loss_type.lower(), {})
|
|
|