"""
Loss functions for signature verification training.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional


class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for Siamese network training, expressed on similarity scores.

    Genuine pairs (label 1) are penalised by ``(1 - similarity) ** 2``; forged
    pairs (label 0) are penalised by ``max(similarity - margin, 0) ** 2``.
    """

    def __init__(self, margin: float = 1.0):
        """
        Initialize contrastive loss.

        Args:
            margin: Similarity above which forged pairs start to be penalised.
                NOTE(review): if similarity is bounded above by 1 (e.g. cosine
                or sigmoid output), the default of 1.0 makes the forged-pair
                term always zero -- confirm the intended margin.
        """
        super().__init__()
        self.margin = margin

    def forward(self, similarity: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Compute contrastive loss.

        Args:
            similarity: Similarity scores, shape (B, 1) or (B,)
            labels: Binary labels (1 for genuine, 0 for forged), shape (B,) or (B, 1)

        Returns:
            Scalar mean contrastive loss over the batch
        """
        # Flatten both operands to (B,). With ``squeeze`` alone, (B, 1) labels
        # against a squeezed (B,) similarity broadcast into a (B, B) matrix.
        sim = similarity.reshape(-1)
        labels = labels.reshape(-1).float()

        # Genuine pairs: push similarity toward 1
        genuine_loss = labels * torch.pow(1 - sim, 2)

        # Forged pairs: penalise only the part of similarity above the margin
        forged_loss = (1 - labels) * torch.pow(torch.clamp(sim - self.margin, min=0), 2)

        return torch.mean(genuine_loss + forged_loss)


class TripletLoss(nn.Module):
    """
    Triplet loss for signature verification.

    Pushes each anchor to be closer (Euclidean distance) to its positive than
    to its negative by at least ``margin``.
    """

    def __init__(self, margin: float = 1.0):
        """
        Initialize triplet loss.

        Args:
            margin: Margin between positive and negative distances
        """
        super().__init__()
        self.margin = margin

    def forward(self, anchor: torch.Tensor, positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        """
        Compute triplet loss.

        Args:
            anchor: Anchor features (B, feature_dim)
            positive: Positive features (B, feature_dim)
            negative: Negative features (B, feature_dim)

        Returns:
            Scalar mean triplet loss over the batch
        """
        pos_dist = F.pairwise_distance(anchor, positive, p=2)
        neg_dist = F.pairwise_distance(anchor, negative, p=2)

        # Hinge: zero once neg_dist exceeds pos_dist by at least the margin
        return torch.mean(F.relu(pos_dist - neg_dist + self.margin))


class CenterLoss(nn.Module):
    """
    Center loss for learning discriminative features.

    Keeps one learnable center per class and penalises the squared distance
    between each feature vector and the center of its class.
    """

    def __init__(self, num_classes: int, feature_dim: int, lambda_c: float = 1.0):
        """
        Initialize center loss.

        Args:
            num_classes: Number of signature classes
            feature_dim: Dimension of feature vectors
            lambda_c: Weight for center loss
        """
        super().__init__()
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.lambda_c = lambda_c

        # One learnable center per class, initialised from a standard normal
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))

    def forward(self, features: torch.Tensor, labels: torch.Tensor) -> torch.Tensor:
        """
        Compute center loss.

        Args:
            features: Feature vectors (B, feature_dim)
            labels: Integer class indices (B,) selecting each sample's center

        Returns:
            ``lambda_c`` times the mean squared error between the features and
            their centers, averaged over all B * feature_dim elements
        """
        centers_batch = self.centers[labels]
        return self.lambda_c * F.mse_loss(features, centers_batch)


class CombinedLoss(nn.Module):
    """
    Weighted sum of contrastive, triplet and (optionally) center losses.

    Each term is included only when the inputs it needs are passed to
    ``forward``.
    """

    def __init__(self, contrastive_weight: float = 1.0, triplet_weight: float = 0.5,
                 center_weight: float = 0.1, margin: float = 1.0,
                 num_classes: Optional[int] = None, feature_dim: Optional[int] = None):
        """
        Initialize combined loss.

        Args:
            contrastive_weight: Weight for contrastive loss
            triplet_weight: Weight for triplet loss
            center_weight: Weight for center loss
            margin: Margin for contrastive and triplet losses
            num_classes: Number of classes for center loss
            feature_dim: Feature dimension for center loss
        """
        super().__init__()
        self.contrastive_weight = contrastive_weight
        self.triplet_weight = triplet_weight
        self.center_weight = center_weight

        self.contrastive_loss = ContrastiveLoss(margin=margin)
        self.triplet_loss = TripletLoss(margin=margin)

        # Center loss is enabled only when both sizes are known. Its centers are
        # registered as parameters of this module through the submodule.
        if num_classes is not None and feature_dim is not None:
            self.center_loss = CenterLoss(num_classes, feature_dim)
        else:
            self.center_loss = None

    def forward(self, similarity: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                anchor: Optional[torch.Tensor] = None,
                positive: Optional[torch.Tensor] = None,
                negative: Optional[torch.Tensor] = None,
                features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute combined loss.

        NOTE(review): the same ``labels`` tensor feeds the contrastive term
        (binary genuine/forged pair labels) and the center term (class indices
        into the centers). Confirm callers intend that.

        Args:
            similarity: Similarity scores for contrastive loss
            labels: Labels for contrastive and center loss
            anchor: Anchor features for triplet loss
            positive: Positive features for triplet loss
            negative: Negative features for triplet loss
            features: Features for center loss

        Returns:
            Weighted sum of the applicable loss terms. If no term applies, a
            zero scalar tensor (without grad) on the inputs' device.
        """
        terms = []

        if similarity is not None and labels is not None:
            terms.append(self.contrastive_weight * self.contrastive_loss(similarity, labels))

        if anchor is not None and positive is not None and negative is not None:
            terms.append(self.triplet_weight * self.triplet_loss(anchor, positive, negative))

        if self.center_loss is not None and features is not None and labels is not None:
            terms.append(self.center_weight * self.center_loss(features, labels))

        if not terms:
            # Keep the return type a Tensor (the old code returned the float 0.0)
            reference = next(
                (t for t in (similarity, labels, anchor, positive, negative, features)
                 if t is not None),
                None,
            )
            return torch.zeros((), device=None if reference is None else reference.device)

        return sum(terms)


class FocalLoss(nn.Module):
    """
    Focal loss for handling class imbalance in signature verification.

    Scales each sample's cross-entropy by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the true class.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self, alpha: float = 1.0, gamma: float = 2.0, reduction: str = 'mean'):
        """
        Initialize focal loss.

        Args:
            alpha: Scalar weight applied to every sample's loss (not per-class)
            gamma: Focusing parameter; 0 reduces to alpha-scaled cross-entropy
            reduction: Reduction method ('mean', 'sum', 'none')

        Raises:
            ValueError: If ``reduction`` is not one of the supported values
        """
        super().__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"Unsupported reduction: {reduction!r}; expected 'mean', 'sum' or 'none'"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self, inputs: torch.Tensor, targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            inputs: Unnormalised class logits (B, num_classes). F.cross_entropy
                applies log-softmax internally, so do not pass probabilities.
            targets: Target class indices (B,)

        Returns:
            Focal loss, reduced according to ``self.reduction``
        """
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # exp(-CE) recovers the softmax probability of the true class
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss

        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        if self.reduction == 'sum':
            return torch.sum(focal_loss)
        return focal_loss


class AdaptiveLoss(nn.Module):
    """
    Contrastive + triplet loss with learnable term weights.

    Each term is weighted by ``sigmoid(raw_weight)``, where the raw weights are
    ``nn.Parameter``s.

    NOTE(review): ``adaptation_rate`` is stored but not used by this class. If
    the raw weights are optimised to minimise ``total_loss`` directly, gradient
    descent can drive both sigmoid weights toward 0. Confirm the intended
    training scheme.
    """

    def __init__(self, initial_contrastive_weight: float = 1.0,
                 initial_triplet_weight: float = 0.5,
                 adaptation_rate: float = 0.01):
        """
        Initialize adaptive loss.

        Args:
            initial_contrastive_weight: Initial raw (pre-sigmoid) weight for
                contrastive loss; the effective initial weight is its sigmoid
            initial_triplet_weight: Initial raw (pre-sigmoid) weight for
                triplet loss; the effective initial weight is its sigmoid
            adaptation_rate: Rate of weight adaptation (currently unused here)
        """
        super().__init__()
        # float() so integer arguments don't produce an integer tensor, which
        # nn.Parameter rejects because it cannot require gradients
        self.contrastive_weight = nn.Parameter(torch.tensor(float(initial_contrastive_weight)))
        self.triplet_weight = nn.Parameter(torch.tensor(float(initial_triplet_weight)))
        self.adaptation_rate = adaptation_rate

        self.contrastive_loss = ContrastiveLoss()
        self.triplet_loss = TripletLoss()

    def forward(self, similarity: torch.Tensor, labels: torch.Tensor,
                anchor: torch.Tensor, positive: torch.Tensor,
                negative: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Compute adaptive loss.

        Args:
            similarity: Similarity scores
            labels: Labels
            anchor: Anchor features
            positive: Positive features
            negative: Negative features

        Returns:
            Tuple of (total_loss, loss_info), where loss_info holds Python
            floats for each term, each effective weight and the total
        """
        contrastive_loss = self.contrastive_loss(similarity, labels)
        triplet_loss = self.triplet_loss(anchor, positive, negative)

        c_weight = torch.sigmoid(self.contrastive_weight)
        t_weight = torch.sigmoid(self.triplet_weight)
        total_loss = c_weight * contrastive_loss + t_weight * triplet_loss

        loss_info = {
            'contrastive_loss': contrastive_loss.item(),
            'triplet_loss': triplet_loss.item(),
            'contrastive_weight': c_weight.item(),
            'triplet_weight': t_weight.item(),
            'total_loss': total_loss.item()
        }

        return total_loss, loss_info