| """
|
| Training utilities and metrics
|
| Implements Critical Fix #9: Dataset-Aware Metric Computation
|
| """
|
|
|
| import torch
|
| import numpy as np
|
| from typing import Dict, List, Optional
|
| from sklearn.metrics import (
|
| accuracy_score, f1_score, precision_score, recall_score,
|
| confusion_matrix
|
| )
|
|
|
|
|
class SegmentationMetrics:
    """Running segmentation metrics (IoU, Dice, precision, recall).

    Statistics are only accumulated for datasets that provide pixel-level
    masks (Critical Fix #9); batches without masks are skipped entirely.
    """

    def __init__(self):
        """Start with empty accumulators."""
        self.reset()

    def reset(self):
        """Zero every running accumulator."""
        self.intersection = 0
        self.union = 0
        self.pred_sum = 0
        self.target_sum = 0
        self.total_samples = 0

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               has_pixel_mask: bool = True):
        """Accumulate statistics from one batch.

        Args:
            pred: Predicted probabilities (B, 1, H, W)
            target: Ground-truth masks (B, 1, H, W)
            has_pixel_mask: When False the batch is ignored (Critical Fix #9)
        """
        if not has_pixel_mask:
            return

        # Threshold probabilities at 0.5 to obtain a hard binary mask.
        hard_pred = (pred > 0.5).float()

        n_pred = hard_pred.sum().item()
        n_target = target.sum().item()
        overlap = (hard_pred * target).sum().item()

        self.intersection += overlap
        self.union += n_pred + n_target - overlap
        self.pred_sum += n_pred
        self.target_sum += n_target
        self.total_samples += pred.shape[0]

    def compute(self) -> Dict[str, float]:
        """Compute aggregate metrics over everything seen since reset().

        Returns:
            Dictionary with IoU, Dice, Precision, Recall. A small epsilon
            guards every denominator against division by zero.
        """
        eps = 1e-8
        return {
            'iou': self.intersection / (self.union + eps),
            'dice': (2 * self.intersection) / (self.pred_sum + self.target_sum + eps),
            'precision': self.intersection / (self.pred_sum + eps),
            'recall': self.intersection / (self.target_sum + eps)
        }
|
|
|
|
|
class ClassificationMetrics:
    """Classification metrics for forgery type classification.

    Accumulates per-batch predictions and computes accuracy, macro and
    weighted F1, macro precision/recall, and a confusion matrix.
    """

    def __init__(self, num_classes: int = 3):
        """
        Initialize metrics

        Args:
            num_classes: Number of forgery types
        """
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        """Reset all accumulated predictions, targets and confidences."""
        self.predictions: List[int] = []
        self.targets: List[int] = []
        self.confidences: List[float] = []

    def update(self,
               pred: np.ndarray,
               target: np.ndarray,
               confidence: Optional[np.ndarray] = None):
        """
        Update metrics with predictions

        Args:
            pred: Predicted class indices
            target: Ground truth class indices
            confidence: Optional prediction confidences
        """
        self.predictions.extend(pred.tolist())
        self.targets.extend(target.tolist())
        if confidence is not None:
            self.confidences.extend(confidence.tolist())

    def compute(self) -> Dict[str, float]:
        """
        Compute final metrics

        Returns:
            Dictionary with Accuracy, F1 (macro/weighted), Precision, Recall
            and 'confusion_matrix' (nested list). All scalar values are plain
            Python floats.
        """
        if not self.predictions:
            # BUGFIX: the empty case previously omitted 'confusion_matrix',
            # giving callers an inconsistent result schema. Return an
            # all-zero matrix so the keys match the non-empty path.
            return {
                'accuracy': 0.0,
                'f1_macro': 0.0,
                'f1_weighted': 0.0,
                'precision': 0.0,
                'recall': 0.0,
                'confusion_matrix': [[0] * self.num_classes
                                     for _ in range(self.num_classes)]
            }

        preds = np.array(self.predictions)
        targets = np.array(self.targets)

        # float() unwraps numpy scalars so the dict really holds floats,
        # as the return annotation promises.
        accuracy = float(accuracy_score(targets, preds))
        f1_macro = float(f1_score(targets, preds, average='macro', zero_division=0))
        f1_weighted = float(f1_score(targets, preds, average='weighted', zero_division=0))
        precision = float(precision_score(targets, preds, average='macro', zero_division=0))
        recall = float(recall_score(targets, preds, average='macro', zero_division=0))

        # Fix labels so the matrix shape is (num_classes, num_classes) even
        # when some classes never appear in this evaluation run.
        cm = confusion_matrix(targets, preds, labels=list(range(self.num_classes)))

        return {
            'accuracy': accuracy,
            'f1_macro': f1_macro,
            'f1_weighted': f1_weighted,
            'precision': precision,
            'recall': recall,
            'confusion_matrix': cm.tolist()
        }
|
|
|
|
|
class MetricsTracker:
    """Track segmentation and classification metrics during training."""

    # Segmentation metric names mirrored into per-phase history lists.
    _TRACKED_KEYS = ('iou', 'dice', 'precision', 'recall')

    def __init__(self, config):
        """
        Initialize metrics tracker

        Args:
            config: Configuration object (must provide .get and
                .should_compute_localization_metrics)
        """
        self.config = config
        self.num_classes = config.get('data.num_classes', 3)

        self.seg_metrics = SegmentationMetrics()
        self.cls_metrics = ClassificationMetrics(self.num_classes)

        # One list per phase/metric pair, built from a single source of
        # truth so log_epoch() and the history keys can never drift apart.
        self.history = {
            f'{phase}_{key}': []
            for phase in ('train', 'val')
            for key in ('loss',) + self._TRACKED_KEYS
        }

    def reset(self):
        """Reset per-epoch metric accumulators."""
        self.seg_metrics.reset()
        self.cls_metrics.reset()

    def update_segmentation(self,
                            pred: torch.Tensor,
                            target: torch.Tensor,
                            dataset_name: str):
        """Update segmentation metrics (dataset-aware, Critical Fix #9)."""
        has_pixel_mask = self.config.should_compute_localization_metrics(dataset_name)
        self.seg_metrics.update(pred, target, has_pixel_mask)

    def update_classification(self,
                              pred: np.ndarray,
                              target: np.ndarray,
                              confidence: Optional[np.ndarray] = None):
        """Update classification metrics."""
        self.cls_metrics.update(pred, target, confidence)

    def compute_all(self) -> Dict[str, float]:
        """Compute all metrics; classification keys are prefixed 'cls_'."""
        seg = self.seg_metrics.compute()

        if len(self.cls_metrics.predictions) > 0:
            cls = self.cls_metrics.compute()
            cls_prefixed = {f'cls_{k}': v for k, v in cls.items()}
            return {**seg, **cls_prefixed}

        return seg

    def log_epoch(self, epoch: int, phase: str, loss: float, metrics: Dict):
        """
        Log metrics for one epoch.

        Args:
            epoch: Epoch index (kept for interface compatibility; unused)
            phase: 'train' or 'val' — must match a history key prefix
            loss: Epoch loss value
            metrics: Metric dict; only tracked keys present are recorded
        """
        # BUGFIX: removed an unused local ('prefix') and collapsed four
        # copy-pasted if-blocks into a loop over the tracked key names.
        self.history[f'{phase}_loss'].append(loss)
        for key in self._TRACKED_KEYS:
            if key in metrics:
                self.history[f'{phase}_{key}'].append(metrics[key])

    def get_history(self) -> Dict:
        """Get full training history."""
        return self.history
|
|
|
|
|
class EarlyStopping:
    """Stop training once a monitored value stops improving.

    Once triggered, ``should_stop`` latches True for all later calls.
    """

    def __init__(self,
                 patience: int = 10,
                 min_delta: float = 0.001,
                 mode: str = 'max'):
        """
        Initialize early stopping.

        Args:
            patience: Number of non-improving epochs to tolerate
            min_delta: Minimum change that counts as an improvement
            mode: 'min' for loss-like values, 'max' for score-like metrics
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode

        self.counter = 0
        self.best_value = None
        self.should_stop = False

    def __call__(self, value: float) -> bool:
        """
        Record a new metric value and decide whether to stop.

        Args:
            value: Current metric value

        Returns:
            True once patience has been exhausted
        """
        # First observation just establishes the baseline.
        if self.best_value is None:
            self.best_value = value
            return False

        if self.mode == 'max':
            has_improved = value > self.best_value + self.min_delta
        else:
            has_improved = value < self.best_value - self.min_delta

        if has_improved:
            self.best_value = value
            self.counter = 0
        else:
            self.counter += 1

        # Checked unconditionally (not only on the non-improving branch) so
        # the patience <= 0 edge case behaves the same on every call.
        if self.counter >= self.patience:
            self.should_stop = True

        return self.should_stop
|
|
|
|
|
def get_metrics_tracker(config) -> MetricsTracker:
    """Build a MetricsTracker wired to the given configuration.

    Args:
        config: Configuration object passed through to the tracker.

    Returns:
        A freshly constructed MetricsTracker.
    """
    tracker = MetricsTracker(config)
    return tracker
|
|
|