|
|
""" |
|
|
Comprehensive evaluator for signature verification models. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from typing import List, Tuple, Dict, Optional, Union |
|
|
import os |
|
|
import json |
|
|
from tqdm import tqdm |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
|
|
|
from ..models.siamese_network import SiameseNetwork, SignatureVerifier |
|
|
from ..data.preprocessing import SignaturePreprocessor |
|
|
from .metrics import SignatureVerificationMetrics, ThresholdOptimizer, CrossValidationEvaluator |
|
|
|
|
|
|
|
|
class SignatureEvaluator:
    """
    Comprehensive evaluator for signature verification models.

    Wraps a trained model and an image preprocessor and provides:
    fixed-threshold dataset evaluation, threshold optimization,
    k-fold cross-validation, per-difficulty breakdowns, and generation
    of a full HTML evaluation report with plots.
    """

    def __init__(self,
                 model: "Union[SiameseNetwork, SignatureVerifier]",
                 preprocessor: "SignaturePreprocessor",
                 device: str = 'auto'):
        """
        Initialize the evaluator.

        Args:
            model: Trained signature verification model
            preprocessor: Image preprocessor
            device: Device to run evaluation on ('auto' prefers CUDA)
        """
        self.model = model
        self.preprocessor = preprocessor
        self.device = self._get_device(device)

        # Move the model to the evaluation device when it supports .to().
        if hasattr(self.model, 'to'):
            self.model.to(self.device)

        # Switch to eval mode; a SignatureVerifier wrapper keeps the
        # underlying network in its .model attribute instead.
        if hasattr(self.model, 'eval'):
            self.model.eval()
        elif hasattr(self.model, 'model') and hasattr(self.model.model, 'eval'):
            self.model.model.eval()

    def _get_device(self, device: str) -> torch.device:
        """Resolve a device string; 'auto' picks CUDA when available."""
        if device == 'auto':
            return torch.device('cuda' if torch.cuda.is_available() else 'cpu')
        return torch.device(device)

    def _compute_similarities(self,
                              data_pairs: List[Tuple[str, str, int]],
                              threshold: float = 0.5,
                              batch_size: int = 32,
                              desc: str = "Evaluating") -> Tuple[np.ndarray, np.ndarray]:
        """
        Compute a similarity score for each signature pair.

        Shared inference loop used by all evaluation entry points
        (previously duplicated three times). Pairs that fail preprocessing
        or inference are reported and skipped, so the returned arrays may
        be shorter than ``data_pairs``.

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            threshold: Forwarded to ``verify_signatures`` when the model
                exposes that API; does not affect the raw similarity score
            batch_size: Chunk size for the progress bar (pairs are still
                processed one at a time)
            desc: Progress-bar description

        Returns:
            Tuple of (similarities, labels) numpy arrays, aligned by index.
        """
        similarities = []
        labels = []

        with torch.no_grad():
            for start in tqdm(range(0, len(data_pairs), batch_size), desc=desc):
                for sig1_path, sig2_path, label in data_pairs[start:start + batch_size]:
                    try:
                        sig1 = self.preprocessor.preprocess_image(sig1_path)
                        sig2 = self.preprocessor.preprocess_image(sig2_path)

                        # Add a batch dimension and move to the eval device.
                        sig1 = sig1.unsqueeze(0).to(self.device)
                        sig2 = sig2.unsqueeze(0).to(self.device)

                        if hasattr(self.model, 'verify_signatures'):
                            similarity, _ = self.model.verify_signatures(sig1, sig2, threshold)
                        else:
                            similarity = self.model(sig1, sig2).item()

                        similarities.append(similarity)
                        labels.append(label)
                    except Exception as e:
                        # Skip unreadable/corrupt pairs rather than aborting
                        # the whole evaluation run.
                        print(f"Error processing pair {sig1_path}, {sig2_path}: {e}")
                        continue

        return np.array(similarities), np.array(labels)

    def evaluate_dataset(self,
                         data_pairs: List[Tuple[str, str, int]],
                         threshold: float = 0.5,
                         batch_size: int = 32,
                         save_results: bool = True,
                         results_dir: str = 'evaluation_results') -> Dict[str, float]:
        """
        Evaluate model on a dataset at a fixed similarity threshold.

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            threshold: Similarity threshold for binary classification
            batch_size: Batch size for evaluation
            save_results: Whether to save metrics and raw scores to disk
            results_dir: Directory to save results

        Returns:
            Dictionary of evaluation metrics
        """
        print(f"Evaluating on {len(data_pairs)} signature pairs...")

        similarities, labels = self._compute_similarities(
            data_pairs, threshold, batch_size, desc="Evaluating")

        metrics_calculator = SignatureVerificationMetrics(threshold=threshold)
        metrics_calculator.update(similarities, labels)
        metrics = metrics_calculator.compute_metrics()

        # Human-readable summary of the headline metrics.
        print("\n" + "="*50)
        print("EVALUATION RESULTS")
        print("="*50)
        print(f"Accuracy: {metrics['accuracy']:.4f}")
        print(f"Precision: {metrics['precision']:.4f}")
        print(f"Recall: {metrics['recall']:.4f}")
        print(f"F1-Score: {metrics['f1_score']:.4f}")
        print(f"ROC AUC: {metrics['roc_auc']:.4f}")
        print(f"PR AUC: {metrics['pr_auc']:.4f}")
        print(f"EER: {metrics['eer']:.4f}")
        print(f"FAR: {metrics['far']:.4f}")
        print(f"FRR: {metrics['frr']:.4f}")
        print("="*50)

        if save_results:
            self._save_evaluation_results(metrics, similarities, labels, results_dir)

        return metrics

    def evaluate_with_threshold_optimization(self,
                                             data_pairs: List[Tuple[str, str, int]],
                                             metric: str = 'f1_score',
                                             batch_size: int = 32) -> Dict[str, float]:
        """
        Evaluate model with threshold optimization.

        Computes raw similarities once, searches for the threshold that
        maximizes ``metric``, then recomputes all metrics at that threshold.

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            metric: Metric to optimize ('f1_score', 'accuracy', 'eer')
            batch_size: Batch size for evaluation

        Returns:
            Dictionary of evaluation metrics with optimized threshold, plus
            'optimized_threshold', 'optimization_metric', 'optimization_score'.
        """
        print(f"Evaluating with threshold optimization on {len(data_pairs)} signature pairs...")

        # Threshold 0.5 is a placeholder here; only the raw similarity
        # scores matter for the optimizer.
        similarities, labels = self._compute_similarities(
            data_pairs, 0.5, batch_size, desc="Computing similarities")

        optimizer = ThresholdOptimizer(metric=metric)
        optimization_result = optimizer.optimize(similarities, labels)

        print(f"Optimized threshold: {optimization_result['best_threshold']:.4f}")
        print(f"Best {metric}: {optimization_result['best_score']:.4f}")

        metrics_calculator = SignatureVerificationMetrics(threshold=optimization_result['best_threshold'])
        metrics_calculator.update(similarities, labels)
        metrics = metrics_calculator.compute_metrics()

        metrics['optimized_threshold'] = optimization_result['best_threshold']
        metrics['optimization_metric'] = metric
        metrics['optimization_score'] = optimization_result['best_score']

        return metrics

    def cross_validate(self,
                       data_pairs: List[Tuple[str, str, int]],
                       k_folds: int = 5,
                       threshold: float = 0.5,
                       batch_size: int = 32) -> Dict[str, float]:
        """
        Perform k-fold cross-validation.

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            k_folds: Number of folds
            threshold: Similarity threshold
            batch_size: Batch size for evaluation

        Returns:
            Average metrics across all folds (with '<name>_std' companions)
        """
        print(f"Performing {k_folds}-fold cross-validation on {len(data_pairs)} signature pairs...")

        evaluator = CrossValidationEvaluator(
            model=self.model,
            k_folds=k_folds,
            threshold=threshold
        )

        metrics = evaluator.evaluate(data_pairs, self.preprocessor, batch_size)

        # Print each mean metric alongside its standard deviation.
        print("\n" + "="*50)
        print("CROSS-VALIDATION RESULTS")
        print("="*50)
        for metric, value in metrics.items():
            if not metric.endswith('_std'):
                std_value = metrics.get(f"{metric}_std", 0.0)
                print(f"{metric.upper()}: {value:.4f} ± {std_value:.4f}")
        print("="*50)

        return metrics

    def evaluate_by_difficulty(self,
                               data_pairs: List[Tuple[str, str, int]],
                               difficulty_categories: Dict[str, List[int]],
                               threshold: float = 0.5,
                               batch_size: int = 32) -> Dict[str, Dict[str, float]]:
        """
        Evaluate model performance by difficulty categories.

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            difficulty_categories: Dictionary mapping category names to pair indices
            threshold: Similarity threshold
            batch_size: Batch size for evaluation

        Returns:
            Dictionary of metrics for each difficulty category
        """
        print("Evaluating by difficulty categories...")

        results = {}

        for category, indices in difficulty_categories.items():
            print(f"Evaluating {category} category ({len(indices)} pairs)...")

            # Out-of-range indices are silently dropped.
            category_pairs = [data_pairs[i] for i in indices if i < len(data_pairs)]

            if not category_pairs:
                print(f"No pairs found for category {category}")
                continue

            results[category] = self.evaluate_dataset(
                category_pairs, threshold, batch_size, save_results=False
            )

        return results

    def generate_evaluation_report(self,
                                   data_pairs: List[Tuple[str, str, int]],
                                   output_dir: str = 'evaluation_report',
                                   threshold: float = 0.5,
                                   batch_size: int = 32) -> str:
        """
        Generate comprehensive evaluation report (metrics JSON, plots, HTML).

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            output_dir: Directory to save report
            threshold: Similarity threshold
            batch_size: Batch size for evaluation

        Returns:
            Path to the report directory (``output_dir``)
        """
        os.makedirs(output_dir, exist_ok=True)

        print("Generating comprehensive evaluation report...")

        # Headline metrics at the fixed threshold and with an optimized one.
        metrics = self.evaluate_dataset(data_pairs, threshold, batch_size, save_results=False)
        opt_metrics = self.evaluate_with_threshold_optimization(data_pairs, 'f1_score', batch_size)

        # Recompute scores on a capped subset purely for plotting.
        similarities, labels = self._compute_similarities(
            data_pairs[:1000], threshold, batch_size, desc="Computing similarities")

        metrics_calculator = SignatureVerificationMetrics(threshold=threshold)
        metrics_calculator.update(similarities, labels)

        metrics_calculator.plot_roc_curve(os.path.join(output_dir, 'roc_curve.png'))
        metrics_calculator.plot_precision_recall_curve(os.path.join(output_dir, 'pr_curve.png'))
        metrics_calculator.plot_confusion_matrix(os.path.join(output_dir, 'confusion_matrix.png'))
        metrics_calculator.plot_similarity_distribution(os.path.join(output_dir, 'similarity_distribution.png'))

        optimizer = ThresholdOptimizer('f1_score')
        optimizer.plot_threshold_analysis(similarities, labels,
                                          os.path.join(output_dir, 'threshold_analysis.png'))

        report_data = {
            'basic_metrics': metrics,
            'optimized_metrics': opt_metrics,
            'dataset_size': len(data_pairs),
            'threshold_used': threshold,
            'optimized_threshold': opt_metrics.get('optimized_threshold', threshold)
        }

        with open(os.path.join(output_dir, 'metrics.json'), 'w') as f:
            json.dump(report_data, f, indent=2)

        self._generate_html_report(report_data, output_dir)

        print(f"Evaluation report saved to: {output_dir}")
        return output_dir

    def _save_evaluation_results(self,
                                 metrics: Dict[str, float],
                                 similarities: np.ndarray,
                                 labels: np.ndarray,
                                 results_dir: str):
        """Save metrics (JSON) and raw scores/labels (.npy) to ``results_dir``."""
        os.makedirs(results_dir, exist_ok=True)

        with open(os.path.join(results_dir, 'metrics.json'), 'w') as f:
            json.dump(metrics, f, indent=2)

        np.save(os.path.join(results_dir, 'similarities.npy'), similarities)
        np.save(os.path.join(results_dir, 'labels.npy'), labels)

    def _generate_html_report(self,
                              report_data: Dict,
                              output_dir: str) -> str:
        """Generate the HTML evaluation report and return its file path."""
        # Format the optimized threshold up front: applying ':.4f' directly
        # to the 'N/A' fallback string would raise ValueError.
        opt_thr = report_data['optimized_metrics'].get('optimized_threshold')
        opt_thr_str = f"{opt_thr:.4f}" if isinstance(opt_thr, (int, float)) else 'N/A'

        html_content = f"""
        <!DOCTYPE html>
        <html>
        <head>
            <title>Signature Verification Evaluation Report</title>
            <style>
                body {{ font-family: Arial, sans-serif; margin: 40px; }}
                .header {{ background-color: #f0f0f0; padding: 20px; border-radius: 5px; }}
                .metrics {{ display: grid; grid-template-columns: repeat(auto-fit, minmax(200px, 1fr)); gap: 20px; margin: 20px 0; }}
                .metric-card {{ background-color: #f9f9f9; padding: 15px; border-radius: 5px; border-left: 4px solid #007acc; }}
                .metric-value {{ font-size: 24px; font-weight: bold; color: #007acc; }}
                .metric-label {{ font-size: 14px; color: #666; }}
                .plot {{ margin: 20px 0; text-align: center; }}
                .plot img {{ max-width: 100%; height: auto; }}
            </style>
        </head>
        <body>
            <div class="header">
                <h1>Signature Verification Evaluation Report</h1>
                <p>Dataset Size: {report_data['dataset_size']} pairs</p>
                <p>Threshold Used: {report_data['threshold_used']:.4f}</p>
                <p>Optimized Threshold: {opt_thr_str}</p>
            </div>

            <h2>Basic Metrics</h2>
            <div class="metrics">
                <div class="metric-card">
                    <div class="metric-value">{report_data['basic_metrics']['accuracy']:.4f}</div>
                    <div class="metric-label">Accuracy</div>
                </div>
                <div class="metric-card">
                    <div class="metric-value">{report_data['basic_metrics']['precision']:.4f}</div>
                    <div class="metric-label">Precision</div>
                </div>
                <div class="metric-card">
                    <div class="metric-value">{report_data['basic_metrics']['recall']:.4f}</div>
                    <div class="metric-label">Recall</div>
                </div>
                <div class="metric-card">
                    <div class="metric-value">{report_data['basic_metrics']['f1_score']:.4f}</div>
                    <div class="metric-label">F1-Score</div>
                </div>
                <div class="metric-card">
                    <div class="metric-value">{report_data['basic_metrics']['roc_auc']:.4f}</div>
                    <div class="metric-label">ROC AUC</div>
                </div>
                <div class="metric-card">
                    <div class="metric-value">{report_data['basic_metrics']['eer']:.4f}</div>
                    <div class="metric-label">EER</div>
                </div>
            </div>

            <h2>Visualizations</h2>
            <div class="plot">
                <h3>ROC Curve</h3>
                <img src="roc_curve.png" alt="ROC Curve">
            </div>
            <div class="plot">
                <h3>Precision-Recall Curve</h3>
                <img src="pr_curve.png" alt="Precision-Recall Curve">
            </div>
            <div class="plot">
                <h3>Confusion Matrix</h3>
                <img src="confusion_matrix.png" alt="Confusion Matrix">
            </div>
            <div class="plot">
                <h3>Similarity Distribution</h3>
                <img src="similarity_distribution.png" alt="Similarity Distribution">
            </div>
            <div class="plot">
                <h3>Threshold Analysis</h3>
                <img src="threshold_analysis.png" alt="Threshold Analysis">
            </div>
        </body>
        </html>
        """

        html_path = os.path.join(output_dir, 'report.html')
        with open(html_path, 'w') as f:
            f.write(html_content)

        return html_path
|
|
|