InklyAI / src /models /siamese_network.py
pravinai's picture
Upload folder using huggingface_hub
8eab354 verified
"""
Siamese network implementation for signature verification.
"""
import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import Tuple, Optional, Union
import numpy as np
from .feature_extractor import SignatureFeatureExtractor, CustomCNNFeatureExtractor
class SiameseNetwork(nn.Module):
    """
    Siamese network for signature verification using twin feature extractors.

    Both inputs are embedded by the *same* feature extractor module (shared
    weights) and the two embeddings are compared with the configured metric.
    Every metric yields a similarity where larger means "more alike":

    * ``'cosine'``    -- cosine similarity of the embeddings, in [-1, 1]
    * ``'euclidean'`` -- ``1 / (1 + ||f1 - f2||)``, in (0, 1]
    * ``'learned'``   -- MLP + sigmoid over the concatenated embeddings, in [0, 1]
    """

    # Metrics understood by ``compute_similarity``.
    _SUPPORTED_METRICS = ('cosine', 'euclidean', 'learned')

    def __init__(self,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 distance_metric: str = 'cosine',
                 pretrained: bool = True):
        """
        Initialize the Siamese network.

        Args:
            feature_extractor: Type of feature extractor ('resnet18', 'resnet34', 'resnet50', 'custom')
            feature_dim: Dimension of feature vectors
            distance_metric: Distance metric ('cosine', 'euclidean', 'learned')
            pretrained: Whether to use pretrained weights (ignored for 'custom')

        Raises:
            ValueError: If ``distance_metric`` is not one of the supported metrics.
        """
        super().__init__()
        # Fail fast: previously an unknown metric was only reported on the
        # first forward pass, after the (possibly expensive) backbone was built.
        if distance_metric not in self._SUPPORTED_METRICS:
            raise ValueError(f"Unsupported distance metric: {distance_metric}")

        self.feature_dim = feature_dim
        self.distance_metric = distance_metric

        # Create feature extractor
        if feature_extractor == 'custom':
            self.feature_extractor = CustomCNNFeatureExtractor(feature_dim=feature_dim)
        else:
            self.feature_extractor = SignatureFeatureExtractor(
                backbone=feature_extractor,
                feature_dim=feature_dim,
                pretrained=pretrained
            )

        # Trainable comparison head, only needed for the 'learned' metric.
        if distance_metric == 'learned':
            self.distance_layer = nn.Sequential(
                nn.Linear(feature_dim * 2, feature_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.3),
                nn.Linear(feature_dim, 1),
                nn.Sigmoid()
            )
        else:
            self.distance_layer = None

    def forward(self,
                signature1: torch.Tensor,
                signature2: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Siamese network.

        Args:
            signature1: First signature batch (B, C, H, W)
            signature2: Second signature batch (B, C, H, W)

        Returns:
            Similarity scores (B, 1); see the class docstring for each metric's range.

        Raises:
            ValueError: If ``self.distance_metric`` is not supported.
        """
        # Both branches share one extractor, so the embeddings live in the same space.
        features1 = self.feature_extractor(signature1)
        features2 = self.feature_extractor(signature2)
        return self.compute_similarity(features1, features2)

    def extract_features(self, signature: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a single signature.

        Args:
            signature: Signature batch (B, C, H, W)

        Returns:
            Feature vectors (B, feature_dim)
        """
        return self.feature_extractor(signature)

    def compute_similarity(self,
                           features1: torch.Tensor,
                           features2: torch.Tensor) -> torch.Tensor:
        """
        Compute similarity between two feature batches.

        This is the single place where the metric is applied; ``forward``
        delegates here, so precomputed embeddings score identically.

        Args:
            features1: First feature batch (B, feature_dim)
            features2: Second feature batch (B, feature_dim)

        Returns:
            Similarity scores (B, 1)

        Raises:
            ValueError: If ``self.distance_metric`` is not supported.
        """
        if self.distance_metric == 'cosine':
            return F.cosine_similarity(features1, features2, dim=1).unsqueeze(1)
        elif self.distance_metric == 'euclidean':
            distance = F.pairwise_distance(features1, features2)
            # Map distance in [0, inf) to a similarity in (0, 1].
            return (1 / (1 + distance)).unsqueeze(1)
        elif self.distance_metric == 'learned':
            # Concatenation order matters: the learned head is not guaranteed
            # to be symmetric in (features1, features2).
            combined_features = torch.cat([features1, features2], dim=1)
            return self.distance_layer(combined_features)
        else:
            raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
class TripletSiameseNetwork(nn.Module):
    """
    Siamese network with triplet loss for signature verification.

    One shared feature extractor embeds the anchor, positive and negative
    signatures. ``margin`` is only stored on the module for use by an external
    triplet-loss criterion; nothing in this class applies it.
    """

    def __init__(self,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 margin: float = 1.0,
                 pretrained: bool = True):
        """
        Build the triplet network.

        Args:
            feature_extractor: Backbone name ('custom' selects the custom CNN)
            feature_dim: Size of each embedding vector
            margin: Margin intended for the triplet loss
            pretrained: Whether a non-custom backbone loads pretrained weights
        """
        super().__init__()
        self.feature_dim = feature_dim
        self.margin = margin
        self.feature_extractor = self._build_extractor(feature_extractor, feature_dim, pretrained)

    @staticmethod
    def _build_extractor(name: str, feature_dim: int, pretrained: bool):
        """Instantiate the embedding backbone named *name*."""
        if name != 'custom':
            return SignatureFeatureExtractor(
                backbone=name,
                feature_dim=feature_dim,
                pretrained=pretrained
            )
        # The custom CNN takes no ``pretrained`` flag.
        return CustomCNNFeatureExtractor(feature_dim=feature_dim)

    def forward(self,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Embed a triplet for training.

        Args:
            anchor: Anchor signature batch (B, C, H, W)
            positive: Positive signature batch (B, C, H, W)
            negative: Negative signature batch (B, C, H, W)

        Returns:
            ``(anchor_features, positive_features, negative_features)``
        """
        embed = self.feature_extractor
        return embed(anchor), embed(positive), embed(negative)

    def extract_features(self, signature: torch.Tensor) -> torch.Tensor:
        """
        Embed one signature batch.

        Args:
            signature: Signature batch (B, C, H, W)

        Returns:
            Feature vectors (B, feature_dim)
        """
        embedding = self.feature_extractor(signature)
        return embedding
class SignatureVerifier:
    """
    High-level interface for signature verification.

    Wraps a ``SiameseNetwork`` in eval mode. Signatures may be given as file
    paths, tensors or numpy arrays. They are normalised into batched float
    tensors on the configured device before inference.
    """

    def __init__(self,
                 model_path: Optional[str] = None,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 distance_metric: str = 'cosine',
                 device: str = 'auto'):
        """
        Initialize the signature verifier.

        Args:
            model_path: Path to saved model weights
            feature_extractor: Type of feature extractor
            feature_dim: Dimension of feature vectors
            distance_metric: Distance metric for comparison
            device: Device to run inference on ('auto', 'cpu', 'cuda')
        """
        self.device = self._get_device(device)
        self.feature_dim = feature_dim

        # Initialize model
        self.model = SiameseNetwork(
            feature_extractor=feature_extractor,
            feature_dim=feature_dim,
            distance_metric=distance_metric
        ).to(self.device)

        # Load weights if provided
        if model_path:
            self.load_model(model_path)

        # Inference only: turns off training-mode behaviour such as the
        # dropout in the 'learned' distance head.
        self.model.eval()

    def _get_device(self, device: str) -> torch.device:
        """Resolve 'auto' to CUDA when available, else CPU; otherwise use *device* as given."""
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        else:
            return torch.device(device)

    def load_model(self, model_path: str):
        """
        Load model weights from file.

        Accepts either a checkpoint dict as written by ``save_model`` (weights
        under ``'model_state_dict'``) or a bare state dict.
        """
        # SECURITY NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code. Only load checkpoints from trusted sources; consider
        # passing weights_only=True once the minimum supported torch allows it.
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

    def save_model(self, model_path: str):
        """Save model weights (and ``feature_dim``) to a checkpoint file."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'feature_dim': self.feature_dim
        }, model_path)

    def _to_batch_tensor(self, signature: Union[str, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Normalise one signature input into a float tensor on ``self.device``.

        File paths are run through ``SignaturePreprocessor.preprocess_image``.
        numpy arrays are made contiguous and converted to tensors. Everything
        is cast to float32, and a 3-D input gets a leading batch axis.
        """
        if isinstance(signature, str):
            # Imported lazily, presumably to avoid an import cycle or a heavy
            # import at module load -- TODO confirm.
            from ..data.preprocessing import SignaturePreprocessor
            signature = SignaturePreprocessor().preprocess_image(signature)

        if isinstance(signature, np.ndarray):
            # torch.from_numpy rejects negative strides (e.g. flipped views),
            # so copy into a contiguous buffer first when needed.
            signature = torch.from_numpy(np.ascontiguousarray(signature))

        # Cast tensors too (not just numpy input) so float64/uint8 tensors
        # match the float32 model weights.
        signature = signature.float()

        # Assumes a 3-D input is a single (C, H, W) image -- TODO confirm
        # against the preprocessor's output shape.
        if signature.dim() == 3:
            signature = signature.unsqueeze(0)

        return signature.to(self.device)

    def verify_signatures(self,
                          signature1: Union[str, torch.Tensor, np.ndarray],
                          signature2: Union[str, torch.Tensor, np.ndarray],
                          threshold: float = 0.5) -> Tuple[float, bool]:
        """
        Verify if two signatures belong to the same person.

        Args:
            signature1: First signature (file path, tensor or numpy array)
            signature2: Second signature (file path, tensor or numpy array)
            threshold: Similarity threshold for verification (score >= threshold is genuine)

        Returns:
            Tuple of (similarity_score, is_genuine)

        Note:
            ``.item()`` requires the model to produce a single score, so each
            argument must describe exactly one signature.
        """
        batch1 = self._to_batch_tensor(signature1)
        batch2 = self._to_batch_tensor(signature2)

        with torch.no_grad():
            similarity = self.model(batch1, batch2)

        similarity_score = similarity.item()
        is_genuine = similarity_score >= threshold
        return similarity_score, is_genuine

    def extract_signature_features(self, signature: Union[str, torch.Tensor, np.ndarray]) -> np.ndarray:
        """
        Extract features from a signature.

        Args:
            signature: Signature (file path, tensor or numpy array)

        Returns:
            Feature vector as numpy array, shape (B, feature_dim)
        """
        batch = self._to_batch_tensor(signature)
        with torch.no_grad():
            features = self.model.extract_features(batch)
        return features.cpu().numpy()

    def batch_verify(self,
                     signature_pairs: list,
                     threshold: float = 0.5) -> list:
        """
        Verify multiple signature pairs, one pair at a time.

        Args:
            signature_pairs: List of (signature1, signature2) tuples
            threshold: Similarity threshold for verification

        Returns:
            List of (similarity_score, is_genuine) tuples, in input order
        """
        return [self.verify_signatures(sig1, sig2, threshold) for sig1, sig2 in signature_pairs]