"""
Siamese network implementation for signature verification.
"""

import torch
import torch.nn as nn
import torch.nn.functional as F
from typing import List, Optional, Tuple, Union
import numpy as np

from .feature_extractor import SignatureFeatureExtractor, CustomCNNFeatureExtractor


# Distance metrics understood by SiameseNetwork.compute_similarity.
_SUPPORTED_DISTANCE_METRICS = ('cosine', 'euclidean', 'learned')


def _build_feature_extractor(feature_extractor: str,
                             feature_dim: int,
                             pretrained: bool) -> nn.Module:
    """
    Build the shared backbone used by the Siamese / triplet networks.

    Args:
        feature_extractor: 'custom' for CustomCNNFeatureExtractor; any other value
            is forwarded as the ``backbone`` name of SignatureFeatureExtractor
            (e.g. 'resnet18', 'resnet34', 'resnet50').
        feature_dim: Dimension of the produced feature vectors.
        pretrained: Forwarded to SignatureFeatureExtractor; not used by the
            custom CNN.

    Returns:
        The feature-extractor module.
    """
    if feature_extractor == 'custom':
        return CustomCNNFeatureExtractor(feature_dim=feature_dim)
    return SignatureFeatureExtractor(
        backbone=feature_extractor,
        feature_dim=feature_dim,
        pretrained=pretrained,
    )


class SiameseNetwork(nn.Module):
    """
    Siamese network for signature verification using twin feature extractors.

    Both inputs are embedded by the same (weight-shared) feature extractor and
    the two embeddings are compared with the configured metric.
    """

    def __init__(self,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 distance_metric: str = 'cosine',
                 pretrained: bool = True):
        """
        Initialize the Siamese network.

        Args:
            feature_extractor: Type of feature extractor ('resnet18', 'resnet34', 'resnet50', 'custom')
            feature_dim: Dimension of feature vectors
            distance_metric: Distance metric ('cosine', 'euclidean', 'learned')
            pretrained: Whether to use pretrained weights

        Raises:
            ValueError: If ``distance_metric`` is not one of the supported metrics.
        """
        super().__init__()

        # Fail fast, before a (possibly pretrained, downloaded) backbone is built.
        if distance_metric not in _SUPPORTED_DISTANCE_METRICS:
            raise ValueError(f"Unsupported distance metric: {distance_metric}")

        self.feature_dim = feature_dim
        self.distance_metric = distance_metric

        self.feature_extractor = _build_feature_extractor(
            feature_extractor, feature_dim, pretrained
        )

        # Only the 'learned' metric has trainable comparison parameters: an MLP
        # over the concatenated embeddings ending in a sigmoid, so its output
        # lies in (0, 1). Note that concatenation makes it order-sensitive.
        if distance_metric == 'learned':
            self.distance_layer = nn.Sequential(
                nn.Linear(feature_dim * 2, feature_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.3),
                nn.Linear(feature_dim, 1),
                nn.Sigmoid()
            )
        else:
            self.distance_layer = None

    def forward(self, signature1: torch.Tensor, signature2: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Siamese network.

        Args:
            signature1: First signature batch (B, C, H, W)
            signature2: Second signature batch (B, C, H, W)

        Returns:
            Similarity scores of shape (B, 1). Every metric yields a similarity
            (never a raw distance); SignatureVerifier treats higher values as
            "more likely the same writer".
        """
        features1 = self.feature_extractor(signature1)
        features2 = self.feature_extractor(signature2)
        return self.compute_similarity(features1, features2)

    def extract_features(self, signature: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a single signature.

        Args:
            signature: Signature batch (B, C, H, W)

        Returns:
            Feature vectors (B, feature_dim)
        """
        return self.feature_extractor(signature)

    def compute_similarity(self, features1: torch.Tensor, features2: torch.Tensor) -> torch.Tensor:
        """
        Compute similarity between two feature batches.

        Args:
            features1: First feature batch (B, feature_dim)
            features2: Second feature batch (B, feature_dim)

        Returns:
            Similarity scores (B, 1):
              * 'cosine'    -> cosine similarity in [-1, 1]
              * 'euclidean' -> 1 / (1 + L2 distance), in (0, 1]
              * 'learned'   -> sigmoid output of the learned MLP, in (0, 1)

        Raises:
            ValueError: If ``self.distance_metric`` is not supported (only
                reachable if the attribute was changed after construction).
        """
        if self.distance_metric == 'cosine':
            return F.cosine_similarity(features1, features2, dim=1).unsqueeze(1)
        if self.distance_metric == 'euclidean':
            distance = F.pairwise_distance(features1, features2)
            # Map a distance in [0, inf) to a similarity in (0, 1].
            return (1 / (1 + distance)).unsqueeze(1)
        if self.distance_metric == 'learned':
            combined_features = torch.cat([features1, features2], dim=1)
            return self.distance_layer(combined_features)
        raise ValueError(f"Unsupported distance metric: {self.distance_metric}")


class TripletSiameseNetwork(nn.Module):
    """
    Siamese network with triplet loss for signature verification.

    The network only produces embeddings; ``margin`` is stored for use by the
    triplet loss computed elsewhere (it is not used inside this class).
    """

    def __init__(self,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 margin: float = 1.0,
                 pretrained: bool = True):
        """
        Initialize the triplet Siamese network.

        Args:
            feature_extractor: Type of feature extractor
            feature_dim: Dimension of feature vectors
            margin: Margin for triplet loss
            pretrained: Whether to use pretrained weights
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.margin = margin

        self.feature_extractor = _build_feature_extractor(
            feature_extractor, feature_dim, pretrained
        )

    def forward(self,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass for triplet training.

        Args:
            anchor: Anchor signature batch (B, C, H, W)
            positive: Positive signature batch (B, C, H, W)
            negative: Negative signature batch (B, C, H, W)

        Returns:
            Tuple of (anchor_features, positive_features, negative_features)
        """
        anchor_features = self.feature_extractor(anchor)
        positive_features = self.feature_extractor(positive)
        negative_features = self.feature_extractor(negative)
        return anchor_features, positive_features, negative_features

    def extract_features(self, signature: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a single signature.

        Args:
            signature: Signature batch (B, C, H, W)

        Returns:
            Feature vectors (B, feature_dim)
        """
        return self.feature_extractor(signature)


class SignatureVerifier:
    """
    High-level interface for signature verification.

    Wraps a SiameseNetwork in eval mode and accepts file paths, numpy arrays
    or tensors as inputs.
    """

    def __init__(self,
                 model_path: Optional[str] = None,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 distance_metric: str = 'cosine',
                 device: str = 'auto'):
        """
        Initialize the signature verifier.

        Args:
            model_path: Path to saved model weights
            feature_extractor: Type of feature extractor
            feature_dim: Dimension of feature vectors
            distance_metric: Distance metric for comparison
            device: Device to run inference on ('auto', 'cpu', 'cuda')
        """
        self.device = self._get_device(device)
        self.feature_dim = feature_dim

        self.model = SiameseNetwork(
            feature_extractor=feature_extractor,
            feature_dim=feature_dim,
            distance_metric=distance_metric
        ).to(self.device)

        if model_path:
            self.load_model(model_path)

        # Inference only: disables dropout in the learned head and freezes
        # any normalization statistics in the backbone.
        self.model.eval()

    def _get_device(self, device: str) -> torch.device:
        """Resolve 'auto' to CUDA when available (else CPU); otherwise use ``device`` as given."""
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.device(device)

    def load_model(self, model_path: str):
        """
        Load model weights from file.

        Accepts either a checkpoint dict containing 'model_state_dict' (as
        written by ``save_model``) or a bare state dict.
        """
        # SECURITY NOTE(review): torch.load unpickles the file, which can
        # execute arbitrary code. Only load checkpoints from trusted sources, or
        # pass weights_only=True if the installed torch version supports it.
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            self.model.load_state_dict(checkpoint['model_state_dict'])
        else:
            self.model.load_state_dict(checkpoint)

    def save_model(self, model_path: str):
        """
        Save model weights (plus metadata needed to rebuild the model) to file.

        The extra metadata keys are ignored by ``load_model``.
        """
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'feature_dim': self.feature_dim,
            # Recorded so a 'learned' checkpoint is not silently paired with a
            # verifier built for a different metric (different state-dict keys).
            'distance_metric': self.model.distance_metric,
        }, model_path)

    def _prepare_input(self, signature: Union[str, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """
        Turn a path / numpy array / tensor into a batched tensor on ``self.device``.

        File paths are run through SignaturePreprocessor.preprocess_image.
        numpy arrays are converted to float32 tensors. Tensors keep their dtype.
        A 3-D input is treated as a single (C, H, W) image and given a batch
        dimension.

        Raises:
            TypeError: If ``signature`` is not a str, numpy array or tensor.
        """
        if isinstance(signature, str):
            # Imported lazily so preprocessing dependencies are only loaded when
            # a file path is actually given.
            from ..data.preprocessing import SignaturePreprocessor
            signature = SignaturePreprocessor().preprocess_image(signature)

        if isinstance(signature, np.ndarray):
            # torch.from_numpy rejects arrays with negative strides (e.g. the
            # result of np.flip); ascontiguousarray is a no-op otherwise.
            signature = torch.from_numpy(np.ascontiguousarray(signature)).float()

        if not isinstance(signature, torch.Tensor):
            raise TypeError(
                "signature must be a file path, numpy array or torch.Tensor, "
                f"got {type(signature).__name__}"
            )

        if signature.dim() == 3:
            signature = signature.unsqueeze(0)

        return signature.to(self.device)

    def verify_signatures(self,
                          signature1: Union[str, torch.Tensor, np.ndarray],
                          signature2: Union[str, torch.Tensor, np.ndarray],
                          threshold: float = 0.5) -> Tuple[float, bool]:
        """
        Verify if two signatures belong to the same person.

        Args:
            signature1: First signature (file path, tensor or numpy array)
            signature2: Second signature (file path, tensor or numpy array)
            threshold: Similarity threshold for verification; a score
                ``>= threshold`` counts as genuine. Note the score range depends
                on the metric (cosine is in [-1, 1]; the others are in (0, 1]).

        Returns:
            Tuple of (similarity_score, is_genuine)

        Note:
            Expects a single pair; ``.item()`` raises if a batch of more than
            one pair is passed.
        """
        tensor1 = self._prepare_input(signature1)
        tensor2 = self._prepare_input(signature2)

        with torch.no_grad():
            similarity = self.model(tensor1, tensor2)

        similarity_score = similarity.item()
        is_genuine = similarity_score >= threshold
        return similarity_score, is_genuine

    def extract_signature_features(self, signature: Union[str, torch.Tensor, np.ndarray]) -> np.ndarray:
        """
        Extract features from a signature.

        Args:
            signature: Signature (file path, tensor or numpy array)

        Returns:
            Feature array of shape (B, feature_dim) on the CPU (B == 1 for a
            single image).
        """
        tensor = self._prepare_input(signature)

        with torch.no_grad():
            features = self.model.extract_features(tensor)

        return features.cpu().numpy()

    def batch_verify(self, signature_pairs: List[tuple], threshold: float = 0.5) -> List[Tuple[float, bool]]:
        """
        Verify multiple signature pairs, one pair at a time.

        Args:
            signature_pairs: List of (signature1, signature2) tuples
            threshold: Similarity threshold for verification

        Returns:
            List of (similarity_score, is_genuine) tuples, in input order
        """
        return [
            self.verify_signatures(sig1, sig2, threshold)
            for sig1, sig2 in signature_pairs
        ]