""" Feature extraction module for signature verification using CNN-based approaches. """ import torch import torch.nn as nn import torchvision.models as models from typing import Tuple, Optional import torch.nn.functional as F class SignatureFeatureExtractor(nn.Module): """ CNN-based feature extractor for signature images. """ def __init__(self, backbone: str = 'resnet18', feature_dim: int = 512, pretrained: bool = True, freeze_backbone: bool = False): """ Initialize the feature extractor. Args: backbone: Backbone architecture ('resnet18', 'resnet34', 'resnet50', 'efficientnet') feature_dim: Dimension of output features pretrained: Whether to use pretrained weights freeze_backbone: Whether to freeze backbone parameters """ super(SignatureFeatureExtractor, self).__init__() self.backbone_name = backbone self.feature_dim = feature_dim self.pretrained = pretrained # Load backbone self.backbone = self._get_backbone(backbone, pretrained) # Freeze backbone if specified if freeze_backbone: for param in self.backbone.parameters(): param.requires_grad = False # Get the number of input features from backbone if 'resnet' in backbone: backbone_features = self.backbone.fc.in_features self.backbone.fc = nn.Identity() # Remove final classification layer elif 'efficientnet' in backbone: backbone_features = self.backbone.classifier.in_features self.backbone.classifier = nn.Identity() else: raise ValueError(f"Unsupported backbone: {backbone}") # Feature projection layers self.feature_projection = nn.Sequential( nn.Linear(backbone_features, feature_dim * 2), nn.BatchNorm1d(feature_dim * 2), nn.ReLU(inplace=True), nn.Dropout(0.3), nn.Linear(feature_dim * 2, feature_dim), nn.BatchNorm1d(feature_dim), nn.ReLU(inplace=True) ) # Initialize weights self._initialize_weights() def _get_backbone(self, backbone: str, pretrained: bool): """Get the backbone model.""" if backbone == 'resnet18': return models.resnet18(pretrained=pretrained) elif backbone == 'resnet34': return models.resnet34(pretrained=pretrained) elif backbone == 'resnet50': return models.resnet50(pretrained=pretrained) elif backbone == 'efficientnet_b0': return models.efficientnet_b0(pretrained=pretrained) elif backbone == 'efficientnet_b1': return models.efficientnet_b1(pretrained=pretrained) else: raise ValueError(f"Unsupported backbone: {backbone}") def _initialize_weights(self): """Initialize weights for the projection layers.""" for m in self.feature_projection.modules(): if isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm1d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the feature extractor. Args: x: Input signature images (B, C, H, W) Returns: Extracted features (B, feature_dim) """ # Extract features using backbone features = self.backbone(x) # Project to desired feature dimension features = self.feature_projection(features) # L2 normalize features features = F.normalize(features, p=2, dim=1) return features class CustomCNNFeatureExtractor(nn.Module): """ Custom CNN architecture specifically designed for signature verification. """ def __init__(self, input_channels: int = 3, feature_dim: int = 512, dropout_rate: float = 0.3): """ Initialize the custom CNN feature extractor. Args: input_channels: Number of input channels feature_dim: Dimension of output features dropout_rate: Dropout rate for regularization """ super(CustomCNNFeatureExtractor, self).__init__() self.feature_dim = feature_dim self.dropout_rate = dropout_rate # Convolutional layers self.conv_layers = nn.Sequential( # First block nn.Conv2d(input_channels, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.Conv2d(32, 32, kernel_size=3, padding=1), nn.BatchNorm2d(32), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), # Second block nn.Conv2d(32, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.Conv2d(64, 64, kernel_size=3, padding=1), nn.BatchNorm2d(64), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), # Third block nn.Conv2d(64, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.Conv2d(128, 128, kernel_size=3, padding=1), nn.BatchNorm2d(128), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), # Fourth block nn.Conv2d(128, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.Conv2d(256, 256, kernel_size=3, padding=1), nn.BatchNorm2d(256), nn.ReLU(inplace=True), nn.MaxPool2d(2, 2), # Fifth block nn.Conv2d(256, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.Conv2d(512, 512, kernel_size=3, padding=1), nn.BatchNorm2d(512), nn.ReLU(inplace=True), nn.AdaptiveAvgPool2d((1, 1)) ) # Fully connected layers self.fc_layers = nn.Sequential( nn.Flatten(), nn.Linear(512, feature_dim * 2), nn.BatchNorm1d(feature_dim * 2), nn.ReLU(inplace=True), nn.Dropout(dropout_rate), nn.Linear(feature_dim * 2, feature_dim), nn.BatchNorm1d(feature_dim), nn.ReLU(inplace=True) ) # Initialize weights self._initialize_weights() def _initialize_weights(self): """Initialize weights for all layers.""" for m in self.modules(): if isinstance(m, nn.Conv2d): nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu') if m.bias is not None: nn.init.constant_(m.bias, 0) elif isinstance(m, nn.BatchNorm2d): nn.init.constant_(m.weight, 1) nn.init.constant_(m.bias, 0) elif isinstance(m, nn.Linear): nn.init.xavier_uniform_(m.weight) if m.bias is not None: nn.init.constant_(m.bias, 0) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the custom CNN. Args: x: Input signature images (B, C, H, W) Returns: Extracted features (B, feature_dim) """ # Extract features using convolutional layers features = self.conv_layers(x) # Project to desired feature dimension features = self.fc_layers(features) # L2 normalize features features = F.normalize(features, p=2, dim=1) return features class MultiScaleFeatureExtractor(nn.Module): """ Multi-scale feature extractor that captures features at different scales. """ def __init__(self, input_channels: int = 3, feature_dim: int = 512, scales: list = [1, 2, 4]): """ Initialize the multi-scale feature extractor. Args: input_channels: Number of input channels feature_dim: Dimension of output features scales: List of scales for multi-scale processing """ super(MultiScaleFeatureExtractor, self).__init__() self.scales = scales self.feature_dim = feature_dim # Create feature extractors for each scale self.scale_extractors = nn.ModuleList() for scale in scales: extractor = CustomCNNFeatureExtractor( input_channels=input_channels, feature_dim=feature_dim // len(scales) ) self.scale_extractors.append(extractor) # Fusion layer self.fusion = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.BatchNorm1d(feature_dim), nn.ReLU(inplace=True), nn.Dropout(0.3) ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the multi-scale extractor. Args: x: Input signature images (B, C, H, W) Returns: Multi-scale features (B, feature_dim) """ scale_features = [] for i, scale in enumerate(self.scales): # Resize input to different scales if scale != 1: scaled_x = F.interpolate(x, scale_factor=1/scale, mode='bilinear', align_corners=False) else: scaled_x = x # Extract features at this scale features = self.scale_extractors[i](scaled_x) scale_features.append(features) # Concatenate features from all scales multi_scale_features = torch.cat(scale_features, dim=1) # Fuse features fused_features = self.fusion(multi_scale_features) # L2 normalize features fused_features = F.normalize(fused_features, p=2, dim=1) return fused_features class AttentionFeatureExtractor(nn.Module): """ Feature extractor with attention mechanism for focusing on important signature regions. """ def __init__(self, input_channels: int = 3, feature_dim: int = 512, attention_dim: int = 256): """ Initialize the attention-based feature extractor. Args: input_channels: Number of input channels feature_dim: Dimension of output features attention_dim: Dimension of attention features """ super(AttentionFeatureExtractor, self).__init__() self.feature_dim = feature_dim self.attention_dim = attention_dim # Base feature extractor self.base_extractor = CustomCNNFeatureExtractor( input_channels=input_channels, feature_dim=feature_dim ) # Attention mechanism self.attention_conv = nn.Sequential( nn.Conv2d(512, attention_dim, kernel_size=1), nn.BatchNorm2d(attention_dim), nn.ReLU(inplace=True), nn.Conv2d(attention_dim, 1, kernel_size=1), nn.Sigmoid() ) # Feature refinement self.feature_refinement = nn.Sequential( nn.Linear(feature_dim, feature_dim), nn.BatchNorm1d(feature_dim), nn.ReLU(inplace=True), nn.Dropout(0.3) ) def forward(self, x: torch.Tensor) -> torch.Tensor: """ Forward pass through the attention-based extractor. Args: x: Input signature images (B, C, H, W) Returns: Attention-weighted features (B, feature_dim) """ # Get base features base_features = self.base_extractor(x) # Get attention map (simplified - in practice, you'd extract intermediate features) # For now, we'll use a simplified approach attention_map = self.attention_conv(x.mean(dim=1, keepdim=True)) # Apply attention to features (simplified) attended_features = base_features * attention_map.mean(dim=[2, 3], keepdim=True).squeeze() # Refine features refined_features = self.feature_refinement(attended_features) # L2 normalize features refined_features = F.normalize(refined_features, p=2, dim=1) return refined_features