# InklyAI — src/training/losses.py
# (uploaded via huggingface_hub, commit 8eab354, user pravinai)
"""
Loss functions for signature verification training.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional
class ContrastiveLoss(nn.Module):
    """Contrastive loss for Siamese signature-verification training.

    Genuine pairs (label 1) are pulled toward similarity 1; forged pairs
    (label 0) are penalized only when their similarity exceeds ``margin``.
    """

    def __init__(self, margin: float = 1.0):
        """
        Args:
            margin: Similarity threshold above which forged pairs incur loss.
                NOTE(review): if similarity scores live in [0, 1], the default
                margin of 1.0 makes the forged-pair term identically zero —
                confirm the expected score range with callers.
        """
        super().__init__()
        self.margin = margin

    def forward(self,
                similarity: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """Compute the mean contrastive loss over the batch.

        Args:
            similarity: Similarity scores, shape (B, 1) (squeezed internally).
            labels: Binary labels, 1 = genuine pair, 0 = forged pair, shape (B,).

        Returns:
            Scalar loss tensor.
        """
        y = labels.float()
        sim = similarity.squeeze()
        # Genuine pairs: quadratic penalty for similarity below 1.
        pull = y * (1.0 - sim) ** 2
        # Forged pairs: quadratic penalty only above the margin.
        push = (1.0 - y) * torch.clamp(sim - self.margin, min=0) ** 2
        return (pull + push).mean()
class TripletLoss(nn.Module):
    """Triplet margin loss: pull the anchor toward the positive sample and
    push it away from the negative sample by at least ``margin``."""

    def __init__(self, margin: float = 1.0):
        """
        Args:
            margin: Required gap between negative and positive distances.
        """
        super().__init__()
        self.margin = margin

    def forward(self,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> torch.Tensor:
        """Compute the mean triplet loss over the batch.

        Args:
            anchor: Anchor features, shape (B, feature_dim).
            positive: Positive features, shape (B, feature_dim).
            negative: Negative features, shape (B, feature_dim).

        Returns:
            Scalar loss tensor.
        """
        d_ap = F.pairwise_distance(anchor, positive, p=2)
        d_an = F.pairwise_distance(anchor, negative, p=2)
        # Hinge: zero once the negative is at least `margin` farther away.
        violation = d_ap - d_an + self.margin
        return violation.clamp(min=0).mean()
class CenterLoss(nn.Module):
    """Center loss: pulls each feature vector toward its learnable class center.

    Centers are ordinary parameters updated by backprop together with the rest
    of the model (not by the separate alpha-update rule of the original
    center-loss paper).
    """

    def __init__(self,
                 num_classes: int,
                 feature_dim: int,
                 lambda_c: float = 1.0):
        """
        Initialize center loss.

        Args:
            num_classes: Number of signature classes.
            feature_dim: Dimension of feature vectors.
            lambda_c: Scalar weight applied to the loss.
        """
        super(CenterLoss, self).__init__()
        self.num_classes = num_classes
        self.feature_dim = feature_dim
        self.lambda_c = lambda_c
        # Learnable class centers, one row per class.
        self.centers = nn.Parameter(torch.randn(num_classes, feature_dim))

    def forward(self,
                features: torch.Tensor,
                labels: torch.Tensor) -> torch.Tensor:
        """
        Compute the weighted center loss.

        Args:
            features: Feature vectors (B, feature_dim).
            labels: Integer class labels (B,), used to index the centers.

        Returns:
            lambda_c * MSE between features and their class centers
            (mean over all elements — not the 0.5 * sum of the paper).
        """
        # Gather each sample's class center. (Removed an unused
        # `batch_size` local that the original computed and never read.)
        centers_batch = self.centers[labels]
        loss = F.mse_loss(features, centers_batch)
        return self.lambda_c * loss
class CombinedLoss(nn.Module):
    """Weighted sum of contrastive, triplet, and (optional) center losses.

    Each component is computed only when its inputs are supplied to
    ``forward``; the center loss additionally requires ``num_classes`` and
    ``feature_dim`` at construction time.
    """

    def __init__(self,
                 contrastive_weight: float = 1.0,
                 triplet_weight: float = 0.5,
                 center_weight: float = 0.1,
                 margin: float = 1.0,
                 num_classes: Optional[int] = None,
                 feature_dim: Optional[int] = None):
        """
        Args:
            contrastive_weight: Weight for the contrastive term.
            triplet_weight: Weight for the triplet term.
            center_weight: Weight for the center term.
            margin: Margin shared by the contrastive and triplet losses.
            num_classes: Number of classes for the center loss (optional).
            feature_dim: Feature dimension for the center loss (optional).
        """
        super().__init__()
        self.contrastive_weight = contrastive_weight
        self.triplet_weight = triplet_weight
        self.center_weight = center_weight
        self.contrastive_loss = ContrastiveLoss(margin=margin)
        self.triplet_loss = TripletLoss(margin=margin)
        # Center loss only when both sizing arguments are provided
        # (built with its default internal lambda_c; scaling is applied
        # here via center_weight).
        if num_classes is None or feature_dim is None:
            self.center_loss = None
        else:
            self.center_loss = CenterLoss(num_classes, feature_dim)

    def forward(self,
                similarity: Optional[torch.Tensor] = None,
                labels: Optional[torch.Tensor] = None,
                anchor: Optional[torch.Tensor] = None,
                positive: Optional[torch.Tensor] = None,
                negative: Optional[torch.Tensor] = None,
                features: Optional[torch.Tensor] = None) -> torch.Tensor:
        """Sum the weighted loss terms whose inputs were provided.

        Args:
            similarity: Similarity scores for the contrastive term.
            labels: Labels for the contrastive and center terms.
            anchor: Anchor features for the triplet term.
            positive: Positive features for the triplet term.
            negative: Negative features for the triplet term.
            features: Features for the center term.

        Returns:
            Scalar loss tensor (the float 0.0 if no term was computed).
        """
        total = 0.0
        if similarity is not None and labels is not None:
            total = total + self.contrastive_weight * self.contrastive_loss(similarity, labels)
        have_triplet = anchor is not None and positive is not None and negative is not None
        if have_triplet:
            total = total + self.triplet_weight * self.triplet_loss(anchor, positive, negative)
        if self.center_loss is not None and features is not None and labels is not None:
            total = total + self.center_weight * self.center_loss(features, labels)
        return total
class FocalLoss(nn.Module):
    """Focal loss (Lin et al., 2017) for class-imbalanced classification.

    Down-weights easy examples by the factor ``alpha * (1 - pt) ** gamma``,
    where ``pt`` is the model's probability for the true class.
    """

    def __init__(self,
                 alpha: float = 1.0,
                 gamma: float = 2.0,
                 reduction: str = 'mean'):
        """
        Initialize focal loss.

        Args:
            alpha: Weighting factor for the rare class.
            gamma: Focusing parameter; gamma = 0 recovers plain cross-entropy.
            reduction: Reduction method ('mean', 'sum', 'none').
        """
        super(FocalLoss, self).__init__()
        self.alpha = alpha
        self.gamma = gamma
        self.reduction = reduction

    def forward(self,
                inputs: torch.Tensor,
                targets: torch.Tensor) -> torch.Tensor:
        """
        Compute focal loss.

        Args:
            inputs: Raw class logits (B, num_classes) — F.cross_entropy applies
                log-softmax internally, so do NOT pass probabilities (the
                original docstring said "probabilities", which was misleading).
            targets: Integer target labels (B,).

        Returns:
            Scalar loss for 'mean'/'sum', per-sample tensor (B,) for 'none'.
        """
        # (Removed dead code: the original built a one-hot `targets_one_hot`
        # tensor that was never read.)
        # Per-sample cross entropy, kept unreduced for per-sample weighting.
        ce_loss = F.cross_entropy(inputs, targets, reduction='none')
        # pt = exp(-CE) is exactly the softmax probability of the true class.
        pt = torch.exp(-ce_loss)
        focal_loss = self.alpha * (1 - pt) ** self.gamma * ce_loss
        if self.reduction == 'mean':
            return torch.mean(focal_loss)
        elif self.reduction == 'sum':
            return torch.sum(focal_loss)
        else:
            return focal_loss
class AdaptiveLoss(nn.Module):
    """Contrastive + triplet loss with learnable mixing weights.

    The raw weight parameters are squashed through a sigmoid before use, so
    the effective weights always lie in (0, 1) and are learned jointly with
    the model via backprop.
    """

    def __init__(self,
                 initial_contrastive_weight: float = 1.0,
                 initial_triplet_weight: float = 0.5,
                 adaptation_rate: float = 0.01):
        """
        Args:
            initial_contrastive_weight: Raw (pre-sigmoid) contrastive weight.
            initial_triplet_weight: Raw (pre-sigmoid) triplet weight.
            adaptation_rate: Stored but never read in this class — presumably
                consumed by an external training loop; confirm before removing.
        """
        super().__init__()
        self.contrastive_weight = nn.Parameter(torch.tensor(initial_contrastive_weight))
        self.triplet_weight = nn.Parameter(torch.tensor(initial_triplet_weight))
        self.adaptation_rate = adaptation_rate
        # Component losses with their default margins.
        self.contrastive_loss = ContrastiveLoss()
        self.triplet_loss = TripletLoss()

    def forward(self,
                similarity: torch.Tensor,
                labels: torch.Tensor,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> Tuple[torch.Tensor, dict]:
        """Compute the weighted loss and a dict of scalar diagnostics.

        Args:
            similarity: Similarity scores for the contrastive term.
            labels: Pair labels for the contrastive term.
            anchor: Anchor features for the triplet term.
            positive: Positive features for the triplet term.
            negative: Negative features for the triplet term.

        Returns:
            Tuple of (total_loss tensor, dict of float diagnostics).
        """
        c_loss = self.contrastive_loss(similarity, labels)
        t_loss = self.triplet_loss(anchor, positive, negative)
        w_c = torch.sigmoid(self.contrastive_weight)
        w_t = torch.sigmoid(self.triplet_weight)
        total = w_c * c_loss + w_t * t_loss
        info = {
            'contrastive_loss': c_loss.item(),
            'triplet_loss': t_loss.item(),
            'contrastive_weight': w_c.item(),
            'triplet_weight': w_t.item(),
            'total_loss': total.item(),
        }
        return total, info