|
|
""" |
|
|
Loss functions for signature verification training. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Tuple, Optional |
|
|
|
|
|
|
|
|
class ContrastiveLoss(nn.Module):
    """
    Contrastive loss for Siamese network training, expressed on similarity scores.

    Genuine pairs (label 1) are penalized by ``(1 - s) ** 2`` and forged pairs
    (label 0) by ``max(s - margin, 0) ** 2``; the per-pair terms are averaged
    over the batch.
    """

    def __init__(self, margin: float = 1.0):
        """
        Initialize contrastive loss.

        Args:
            margin: Similarity value above which forged pairs are penalized.
        """
        super().__init__()
        self.margin = margin

    def forward(self,
                similarity: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Compute contrastive loss.

        Args:
            similarity: Similarity scores (B, 1); singleton dims are squeezed.
            labels: Binary labels (1 for genuine, 0 for forged) (B,)

        Returns:
            Scalar tensor: mean of the genuine and forged penalty terms.
        """
        target = labels.float()
        sim = similarity.squeeze()

        # Genuine pairs are pulled toward a similarity of 1.
        pull_genuine = target * (1 - sim).pow(2)

        # Forged pairs are pushed below ``margin``.
        # NOTE(review): with the default margin of 1.0 and similarity scores
        # bounded by 1 (e.g. sigmoid/cosine outputs), this term is always
        # zero. Confirm the intended margin for the similarity range in use.
        push_forged = (1 - target) * (sim - self.margin).clamp(min=0).pow(2)

        return (pull_genuine + push_forged).mean()
|
|
|
|
|
|
|
|
class TripletLoss(nn.Module):
    """
    Triplet margin loss for signature verification.

    Encourages the anchor-positive Euclidean distance to be smaller than the
    anchor-negative distance by at least ``margin``.
    """

    def __init__(self, margin: float = 1.0):
        """
        Initialize triplet loss.

        Args:
            margin: Required gap between negative and positive distances.
        """
        super().__init__()
        self.margin = margin

    def forward(self,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        """
        Compute triplet loss.

        Args:
            anchor: Anchor features (B, feature_dim)
            positive: Positive features (B, feature_dim)
            negative: Negative features (B, feature_dim)

        Returns:
            Scalar tensor: batch mean of ``relu(d(a, p) - d(a, n) + margin)``.
        """
        d_ap = F.pairwise_distance(anchor, positive, p=2)
        d_an = F.pairwise_distance(anchor, negative, p=2)

        # Hinge: only triplets that violate the margin contribute.
        violation = d_ap - d_an + self.margin
        return F.relu(violation).mean()
|
|
|
|
|
|
|
|
class CenterLoss(nn.Module):
    """
    Center loss for learning discriminative features.

    Keeps one learnable center per class and penalizes the mean squared
    difference between each feature vector and the center of its class.
    """

    def __init__(self,
                 num_classes: int,
                 feature_dim: int,
                 lambda_c: float = 1.0):
        """
        Initialize center loss.

        Args:
            num_classes: Number of signature classes
            feature_dim: Dimension of feature vectors
            lambda_c: Weight for center loss
        """
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.lambda_c = lambda_c

        # One center per class, randomly initialized. As an nn.Parameter it
        # receives gradients and is updated only if this module's parameters
        # are given to an optimizer.
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))

    def forward(self,
                features: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Compute center loss.

        Args:
            features: Feature vectors (B, feature_dim)
            labels: Class labels (B,). Any numeric or bool dtype is accepted
                and interpreted as class indices in ``[0, num_classes)``.

        Returns:
            ``lambda_c`` times the mean squared error between the features and
            their class centers, averaged over all B * feature_dim elements.
        """
        # Cast to int64 so the labels always select rows by class index. A
        # bool tensor would otherwise be treated as a mask, and float tensors
        # cannot be used as indices at all.
        class_idx = labels.long()
        centers_batch = self.centers[class_idx]

        loss = F.mse_loss(features, centers_batch)
        return self.lambda_c * loss
|
|
|
|
|
|
|
|
class CombinedLoss(nn.Module):
    """
    Weighted sum of contrastive, triplet and (optionally) center losses.

    Each term is included only when all of its inputs are passed to
    ``forward``. The center term also requires that the module was built
    with both ``num_classes`` and ``feature_dim``.
    """

    def __init__(self,
                 contrastive_weight: float = 1.0,
                 triplet_weight: float = 0.5,
                 center_weight: float = 0.1,
                 margin: float = 1.0,
                 num_classes: Optional[int] = None,
                 feature_dim: Optional[int] = None):
        """
        Initialize combined loss.

        Args:
            contrastive_weight: Weight for contrastive loss
            triplet_weight: Weight for triplet loss
            center_weight: Weight for center loss
            margin: Margin for contrastive and triplet losses
            num_classes: Number of classes for center loss
            feature_dim: Feature dimension for center loss
        """
        super(CombinedLoss, self).__init__()

        self.contrastive_weight = contrastive_weight
        self.triplet_weight = triplet_weight
        self.center_weight = center_weight

        self.contrastive_loss = ContrastiveLoss(margin=margin)
        self.triplet_loss = TripletLoss(margin=margin)

        # Center loss needs a fixed class count and feature size up front;
        # without them the center term is disabled.
        if num_classes is not None and feature_dim is not None:
            self.center_loss = CenterLoss(num_classes, feature_dim)
        else:
            self.center_loss = None

    def forward(self,
                similarity: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                anchor: Optional[torch.Tensor] = None,
                positive: Optional[torch.Tensor] = None,
                negative: Optional[torch.Tensor] = None,
                features: Optional[torch.Tensor] = None,
                center_labels: Optional[torch.Tensor] = None) -> torch.Tensor:
        """
        Compute combined loss.

        Args:
            similarity: Similarity scores for contrastive loss
            labels: Labels for contrastive loss. Also used for center loss
                when ``center_labels`` is not given.
            anchor: Anchor features for triplet loss
            positive: Positive features for triplet loss
            negative: Negative features for triplet loss
            features: Features for center loss
            center_labels: Optional class labels for center loss. Pass this
                when ``labels`` holds binary genuine/forged targets but the
                center term needs class ids. Defaults to ``labels``.

        Returns:
            Combined loss as a tensor. If no term's inputs are supplied, a
            zero tensor (on the device of any given input) is returned
            instead of a bare Python float.
        """
        terms = []

        # Contrastive term
        if similarity is not None and labels is not None:
            terms.append(self.contrastive_weight *
                         self.contrastive_loss(similarity, labels))

        # Triplet term
        if anchor is not None and positive is not None and negative is not None:
            terms.append(self.triplet_weight *
                         self.triplet_loss(anchor, positive, negative))

        # Center term: fall back to ``labels`` to preserve the original API.
        class_labels = center_labels if center_labels is not None else labels
        if (self.center_loss is not None and features is not None
                and class_labels is not None):
            terms.append(self.center_weight *
                         self.center_loss(features, class_labels))

        if not terms:
            # Keep the return type a tensor even when nothing was computed.
            reference = next(
                (t for t in (similarity, labels, anchor, positive, negative,
                             features, center_labels) if t is not None),
                None,
            )
            device = reference.device if reference is not None else None
            return torch.zeros((), device=device)

        return sum(terms)
|
|
|
|
|
|
|
|
class FocalLoss(nn.Module):
    """
    Focal loss for handling class imbalance in signature verification.

    Scales the per-sample cross-entropy by ``alpha * (1 - p_t) ** gamma``,
    where ``p_t`` is the predicted probability of the true class. This
    down-weights samples the model already classifies confidently.
    """

    _REDUCTIONS = ('mean', 'sum', 'none')

    def __init__(self,
                 alpha: float = 1.0,
                 gamma: float = 2.0,
                 reduction: str = 'mean'):
        """
        Initialize focal loss.

        Args:
            alpha: Weighting factor for rare class
            gamma: Focusing parameter
            reduction: Reduction method ('mean', 'sum', 'none')

        Raises:
            ValueError: If ``reduction`` is not one of 'mean', 'sum', 'none'.
        """
        super(FocalLoss, self).__init__()
        if reduction not in self._REDUCTIONS:
            raise ValueError(
                f"reduction must be one of {self._REDUCTIONS}, got {reduction!r}"
            )
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            inputs: Unnormalized class logits (B, num_classes). They are passed
                directly to ``F.cross_entropy``, which applies log-softmax
                itself, so probabilities must not be passed here.
            targets: Integer class indices (B,)

        Returns:
            Focal loss reduced per ``self.reduction``: a scalar for 'mean' and
            'sum', or a (B,) tensor for 'none'.

        Raises:
            ValueError: If ``self.reduction`` was changed to an unsupported value.
        """
        # Per-sample cross-entropy, i.e. -log(p_t).
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')

        # Recover p_t from the cross-entropy without a second softmax.
        pt = torch.exp(-ce_loss)
        focal_weight = self.alpha * (1 - pt) ** self.gamma

        focal_loss = focal_weight * ce_loss

        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        if self.reduction == 'sum':
            return torch.sum(focal_loss)
        if self.reduction == 'none':
            return focal_loss
        raise ValueError(
            f"reduction must be one of {self._REDUCTIONS}, got {self.reduction!r}"
        )
|
|
|
|
|
|
|
|
class AdaptiveLoss(nn.Module):
    """
    Contrastive + triplet loss combined with learnable, sigmoid-squashed weights.

    Each weight is stored as a raw ``nn.Parameter``. The value actually
    applied to its loss is ``sigmoid(raw)``, so it always lies in (0, 1).
    """

    def __init__(self,
                 initial_contrastive_weight: float = 1.0,
                 initial_triplet_weight: float = 0.5,
                 adaptation_rate: float = 0.01):
        """
        Initialize adaptive loss.

        Args:
            initial_contrastive_weight: Initial raw (pre-sigmoid) weight for
                contrastive loss; e.g. 1.0 gives an effective weight of ~0.73.
            initial_triplet_weight: Initial raw (pre-sigmoid) weight for
                triplet loss.
            adaptation_rate: Rate of weight adaptation (stored only).
        """
        super().__init__()

        self.contrastive_weight = nn.Parameter(torch.tensor(initial_contrastive_weight))
        self.triplet_weight = nn.Parameter(torch.tensor(initial_triplet_weight))
        # NOTE(review): stored but never read anywhere in this class.
        self.adaptation_rate = adaptation_rate

        self.contrastive_loss = ContrastiveLoss()
        self.triplet_loss = TripletLoss()

    def forward(self,
                similarity: torch.Tensor,
                labels: torch.Tensor,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """
        Compute adaptive loss.

        Args:
            similarity: Similarity scores
            labels: Labels
            anchor: Anchor features
            positive: Positive features
            negative: Negative features

        Returns:
            Tuple of (total_loss, loss_info), where loss_info holds Python
            floats for each component loss, each effective weight, and the total.
        """
        c_loss = self.contrastive_loss(similarity, labels)
        t_loss = self.triplet_loss(anchor, positive, negative)

        w_c = torch.sigmoid(self.contrastive_weight)
        w_t = torch.sigmoid(self.triplet_weight)

        # NOTE(review): nothing here regularizes the weights. If they are
        # optimized on total_loss alone, gradients push both toward 0.
        # Confirm this is intended.
        total_loss = w_c * c_loss + w_t * t_loss

        info = {
            'contrastive_loss': c_loss.item(),
            'triplet_loss': t_loss.item(),
            'contrastive_weight': w_c.item(),
            'triplet_weight': w_t.item(),
            'total_loss': total_loss.item(),
        }
        return total_loss, info
|
|
|