|
|
""" |
|
|
Evaluation metrics for signature verification. |
|
|
""" |
|
|
|
|
|
import torch |
|
|
import numpy as np |
|
|
from typing import List, Tuple, Dict, Optional |
|
|
from sklearn.metrics import ( |
|
|
accuracy_score, precision_score, recall_score, f1_score, |
|
|
roc_auc_score, roc_curve, precision_recall_curve, confusion_matrix |
|
|
) |
|
|
import matplotlib.pyplot as plt |
|
|
import seaborn as sns |
|
|
|
|
|
|
|
|
class SignatureVerificationMetrics:
    """
    Comprehensive metrics for signature verification evaluation.

    Accumulates similarity scores and ground-truth labels across calls to
    ``update`` and computes classification / biometric metrics on demand.
    """

    def __init__(self, threshold: float = 0.5):
        """
        Initialize metrics calculator.

        Args:
            threshold: Similarity threshold for binary classification
                (scores >= threshold are predicted genuine).
        """
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Reset all stored predictions, labels and similarity scores."""
        self.predictions = []
        self.labels = []
        self.similarities = []

    def update(self,
               similarities: np.ndarray,
               labels: np.ndarray):
        """
        Update metrics with new predictions.

        Args:
            similarities: Similarity scores (array-like of float)
            labels: Ground truth labels (1 for genuine, 0 for forged)
        """
        # Accept plain Python sequences as well as ndarrays.
        similarities = np.asarray(similarities)
        labels = np.asarray(labels)

        self.similarities.extend(similarities)
        self.labels.extend(labels)

        # Binarize scores at the configured threshold.
        predictions = (similarities >= self.threshold).astype(int)
        self.predictions.extend(predictions)

    def compute_metrics(self) -> Dict[str, float]:
        """
        Compute all evaluation metrics.

        Returns:
            Dictionary of metrics (accuracy, precision, recall, f1_score,
            roc_auc, pr_auc, specificity, sensitivity, eer, far, frr,
            threshold).

        Raises:
            ValueError: If ``update`` has not been called yet.
        """
        if not self.predictions or not self.labels:
            raise ValueError("No predictions or labels available. Call update() first.")

        similarities = np.array(self.similarities)
        labels = np.array(self.labels)
        predictions = np.array(self.predictions)

        # Threshold-dependent classification metrics.
        accuracy = accuracy_score(labels, predictions)
        precision = precision_score(labels, predictions, zero_division=0)
        recall = recall_score(labels, predictions, zero_division=0)
        f1 = f1_score(labels, predictions, zero_division=0)

        # ROC AUC is undefined when only one class is present.
        try:
            roc_auc = roc_auc_score(labels, similarities)
        except ValueError:
            roc_auc = 0.0

        # PR AUC via trapezoidal integration. precision_recall_curve
        # returns recall in DECREASING order, so integrate over the
        # reversed arrays to get a positive area.
        try:
            precision_vals, recall_vals, _ = precision_recall_curve(labels, similarities)
            pr_auc = np.trapz(precision_vals[::-1], recall_vals[::-1])
        except ValueError:
            pr_auc = 0.0

        # Confusion-matrix derived rates; a degenerate 1x1 matrix
        # (single class) falls back to all-zero counts.
        cm = confusion_matrix(labels, predictions)
        tn, fp, fn, tp = cm.ravel() if cm.size == 4 else (0, 0, 0, 0)

        specificity = tn / (tn + fp) if (tn + fp) > 0 else 0.0
        sensitivity = tp / (tp + fn) if (tp + fn) > 0 else 0.0

        eer = self._compute_eer(labels, similarities)

        # Biometric error rates: False Acceptance / False Rejection.
        far = fp / (fp + tn) if (fp + tn) > 0 else 0.0
        frr = fn / (fn + tp) if (fn + tp) > 0 else 0.0

        metrics = {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'roc_auc': roc_auc,
            'pr_auc': pr_auc,
            'specificity': specificity,
            'sensitivity': sensitivity,
            'eer': eer,
            'far': far,
            'frr': frr,
            'threshold': self.threshold
        }

        return metrics

    def _compute_eer(self, labels: np.ndarray, similarities: np.ndarray) -> float:
        """
        Compute Equal Error Rate (EER).

        The EER is the error rate at the ROC operating point where the
        false-positive rate equals the false-negative rate.

        Args:
            labels: Ground truth labels
            similarities: Similarity scores

        Returns:
            Equal Error Rate, or 0.0 when it cannot be computed
            (e.g. a single class in ``labels``).
        """
        try:
            fpr, tpr, _ = roc_curve(labels, similarities)
            fnr = 1 - tpr
            # Operating point closest to FPR == FNR; report FPR there.
            eer_idx = np.nanargmin(np.absolute(fnr - fpr))
            return float(fpr[eer_idx])
        except (ValueError, IndexError):
            return 0.0

    def plot_roc_curve(self, save_path: Optional[str] = None):
        """
        Plot ROC curve.

        Args:
            save_path: Path to save the plot

        Raises:
            ValueError: If no data has been accumulated via ``update``.
        """
        if not self.similarities or not self.labels:
            raise ValueError("No data available for plotting.")

        similarities = np.array(self.similarities)
        labels = np.array(self.labels)

        fpr, tpr, _ = roc_curve(labels, similarities)
        roc_auc = roc_auc_score(labels, similarities)

        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
        # Diagonal reference: performance of a random classifier.
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.grid(True)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def plot_precision_recall_curve(self, save_path: Optional[str] = None):
        """
        Plot Precision-Recall curve.

        Args:
            save_path: Path to save the plot

        Raises:
            ValueError: If no data has been accumulated via ``update``.
        """
        if not self.similarities or not self.labels:
            raise ValueError("No data available for plotting.")

        similarities = np.array(self.similarities)
        labels = np.array(self.labels)

        precision, recall, _ = precision_recall_curve(labels, similarities)
        # recall comes back decreasing; integrate reversed arrays so the
        # trapezoidal area is positive.
        pr_auc = np.trapz(precision[::-1], recall[::-1])

        plt.figure(figsize=(8, 6))
        plt.plot(recall, precision, color='darkorange', lw=2, label=f'PR curve (AUC = {pr_auc:.2f})')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.legend(loc="lower left")
        plt.grid(True)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def plot_confusion_matrix(self, save_path: Optional[str] = None):
        """
        Plot confusion matrix.

        Args:
            save_path: Path to save the plot

        Raises:
            ValueError: If no data has been accumulated via ``update``.
        """
        if not self.predictions or not self.labels:
            raise ValueError("No data available for plotting.")

        cm = confusion_matrix(self.labels, self.predictions)

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Forged', 'Genuine'],
                    yticklabels=['Forged', 'Genuine'])
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def plot_similarity_distribution(self, save_path: Optional[str] = None):
        """
        Plot distribution of similarity scores for genuine and forged pairs.

        Args:
            save_path: Path to save the plot

        Raises:
            ValueError: If no data has been accumulated via ``update``.
        """
        if not self.similarities or not self.labels:
            raise ValueError("No data available for plotting.")

        similarities = np.array(self.similarities)
        labels = np.array(self.labels)

        genuine_similarities = similarities[labels == 1]
        forged_similarities = similarities[labels == 0]

        plt.figure(figsize=(10, 6))
        plt.hist(genuine_similarities, bins=50, alpha=0.7, label='Genuine', color='green')
        plt.hist(forged_similarities, bins=50, alpha=0.7, label='Forged', color='red')
        plt.axvline(self.threshold, color='black', linestyle='--', label=f'Threshold = {self.threshold}')
        plt.xlabel('Similarity Score')
        plt.ylabel('Frequency')
        plt.title('Distribution of Similarity Scores')
        plt.legend()
        plt.grid(True, alpha=0.3)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
|
|
|
|
|
|
|
|
class ThresholdOptimizer:
    """
    Optimize the decision threshold for signature verification.
    """

    def __init__(self, metric: str = 'f1_score'):
        """
        Initialize threshold optimizer.

        Args:
            metric: Metric to optimize ('f1_score', 'accuracy', 'eer')
        """
        self.metric = metric
        self.best_threshold = 0.5
        self.best_score = 0.0

    def optimize(self,
                 similarities: np.ndarray,
                 labels: np.ndarray,
                 threshold_range: Tuple[float, float] = (0.0, 1.0),
                 num_thresholds: int = 100) -> Dict[str, float]:
        """
        Optimize threshold for the configured metric.

        For 'f1_score' and 'accuracy' the threshold maximizing the metric
        is chosen. For 'eer' the per-threshold score is max(FAR, FRR);
        minimizing it selects the operating point where the two error
        curves cross, i.e. the Equal Error Rate point.

        Args:
            similarities: Similarity scores
            labels: Ground truth labels (1 genuine, 0 forged)
            threshold_range: Range of thresholds to test
            num_thresholds: Number of thresholds to test

        Returns:
            Dictionary with best threshold, best score, and the full
            threshold/score sweep.

        Raises:
            ValueError: If ``self.metric`` is not supported.
        """
        # Validate up front instead of failing mid-sweep.
        if self.metric not in ('f1_score', 'accuracy', 'eer'):
            raise ValueError(f"Unsupported metric: {self.metric}")

        similarities = np.asarray(similarities)
        labels = np.asarray(labels)

        thresholds = np.linspace(threshold_range[0], threshold_range[1], num_thresholds)
        scores = []

        for threshold in thresholds:
            predictions = (similarities >= threshold).astype(int)

            if self.metric == 'f1_score':
                score = f1_score(labels, predictions, zero_division=0)
            elif self.metric == 'accuracy':
                score = accuracy_score(labels, predictions)
            else:  # 'eer'
                # Per-threshold error rates. (The previous implementation
                # computed a global, threshold-independent EER inside the
                # loop, so every score was identical and argmin always
                # picked the first threshold.)
                tp = np.sum((predictions == 1) & (labels == 1))
                tn = np.sum((predictions == 0) & (labels == 0))
                fp = np.sum((predictions == 1) & (labels == 0))
                fn = np.sum((predictions == 0) & (labels == 1))
                far = fp / (fp + tn) if (fp + tn) > 0 else 0.0
                frr = fn / (fn + tp) if (fn + tp) > 0 else 0.0
                # max(FAR, FRR) is minimized where FAR == FRR (the EER).
                score = max(far, frr)

            scores.append(score)

        # EER is an error rate (lower is better); the others are maximized.
        if self.metric == 'eer':
            best_idx = np.argmin(scores)
        else:
            best_idx = np.argmax(scores)

        self.best_threshold = thresholds[best_idx]
        self.best_score = scores[best_idx]

        return {
            'best_threshold': self.best_threshold,
            'best_score': self.best_score,
            'thresholds': thresholds,
            'scores': scores
        }

    def plot_threshold_analysis(self,
                                similarities: np.ndarray,
                                labels: np.ndarray,
                                save_path: Optional[str] = None):
        """
        Plot the metric value as a function of the decision threshold.

        Args:
            similarities: Similarity scores
            labels: Ground truth labels
            save_path: Path to save the plot
        """
        result = self.optimize(similarities, labels)

        plt.figure(figsize=(10, 6))
        plt.plot(result['thresholds'], result['scores'], 'b-', linewidth=2)
        plt.axvline(self.best_threshold, color='red', linestyle='--',
                    label=f'Best threshold = {self.best_threshold:.3f}')
        plt.xlabel('Threshold')
        plt.ylabel(f'{self.metric.upper()}')
        plt.title(f'Threshold Optimization - {self.metric.upper()}')
        plt.legend()
        plt.grid(True, alpha=0.3)

        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
|
|
|
|
|
|
|
|
class CrossValidationEvaluator:
    """
    Cross-validation evaluator for signature verification.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 k_folds: int = 5,
                 threshold: float = 0.5):
        """
        Initialize cross-validation evaluator.

        Args:
            model: Model to evaluate (expected to map a pair of image
                tensors to a similarity score)
            k_folds: Number of folds for cross-validation
            threshold: Similarity threshold
        """
        self.model = model
        self.k_folds = k_folds
        self.threshold = threshold
        # Per-fold metric dictionaries from the last ``evaluate`` call.
        self.results = []

    def evaluate(self,
                 data_pairs: List[Tuple[str, str, int]],
                 preprocessor,
                 batch_size: int = 32) -> Dict[str, float]:
        """
        Perform k-fold cross-validation.

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            preprocessor: Image preprocessor exposing ``preprocess_image``
            batch_size: Batch size for evaluation

        Returns:
            Average metrics across all folds; for each metric a
            ``<metric>_std`` entry holds the across-fold std deviation.
        """
        from sklearn.model_selection import KFold

        kf = KFold(n_splits=self.k_folds, shuffle=True, random_state=42)

        # Keep data_pairs as a plain list: np.array() on (str, str, int)
        # tuples would coerce everything (including the int labels) to
        # strings, corrupting the downstream metric computation.
        fold_metrics = []

        for fold, (_, val_idx) in enumerate(kf.split(data_pairs)):
            print(f"Evaluating fold {fold + 1}/{self.k_folds}")

            val_pairs = [data_pairs[i] for i in val_idx]
            fold_metrics.append(self._evaluate_fold(val_pairs, preprocessor, batch_size))

        # Aggregate mean and std for every metric across folds.
        avg_metrics = {}
        for metric in fold_metrics[0]:
            values = [fold[metric] for fold in fold_metrics]
            avg_metrics[metric] = float(np.mean(values))
            avg_metrics[f'{metric}_std'] = float(np.std(values))

        self.results = fold_metrics
        return avg_metrics

    def _evaluate_fold(self,
                       val_pairs: List[Tuple[str, str, int]],
                       preprocessor,
                       batch_size: int) -> Dict[str, float]:
        """
        Evaluate a single fold.

        Args:
            val_pairs: Validation pairs for this fold
            preprocessor: Image preprocessor
            batch_size: Batch size

        Returns:
            Metrics for this fold
        """
        self.model.eval()
        similarities = []
        labels = []

        with torch.no_grad():
            for i in range(0, len(val_pairs), batch_size):
                for sig1_path, sig2_path, label in val_pairs[i:i + batch_size]:
                    sig1 = preprocessor.preprocess_image(sig1_path)
                    sig2 = preprocessor.preprocess_image(sig2_path)

                    # Add a batch dimension for the model.
                    sig1 = sig1.unsqueeze(0)
                    sig2 = sig2.unsqueeze(0)

                    similarity = self.model(sig1, sig2)
                    similarities.append(similarity.item())
                    # int() guards against labels arriving as strings.
                    labels.append(int(label))

        similarities = np.array(similarities)
        labels = np.array(labels)

        metrics_calculator = SignatureVerificationMetrics(threshold=self.threshold)
        metrics_calculator.update(similarities, labels)

        return metrics_calculator.compute_metrics()
|
|
|