""" Metrics for segmentation evaluation. """ import numpy as np def compute_iou(pred, gt): """Intersection over Union (Jaccard Index).""" pred = pred.astype(bool) gt = gt.astype(bool) intersection = np.logical_and(pred, gt).sum() union = np.logical_or(pred, gt).sum() if union == 0: return 1.0 if intersection == 0 else 0.0 return float(intersection / union) def compute_dice(pred, gt): """Dice score / F1 score.""" pred = pred.astype(bool) gt = gt.astype(bool) intersection = np.logical_and(pred, gt).sum() total = pred.sum() + gt.sum() if total == 0: return 1.0 return float(2.0 * intersection / total) def compute_pixel_accuracy(pred, gt): """Pixel accuracy.""" pred = pred.astype(bool) gt = gt.astype(bool) correct = (pred == gt).sum() total = pred.size return float(correct / total) def compute_precision_recall(pred, gt): """Precision and recall.""" pred = pred.astype(bool) gt = gt.astype(bool) tp = np.logical_and(pred, gt).sum() fp = np.logical_and(pred, ~gt).sum() fn = np.logical_and(~pred, gt).sum() precision = float(tp / (tp + fp)) if (tp + fp) > 0 else 0.0 recall = float(tp / (tp + fn)) if (tp + fn) > 0 else 0.0 return precision, recall def compute_fpr_fnr(pred, gt): """False positive rate and false negative rate.""" pred = pred.astype(bool) gt = gt.astype(bool) tn = np.logical_and(~pred, ~gt).sum() fp = np.logical_and(pred, ~gt).sum() fn = np.logical_and(~pred, gt).sum() tp = np.logical_and(pred, gt).sum() fpr = float(fp / (fp + tn)) if (fp + tn) > 0 else 0.0 fnr = float(fn / (fn + tp)) if (fn + tp) > 0 else 0.0 return fpr, fnr def compute_all_metrics(pred_batch, gt_batch): """ pred_batch: (N, H, W) binary/float array gt_batch: (N, H, W) binary/float array """ pred_batch = pred_batch > 0.5 gt_batch = gt_batch > 0.5 ious = [] dices = [] accs = [] precisions = [] recalls = [] fprs = [] fnrs = [] N = pred_batch.shape[0] for i in range(N): pred = pred_batch[i] gt = gt_batch[i] ious.append(compute_iou(pred, gt)) dices.append(compute_dice(pred, gt)) accs.append(compute_pixel_accuracy(pred, gt)) p, r = compute_precision_recall(pred, gt) precisions.append(p) recalls.append(r) fpr, fnr = compute_fpr_fnr(pred, gt) fprs.append(fpr) fnrs.append(fnr) return { "iou": float(np.mean(ious)), "miou": float(np.mean(ious)), "dice_score": float(np.mean(dices)), "pixel_accuracy": float(np.mean(accs)), "precision": float(np.mean(precisions)), "recall": float(np.mean(recalls)), "fpr": float(np.mean(fprs)), "fnr": float(np.mean(fnrs)), }