|
|
""" |
|
|
Feature extraction module for signature verification using CNN-based approaches. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torchvision.models as models |
|
|
from typing import Tuple, Optional |
|
|
import torch.nn.functional as F |
|
|
|
|
|
|
|
|
class SignatureFeatureExtractor(nn.Module):
    """
    CNN-based feature extractor for signature images.

    Strips the classification head off a torchvision backbone and projects
    the pooled backbone features to a fixed-size, L2-normalized embedding.
    """

    def __init__(self,
                 backbone: str = 'resnet18',
                 feature_dim: int = 512,
                 pretrained: bool = True,
                 freeze_backbone: bool = False):
        """
        Initialize the feature extractor.

        Args:
            backbone: Backbone architecture ('resnet18', 'resnet34', 'resnet50',
                'efficientnet_b0', 'efficientnet_b1')
            feature_dim: Dimension of output features
            pretrained: Whether to use pretrained weights
            freeze_backbone: Whether to freeze backbone parameters

        Raises:
            ValueError: If `backbone` is not a supported architecture name.
        """
        super(SignatureFeatureExtractor, self).__init__()

        self.backbone_name = backbone
        self.feature_dim = feature_dim
        self.pretrained = pretrained

        self.backbone = self._get_backbone(backbone, pretrained)

        # Optionally freeze the backbone so only the projection head trains.
        if freeze_backbone:
            for param in self.backbone.parameters():
                param.requires_grad = False

        # Replace the classification head with Identity so the backbone
        # emits its pooled feature vector instead of class logits.
        if 'resnet' in backbone:
            backbone_features = self.backbone.fc.in_features
            self.backbone.fc = nn.Identity()
        elif 'efficientnet' in backbone:
            # BUGFIX: torchvision's EfficientNet classifier is
            # nn.Sequential(Dropout, Linear); `in_features` lives on the
            # Linear at index 1, not on the Sequential itself, so the old
            # `self.backbone.classifier.in_features` raised AttributeError.
            backbone_features = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Identity()
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

        # Two-layer MLP head projecting backbone features down to feature_dim.
        self.feature_projection = nn.Sequential(
            nn.Linear(backbone_features, feature_dim * 2),
            nn.BatchNorm1d(feature_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(feature_dim * 2, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True)
        )

        self._initialize_weights()

    def _get_backbone(self, backbone: str, pretrained: bool):
        """Instantiate the requested torchvision backbone.

        Raises:
            ValueError: If `backbone` is not a supported architecture name.
        """
        if backbone == 'resnet18':
            return models.resnet18(pretrained=pretrained)
        elif backbone == 'resnet34':
            return models.resnet34(pretrained=pretrained)
        elif backbone == 'resnet50':
            return models.resnet50(pretrained=pretrained)
        elif backbone == 'efficientnet_b0':
            return models.efficientnet_b0(pretrained=pretrained)
        elif backbone == 'efficientnet_b1':
            return models.efficientnet_b1(pretrained=pretrained)
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

    def _initialize_weights(self):
        """Initialize projection-head weights: Xavier-uniform for Linear
        layers, unit scale / zero shift for BatchNorm."""
        for m in self.feature_projection.modules():
            if isinstance(m, nn.Linear):
                nn.init.xavier_uniform_(m.weight)
                if m.bias is not None:
                    nn.init.constant_(m.bias, 0)
            elif isinstance(m, nn.BatchNorm1d):
                nn.init.constant_(m.weight, 1)
                nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the feature extractor.

        Args:
            x: Input signature images (B, C, H, W)

        Returns:
            L2-normalized features (B, feature_dim)
        """
        features = self.backbone(x)
        features = self.feature_projection(features)
        # Unit-normalize so downstream similarity reduces to cosine distance.
        features = F.normalize(features, p=2, dim=1)
        return features
|
|
|
|
|
|
|
|
class CustomCNNFeatureExtractor(nn.Module):
    """
    Custom CNN architecture specifically designed for signature verification.

    Five double-conv stages (32 -> 64 -> 128 -> 256 -> 512 channels) followed
    by global average pooling and an MLP head; the output embedding is
    L2-normalized.
    """

    def __init__(self,
                 input_channels: int = 3,
                 feature_dim: int = 512,
                 dropout_rate: float = 0.3):
        """
        Initialize the custom CNN feature extractor.

        Args:
            input_channels: Number of input channels
            feature_dim: Dimension of output features
            dropout_rate: Dropout rate for regularization
        """
        super(CustomCNNFeatureExtractor, self).__init__()

        self.feature_dim = feature_dim
        self.dropout_rate = dropout_rate

        # Build the convolutional trunk: each stage is two Conv-BN-ReLU
        # triples; the first four stages end in 2x2 max pooling, the last
        # ends in global average pooling down to a 1x1 map.
        stage_widths = (32, 64, 128, 256, 512)
        trunk = []
        in_ch = input_channels
        for stage_idx, out_ch in enumerate(stage_widths):
            for _ in range(2):
                trunk.append(nn.Conv2d(in_ch, out_ch, kernel_size=3, padding=1))
                trunk.append(nn.BatchNorm2d(out_ch))
                trunk.append(nn.ReLU(inplace=True))
                in_ch = out_ch
            if stage_idx < len(stage_widths) - 1:
                trunk.append(nn.MaxPool2d(2, 2))
            else:
                trunk.append(nn.AdaptiveAvgPool2d((1, 1)))
        self.conv_layers = nn.Sequential(*trunk)

        # MLP head: 512 -> 2*feature_dim -> feature_dim.
        self.fc_layers = nn.Sequential(
            nn.Flatten(),
            nn.Linear(512, feature_dim * 2),
            nn.BatchNorm1d(feature_dim * 2),
            nn.ReLU(inplace=True),
            nn.Dropout(dropout_rate),
            nn.Linear(feature_dim * 2, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True)
        )

        self._initialize_weights()

    def _initialize_weights(self):
        """Initialize all layers: Kaiming-normal for convolutions,
        Xavier-uniform for linear layers, unit scale for batch norms."""
        for module in self.modules():
            if isinstance(module, nn.Conv2d):
                nn.init.kaiming_normal_(module.weight, mode='fan_out', nonlinearity='relu')
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.BatchNorm2d):
                nn.init.constant_(module.weight, 1)
                nn.init.constant_(module.bias, 0)
            elif isinstance(module, nn.Linear):
                nn.init.xavier_uniform_(module.weight)
                if module.bias is not None:
                    nn.init.constant_(module.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the custom CNN.

        Args:
            x: Input signature images (B, C, H, W)

        Returns:
            L2-normalized features (B, feature_dim)
        """
        embedding = self.fc_layers(self.conv_layers(x))
        # Unit-normalize so downstream similarity reduces to cosine distance.
        return F.normalize(embedding, p=2, dim=1)
|
|
|
|
|
|
|
|
class MultiScaleFeatureExtractor(nn.Module):
    """
    Multi-scale feature extractor that captures features at different scales.

    Runs one CustomCNNFeatureExtractor per scale on a resized copy of the
    input, concatenates the per-scale embeddings, and fuses them back to
    `feature_dim` with a small MLP.
    """

    def __init__(self,
                 input_channels: int = 3,
                 feature_dim: int = 512,
                 scales: Optional[list] = None):
        """
        Initialize the multi-scale feature extractor.

        Args:
            input_channels: Number of input channels
            feature_dim: Dimension of output features
            scales: List of downsampling factors for multi-scale processing
                (defaults to [1, 2, 4]; factor s feeds the corresponding
                extractor the input resized by 1/s)
        """
        super(MultiScaleFeatureExtractor, self).__init__()

        # Avoid a shared mutable default argument.
        if scales is None:
            scales = [1, 2, 4]

        self.scales = scales
        self.feature_dim = feature_dim

        # One extractor per scale, each producing an equal slice of the
        # concatenated embedding.
        per_scale_dim = feature_dim // len(scales)
        self.scale_extractors = nn.ModuleList()
        for scale in scales:
            extractor = CustomCNNFeatureExtractor(
                input_channels=input_channels,
                feature_dim=per_scale_dim
            )
            self.scale_extractors.append(extractor)

        # BUGFIX: when feature_dim is not divisible by len(scales) the
        # concatenated features have per_scale_dim * len(scales) dims
        # (e.g. 510 for the defaults feature_dim=512, 3 scales), so the
        # fusion layer's input size must be the actual concatenated width,
        # not feature_dim — the old code crashed with a shape mismatch.
        concat_dim = per_scale_dim * len(scales)
        self.fusion = nn.Sequential(
            nn.Linear(concat_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the multi-scale extractor.

        Args:
            x: Input signature images (B, C, H, W)

        Returns:
            L2-normalized multi-scale features (B, feature_dim)
        """
        scale_features = []

        for i, scale in enumerate(self.scales):
            # Downsample for scales > 1; scale 1 uses the original image.
            if scale != 1:
                scaled_x = F.interpolate(x, scale_factor=1 / scale, mode='bilinear', align_corners=False)
            else:
                scaled_x = x

            features = self.scale_extractors[i](scaled_x)
            scale_features.append(features)

        # Concatenate per-scale embeddings along the feature dimension.
        multi_scale_features = torch.cat(scale_features, dim=1)

        fused_features = self.fusion(multi_scale_features)
        # Unit-normalize so downstream similarity reduces to cosine distance.
        fused_features = F.normalize(fused_features, p=2, dim=1)

        return fused_features
|
|
|
|
|
|
|
|
class AttentionFeatureExtractor(nn.Module):
    """
    Feature extractor with attention mechanism for focusing on important signature regions.

    Computes a spatial saliency map from the channel-averaged input image,
    collapses it to one scalar weight per sample, and uses it to rescale the
    base CNN embedding before refinement and L2 normalization.
    """

    def __init__(self,
                 input_channels: int = 3,
                 feature_dim: int = 512,
                 attention_dim: int = 256):
        """
        Initialize the attention-based feature extractor.

        Args:
            input_channels: Number of input channels
            feature_dim: Dimension of output features
            attention_dim: Dimension of attention features
        """
        super(AttentionFeatureExtractor, self).__init__()

        self.feature_dim = feature_dim
        self.attention_dim = attention_dim

        # Backbone embedding network.
        self.base_extractor = CustomCNNFeatureExtractor(
            input_channels=input_channels,
            feature_dim=feature_dim
        )

        # Spatial attention over the channel-averaged input image.
        # BUGFIX: forward feeds this x.mean(dim=1, keepdim=True), which has
        # exactly 1 channel, so the first conv must accept 1 input channel;
        # the old Conv2d(512, ...) crashed with a channel mismatch.
        self.attention_conv = nn.Sequential(
            nn.Conv2d(1, attention_dim, kernel_size=1),
            nn.BatchNorm2d(attention_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(attention_dim, 1, kernel_size=1),
            nn.Sigmoid()
        )

        # Post-attention refinement MLP.
        self.feature_refinement = nn.Sequential(
            nn.Linear(feature_dim, feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the attention-based extractor.

        Args:
            x: Input signature images (B, C, H, W)

        Returns:
            Attention-weighted, L2-normalized features (B, feature_dim)
        """
        base_features = self.base_extractor(x)

        # Per-pixel saliency in [0, 1] from the channel-averaged image.
        attention_map = self.attention_conv(x.mean(dim=1, keepdim=True))

        # BUGFIX: collapse the map to one scalar weight per sample shaped
        # (B, 1) so it broadcasts over (B, feature_dim). The old
        # `.squeeze()` produced shape (B,), which fails to broadcast
        # against the embedding for any batch size > 1.
        attention_weight = attention_map.mean(dim=[2, 3]).view(-1, 1)
        attended_features = base_features * attention_weight

        refined_features = self.feature_refinement(attended_features)
        # Unit-normalize so downstream similarity reduces to cosine distance.
        refined_features = F.normalize(refined_features, p=2, dim=1)

        return refined_features
|
|
|