|
|
""" |
|
|
Siamese network implementation for signature verification. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import torch.nn as nn |
|
|
import torch.nn.functional as F |
|
|
from typing import Tuple, Optional, Union |
|
|
import numpy as np |
|
|
from .feature_extractor import SignatureFeatureExtractor, CustomCNNFeatureExtractor |
|
|
|
|
|
|
|
|
class SiameseNetwork(nn.Module):
    """
    Siamese network for signature verification using twin feature extractors.

    Both inputs are encoded by the same ``feature_extractor`` module (shared
    weights); the two embeddings are then compared with ``distance_metric``.
    """

    # Metrics accepted for ``distance_metric``. Validated in ``__init__`` so a
    # typo fails at construction time instead of on the first forward pass.
    SUPPORTED_METRICS = ('cosine', 'euclidean', 'learned')

    def __init__(self,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 distance_metric: str = 'cosine',
                 pretrained: bool = True):
        """
        Initialize the Siamese network.

        Args:
            feature_extractor: Type of feature extractor ('resnet18', 'resnet34',
                'resnet50', 'custom'). Any value other than 'custom' is passed
                to ``SignatureFeatureExtractor`` as the backbone name.
            feature_dim: Dimension of feature vectors
            distance_metric: Distance metric ('cosine', 'euclidean', 'learned')
            pretrained: Whether to use pretrained weights (ignored for 'custom')

        Raises:
            ValueError: If ``distance_metric`` is not in ``SUPPORTED_METRICS``.
        """
        super().__init__()

        if distance_metric not in self.SUPPORTED_METRICS:
            raise ValueError(f"Unsupported distance metric: {distance_metric}")

        self.feature_dim = feature_dim
        self.distance_metric = distance_metric

        if feature_extractor == 'custom':
            self.feature_extractor = CustomCNNFeatureExtractor(feature_dim=feature_dim)
        else:
            self.feature_extractor = SignatureFeatureExtractor(
                backbone=feature_extractor,
                feature_dim=feature_dim,
                pretrained=pretrained
            )

        if distance_metric == 'learned':
            # MLP head over the concatenated pair of embeddings; the final
            # Sigmoid squashes the score into (0, 1).
            self.distance_layer = nn.Sequential(
                nn.Linear(feature_dim * 2, feature_dim),
                nn.ReLU(inplace=True),
                nn.Dropout(0.3),
                nn.Linear(feature_dim, 1),
                nn.Sigmoid()
            )
        else:
            self.distance_layer = None

    def forward(self,
                signature1: torch.Tensor,
                signature2: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the Siamese network.

        Args:
            signature1: First signature batch (B, C, H, W)
            signature2: Second signature batch (B, C, H, W)

        Returns:
            Similarity scores of shape (B, 1); higher means more similar.
            See ``compute_similarity`` for the range of each metric.

        Raises:
            ValueError: If ``self.distance_metric`` is unsupported.
        """
        features1 = self.feature_extractor(signature1)
        features2 = self.feature_extractor(signature2)
        # Single source of truth for the metric logic.
        return self.compute_similarity(features1, features2)

    def extract_features(self, signature: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a single signature.

        Args:
            signature: Signature batch (B, C, H, W)

        Returns:
            Feature vectors (B, feature_dim)
        """
        return self.feature_extractor(signature)

    def compute_similarity(self,
                           features1: torch.Tensor,
                           features2: torch.Tensor) -> torch.Tensor:
        """
        Compute similarity between two batches of feature vectors.

        Score ranges by metric:
            * 'cosine'    -> cosine similarity in [-1, 1]
            * 'euclidean' -> 1 / (1 + L2 distance), in (0, 1]
            * 'learned'   -> output of the Sigmoid head, in (0, 1).
              Note that it depends on argument order (features are concatenated).

        Args:
            features1: First feature batch (B, feature_dim)
            features2: Second feature batch (B, feature_dim)

        Returns:
            Similarity scores (B, 1)

        Raises:
            ValueError: If ``self.distance_metric`` is unsupported.
        """
        if self.distance_metric == 'cosine':
            return F.cosine_similarity(features1, features2, dim=1).unsqueeze(1)
        if self.distance_metric == 'euclidean':
            distance = F.pairwise_distance(features1, features2)
            # Map a distance in [0, inf) to a similarity in (0, 1].
            return (1 / (1 + distance)).unsqueeze(1)
        if self.distance_metric == 'learned':
            combined_features = torch.cat([features1, features2], dim=1)
            return self.distance_layer(combined_features)
        raise ValueError(f"Unsupported distance metric: {self.distance_metric}")
|
|
|
|
|
|
|
|
class TripletSiameseNetwork(nn.Module):
    """
    Siamese network with triplet loss for signature verification.

    A single ``feature_extractor`` (shared weights) embeds anchor, positive
    and negative samples; ``triplet_loss`` applies the configured ``margin``.
    """

    def __init__(self,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 margin: float = 1.0,
                 pretrained: bool = True):
        """
        Initialize the triplet Siamese network.

        Args:
            feature_extractor: Type of feature extractor ('custom' selects
                ``CustomCNNFeatureExtractor``; anything else is passed to
                ``SignatureFeatureExtractor`` as the backbone name)
            feature_dim: Dimension of feature vectors
            margin: Margin for triplet loss (used by ``triplet_loss``)
            pretrained: Whether to use pretrained weights (ignored for 'custom')
        """
        super().__init__()

        self.feature_dim = feature_dim
        self.margin = margin

        if feature_extractor == 'custom':
            self.feature_extractor = CustomCNNFeatureExtractor(feature_dim=feature_dim)
        else:
            self.feature_extractor = SignatureFeatureExtractor(
                backbone=feature_extractor,
                feature_dim=feature_dim,
                pretrained=pretrained
            )

    def forward(self,
                anchor: torch.Tensor,
                positive: torch.Tensor,
                negative: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
        """
        Forward pass for triplet training.

        Args:
            anchor: Anchor signature batch (B, C, H, W)
            positive: Positive signature batch (B, C, H, W)
            negative: Negative signature batch (B, C, H, W)

        Returns:
            Tuple of (anchor_features, positive_features, negative_features)
        """
        anchor_features = self.feature_extractor(anchor)
        positive_features = self.feature_extractor(positive)
        negative_features = self.feature_extractor(negative)

        return anchor_features, positive_features, negative_features

    def triplet_loss(self,
                     anchor_features: torch.Tensor,
                     positive_features: torch.Tensor,
                     negative_features: torch.Tensor) -> torch.Tensor:
        """
        Triplet margin loss over embeddings, using this network's ``margin``.

        Computes mean(max(d(a, p) - d(a, n) + margin, 0)) with L2 distance.

        Args:
            anchor_features: Anchor embeddings (B, feature_dim)
            positive_features: Positive embeddings (B, feature_dim)
            negative_features: Negative embeddings (B, feature_dim)

        Returns:
            Scalar loss tensor (mean over the batch).
        """
        return F.triplet_margin_loss(
            anchor_features,
            positive_features,
            negative_features,
            margin=self.margin,
        )

    def extract_features(self, signature: torch.Tensor) -> torch.Tensor:
        """
        Extract features from a single signature.

        Args:
            signature: Signature batch (B, C, H, W)

        Returns:
            Feature vectors (B, feature_dim)
        """
        return self.feature_extractor(signature)
|
|
|
|
|
|
|
|
class SignatureVerifier:
    """
    High-level interface for signature verification.

    Wraps a ``SiameseNetwork`` in eval mode. Inputs (file paths, numpy arrays
    or tensors) are converted into batched tensors on the target device before
    inference.
    """

    def __init__(self,
                 model_path: Optional[str] = None,
                 feature_extractor: str = 'resnet18',
                 feature_dim: int = 512,
                 distance_metric: str = 'cosine',
                 device: str = 'auto'):
        """
        Initialize the signature verifier.

        Args:
            model_path: Path to saved model weights (optional)
            feature_extractor: Type of feature extractor
            feature_dim: Dimension of feature vectors
            distance_metric: Distance metric for comparison
            device: Device to run inference on ('auto', 'cpu', 'cuda')
        """
        self.device = self._get_device(device)
        self.feature_dim = feature_dim

        self.model = SiameseNetwork(
            feature_extractor=feature_extractor,
            feature_dim=feature_dim,
            distance_metric=distance_metric
        ).to(self.device)

        if model_path:
            self.load_model(model_path)

        # Inference-only wrapper: switch dropout / batch-norm to eval behaviour.
        self.model.eval()

    def _get_device(self, device: str) -> torch.device:
        """Resolve 'auto' to CUDA when available (else CPU); otherwise use ``device`` as given."""
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.device(device)

    def load_model(self, model_path: str):
        """
        Load model weights from file.

        Accepts either a checkpoint dict as written by ``save_model`` (weights
        under 'model_state_dict') or a bare state_dict.
        """
        # SECURITY NOTE(review): torch.load unpickles the file and can execute
        # arbitrary code; only load checkpoints from trusted sources (consider
        # weights_only=True on torch versions that support it).
        checkpoint = torch.load(model_path, map_location=self.device)
        if 'model_state_dict' in checkpoint:
            state_dict = checkpoint['model_state_dict']
        else:
            state_dict = checkpoint
        self.model.load_state_dict(state_dict)

    def save_model(self, model_path: str):
        """Save model weights (plus ``feature_dim``) to file as a checkpoint dict."""
        torch.save({
            'model_state_dict': self.model.state_dict(),
            'feature_dim': self.feature_dim
        }, model_path)

    def _prepare_input(self, signature: Union[str, torch.Tensor, np.ndarray]) -> torch.Tensor:
        """Turn a path / ndarray / tensor into a batched tensor on ``self.device``."""
        if isinstance(signature, str):
            # Local import kept as in the original design (presumably to avoid
            # an import cycle with the data package; unverified).
            from ..data.preprocessing import SignaturePreprocessor
            signature = SignaturePreprocessor().preprocess_image(signature)

        if isinstance(signature, np.ndarray):
            # torch.from_numpy rejects negative strides (e.g. img[..., ::-1]);
            # ascontiguousarray is a no-op for already-contiguous arrays.
            signature = torch.from_numpy(np.ascontiguousarray(signature)).float()

        # A single (C, H, W) image gets a batch dimension.
        if signature.dim() == 3:
            signature = signature.unsqueeze(0)

        return signature.to(self.device)

    def verify_signatures(self,
                          signature1: Union[str, torch.Tensor, np.ndarray],
                          signature2: Union[str, torch.Tensor, np.ndarray],
                          threshold: float = 0.5) -> Tuple[float, bool]:
        """
        Verify if two signatures belong to the same person.

        Args:
            signature1: First signature (file path, tensor or numpy array)
            signature2: Second signature (file path, tensor or numpy array)
            threshold: Similarity threshold; scores >= threshold count as genuine.
                The score scale depends on the model's distance metric.

        Returns:
            Tuple of (similarity_score, is_genuine)
        """
        batch1 = self._prepare_input(signature1)
        batch2 = self._prepare_input(signature2)

        with torch.no_grad():
            similarity = self.model(batch1, batch2)

        # .item() requires a single element, i.e. one pair per call.
        similarity_score = similarity.item()
        is_genuine = similarity_score >= threshold

        return similarity_score, is_genuine

    def extract_signature_features(self, signature: Union[str, torch.Tensor, np.ndarray]) -> np.ndarray:
        """
        Extract features from a signature.

        Args:
            signature: Signature (file path, tensor or numpy array)

        Returns:
            Feature array with a leading batch dimension, moved to CPU.
        """
        batch = self._prepare_input(signature)

        with torch.no_grad():
            features = self.model.extract_features(batch)

        return features.cpu().numpy()

    def batch_verify(self,
                     signature_pairs: list,
                     threshold: float = 0.5) -> list:
        """
        Verify multiple signature pairs, one pair at a time.

        Args:
            signature_pairs: List of (signature1, signature2) tuples
            threshold: Similarity threshold for verification

        Returns:
            List of (similarity_score, is_genuine) tuples
        """
        return [
            self.verify_signatures(sig1, sig2, threshold)
            for sig1, sig2 in signature_pairs
        ]
|
|
|