SentinelWatch / utils /evaluation.py
VishaliniS456's picture
Upload 8 files
9875bf8 verified
"""Evaluation metrics for change and cloud detection."""
import numpy as np
from typing import Dict, Optional
def calculate_metrics(
    pred_mask: np.ndarray,
    gt_mask: np.ndarray,
    threshold: float = 0.5
) -> Dict[str, float]:
    """
    Calculate pixel-level classification metrics.

    Args:
        pred_mask: Predicted binary mask (H, W) or confidence map (H, W).
            Boolean arrays, and uint8 arrays whose values are all 0/1, are
            treated as already binary and used as-is (``threshold`` is
            ignored). Anything else is binarised with ``pred_mask > threshold``.
            Any array-like accepted by ``np.asarray`` is allowed.
        gt_mask: Ground truth mask (H, W); any value > 0 counts as positive.
        threshold: Threshold used to binarise ``pred_mask`` when it is not
            already binary.

    Returns:
        Dict with ratio keys ``accuracy``, ``precision``, ``recall``, ``f1``
        and ``iou`` (floats; 0.0 whenever the denominator is zero), plus the
        raw pixel counts ``tp``, ``tn``, ``fp`` and ``fn`` (ints).

    Raises:
        ValueError: If the two masks' shapes differ even after dropping
            singleton axes. Such shapes would otherwise be broadcast by NumPy
            into a larger grid and produce meaningless counts.
    """
    pred_mask = np.asarray(pred_mask)
    gt_mask = np.asarray(gt_mask)

    if pred_mask.shape != gt_mask.shape:
        # Tolerate extra singleton axes such as (1, H, W) vs (H, W), but
        # refuse shapes that would silently broadcast, e.g. (H, W) vs (H, 1).
        orig_shapes = (pred_mask.shape, gt_mask.shape)
        pred_mask = np.squeeze(pred_mask)
        gt_mask = np.squeeze(gt_mask)
        if pred_mask.shape != gt_mask.shape:
            raise ValueError(
                f"pred_mask shape {orig_shapes[0]} does not match "
                f"gt_mask shape {orig_shapes[1]}"
            )

    # Binarise predictions only when they are not already a binary mask.
    if pred_mask.dtype == np.bool_:
        pred = pred_mask
    elif pred_mask.dtype == np.uint8 and (
        pred_mask.size == 0 or pred_mask.max() <= 1  # max() fails on empty
    ):
        pred = pred_mask.astype(bool)
    else:
        pred = pred_mask > threshold
    gt = gt_mask > 0

    # tp, predicted positives and actual positives are enough to derive
    # every cell of the confusion matrix.
    total = int(np.size(pred))
    tp = int(np.count_nonzero(pred & gt))
    fp = int(np.count_nonzero(pred)) - tp
    fn = int(np.count_nonzero(gt)) - tp
    tn = total - tp - fp - fn

    accuracy = (tp + tn) / total if total > 0 else 0.0
    precision = tp / (tp + fp) if (tp + fp) > 0 else 0.0
    recall = tp / (tp + fn) if (tp + fn) > 0 else 0.0
    f1 = (2 * precision * recall / (precision + recall)
          if (precision + recall) > 0 else 0.0)
    iou = tp / (tp + fp + fn) if (tp + fp + fn) > 0 else 0.0
    return {
        "accuracy": accuracy,
        "precision": precision,
        "recall": recall,
        "f1": f1,
        "iou": iou,
        "tp": tp,
        "tn": tn,
        "fp": fp,
        "fn": fn,
    }