"""
Comprehensive evaluator for signature verification models.
"""
import torch
import numpy as np
from typing import List, Tuple, Dict, Union
import os
import json
from tqdm import tqdm
from ..models.siamese_network import SiameseNetwork, SignatureVerifier
from ..data.preprocessing import SignaturePreprocessor
from .metrics import SignatureVerificationMetrics, ThresholdOptimizer, CrossValidationEvaluator
class SignatureEvaluator:
"""
Comprehensive evaluator for signature verification models.
"""
def __init__(self,
model: Union[SiameseNetwork, SignatureVerifier],
preprocessor: SignaturePreprocessor,
device: str = 'auto'):
"""
Initialize the evaluator.
Args:
model: Trained signature verification model
preprocessor: Image preprocessor
device: Device to run evaluation on
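        Example (a minimal sketch; the checkpoint path is hypothetical and
        constructor arguments depend on your configuration):

            model = SiameseNetwork()
            model.load_state_dict(torch.load('checkpoints/siamese.pt', map_location='cpu'))
            evaluator = SignatureEvaluator(model, SignaturePreprocessor(), device='auto')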
"""
self.model = model
self.preprocessor = preprocessor
self.device = self._get_device(device)
# Move model to device
if hasattr(self.model, 'to'):
self.model.to(self.device)
if hasattr(self.model, 'eval'):
self.model.eval()
elif hasattr(self.model, 'model') and hasattr(self.model.model, 'eval'):
self.model.model.eval()
def _get_device(self, device: str) -> torch.device:
"""Get the appropriate device."""
if device == 'auto':
return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
else:
return torch.device(device)
def evaluate_dataset(self,
data_pairs: List[Tuple[str, str, int]],
threshold: float = 0.5,
batch_size: int = 32,
save_results: bool = True,
results_dir: str = 'evaluation_results') -> Dict[str, float]:
"""
Evaluate model on a dataset.
Args:
data_pairs: List of (signature1_path, signature2_path, label) tuples
threshold: Similarity threshold for binary classification
batch_size: Batch size for evaluation
save_results: Whether to save results
results_dir: Directory to save results
Returns:
Dictionary of evaluation metrics
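        Example (hypothetical image paths; assumes label 1 marks a genuine pair
        and 0 a forgery):

            pairs = [
                ('sigs/u01_genuine_a.png', 'sigs/u01_genuine_b.png', 1),
                ('sigs/u01_genuine_a.png', 'sigs/u01_forgery_a.png', 0),
            ]
            metrics = evaluator.evaluate_dataset(pairs, threshold=0.5, batch_size=32)
            print(metrics['accuracy'], metrics['eer'])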
"""
print(f"Evaluating on {len(data_pairs)} signature pairs...")
# Initialize metrics calculator
metrics_calculator = SignatureVerificationMetrics(threshold=threshold)
        # Score pairs in chunks (each pair is processed individually;
        # batch_size only sets the progress-bar chunk size, not true batched inference)
similarities = []
labels = []
with torch.no_grad():
for i in tqdm(range(0, len(data_pairs), batch_size), desc="Evaluating"):
batch_pairs = data_pairs[i:i+batch_size]
for sig1_path, sig2_path, label in batch_pairs:
try:
# Load and preprocess images
sig1 = self.preprocessor.preprocess_image(sig1_path)
sig2 = self.preprocessor.preprocess_image(sig2_path)
# Add batch dimension
sig1 = sig1.unsqueeze(0).to(self.device)
sig2 = sig2.unsqueeze(0).to(self.device)
# Compute similarity
if hasattr(self.model, 'verify_signatures'):
# Using SignatureVerifier
similarity, _ = self.model.verify_signatures(sig1, sig2, threshold)
else:
# Using SiameseNetwork directly
similarity = self.model(sig1, sig2)
similarity = similarity.item()
similarities.append(similarity)
labels.append(label)
except Exception as e:
print(f"Error processing pair {sig1_path}, {sig2_path}: {e}")
continue
# Update metrics
similarities = np.array(similarities)
labels = np.array(labels)
metrics_calculator.update(similarities, labels)
# Compute metrics
metrics = metrics_calculator.compute_metrics()
# Print results
print("\n" + "="*50)
print("EVALUATION RESULTS")
print("="*50)
print(f"Accuracy: {metrics['accuracy']:.4f}")
print(f"Precision: {metrics['precision']:.4f}")
print(f"Recall: {metrics['recall']:.4f}")
print(f"F1-Score: {metrics['f1_score']:.4f}")
print(f"ROC AUC: {metrics['roc_auc']:.4f}")
print(f"PR AUC: {metrics['pr_auc']:.4f}")
print(f"EER: {metrics['eer']:.4f}")
print(f"FAR: {metrics['far']:.4f}")
print(f"FRR: {metrics['frr']:.4f}")
print("="*50)
# Save results if requested
if save_results:
self._save_evaluation_results(metrics, similarities, labels, results_dir)
return metrics
def evaluate_with_threshold_optimization(self,
data_pairs: List[Tuple[str, str, int]],
metric: str = 'f1_score',
batch_size: int = 32) -> Dict[str, float]:
"""
Evaluate model with threshold optimization.
Args:
data_pairs: List of (signature1_path, signature2_path, label) tuples
metric: Metric to optimize ('f1_score', 'accuracy', 'eer')
batch_size: Batch size for evaluation
Returns:
Dictionary of evaluation metrics with optimized threshold
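        Example (a sketch; `pairs` as in evaluate_dataset):

            metrics = evaluator.evaluate_with_threshold_optimization(pairs, metric='f1_score')
            print(metrics['optimized_threshold'], metrics['optimization_score'])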
"""
print(f"Evaluating with threshold optimization on {len(data_pairs)} signature pairs...")
# First, get all similarities and labels
similarities = []
labels = []
with torch.no_grad():
for i in tqdm(range(0, len(data_pairs), batch_size), desc="Computing similarities"):
batch_pairs = data_pairs[i:i+batch_size]
for sig1_path, sig2_path, label in batch_pairs:
try:
# Load and preprocess images
sig1 = self.preprocessor.preprocess_image(sig1_path)
sig2 = self.preprocessor.preprocess_image(sig2_path)
# Add batch dimension
sig1 = sig1.unsqueeze(0).to(self.device)
sig2 = sig2.unsqueeze(0).to(self.device)
# Compute similarity
if hasattr(self.model, 'verify_signatures'):
                                # Fixed 0.5 is a placeholder; only the similarity score is kept
                                similarity, _ = self.model.verify_signatures(sig1, sig2, 0.5)
else:
similarity = self.model(sig1, sig2)
similarity = similarity.item()
similarities.append(similarity)
labels.append(label)
except Exception as e:
print(f"Error processing pair {sig1_path}, {sig2_path}: {e}")
continue
similarities = np.array(similarities)
labels = np.array(labels)
# Optimize threshold
optimizer = ThresholdOptimizer(metric=metric)
optimization_result = optimizer.optimize(similarities, labels)
print(f"Optimized threshold: {optimization_result['best_threshold']:.4f}")
print(f"Best {metric}: {optimization_result['best_score']:.4f}")
# Evaluate with optimized threshold
metrics_calculator = SignatureVerificationMetrics(threshold=optimization_result['best_threshold'])
metrics_calculator.update(similarities, labels)
metrics = metrics_calculator.compute_metrics()
# Add optimization info
metrics['optimized_threshold'] = optimization_result['best_threshold']
metrics['optimization_metric'] = metric
metrics['optimization_score'] = optimization_result['best_score']
return metrics
def cross_validate(self,
data_pairs: List[Tuple[str, str, int]],
k_folds: int = 5,
threshold: float = 0.5,
batch_size: int = 32) -> Dict[str, float]:
"""
Perform k-fold cross-validation.
Args:
data_pairs: List of (signature1_path, signature2_path, label) tuples
k_folds: Number of folds
threshold: Similarity threshold
batch_size: Batch size for evaluation
Returns:
Average metrics across all folds
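        Example (a sketch; `pairs` as in evaluate_dataset):

            cv_metrics = evaluator.cross_validate(pairs, k_folds=5, threshold=0.5)
            print(cv_metrics['accuracy'], cv_metrics.get('accuracy_std', 0.0))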
"""
print(f"Performing {k_folds}-fold cross-validation on {len(data_pairs)} signature pairs...")
evaluator = CrossValidationEvaluator(
model=self.model,
k_folds=k_folds,
threshold=threshold
)
metrics = evaluator.evaluate(data_pairs, self.preprocessor, batch_size)
# Print results
print("\n" + "="*50)
print("CROSS-VALIDATION RESULTS")
print("="*50)
for metric, value in metrics.items():
if not metric.endswith('_std'):
std_key = f"{metric}_std"
std_value = metrics.get(std_key, 0.0)
print(f"{metric.upper()}: {value:.4f} ± {std_value:.4f}")
print("="*50)
return metrics
def evaluate_by_difficulty(self,
data_pairs: List[Tuple[str, str, int]],
difficulty_categories: Dict[str, List[int]],
threshold: float = 0.5,
batch_size: int = 32) -> Dict[str, Dict[str, float]]:
"""
Evaluate model performance by difficulty categories.
Args:
data_pairs: List of (signature1_path, signature2_path, label) tuples
            difficulty_categories: Dictionary mapping category names to lists of indices into data_pairs
threshold: Similarity threshold
batch_size: Batch size for evaluation
Returns:
Dictionary of metrics for each difficulty category
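        Example (hypothetical index groups; each list holds positions into data_pairs):

            categories = {'easy': [0, 1, 2], 'hard': [3, 4]}
            per_category = evaluator.evaluate_by_difficulty(pairs, categories)
            print(per_category['hard']['f1_score'])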
"""
print("Evaluating by difficulty categories...")
results = {}
for category, indices in difficulty_categories.items():
print(f"Evaluating {category} category ({len(indices)} pairs)...")
category_pairs = [data_pairs[i] for i in indices if i < len(data_pairs)]
if not category_pairs:
print(f"No pairs found for category {category}")
continue
# Evaluate this category
category_metrics = self.evaluate_dataset(
category_pairs, threshold, batch_size, save_results=False
)
results[category] = category_metrics
return results
def generate_evaluation_report(self,
data_pairs: List[Tuple[str, str, int]],
output_dir: str = 'evaluation_report',
threshold: float = 0.5,
batch_size: int = 32) -> str:
"""
Generate comprehensive evaluation report.
Args:
data_pairs: List of (signature1_path, signature2_path, label) tuples
output_dir: Directory to save report
threshold: Similarity threshold
batch_size: Batch size for evaluation
Returns:
            Path to the report directory (plots, metrics.json, and report.html)
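        Example (a sketch; writes plots, metrics.json, and report.html into output_dir):

            report_dir = evaluator.generate_evaluation_report(pairs, output_dir='report')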
"""
os.makedirs(output_dir, exist_ok=True)
print("Generating comprehensive evaluation report...")
# Basic evaluation
metrics = self.evaluate_dataset(data_pairs, threshold, batch_size, save_results=False)
# Threshold optimization
opt_metrics = self.evaluate_with_threshold_optimization(data_pairs, 'f1_score', batch_size)
# Get similarities for plotting
similarities = []
labels = []
with torch.no_grad():
for sig1_path, sig2_path, label in data_pairs[:1000]: # Limit for plotting
try:
sig1 = self.preprocessor.preprocess_image(sig1_path)
sig2 = self.preprocessor.preprocess_image(sig2_path)
sig1 = sig1.unsqueeze(0).to(self.device)
sig2 = sig2.unsqueeze(0).to(self.device)
if hasattr(self.model, 'verify_signatures'):
similarity, _ = self.model.verify_signatures(sig1, sig2, threshold)
else:
similarity = self.model(sig1, sig2)
similarity = similarity.item()
similarities.append(similarity)
labels.append(label)
                except Exception as e:
                    print(f"Error processing pair {sig1_path}, {sig2_path}: {e}")
                    continue
similarities = np.array(similarities)
labels = np.array(labels)
# Generate plots
metrics_calculator = SignatureVerificationMetrics(threshold=threshold)
metrics_calculator.update(similarities, labels)
# ROC curve
metrics_calculator.plot_roc_curve(os.path.join(output_dir, 'roc_curve.png'))
# Precision-Recall curve
metrics_calculator.plot_precision_recall_curve(os.path.join(output_dir, 'pr_curve.png'))
# Confusion matrix
metrics_calculator.plot_confusion_matrix(os.path.join(output_dir, 'confusion_matrix.png'))
# Similarity distribution
metrics_calculator.plot_similarity_distribution(os.path.join(output_dir, 'similarity_distribution.png'))
# Threshold analysis
optimizer = ThresholdOptimizer('f1_score')
optimizer.plot_threshold_analysis(similarities, labels,
os.path.join(output_dir, 'threshold_analysis.png'))
# Save metrics to JSON
report_data = {
'basic_metrics': metrics,
'optimized_metrics': opt_metrics,
'dataset_size': len(data_pairs),
'threshold_used': threshold,
'optimized_threshold': opt_metrics.get('optimized_threshold', threshold)
}
with open(os.path.join(output_dir, 'metrics.json'), 'w') as f:
json.dump(report_data, f, indent=2)
        # Generate HTML report
        html_path = self._generate_html_report(report_data, output_dir)
        print(f"Evaluation report saved to: {output_dir} (HTML: {html_path})")
        return output_dir
def _save_evaluation_results(self,
metrics: Dict[str, float],
similarities: np.ndarray,
labels: np.ndarray,
results_dir: str):
"""Save evaluation results to files."""
os.makedirs(results_dir, exist_ok=True)
# Save metrics
with open(os.path.join(results_dir, 'metrics.json'), 'w') as f:
json.dump(metrics, f, indent=2)
# Save raw data
np.save(os.path.join(results_dir, 'similarities.npy'), similarities)
np.save(os.path.join(results_dir, 'labels.npy'), labels)
def _generate_html_report(self,
report_data: Dict,
output_dir: str) -> str:
"""Generate HTML evaluation report."""
html_content = f"""
<!DOCTYPE html>
<html>
<head>
<title>Signature Verification Evaluation Report</title>
<style>
body {{ font-family: Arial, sans-serif; margin: 40px; }}
.header {{ background-color: #f0f0f0; padding: 20px; border-radius: 5px; }}
.metrics {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 20px; margin: 20px 0; }}
.metric-card {{ background-color: #f9f9f9; padding: 15px; border-radius: 5px; border-left: 4px solid #007acc; }}
.metric-value {{ font-size: 24px; font-weight: bold; color: #007acc; }}
.metric-label {{ font-size: 14px; color: #666; }}
.plot {{ margin: 20px 0; text-align: center; }}
.plot img {{ max-width: 100%; height: auto; }}
</style>
</head>
<body>
<div class="header">
<h1>Signature Verification Evaluation Report</h1>
<p>Dataset Size: {report_data['dataset_size']} pairs</p>
<p>Threshold Used: {report_data['threshold_used']:.4f}</p>
            <p>Optimized Threshold: {opt_threshold_str}</p>
</div>
<h2>Basic Metrics</h2>
<div class="metrics">
<div class="metric-card">
<div class="metric-value">{report_data['basic_metrics']['accuracy']:.4f}</div>
<div class="metric-label">Accuracy</div>
</div>
<div class="metric-card">
<div class="metric-value">{report_data['basic_metrics']['precision']:.4f}</div>
<div class="metric-label">Precision</div>
</div>
<div class="metric-card">
<div class="metric-value">{report_data['basic_metrics']['recall']:.4f}</div>
<div class="metric-label">Recall</div>
</div>
<div class="metric-card">
<div class="metric-value">{report_data['basic_metrics']['f1_score']:.4f}</div>
<div class="metric-label">F1-Score</div>
</div>
<div class="metric-card">
<div class="metric-value">{report_data['basic_metrics']['roc_auc']:.4f}</div>
<div class="metric-label">ROC AUC</div>
</div>
<div class="metric-card">
<div class="metric-value">{report_data['basic_metrics']['eer']:.4f}</div>
<div class="metric-label">EER</div>
</div>
</div>
<h2>Visualizations</h2>
<div class="plot">
<h3>ROC Curve</h3>
<img src="roc_curve.png" alt="ROC Curve">
</div>
<div class="plot">
<h3>Precision-Recall Curve</h3>
<img src="pr_curve.png" alt="Precision-Recall Curve">
</div>
<div class="plot">
<h3>Confusion Matrix</h3>
<img src="confusion_matrix.png" alt="Confusion Matrix">
</div>
<div class="plot">
<h3>Similarity Distribution</h3>
<img src="similarity_distribution.png" alt="Similarity Distribution">
</div>
<div class="plot">
<h3>Threshold Analysis</h3>
<img src="threshold_analysis.png" alt="Threshold Analysis">
</div>
</body>
</html>
"""
html_path = os.path.join(output_dir, 'report.html')
with open(html_path, 'w') as f:
f.write(html_content)
return html_path
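

# ---------------------------------------------------------------------------
# Usage sketch (a minimal, hypothetical end-to-end run; the checkpoint path,
# constructor arguments, and image paths below are placeholders, not part of
# this module's API):
#
#     from src.models.siamese_network import SiameseNetwork
#     from src.data.preprocessing import SignaturePreprocessor
#
#     model = SiameseNetwork()  # constructor args as required by your config
#     model.load_state_dict(torch.load('checkpoints/siamese.pt', map_location='cpu'))
#     evaluator = SignatureEvaluator(model, SignaturePreprocessor())
#     pairs = [('sigs/a1.png', 'sigs/a2.png', 1), ('sigs/a1.png', 'sigs/f1.png', 0)]
#     evaluator.generate_evaluation_report(pairs, output_dir='evaluation_report')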