|
|
""" |
|
|
Demo script for signature verification model. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
import cv2 |
|
|
from PIL import Image |
|
|
import matplotlib.pyplot as plt |
|
|
import os |
|
|
import sys |
|
|
from pathlib import Path |
|
|
|
|
|
|
|
|
sys.path.append(str(Path(__file__).parent / 'src')) |
|
|
|
|
|
from src.models.siamese_network import SignatureVerifier |
|
|
from src.data.preprocessing import SignaturePreprocessor |
|
|
from src.evaluation.evaluator import SignatureEvaluator |
|
|
from src.training.trainer import SignatureTrainer, SignatureDataset |
|
|
from src.data.augmentation import SignatureAugmentationPipeline |
|
|
|
|
|
|
|
|
def create_sample_signatures():
    """Create synthetic signature images for demonstration.

    Writes twelve 224x224 PNG files (three per fictitious person) into
    ``data/samples`` and returns the ``(filename, style)`` list that was used
    to generate them.

    Returns:
        list of (filename, style) tuples; filenames are relative to
        ``data/samples``.
    """
    print("Creating sample signature images...")

    os.makedirs('data/samples', exist_ok=True)

    def create_signature_image(filename, style='normal'):
        """Draw one synthetic signature, add mild noise, and save it.

        Args:
            filename: Destination path of the PNG file.
            style: One of ``'normal'``, ``'cursive'`` or ``'simple'``.

        Returns:
            The generated image as a ``(224, 224, 3)`` uint8 array.

        Raises:
            ValueError: If ``style`` is not one of the supported styles.
            OSError: If OpenCV reports that the file could not be written.
        """
        # White canvas; strokes are drawn in black on top of it.
        img = np.ones((224, 224, 3), dtype=np.uint8) * 255

        if style == 'normal':
            points = [(50, 100), (80, 90), (120, 95), (160, 85), (180, 100)]
            for i in range(len(points) - 1):
                cv2.line(img, points[i], points[i + 1], (0, 0, 0), 3)

            cv2.ellipse(img, (60, 110), (20, 10), 0, 0, 180, (0, 0, 0), 2)
            cv2.ellipse(img, (170, 110), (15, 8), 0, 0, 180, (0, 0, 0), 2)

        elif style == 'cursive':
            points = [(40, 120), (70, 100), (100, 110), (130, 95), (160, 105), (190, 100)]
            for i in range(len(points) - 1):
                cv2.line(img, points[i], points[i + 1], (0, 0, 0), 4)

            cv2.ellipse(img, (50, 130), (25, 15), 0, 0, 180, (0, 0, 0), 2)
            cv2.ellipse(img, (180, 115), (20, 12), 0, 0, 180, (0, 0, 0), 2)

        elif style == 'simple':
            cv2.line(img, (50, 100), (180, 100), (0, 0, 0), 3)
            cv2.line(img, (50, 110), (180, 110), (0, 0, 0), 2)
            cv2.line(img, (50, 120), (180, 120), (0, 0, 0), 2)

        else:
            # Previously an unknown style silently produced a blank image.
            raise ValueError(f"Unknown signature style: {style!r}")

        # Keep the Gaussian noise signed. Casting negative samples to uint8
        # wraps them to ~250, which turned black stroke pixels nearly white.
        noise = np.random.normal(0, 10, img.shape).astype(np.int16)
        img = np.clip(img.astype(np.int16) + noise, 0, 255).astype(np.uint8)

        # cv2.imwrite signals failure through its return value, not an exception.
        if not cv2.imwrite(filename, cv2.cvtColor(img, cv2.COLOR_RGB2BGR)):
            raise OSError(f"Failed to write signature image: {filename}")

        return img

    signatures = [
        ('john_doe_1.png', 'normal'),
        ('john_doe_2.png', 'normal'),
        ('john_doe_3.png', 'cursive'),
        ('jane_smith_1.png', 'simple'),
        ('jane_smith_2.png', 'simple'),
        ('jane_smith_3.png', 'cursive'),
        ('bob_wilson_1.png', 'cursive'),
        ('bob_wilson_2.png', 'cursive'),
        ('bob_wilson_3.png', 'normal'),
        ('alice_brown_1.png', 'simple'),
        ('alice_brown_2.png', 'simple'),
        ('alice_brown_3.png', 'normal'),
    ]

    for filename, style in signatures:
        create_signature_image(f'data/samples/{filename}', style)

    print(f"Created {len(signatures)} sample signature images in data/samples/")

    return signatures
|
|
|
|
|
|
|
|
def create_training_data():
    """Build the labelled signature pairs used by the demos.

    Returns:
        list of ``(path_a, path_b, label)`` tuples. ``label`` is 1 for two
        samples of the same person (genuine) and 0 for samples of different
        people (forged). The twelve genuine pairs come first, followed by the
        twelve forged pairs.
    """
    print("Creating training data pairs...")

    def sample_path(stem):
        return f'data/samples/{stem}.png'

    people = ('john_doe', 'jane_smith', 'bob_wilson', 'alice_brown')

    # Each person has samples 1..3; every unordered pair of them is genuine.
    genuine_pairs = [
        (sample_path(f'{person}_{first}'), sample_path(f'{person}_{second}'), 1)
        for person in people
        for first in (1, 2, 3)
        for second in range(first + 1, 4)
    ]

    # Hand-picked cross-person combinations act as "forgeries".
    forged_stems = [
        ('john_doe_1', 'jane_smith_1'),
        ('john_doe_2', 'bob_wilson_1'),
        ('john_doe_3', 'alice_brown_1'),
        ('jane_smith_1', 'bob_wilson_2'),
        ('jane_smith_2', 'alice_brown_2'),
        ('jane_smith_3', 'john_doe_1'),
        ('bob_wilson_1', 'alice_brown_3'),
        ('bob_wilson_2', 'john_doe_2'),
        ('bob_wilson_3', 'jane_smith_1'),
        ('alice_brown_1', 'john_doe_3'),
        ('alice_brown_2', 'bob_wilson_1'),
        ('alice_brown_3', 'jane_smith_2'),
    ]
    forged_pairs = [(sample_path(a), sample_path(b), 0) for a, b in forged_stems]

    print(f"Created {len(genuine_pairs)} genuine pairs and {len(forged_pairs)} forged pairs")

    return genuine_pairs + forged_pairs
|
|
|
|
|
|
|
|
def demo_basic_verification():
    """Compare a few hand-labelled signature pairs with a fresh verifier.

    Regenerates the sample images and the labelled pair list, builds a
    ``SignatureVerifier``, then, for four pairs, prints the expected label,
    the predicted label, the similarity score and whether they agree. A
    failure on one pair is reported and the loop moves on to the next.

    Returns:
        tuple ``(verifier, preprocessor, data_pairs)``.
    """
    banner = "=" * 60
    print("\n" + banner)
    print("BASIC SIGNATURE VERIFICATION DEMO")
    print(banner)

    # Called for its side effect of writing the sample images to disk.
    create_sample_signatures()
    data_pairs = create_training_data()

    preprocessor = SignaturePreprocessor()
    verifier = SignatureVerifier(feature_extractor='resnet18', feature_dim=512)

    print("\nTesting signature verification on sample pairs...")

    cases = (
        ('data/samples/john_doe_1.png', 'data/samples/john_doe_2.png', 'Genuine'),
        ('data/samples/john_doe_1.png', 'data/samples/jane_smith_1.png', 'Forged'),
        ('data/samples/jane_smith_1.png', 'data/samples/jane_smith_2.png', 'Genuine'),
        ('data/samples/bob_wilson_1.png', 'data/samples/alice_brown_1.png', 'Forged'),
    )

    for left, right, label in cases:
        try:
            score, genuine = verifier.verify_signatures(left, right)

            if genuine:
                predicted, wanted = "✓ GENUINE", "Genuine"
            else:
                predicted, wanted = "✗ FORGED", "Forged"
            agreement = "✓" if label == wanted else "✗"

            print(f"{left} vs {right}")
            print(f" Expected: {label}")
            print(f" Predicted: {predicted}")
            print(f" Similarity: {score:.4f}")
            print(f" Correct: {agreement}")
            print()
        except Exception as err:
            # Best-effort demo: report the failing pair and continue.
            print(f"Error processing {left} vs {right}: {err}")

    return verifier, preprocessor, data_pairs
|
|
|
|
|
|
|
|
def demo_training():
    """Train a Siamese network on the synthetic sample pairs.

    Regenerates the sample images, shuffles the labelled pairs in place, and
    splits them 80/20 into training and validation sets. It then trains for up
    to 5 epochs with contrastive loss (patience 3) and prints the final
    losses and accuracies. The trainer is closed even if training raises.

    Returns:
        tuple ``(model, preprocessor, val_pairs)``.
    """
    print("\n" + "="*60)
    print("MODEL TRAINING DEMO")
    print("="*60)

    # Called for its side effect of writing the sample images to disk.
    create_sample_signatures()
    data_pairs = create_training_data()

    # NOTE: genuine and forged pairs reuse the same twelve image files, so a
    # random split puts the same images in both sets; validation scores here
    # therefore overstate generalisation.
    np.random.shuffle(data_pairs)
    split_idx = int(0.8 * len(data_pairs))
    train_pairs = data_pairs[:split_idx]
    val_pairs = data_pairs[split_idx:]

    print(f"Training pairs: {len(train_pairs)}")
    print(f"Validation pairs: {len(val_pairs)}")

    preprocessor = SignaturePreprocessor()
    augmenter = SignatureAugmentationPipeline()

    # The augmenter is passed only to the training dataset.
    train_dataset = SignatureDataset(train_pairs, preprocessor, augmenter, is_training=True)
    val_dataset = SignatureDataset(val_pairs, preprocessor, None, is_training=False)

    train_loader = torch.utils.data.DataLoader(train_dataset, batch_size=4, shuffle=True)
    val_loader = torch.utils.data.DataLoader(val_dataset, batch_size=4, shuffle=False)

    from src.models.siamese_network import SiameseNetwork

    model = SiameseNetwork(feature_extractor='resnet18', feature_dim=512)

    trainer = SignatureTrainer(
        model=model,
        learning_rate=1e-4,
        loss_type='contrastive'
    )

    try:
        print("\nStarting training...")
        print("Note: This is a demo with limited data. In practice, you would need much more data.")

        history = trainer.train(
            train_loader=train_loader,
            val_loader=val_loader,
            num_epochs=5,
            save_best=True,
            patience=3
        )

        print("\nTraining completed!")
        print(f"Final training loss: {history['train_losses'][-1]:.4f}")
        print(f"Final validation loss: {history['val_losses'][-1]:.4f}")
        print(f"Final training accuracy: {history['train_accuracies'][-1]:.4f}")
        print(f"Final validation accuracy: {history['val_accuracies'][-1]:.4f}")
    finally:
        # Previously close() was skipped whenever training raised.
        trainer.close()

    return model, preprocessor, val_pairs
|
|
|
|
|
|
|
|
def demo_evaluation():
    """Evaluate a fresh verifier on all demo pairs and tune its threshold.

    First scores the pairs at a fixed threshold of 0.5, saving results under
    ``evaluation_results``, and prints the headline metrics. Then asks the
    evaluator for the F1-optimal threshold and prints it.

    Returns:
        tuple ``(metrics, opt_metrics)`` as produced by ``SignatureEvaluator``.
    """
    rule = "=" * 60
    print("\n" + rule)
    print("MODEL EVALUATION DEMO")
    print(rule)

    # Called for its side effect of writing the sample images to disk.
    create_sample_signatures()
    pairs = create_training_data()

    preprocessor = SignaturePreprocessor()
    verifier = SignatureVerifier(feature_extractor='resnet18', feature_dim=512)
    evaluator = SignatureEvaluator(verifier, preprocessor)

    print("Evaluating model performance...")
    metrics = evaluator.evaluate_dataset(
        pairs,
        threshold=0.5,
        batch_size=4,
        save_results=True,
        results_dir='evaluation_results',
    )

    print("\nEvaluation Results:")
    report_rows = (
        ("Accuracy", 'accuracy'),
        ("Precision", 'precision'),
        ("Recall", 'recall'),
        ("F1-Score", 'f1_score'),
        ("ROC AUC", 'roc_auc'),
    )
    for title, key in report_rows:
        print(f"{title}: {metrics[key]:.4f}")

    print("\nOptimizing threshold...")
    opt_metrics = evaluator.evaluate_with_threshold_optimization(
        pairs,
        metric='f1_score',
        batch_size=4,
    )

    print(f"Optimized threshold: {opt_metrics['optimized_threshold']:.4f}")
    print(f"Optimized F1-Score: {opt_metrics['f1_score']:.4f}")

    return metrics, opt_metrics
|
|
|
|
|
|
|
|
def _cosine_similarity(a, b):
    """Return the cosine similarity of two feature arrays as a float.

    Both inputs are flattened. Torch tensors are detached and moved to the
    CPU first so NumPy can consume them (a grad-tracking or GPU tensor would
    otherwise make ``np.dot`` fail). If either vector has zero norm, 0.0 is
    returned instead of dividing by zero.
    """
    flat = []
    for feat in (a, b):
        if isinstance(feat, torch.Tensor):
            feat = feat.detach().cpu().numpy()
        flat.append(np.asarray(feat, dtype=np.float64).ravel())
    u, v = flat

    denom = np.linalg.norm(u) * np.linalg.norm(v)
    if denom == 0:
        return 0.0
    return float(np.dot(u, v) / denom)


def demo_feature_extraction():
    """Extract embeddings for four sample signatures and compare them pairwise.

    Regenerates the sample images and builds a ``SignatureVerifier``. It then
    extracts features for four files (reporting, not raising, on per-file
    failures) and prints the cosine similarity of every pair of successfully
    extracted features.

    Returns:
        dict mapping signature path -> extracted features. Only files whose
        extraction succeeded are included.
    """
    print("\n" + "="*60)
    print("FEATURE EXTRACTION DEMO")
    print("="*60)

    # Called for its side effect of writing the sample images to disk.
    create_sample_signatures()

    verifier = SignatureVerifier(feature_extractor='resnet18', feature_dim=512)

    print("Extracting features from sample signatures...")

    signature_files = [
        'data/samples/john_doe_1.png',
        'data/samples/john_doe_2.png',
        'data/samples/jane_smith_1.png',
        'data/samples/bob_wilson_1.png'
    ]

    features = {}
    for sig_file in signature_files:
        try:
            features[sig_file] = verifier.extract_signature_features(sig_file)
            print(f"Extracted features for {sig_file}: shape {features[sig_file].shape}")
        except Exception as e:
            # Best-effort demo: skip files that fail and keep going.
            print(f"Error extracting features from {sig_file}: {e}")

    print("\nComputing similarities between extracted features...")
    sig_files = list(features)
    for i, sig1 in enumerate(sig_files):
        for sig2 in sig_files[i + 1:]:
            similarity = _cosine_similarity(features[sig1], features[sig2])
            print(f"{sig1} vs {sig2}: {similarity:.4f}")

    return features
|
|
|
|
|
|
|
|
def main():
    """Run the demo sequence: basic verification, feature extraction, evaluation.

    The training demo is skipped by default to save time; uncomment the
    ``demo_training()`` call below to include it. On any failure the error and
    traceback are printed and the process exits with status 1, so shells and
    CI can detect that the demo failed.
    """
    print("E-Signature Verification Model Demo")
    print("="*60)

    try:
        demo_basic_verification()

        demo_feature_extraction()

        print("\nNote: Skipping training demo to save time. Uncomment the next line to run it.")
        # model, preprocessor, val_pairs = demo_training()

        demo_evaluation()

        print("\n" + "="*60)
        print("DEMO COMPLETED SUCCESSFULLY!")
        print("="*60)
        print("\nNext steps:")
        print("1. Collect more signature data for better training")
        print("2. Experiment with different model architectures")
        print("3. Tune hyperparameters for your specific use case")
        print("4. Deploy the model for production use")
        print("\nCheck the 'evaluation_results' directory for detailed evaluation reports.")

    except Exception as e:
        print(f"Demo failed with error: {e}")
        import traceback
        traceback.print_exc()
        # Previously the script exited with status 0 even after a failure.
        sys.exit(1)
|
|
|
|
|
|
|
|
if __name__ == '__main__':
    # Run the demos only when executed as a script, not when imported.
    main()
|
|
|