Spaces:
Sleeping
Sleeping
| """ | |
| Training utilities and metrics | |
| Implements Critical Fix #9: Dataset-Aware Metric Computation | |
| """ | |
| import torch | |
| import numpy as np | |
| from typing import Dict, List, Optional | |
| from sklearn.metrics import ( | |
| accuracy_score, f1_score, precision_score, recall_score, | |
| confusion_matrix | |
| ) | |
class SegmentationMetrics:
    """
    Running binary segmentation metrics (IoU, Dice, precision, recall).

    Statistics are accumulated across batches and only for datasets that
    actually carry pixel-level masks (Critical Fix #9).
    """

    def __init__(self):
        """Start with empty accumulators."""
        self.reset()

    def reset(self):
        """Clear all running totals."""
        self.intersection = 0
        self.union = 0
        self.pred_sum = 0
        self.target_sum = 0
        self.total_samples = 0

    def update(self,
               pred: torch.Tensor,
               target: torch.Tensor,
               has_pixel_mask: bool = True):
        """
        Accumulate statistics from one batch.

        Args:
            pred: Predicted probabilities (B, 1, H, W)
            target: Ground truth masks (B, 1, H, W)
            has_pixel_mask: When False, the batch is skipped entirely
                (Critical Fix #9)
        """
        if not has_pixel_mask:
            return

        # Threshold probabilities at 0.5 to obtain a hard binary mask.
        hard_pred = (pred > 0.5).float()

        overlap = (hard_pred * target).sum().item()
        pred_total = hard_pred.sum().item()
        target_total = target.sum().item()

        self.intersection += overlap
        self.union += pred_total + target_total - overlap
        self.pred_sum += pred_total
        self.target_sum += target_total
        self.total_samples += pred.shape[0]

    def compute(self) -> Dict[str, float]:
        """
        Finalize the accumulated metrics.

        Returns:
            Dictionary with 'iou', 'dice', 'precision', 'recall'. A small
            epsilon in each denominator guards against division by zero
            when nothing has been accumulated.
        """
        eps = 1e-8
        return {
            'iou': self.intersection / (self.union + eps),
            'dice': (2 * self.intersection) / (self.pred_sum + self.target_sum + eps),
            'precision': self.intersection / (self.pred_sum + eps),
            'recall': self.intersection / (self.target_sum + eps),
        }
class ClassificationMetrics:
    """Classification metrics for forgery type classification."""

    def __init__(self, num_classes: int = 3):
        """
        Initialize metrics.

        Args:
            num_classes: Number of forgery types
        """
        self.num_classes = num_classes
        self.reset()

    def reset(self):
        """Reset accumulated predictions, targets, and confidences."""
        self.predictions = []
        self.targets = []
        self.confidences = []

    def update(self,
               pred: np.ndarray,
               target: np.ndarray,
               confidence: Optional[np.ndarray] = None):
        """
        Update metrics with predictions.

        Args:
            pred: Predicted class indices
            target: Ground truth class indices
            confidence: Optional prediction confidences
        """
        self.predictions.extend(pred.tolist())
        self.targets.extend(target.tolist())
        if confidence is not None:
            self.confidences.extend(confidence.tolist())

    def compute(self) -> Dict[str, float]:
        """
        Compute final metrics.

        Returns:
            Dictionary with Accuracy, F1 (macro/weighted), Precision,
            Recall, and 'confusion_matrix' (nested list of ints). The
            empty-data case returns the same keys so callers can index
            the result unconditionally.
        """
        if len(self.predictions) == 0:
            # Bug fix: previously the empty case omitted 'confusion_matrix',
            # so callers indexing it crashed on an epoch with no data.
            zero_cm = [[0] * self.num_classes for _ in range(self.num_classes)]
            return {
                'accuracy': 0.0,
                'f1_macro': 0.0,
                'f1_weighted': 0.0,
                'precision': 0.0,
                'recall': 0.0,
                'confusion_matrix': zero_cm
            }

        preds = np.array(self.predictions)
        targets = np.array(self.targets)

        # Accuracy
        accuracy = accuracy_score(targets, preds)
        # F1 score (macro and weighted); zero_division=0 silences warnings
        # for classes absent from this epoch.
        f1_macro = f1_score(targets, preds, average='macro', zero_division=0)
        f1_weighted = f1_score(targets, preds, average='weighted', zero_division=0)
        # Precision and Recall (macro-averaged)
        precision = precision_score(targets, preds, average='macro', zero_division=0)
        recall = recall_score(targets, preds, average='macro', zero_division=0)
        # Confusion matrix with a fixed label set so the shape is always
        # (num_classes, num_classes) even when some classes never appear.
        cm = confusion_matrix(targets, preds, labels=list(range(self.num_classes)))

        return {
            'accuracy': accuracy,
            'f1_macro': f1_macro,
            'f1_weighted': f1_weighted,
            'precision': precision,
            'recall': recall,
            'confusion_matrix': cm.tolist()
        }
class MetricsTracker:
    """Track segmentation and classification metrics during training."""

    def __init__(self, config):
        """
        Initialize metrics tracker.

        Args:
            config: Configuration object; must provide ``get(key, default)``
                and ``should_compute_localization_metrics(dataset_name)``.
        """
        self.config = config
        self.num_classes = config.get('data.num_classes', 3)

        self.seg_metrics = SegmentationMetrics()
        self.cls_metrics = ClassificationMetrics(self.num_classes)

        # Per-epoch history: loss plus the four segmentation metrics,
        # tracked separately for the train and val phases.
        self.history = {
            'train_loss': [],
            'val_loss': [],
            'train_iou': [],
            'val_iou': [],
            'train_dice': [],
            'val_dice': [],
            'train_precision': [],
            'val_precision': [],
            'train_recall': [],
            'val_recall': []
        }

    def reset(self):
        """Reset per-epoch accumulators (history is preserved)."""
        self.seg_metrics.reset()
        self.cls_metrics.reset()

    def update_segmentation(self,
                            pred: torch.Tensor,
                            target: torch.Tensor,
                            dataset_name: str):
        """Update segmentation metrics (dataset-aware, Critical Fix #9)."""
        has_pixel_mask = self.config.should_compute_localization_metrics(dataset_name)
        self.seg_metrics.update(pred, target, has_pixel_mask)

    def update_classification(self,
                              pred: np.ndarray,
                              target: np.ndarray,
                              confidence: Optional[np.ndarray] = None):
        """Update classification metrics."""
        self.cls_metrics.update(pred, target, confidence)

    def compute_all(self) -> Dict[str, float]:
        """
        Compute all metrics for the current epoch.

        Classification metrics are included (with a ``cls_`` prefix to
        avoid key collisions) only when classification data was recorded.
        """
        seg = self.seg_metrics.compute()
        if len(self.cls_metrics.predictions) > 0:
            cls = self.cls_metrics.compute()
            cls_prefixed = {f'cls_{k}': v for k, v in cls.items()}
            return {**seg, **cls_prefixed}
        return seg

    def log_epoch(self, epoch: int, phase: str, loss: float, metrics: Dict):
        """
        Append one epoch's results to the history.

        Args:
            epoch: Epoch index (kept for interface compatibility; not stored)
            phase: 'train' or 'val'
            loss: Epoch loss value
            metrics: Metrics dict; only the four segmentation keys present
                in it are logged
        """
        # Fix: removed the unused local `prefix` and collapsed the
        # repetitive per-key if-chain into a single loop.
        self.history[f'{phase}_loss'].append(loss)
        for key in ('iou', 'dice', 'precision', 'recall'):
            if key in metrics:
                self.history[f'{phase}_{key}'].append(metrics[key])

    def get_history(self) -> Dict:
        """Get the full training history."""
        return self.history
class EarlyStopping:
    """Stop training when a monitored value stops improving."""

    def __init__(self,
                 patience: int = 10,
                 min_delta: float = 0.001,
                 mode: str = 'max'):
        """
        Initialize early stopping.

        Args:
            patience: Number of non-improving epochs to tolerate
            min_delta: Minimum change that counts as an improvement
            mode: 'min' for losses, 'max' for metrics
        """
        self.patience = patience
        self.min_delta = min_delta
        self.mode = mode
        self.counter = 0
        self.best_value = None
        self.should_stop = False

    def __call__(self, value: float) -> bool:
        """
        Record a new metric value and decide whether to stop.

        Args:
            value: Current metric value

        Returns:
            True if patience has been exhausted and training should stop
        """
        # The first observation simply becomes the baseline.
        if self.best_value is None:
            self.best_value = value
            return False

        # Improvement must beat the best value by at least min_delta,
        # in the direction given by `mode`.
        if self.mode == 'max':
            has_improved = value > self.best_value + self.min_delta
        else:
            has_improved = value < self.best_value - self.min_delta

        if has_improved:
            self.best_value = value
            self.counter = 0
        else:
            self.counter += 1
            if self.counter >= self.patience:
                self.should_stop = True

        return self.should_stop
def get_metrics_tracker(config) -> MetricsTracker:
    """Build and return a MetricsTracker for the given configuration."""
    tracker = MetricsTracker(config)
    return tracker