Spaces:
Running
on
Zero
Running
on
Zero
| import numpy as np | |
| import cv2 | |
| from sklearn.metrics import roc_auc_score, precision_recall_curve, average_precision_score | |
| from skimage.measure import label, regionprops | |
| class AnomalyEvaluator: | |
| def __init__(self, pixel_subsample_rate=0.01, compute_pro=False): | |
| self.subsample_rate = pixel_subsample_rate | |
| self.compute_pro = compute_pro | |
| self.reset() | |
| def reset(self): | |
| self.img_preds = [] | |
| self.img_labels = [] | |
| self.pix_preds = [] | |
| self.pix_labels = [] | |
| self.full_amaps = [] | |
| self.full_masks = [] | |
| def update(self, image_score, gt_label, anomaly_map=None, gt_mask=None): | |
| self.img_preds.append(image_score) | |
| self.img_labels.append(gt_label) | |
| if anomaly_map is not None and gt_mask is not None: | |
| self._update_pixel_metrics(anomaly_map, gt_mask) | |
| def _update_pixel_metrics(self, amap, mask): | |
| if mask.shape != amap.shape: | |
| mask = cv2.resize(mask, (amap.shape[1], amap.shape[0]), interpolation=cv2.INTER_NEAREST) | |
| mask = (mask > 0).astype(int) | |
| if self.compute_pro: | |
| self.full_amaps.append(amap) | |
| self.full_masks.append(mask) | |
| flat_amap = amap.flatten() | |
| flat_mask = mask.flatten() | |
| if self.compute_pro or self.subsample_rate >= 1.0: | |
| self.pix_preds.extend(flat_amap) | |
| self.pix_labels.extend(flat_mask) | |
| else: | |
| # Random Subsampling to save memory | |
| num_pixels = len(flat_mask) | |
| sample_size = int(num_pixels * self.subsample_rate) | |
| indices = np.random.choice(num_pixels, sample_size, replace=False) | |
| self.pix_preds.extend(flat_amap[indices]) | |
| self.pix_labels.extend(flat_mask[indices]) | |
| def compute(self): | |
| results = {} | |
| y_true = np.array(self.img_labels) | |
| y_score = np.array(self.img_preds) | |
| results['image_auroc'] = roc_auc_score(y_true, y_score) | |
| results['image_ap'] = average_precision_score(y_true, y_score) | |
| prec, rec, _ = precision_recall_curve(y_true, y_score) | |
| f1_scores = 2 * (prec * rec) / (prec + rec + 1e-8) | |
| results['image_f1_max'] = np.max(f1_scores) | |
| if len(self.pix_labels) > 0: | |
| pix_true = np.array(self.pix_labels) | |
| pix_score = np.array(self.pix_preds) | |
| results['pixel_auroc'] = roc_auc_score(pix_true, pix_score) | |
| prec_p, rec_p, thresholds_p = precision_recall_curve(pix_true, pix_score) | |
| f1_p = 2 * (prec_p * rec_p) / (prec_p + rec_p + 1e-8) | |
| best_idx = np.argmax(f1_p) | |
| best_threshold = thresholds_p[best_idx] if best_idx < len(thresholds_p) else 0.5 | |
| results['pixel_f1_max'] = np.max(f1_p) | |
| if self.compute_pro: | |
| results['pixel_pro'] = self._compute_pro(best_threshold) | |
| return results | |
| def _compute_pro(self, threshold): | |
| total_pro = 0 | |
| n_defects = 0 | |
| for i in range(len(self.full_amaps)): | |
| gt = self.full_masks[i] | |
| # Skip normal images | |
| if np.sum(gt) == 0: | |
| continue | |
| pred_mask = (self.full_amaps[i] >= threshold).astype(int) | |
| # Label connected components in Ground Truth | |
| labeled_gt = label(gt) | |
| regions = regionprops(labeled_gt) | |
| for region in regions: | |
| n_defects += 1 | |
| blob_mask = (labeled_gt == region.label) | |
| overlap_pixels = np.sum(pred_mask & blob_mask) | |
| blob_area = region.area | |
| total_pro += (overlap_pixels / blob_area) | |
| return total_pro / n_defects if n_defects > 0 else 0.0 |