""" Training utilities and metrics Implements Critical Fix #9: Dataset-Aware Metric Computation """ import torch import numpy as np from typing import Dict, List, Optional from sklearn.metrics import ( accuracy_score, f1_score, precision_score, recall_score, confusion_matrix ) class SegmentationMetrics: """ Segmentation metrics (IoU, Dice) Only computed for datasets with pixel masks (Critical Fix #9) """ def __init__(self): """Initialize metrics""" self.reset() def reset(self): """Reset all metrics""" self.intersection = 0 self.union = 0 self.pred_sum = 0 self.target_sum = 0 self.total_samples = 0 def update(self, pred: torch.Tensor, target: torch.Tensor, has_pixel_mask: bool = True): """ Update metrics with batch Args: pred: Predicted probabilities (B, 1, H, W) target: Ground truth masks (B, 1, H, W) has_pixel_mask: Whether to compute metrics (Critical Fix #9) """ if not has_pixel_mask: return # Binarize predictions pred_binary = (pred > 0.5).float() # Compute intersection and union intersection = (pred_binary * target).sum().item() union = pred_binary.sum().item() + target.sum().item() - intersection self.intersection += intersection self.union += union self.pred_sum += pred_binary.sum().item() self.target_sum += target.sum().item() self.total_samples += pred.shape[0] def compute(self) -> Dict[str, float]: """ Compute final metrics Returns: Dictionary with IoU, Dice, Precision, Recall """ # IoU (Jaccard) iou = self.intersection / (self.union + 1e-8) # Dice (F1) dice = (2 * self.intersection) / (self.pred_sum + self.target_sum + 1e-8) # Precision precision = self.intersection / (self.pred_sum + 1e-8) # Recall recall = self.intersection / (self.target_sum + 1e-8) return { 'iou': iou, 'dice': dice, 'precision': precision, 'recall': recall } class ClassificationMetrics: """Classification metrics for forgery type classification""" def __init__(self, num_classes: int = 3): """ Initialize metrics Args: num_classes: Number of forgery types """ self.num_classes = num_classes self.reset() def reset(self): """Reset all metrics""" self.predictions = [] self.targets = [] self.confidences = [] def update(self, pred: np.ndarray, target: np.ndarray, confidence: Optional[np.ndarray] = None): """ Update metrics with predictions Args: pred: Predicted class indices target: Ground truth class indices confidence: Optional prediction confidences """ self.predictions.extend(pred.tolist()) self.targets.extend(target.tolist()) if confidence is not None: self.confidences.extend(confidence.tolist()) def compute(self) -> Dict[str, float]: """ Compute final metrics Returns: Dictionary with Accuracy, F1, Precision, Recall """ if len(self.predictions) == 0: return { 'accuracy': 0.0, 'f1_macro': 0.0, 'f1_weighted': 0.0, 'precision': 0.0, 'recall': 0.0 } preds = np.array(self.predictions) targets = np.array(self.targets) # Accuracy accuracy = accuracy_score(targets, preds) # F1 score (macro and weighted) f1_macro = f1_score(targets, preds, average='macro', zero_division=0) f1_weighted = f1_score(targets, preds, average='weighted', zero_division=0) # Precision and Recall precision = precision_score(targets, preds, average='macro', zero_division=0) recall = recall_score(targets, preds, average='macro', zero_division=0) # Confusion matrix cm = confusion_matrix(targets, preds, labels=range(self.num_classes)) return { 'accuracy': accuracy, 'f1_macro': f1_macro, 'f1_weighted': f1_weighted, 'precision': precision, 'recall': recall, 'confusion_matrix': cm.tolist() } class MetricsTracker: """Track all metrics during training""" def __init__(self, config): """ Initialize metrics tracker Args: config: Configuration object """ self.config = config self.num_classes = config.get('data.num_classes', 3) self.seg_metrics = SegmentationMetrics() self.cls_metrics = ClassificationMetrics(self.num_classes) self.history = { 'train_loss': [], 'val_loss': [], 'train_iou': [], 'val_iou': [], 'train_dice': [], 'val_dice': [], 'train_precision': [], 'val_precision': [], 'train_recall': [], 'val_recall': [] } def reset(self): """Reset metrics for new epoch""" self.seg_metrics.reset() self.cls_metrics.reset() def update_segmentation(self, pred: torch.Tensor, target: torch.Tensor, dataset_name: str): """Update segmentation metrics (dataset-aware)""" has_pixel_mask = self.config.should_compute_localization_metrics(dataset_name) self.seg_metrics.update(pred, target, has_pixel_mask) def update_classification(self, pred: np.ndarray, target: np.ndarray, confidence: Optional[np.ndarray] = None): """Update classification metrics""" self.cls_metrics.update(pred, target, confidence) def compute_all(self) -> Dict[str, float]: """Compute all metrics""" seg = self.seg_metrics.compute() # Only include classification metrics if they have data if len(self.cls_metrics.predictions) > 0: cls = self.cls_metrics.compute() # Prefix classification metrics to avoid collision cls_prefixed = {f'cls_{k}': v for k, v in cls.items()} return {**seg, **cls_prefixed} return seg def log_epoch(self, epoch: int, phase: str, loss: float, metrics: Dict): """Log metrics for epoch""" prefix = f'{phase}_' self.history[f'{phase}_loss'].append(loss) if 'iou' in metrics: self.history[f'{phase}_iou'].append(metrics['iou']) if 'dice' in metrics: self.history[f'{phase}_dice'].append(metrics['dice']) if 'precision' in metrics: self.history[f'{phase}_precision'].append(metrics['precision']) if 'recall' in metrics: self.history[f'{phase}_recall'].append(metrics['recall']) def get_history(self) -> Dict: """Get full training history""" return self.history class EarlyStopping: """Early stopping to prevent overfitting""" def __init__(self, patience: int = 10, min_delta: float = 0.001, mode: str = 'max'): """ Initialize early stopping Args: patience: Number of epochs to wait min_delta: Minimum improvement required mode: 'min' for loss, 'max' for metrics """ self.patience = patience self.min_delta = min_delta self.mode = mode self.counter = 0 self.best_value = None self.should_stop = False def __call__(self, value: float) -> bool: """ Check if training should stop Args: value: Current metric value Returns: True if should stop """ if self.best_value is None: self.best_value = value return False if self.mode == 'max': improved = value > self.best_value + self.min_delta else: improved = value < self.best_value - self.min_delta if improved: self.best_value = value self.counter = 0 else: self.counter += 1 if self.counter >= self.patience: self.should_stop = True return self.should_stop def get_metrics_tracker(config) -> MetricsTracker: """Factory function for metrics tracker""" return MetricsTracker(config)