InklyAI / src /evaluation /metrics.py
pravinai's picture
Upload folder using huggingface_hub
8eab354 verified
"""
Evaluation metrics for signature verification.
"""
import torch
import numpy as np
from typing import List, Tuple, Dict, Optional
from sklearn.metrics import (
accuracy_score, precision_score, recall_score, f1_score,
roc_auc_score, roc_curve, precision_recall_curve, confusion_matrix
)
import matplotlib.pyplot as plt
import seaborn as sns
class SignatureVerificationMetrics:
    """
    Comprehensive metrics for signature verification evaluation.

    Similarity scores and ground-truth labels are accumulated across calls to
    ``update()``. ``compute_metrics()`` and the ``plot_*`` methods operate on
    everything accumulated since the last ``reset()``.
    """

    def __init__(self, threshold: float = 0.5):
        """
        Initialize metrics calculator.

        Args:
            threshold: Similarity threshold for binary classification; a score
                greater than or equal to it is predicted as genuine (1).
        """
        self.threshold = threshold
        self.reset()

    def reset(self):
        """Reset all stored predictions, labels and similarity scores."""
        self.predictions = []
        self.labels = []
        self.similarities = []

    def update(self,
               similarities: np.ndarray,
               labels: np.ndarray):
        """
        Update metrics with new predictions.

        Inputs are converted with ``np.asarray`` and flattened, so plain
        Python sequences are accepted as well as arrays.

        Args:
            similarities: Similarity scores
            labels: Ground truth labels (1 for genuine, 0 for forged)

        Raises:
            ValueError: If the number of scores and labels differ.
        """
        scores = np.asarray(similarities, dtype=float).ravel()
        truth = np.asarray(labels).astype(int).ravel()
        if scores.shape[0] != truth.shape[0]:
            raise ValueError(
                f"similarities and labels must have the same length "
                f"(got {scores.shape[0]} and {truth.shape[0]})."
            )

        # Convert similarities to binary predictions *before* touching the
        # stored lists, so a failure cannot leave them with unequal lengths.
        predictions = (scores >= self.threshold).astype(int)

        self.similarities.extend(scores.tolist())
        self.labels.extend(truth.tolist())
        self.predictions.extend(predictions.tolist())

    @staticmethod
    def _area_under_curve(x, y) -> float:
        """
        Trapezoidal area under the curve ``y(x)`` for monotonic ``x``.

        The absolute value is returned so the result is positive whether
        ``x`` is increasing or decreasing.
        """
        x = np.asarray(x, dtype=float)
        y = np.asarray(y, dtype=float)
        if x.size < 2:
            return 0.0
        signed_area = np.sum(np.diff(x) * (y[1:] + y[:-1]) / 2.0)
        return float(abs(signed_area))

    def compute_metrics(self) -> Dict[str, float]:
        """
        Compute all evaluation metrics.

        Returns:
            Dictionary of metrics with keys 'accuracy', 'precision', 'recall',
            'f1_score', 'roc_auc', 'pr_auc', 'specificity', 'sensitivity',
            'eer', 'far', 'frr' and 'threshold'.

        Raises:
            ValueError: If ``update()`` has not been called yet.
        """
        if not self.predictions or not self.labels:
            raise ValueError("No predictions or labels available. Call update() first.")

        similarities = np.array(self.similarities)
        labels = np.array(self.labels)
        predictions = np.array(self.predictions)

        # Basic classification metrics
        accuracy = accuracy_score(labels, predictions)
        precision = precision_score(labels, predictions, zero_division=0)
        recall = recall_score(labels, predictions, zero_division=0)
        f1 = f1_score(labels, predictions, zero_division=0)

        # ROC AUC; a ValueError from sklearn is reported as 0.0
        try:
            roc_auc = roc_auc_score(labels, similarities)
        except ValueError:
            roc_auc = 0.0

        # Precision-Recall AUC. precision_recall_curve returns recall in
        # decreasing order, so a plain trapezoid integral over it is negative;
        # _area_under_curve returns the magnitude.
        try:
            precision_vals, recall_vals, _ = precision_recall_curve(labels, similarities)
            pr_auc = self._area_under_curve(recall_vals, precision_vals)
        except ValueError:
            pr_auc = 0.0

        # Confusion matrix. labels=[0, 1] forces a 2x2 result even when only
        # one class occurs in the data, so the counts are never discarded.
        tn, fp, fn, tp = confusion_matrix(labels, predictions, labels=[0, 1]).ravel()

        # Additional metrics
        specificity = float(tn / (tn + fp)) if (tn + fp) > 0 else 0.0
        sensitivity = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0

        # Equal Error Rate (EER)
        eer = self._compute_eer(labels, similarities)

        # False Acceptance Rate (FAR) and False Rejection Rate (FRR)
        far = float(fp / (fp + tn)) if (fp + tn) > 0 else 0.0
        frr = float(fn / (fn + tp)) if (fn + tp) > 0 else 0.0

        return {
            'accuracy': accuracy,
            'precision': precision,
            'recall': recall,
            'f1_score': f1,
            'roc_auc': roc_auc,
            'pr_auc': pr_auc,
            'specificity': specificity,
            'sensitivity': sensitivity,
            'eer': eer,
            'far': far,
            'frr': frr,
            'threshold': self.threshold,
        }

    def _compute_eer(self, labels: np.ndarray, similarities: np.ndarray) -> float:
        """
        Compute Equal Error Rate (EER).

        The EER is taken as the false-positive rate at the ROC point where
        the false-positive and false-negative rates are closest.

        Args:
            labels: Ground truth labels
            similarities: Similarity scores

        Returns:
            Equal Error Rate, or 0.0 if the computation raises ValueError or
            IndexError. Note that this fallback is indistinguishable from a
            genuinely perfect EER.
        """
        try:
            fpr, tpr, _ = roc_curve(labels, similarities)
            fnr = 1 - tpr
            crossing = np.nanargmin(np.absolute(fnr - fpr))
            return float(fpr[crossing])
        except (ValueError, IndexError):
            return 0.0

    def plot_roc_curve(self, save_path: Optional[str] = None):
        """
        Plot ROC curve.

        Args:
            save_path: Path to save the plot

        Raises:
            ValueError: If no data has been accumulated.
        """
        if not self.similarities or not self.labels:
            raise ValueError("No data available for plotting.")

        similarities = np.array(self.similarities)
        labels = np.array(self.labels)
        fpr, tpr, _ = roc_curve(labels, similarities)
        roc_auc = roc_auc_score(labels, similarities)

        plt.figure(figsize=(8, 6))
        plt.plot(fpr, tpr, color='darkorange', lw=2, label=f'ROC curve (AUC = {roc_auc:.2f})')
        # Diagonal reference line
        plt.plot([0, 1], [0, 1], color='navy', lw=2, linestyle='--')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('False Positive Rate')
        plt.ylabel('True Positive Rate')
        plt.title('Receiver Operating Characteristic (ROC) Curve')
        plt.legend(loc="lower right")
        plt.grid(True)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def plot_precision_recall_curve(self, save_path: Optional[str] = None):
        """
        Plot Precision-Recall curve.

        Args:
            save_path: Path to save the plot

        Raises:
            ValueError: If no data has been accumulated.
        """
        if not self.similarities or not self.labels:
            raise ValueError("No data available for plotting.")

        similarities = np.array(self.similarities)
        labels = np.array(self.labels)
        precision, recall, _ = precision_recall_curve(labels, similarities)
        # recall is returned in decreasing order; use the sign-safe integral
        # so the AUC in the legend is positive.
        pr_auc = self._area_under_curve(recall, precision)

        plt.figure(figsize=(8, 6))
        plt.plot(recall, precision, color='darkorange', lw=2, label=f'PR curve (AUC = {pr_auc:.2f})')
        plt.xlim([0.0, 1.0])
        plt.ylim([0.0, 1.05])
        plt.xlabel('Recall')
        plt.ylabel('Precision')
        plt.title('Precision-Recall Curve')
        plt.legend(loc="lower left")
        plt.grid(True)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def plot_confusion_matrix(self, save_path: Optional[str] = None):
        """
        Plot confusion matrix.

        Args:
            save_path: Path to save the plot

        Raises:
            ValueError: If no data has been accumulated.
        """
        if not self.predictions or not self.labels:
            raise ValueError("No data available for plotting.")

        # labels=[0, 1] keeps the matrix 2x2 so it always matches the two
        # tick labels below, even if only one class is present.
        cm = confusion_matrix(self.labels, self.predictions, labels=[0, 1])

        plt.figure(figsize=(8, 6))
        sns.heatmap(cm, annot=True, fmt='d', cmap='Blues',
                    xticklabels=['Forged', 'Genuine'],
                    yticklabels=['Forged', 'Genuine'])
        plt.title('Confusion Matrix')
        plt.xlabel('Predicted')
        plt.ylabel('Actual')
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()

    def plot_similarity_distribution(self, save_path: Optional[str] = None):
        """
        Plot distribution of similarity scores for genuine and forged pairs.

        Args:
            save_path: Path to save the plot

        Raises:
            ValueError: If no data has been accumulated.
        """
        if not self.similarities or not self.labels:
            raise ValueError("No data available for plotting.")

        similarities = np.array(self.similarities)
        labels = np.array(self.labels)
        genuine_similarities = similarities[labels == 1]
        forged_similarities = similarities[labels == 0]

        plt.figure(figsize=(10, 6))
        plt.hist(genuine_similarities, bins=50, alpha=0.7, label='Genuine', color='green')
        plt.hist(forged_similarities, bins=50, alpha=0.7, label='Forged', color='red')
        plt.axvline(self.threshold, color='black', linestyle='--', label=f'Threshold = {self.threshold}')
        plt.xlabel('Similarity Score')
        plt.ylabel('Frequency')
        plt.title('Distribution of Similarity Scores')
        plt.legend()
        plt.grid(True, alpha=0.3)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
class ThresholdOptimizer:
    """
    Optimize threshold for signature verification.

    A grid of candidate thresholds is scored with the configured metric and
    the best-scoring threshold is kept in ``best_threshold`` / ``best_score``.
    """

    def __init__(self, metric: str = 'f1_score'):
        """
        Initialize threshold optimizer.

        Args:
            metric: Metric to optimize ('f1_score', 'accuracy', 'eer')
        """
        self.metric = metric
        self.best_threshold = 0.5
        self.best_score = 0.0

    def optimize(self,
                 similarities: np.ndarray,
                 labels: np.ndarray,
                 threshold_range: Tuple[float, float] = (0.0, 1.0),
                 num_thresholds: int = 100) -> Dict[str, float]:
        """
        Optimize threshold for given metric.

        'f1_score' and 'accuracy' are maximized. For 'eer' the score of a
        threshold is ``max(FAR, FRR)`` at that threshold, which is minimized:
        FAR does not increase and FRR does not decrease as the threshold
        grows, so the minimum lies where the two rates cross (the equal-error
        operating point).

        Args:
            similarities: Similarity scores
            labels: Ground truth labels (1 for genuine, 0 for forged)
            threshold_range: Range of thresholds to test
            num_thresholds: Number of thresholds to test

        Returns:
            Dictionary with 'best_threshold', 'best_score', the tested
            'thresholds' (array) and their 'scores' (list).

        Raises:
            ValueError: If the metric is unsupported, ``num_thresholds`` is
                less than 1, or the inputs differ in length.
        """
        # Validate up front instead of failing part-way through the sweep.
        if self.metric not in ('f1_score', 'accuracy', 'eer'):
            raise ValueError(f"Unsupported metric: {self.metric}")
        if num_thresholds < 1:
            raise ValueError("num_thresholds must be at least 1.")

        sims = np.asarray(similarities, dtype=float).ravel()
        truth = np.asarray(labels).astype(int).ravel()
        if sims.shape[0] != truth.shape[0]:
            raise ValueError(
                f"similarities and labels must have the same length "
                f"(got {sims.shape[0]} and {truth.shape[0]})."
            )

        thresholds = np.linspace(threshold_range[0], threshold_range[1], num_thresholds)

        # Class masks are loop-invariant; only the predictions change.
        genuine = truth == 1
        forged = truth == 0
        has_genuine = bool(genuine.any())
        has_forged = bool(forged.any())

        scores = []
        for threshold in thresholds:
            accepted = sims >= threshold
            if self.metric == 'f1_score':
                score = f1_score(truth, accepted.astype(int), zero_division=0)
            elif self.metric == 'accuracy':
                score = accuracy_score(truth, accepted.astype(int))
            else:
                # Error rates at *this* threshold: forged pairs accepted (FAR)
                # and genuine pairs rejected (FRR).
                far = float(np.mean(accepted[forged])) if has_forged else 0.0
                frr = float(np.mean(~accepted[genuine])) if has_genuine else 0.0
                score = max(far, frr)
            scores.append(score)

        # Find best threshold (error rate is minimized, the others maximized)
        if self.metric == 'eer':
            best_idx = int(np.argmin(scores))
        else:
            best_idx = int(np.argmax(scores))

        self.best_threshold = float(thresholds[best_idx])
        self.best_score = float(scores[best_idx])

        return {
            'best_threshold': self.best_threshold,
            'best_score': self.best_score,
            'thresholds': thresholds,
            'scores': scores,
        }

    def plot_threshold_analysis(self,
                                similarities: np.ndarray,
                                labels: np.ndarray,
                                save_path: Optional[str] = None):
        """
        Plot threshold analysis.

        Runs ``optimize()`` with its default range and resolution, then plots
        the score of every tested threshold and marks the best one.

        Args:
            similarities: Similarity scores
            labels: Ground truth labels
            save_path: Path to save the plot
        """
        result = self.optimize(similarities, labels)

        plt.figure(figsize=(10, 6))
        plt.plot(result['thresholds'], result['scores'], 'b-', linewidth=2)
        plt.axvline(self.best_threshold, color='red', linestyle='--',
                    label=f'Best threshold = {self.best_threshold:.3f}')
        plt.xlabel('Threshold')
        plt.ylabel(f'{self.metric.upper()}')
        plt.title(f'Threshold Optimization - {self.metric.upper()}')
        plt.legend()
        plt.grid(True, alpha=0.3)
        if save_path:
            plt.savefig(save_path, dpi=300, bbox_inches='tight')
        plt.show()
class CrossValidationEvaluator:
    """
    Cross-validation evaluator for signature verification.

    The data is split into ``k_folds`` shuffled folds and the (already
    trained) model is scored on the validation part of each fold; the
    training indices of each split are not used.
    """

    def __init__(self,
                 model: torch.nn.Module,
                 k_folds: int = 5,
                 threshold: float = 0.5):
        """
        Initialize cross-validation evaluator.

        Args:
            model: Model to evaluate
            k_folds: Number of folds for cross-validation
            threshold: Similarity threshold
        """
        self.model = model
        self.k_folds = k_folds
        self.threshold = threshold
        self.results = []

    def evaluate(self,
                 data_pairs: List[Tuple[str, str, int]],
                 preprocessor,
                 batch_size: int = 32) -> Dict[str, float]:
        """
        Perform k-fold cross-validation.

        Args:
            data_pairs: List of (signature1_path, signature2_path, label) tuples
            preprocessor: Image preprocessor
            batch_size: Batch size for evaluation

        Returns:
            Average metrics across all folds; for every metric ``m`` the
            result also contains ``m_std`` with its standard deviation. The
            per-fold metrics are stored in ``self.results``.
        """
        from sklearn.model_selection import KFold

        kf = KFold(n_splits=self.k_folds, shuffle=True, random_state=42)

        # Keep the pairs as Python tuples. Converting mixed (str, str, int)
        # tuples with np.array() would coerce every label to a string.
        pairs = list(data_pairs)

        fold_metrics = []
        # Only the sample count matters for the split, so split over indices.
        for fold, (_, val_idx) in enumerate(kf.split(np.arange(len(pairs)))):
            print(f"Evaluating fold {fold + 1}/{self.k_folds}")
            val_pairs = [pairs[i] for i in val_idx]
            # Evaluate on validation set
            fold_metrics.append(self._evaluate_fold(val_pairs, preprocessor, batch_size))

        # Compute average metrics
        avg_metrics = {}
        for metric in fold_metrics[0].keys():
            values = [fold[metric] for fold in fold_metrics]
            avg_metrics[metric] = np.mean(values)
            avg_metrics[f'{metric}_std'] = np.std(values)

        self.results = fold_metrics
        return avg_metrics

    def _score_pairs(self, val_pairs, preprocessor) -> Tuple[np.ndarray, np.ndarray]:
        """
        Run the model on every pair and collect similarity scores and labels.

        Args:
            val_pairs: Iterable of (signature1_path, signature2_path, label)
            preprocessor: Object exposing ``preprocess_image(path)``

        Returns:
            ``(similarities, labels)`` as a float array and an int array.
        """
        self.model.eval()
        similarities = []
        labels = []

        with torch.no_grad():
            for sig1_path, sig2_path, label in val_pairs:
                # Load and preprocess images, then add a batch dimension
                sig1 = preprocessor.preprocess_image(sig1_path).unsqueeze(0)
                sig2 = preprocessor.preprocess_image(sig2_path).unsqueeze(0)

                # Compute similarity
                similarity = self.model(sig1, sig2)
                similarities.append(float(similarity.item()))
                # int() also recovers labels that arrive as strings, e.g.
                # from a string-typed array of pairs.
                labels.append(int(label))

        return np.array(similarities, dtype=float), np.array(labels, dtype=int)

    def _evaluate_fold(self,
                       val_pairs: np.ndarray,
                       preprocessor,
                       batch_size: int) -> Dict[str, float]:
        """
        Evaluate a single fold.

        Args:
            val_pairs: Validation pairs (any sequence of
                (signature1_path, signature2_path, label) entries)
            preprocessor: Image preprocessor
            batch_size: Batch size. Kept for interface compatibility; pairs
                are scored one at a time, so it does not affect the result.

        Returns:
            Metrics for this fold
        """
        similarities, labels = self._score_pairs(val_pairs, preprocessor)

        # Compute metrics
        metrics_calculator = SignatureVerificationMetrics(threshold=self.threshold)
        metrics_calculator.update(similarities, labels)
        return metrics_calculator.compute_metrics()