acd23's picture
Upload folder using huggingface_hub
3cc53ab verified
"""
Metrics for segmentation evaluation.
"""
import numpy as np
def compute_iou(pred, gt):
    """Intersection over Union (Jaccard Index) of two binary masks.

    Args:
        pred: Predicted mask; anything with ``astype(bool)`` (nonzero = foreground).
        gt: Ground-truth mask, same shape as ``pred`` (or broadcastable to it).

    Returns:
        ``|pred & gt| / |pred | gt|`` as a Python float. If both masks are
        empty the union is zero; that case is scored as a perfect match (1.0).
    """
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    intersection = np.logical_and(pred, gt).sum()
    union = np.logical_or(pred, gt).sum()
    if union == 0:
        # An empty union forces an empty intersection, so both masks are empty
        # and agree exactly; there is no "mismatch" case to distinguish here.
        return 1.0
    return float(intersection / union)
def compute_dice(pred, gt):
    """Dice coefficient (equivalently the F1 score) of two binary masks.

    Computes ``2 * |pred & gt| / (|pred| + |gt|)`` as a Python float, where
    nonzero entries count as foreground. When both masks are empty the
    denominator is zero and 1.0 is returned.
    """
    p_mask, g_mask = pred.astype(bool), gt.astype(bool)
    overlap = np.logical_and(p_mask, g_mask).sum()
    denom = p_mask.sum() + g_mask.sum()
    return 1.0 if denom == 0 else float(2.0 * overlap / denom)
def compute_pixel_accuracy(pred, gt):
    """Fraction of pixels on which the binary prediction matches ground truth.

    Args:
        pred: Predicted mask; anything with ``astype(bool)`` (nonzero = foreground).
        gt: Ground-truth mask, same shape as ``pred`` (or broadcastable to it).

    Returns:
        Python float in [0, 1]. An empty mask (no pixels) returns 1.0 rather
        than dividing 0 by 0, mirroring the empty-mask convention used by the
        IoU and Dice metrics in this module.
    """
    pred = pred.astype(bool)
    gt = gt.astype(bool)
    matches = pred == gt
    # Normalise by the size of the compared (possibly broadcast) array, not
    # pred.size; otherwise broadcasting gt to a larger shape could push the
    # result above 1.0.
    total = matches.size
    if total == 0:
        return 1.0
    return float(matches.sum() / total)
def compute_precision_recall(pred, gt):
    """Return ``(precision, recall)`` of a binary prediction against ground truth.

    precision = TP / (TP + FP) and recall = TP / (TP + FN), each a Python
    float, with nonzero entries counted as foreground. A ratio whose
    denominator is zero is reported as 0.0. That happens for precision when
    nothing is predicted, and for recall when the ground truth is empty.
    """
    p_mask = pred.astype(bool)
    g_mask = gt.astype(bool)
    true_pos = np.logical_and(p_mask, g_mask).sum()
    false_pos = np.logical_and(p_mask, ~g_mask).sum()
    false_neg = np.logical_and(~p_mask, g_mask).sum()

    def ratio(num, den):
        # Zero-denominator convention: report 0.0 instead of nan.
        return float(num / den) if den > 0 else 0.0

    return ratio(true_pos, true_pos + false_pos), ratio(true_pos, true_pos + false_neg)
def compute_fpr_fnr(pred, gt):
    """Return ``(fpr, fnr)`` of a binary prediction against ground truth.

    fpr = FP / (FP + TN) and fnr = FN / (FN + TP), each a Python float, with
    nonzero entries counted as foreground. If the ground truth has no
    negatives (for fpr) or no positives (for fnr), that rate is 0.0.
    """
    p_mask = pred.astype(bool)
    g_mask = gt.astype(bool)
    p_neg = ~p_mask
    g_neg = ~g_mask
    tn = np.logical_and(p_neg, g_neg).sum()
    fp = np.logical_and(p_mask, g_neg).sum()
    fn = np.logical_and(p_neg, g_mask).sum()
    tp = np.logical_and(p_mask, g_mask).sum()

    actual_negatives = fp + tn
    actual_positives = fn + tp

    fpr = 0.0
    if actual_negatives > 0:
        fpr = float(fp / actual_negatives)
    fnr = 0.0
    if actual_positives > 0:
        fnr = float(fn / actual_positives)
    return fpr, fnr
def compute_all_metrics(pred_batch, gt_batch, threshold=0.5):
    """Compute per-image segmentation metrics and average them over a batch.

    Args:
        pred_batch: (N, H, W) binary/float array (or array-like) of predictions.
        gt_batch: (N, H, W) binary/float array (or array-like) of ground truth.
        threshold: Values strictly greater than this count as foreground.
            It is applied to both inputs. Defaults to 0.5.

    Returns:
        dict mapping each metric name to the mean of its per-image values.
        "iou" and "miou" hold the same number (the mean over images of the
        per-image foreground IoU). Both keys are kept for callers that read
        either one. For an empty batch (N == 0) every value is nan.

    Raises:
        ValueError: if the two batches contain a different number of images.
    """
    # np.asarray lets plain lists / other array-likes be thresholded too.
    pred_batch = np.asarray(pred_batch) > threshold
    gt_batch = np.asarray(gt_batch) > threshold

    n_images = pred_batch.shape[0]
    if gt_batch.shape[0] != n_images:
        # Without this check, extra ground-truth images would be silently
        # ignored, and missing ones would surface as an IndexError.
        raise ValueError(
            f"pred_batch has {n_images} images but gt_batch has {gt_batch.shape[0]}"
        )

    ious = []
    dices = []
    accs = []
    precisions = []
    recalls = []
    fprs = []
    fnrs = []
    for pred, gt in zip(pred_batch, gt_batch):
        ious.append(compute_iou(pred, gt))
        dices.append(compute_dice(pred, gt))
        accs.append(compute_pixel_accuracy(pred, gt))
        p, r = compute_precision_recall(pred, gt)
        precisions.append(p)
        recalls.append(r)
        fpr, fnr = compute_fpr_fnr(pred, gt)
        fprs.append(fpr)
        fnrs.append(fnr)

    def _batch_mean(values):
        # np.mean([]) already yields nan, but it also emits a
        # "Mean of empty slice" RuntimeWarning; return the same nan quietly.
        return float(np.mean(values)) if values else float("nan")

    mean_iou = _batch_mean(ious)
    return {
        "iou": mean_iou,
        "miou": mean_iou,
        "dice_score": _batch_mean(dices),
        "pixel_accuracy": _batch_mean(accs),
        "precision": _batch_mean(precisions),
        "recall": _batch_mean(recalls),
        "fpr": _batch_mean(fprs),
        "fnr": _batch_mean(fnrs),
    }