"""
Feature extraction module for signature verification using CNN-based approaches.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision.models as models
from typing import Tuple


class SignatureFeatureExtractor(nn.Module):
"""
CNN-based feature extractor for signature images.
"""

    def __init__(self,
backbone: str = 'resnet18',
feature_dim: int = 512,
pretrained: bool = True,
freeze_backbone: bool = False):
"""
Initialize the feature extractor.
Args:
            backbone: Backbone architecture ('resnet18', 'resnet34',
                'resnet50', 'efficientnet_b0', 'efficientnet_b1')
feature_dim: Dimension of output features
pretrained: Whether to use pretrained weights
freeze_backbone: Whether to freeze backbone parameters
"""
super(SignatureFeatureExtractor, self).__init__()
self.backbone_name = backbone
self.feature_dim = feature_dim
self.pretrained = pretrained
# Load backbone
self.backbone = self._get_backbone(backbone, pretrained)
# Freeze backbone if specified
if freeze_backbone:
for param in self.backbone.parameters():
param.requires_grad = False
# Get the number of input features from backbone
if 'resnet' in backbone:
backbone_features = self.backbone.fc.in_features
self.backbone.fc = nn.Identity() # Remove final classification layer
        elif 'efficientnet' in backbone:
            # torchvision EfficientNets wrap the classifier in
            # nn.Sequential(Dropout, Linear), so the Linear layer sits at
            # index 1; `classifier.in_features` would raise AttributeError
            backbone_features = self.backbone.classifier[1].in_features
            self.backbone.classifier = nn.Identity()
else:
raise ValueError(f"Unsupported backbone: {backbone}")
# Feature projection layers
self.feature_projection = nn.Sequential(
nn.Linear(backbone_features, feature_dim * 2),
nn.BatchNorm1d(feature_dim * 2),
nn.ReLU(inplace=True),
nn.Dropout(0.3),
nn.Linear(feature_dim * 2, feature_dim),
nn.BatchNorm1d(feature_dim),
nn.ReLU(inplace=True)
)
# Initialize weights
self._initialize_weights()

    def _get_backbone(self, backbone: str, pretrained: bool):
        """Get the backbone model."""
        # torchvision >= 0.13 deprecates the `pretrained` flag in favour of
        # `weights`; map the boolean onto each model's default weight enum
        if backbone == 'resnet18':
            weights = models.ResNet18_Weights.DEFAULT if pretrained else None
            return models.resnet18(weights=weights)
        elif backbone == 'resnet34':
            weights = models.ResNet34_Weights.DEFAULT if pretrained else None
            return models.resnet34(weights=weights)
        elif backbone == 'resnet50':
            weights = models.ResNet50_Weights.DEFAULT if pretrained else None
            return models.resnet50(weights=weights)
        elif backbone == 'efficientnet_b0':
            weights = models.EfficientNet_B0_Weights.DEFAULT if pretrained else None
            return models.efficientnet_b0(weights=weights)
        elif backbone == 'efficientnet_b1':
            weights = models.EfficientNet_B1_Weights.DEFAULT if pretrained else None
            return models.efficientnet_b1(weights=weights)
        else:
            raise ValueError(f"Unsupported backbone: {backbone}")

def _initialize_weights(self):
"""Initialize weights for the projection layers."""
for m in self.feature_projection.modules():
if isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm1d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the feature extractor.
Args:
x: Input signature images (B, C, H, W)
Returns:
Extracted features (B, feature_dim)
"""
# Extract features using backbone
features = self.backbone(x)
# Project to desired feature dimension
features = self.feature_projection(features)
# L2 normalize features
features = F.normalize(features, p=2, dim=1)
return features
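

# Usage sketch (illustrative only; `_example_backbone_similarity` is a
# hypothetical helper, and the 224x224 input size and pretrained=False are
# assumptions, not requirements of the class). Because forward() L2-normalizes
# its output, the cosine similarity between two signature embeddings is just
# their dot product.
def _example_backbone_similarity() -> torch.Tensor:
    extractor = SignatureFeatureExtractor(backbone='resnet18', pretrained=False)
    extractor.eval()  # use running BatchNorm stats; disable dropout
    with torch.no_grad():
        anchor = extractor(torch.randn(2, 3, 224, 224))
        candidate = extractor(torch.randn(2, 3, 224, 224))
    # Unit-norm embeddings: the dot product is the cosine similarity in [-1, 1]
    return (anchor * candidate).sum(dim=1)
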
class CustomCNNFeatureExtractor(nn.Module):
"""
Custom CNN architecture specifically designed for signature verification.
"""

    def __init__(self,
input_channels: int = 3,
feature_dim: int = 512,
dropout_rate: float = 0.3):
"""
Initialize the custom CNN feature extractor.
Args:
input_channels: Number of input channels
feature_dim: Dimension of output features
dropout_rate: Dropout rate for regularization
"""
super(CustomCNNFeatureExtractor, self).__init__()
self.feature_dim = feature_dim
self.dropout_rate = dropout_rate
# Convolutional layers
self.conv_layers = nn.Sequential(
# First block
nn.Conv2d(input_channels, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.Conv2d(32, 32, kernel_size=3, padding=1),
nn.BatchNorm2d(32),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# Second block
nn.Conv2d(32, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.Conv2d(64, 64, kernel_size=3, padding=1),
nn.BatchNorm2d(64),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# Third block
nn.Conv2d(64, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.Conv2d(128, 128, kernel_size=3, padding=1),
nn.BatchNorm2d(128),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# Fourth block
nn.Conv2d(128, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.Conv2d(256, 256, kernel_size=3, padding=1),
nn.BatchNorm2d(256),
nn.ReLU(inplace=True),
nn.MaxPool2d(2, 2),
# Fifth block
nn.Conv2d(256, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.Conv2d(512, 512, kernel_size=3, padding=1),
nn.BatchNorm2d(512),
nn.ReLU(inplace=True),
nn.AdaptiveAvgPool2d((1, 1))
)
# Fully connected layers
self.fc_layers = nn.Sequential(
nn.Flatten(),
nn.Linear(512, feature_dim * 2),
nn.BatchNorm1d(feature_dim * 2),
nn.ReLU(inplace=True),
nn.Dropout(dropout_rate),
nn.Linear(feature_dim * 2, feature_dim),
nn.BatchNorm1d(feature_dim),
nn.ReLU(inplace=True)
)
# Initialize weights
self._initialize_weights()

    def _initialize_weights(self):
"""Initialize weights for all layers."""
for m in self.modules():
if isinstance(m, nn.Conv2d):
nn.init.kaiming_normal_(m.weight, mode='fan_out', nonlinearity='relu')
if m.bias is not None:
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.BatchNorm2d):
nn.init.constant_(m.weight, 1)
nn.init.constant_(m.bias, 0)
elif isinstance(m, nn.Linear):
nn.init.xavier_uniform_(m.weight)
if m.bias is not None:
nn.init.constant_(m.bias, 0)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the custom CNN.
Args:
x: Input signature images (B, C, H, W)
Returns:
Extracted features (B, feature_dim)
"""
# Extract features using convolutional layers
features = self.conv_layers(x)
# Project to desired feature dimension
features = self.fc_layers(features)
# L2 normalize features
features = F.normalize(features, p=2, dim=1)
return features
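

# Usage sketch (a hypothetical helper; the 128x128 single-channel input is an
# assumption). Four 2x2 max-pools halve the spatial size four times and the
# adaptive average pool collapses whatever remains, so any reasonably sized
# input yields a fixed (B, feature_dim) embedding.
def _example_custom_cnn() -> torch.Tensor:
    extractor = CustomCNNFeatureExtractor(input_channels=1, feature_dim=256)
    extractor.eval()
    with torch.no_grad():
        # 128 -> 64 -> 32 -> 16 -> 8 across the four pooling stages
        return extractor(torch.randn(4, 1, 128, 128))  # -> (4, 256)
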
class MultiScaleFeatureExtractor(nn.Module):
"""
Multi-scale feature extractor that captures features at different scales.
"""

    def __init__(self,
                 input_channels: int = 3,
                 feature_dim: int = 512,
                 scales: Tuple[int, ...] = (1, 2, 4)):
        """
        Initialize the multi-scale feature extractor.
        Args:
            input_channels: Number of input channels
            feature_dim: Dimension of output features
            scales: Downsampling factors; the input is processed at 1/s
                resolution for each scale s
        """
super(MultiScaleFeatureExtractor, self).__init__()
self.scales = scales
self.feature_dim = feature_dim
        # Create a feature extractor for each scale; each branch emits
        # feature_dim // len(scales) features, and the fusion layer maps the
        # concatenation (which is narrower than feature_dim whenever
        # feature_dim is not divisible by len(scales)) back to feature_dim
        per_scale_dim = feature_dim // len(scales)
        self.scale_extractors = nn.ModuleList()
        for scale in scales:
            extractor = CustomCNNFeatureExtractor(
                input_channels=input_channels,
                feature_dim=per_scale_dim
            )
            self.scale_extractors.append(extractor)
        # Fusion layer
        self.fusion = nn.Sequential(
            nn.Linear(per_scale_dim * len(scales), feature_dim),
            nn.BatchNorm1d(feature_dim),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3)
        )

def forward(self, x: torch.Tensor) -> torch.Tensor:
"""
Forward pass through the multi-scale extractor.
Args:
x: Input signature images (B, C, H, W)
Returns:
Multi-scale features (B, feature_dim)
"""
scale_features = []
for i, scale in enumerate(self.scales):
# Resize input to different scales
if scale != 1:
scaled_x = F.interpolate(x, scale_factor=1/scale, mode='bilinear', align_corners=False)
else:
scaled_x = x
# Extract features at this scale
features = self.scale_extractors[i](scaled_x)
scale_features.append(features)
# Concatenate features from all scales
multi_scale_features = torch.cat(scale_features, dim=1)
# Fuse features
fused_features = self.fusion(multi_scale_features)
# L2 normalize features
fused_features = F.normalize(fused_features, p=2, dim=1)
return fused_features
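

# Usage sketch (a hypothetical helper; feature_dim=384 is chosen so it divides
# evenly by the three scales). With scales (1, 2, 4) the input is processed at
# full, half, and quarter resolution; each branch emits 128 features and the
# fusion layer maps the 384-wide concatenation to the output dimension.
def _example_multi_scale() -> torch.Tensor:
    extractor = MultiScaleFeatureExtractor(feature_dim=384, scales=(1, 2, 4))
    extractor.eval()
    with torch.no_grad():
        return extractor(torch.randn(2, 3, 256, 256))  # -> (2, 384)
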
class AttentionFeatureExtractor(nn.Module):
"""
Feature extractor with attention mechanism for focusing on important signature regions.
"""

    def __init__(self,
input_channels: int = 3,
feature_dim: int = 512,
attention_dim: int = 256):
"""
Initialize the attention-based feature extractor.
Args:
input_channels: Number of input channels
feature_dim: Dimension of output features
attention_dim: Dimension of attention features
"""
super(AttentionFeatureExtractor, self).__init__()
self.feature_dim = feature_dim
self.attention_dim = attention_dim
# Base feature extractor
self.base_extractor = CustomCNNFeatureExtractor(
input_channels=input_channels,
feature_dim=feature_dim
)
        # Attention mechanism over the channel-averaged input (hence the
        # single input channel)
        self.attention_conv = nn.Sequential(
            nn.Conv2d(1, attention_dim, kernel_size=1),
            nn.BatchNorm2d(attention_dim),
            nn.ReLU(inplace=True),
            nn.Conv2d(attention_dim, 1, kernel_size=1),
            nn.Sigmoid()
        )
# Feature refinement
self.feature_refinement = nn.Sequential(
nn.Linear(feature_dim, feature_dim),
nn.BatchNorm1d(feature_dim),
nn.ReLU(inplace=True),
nn.Dropout(0.3)
)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the attention-based extractor.
        Args:
            x: Input signature images (B, C, H, W)
        Returns:
            Attention-weighted features (B, feature_dim)
        """
        # Get base features
        base_features = self.base_extractor(x)
        # Compute a spatial attention map from the channel-averaged input
        # (simplified: a fuller implementation would attend over intermediate
        # backbone feature maps rather than the raw image)
        attention_map = self.attention_conv(x.mean(dim=1, keepdim=True))  # (B, 1, H, W)
        # Collapse the map to one attention weight per sample and rescale
        # the base features with it
        attention_weight = attention_map.mean(dim=[2, 3])  # (B, 1)
        attended_features = base_features * attention_weight
# Refine features
refined_features = self.feature_refinement(attended_features)
# L2 normalize features
refined_features = F.normalize(refined_features, p=2, dim=1)
return refined_features
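

# Usage sketch (a hypothetical helper; the input size is an assumption). As
# noted in forward(), the attention map here is computed from the
# channel-averaged input image rather than from intermediate backbone
# activations, so it acts as a single learned gain per sample.
def _example_attention() -> torch.Tensor:
    extractor = AttentionFeatureExtractor(input_channels=3, feature_dim=512)
    extractor.eval()
    with torch.no_grad():
        return extractor(torch.randn(2, 3, 128, 128))  # -> (2, 512)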