import time
from collections import defaultdict

import numpy as np
import torch
from clipseg_eval.general_utils import log
from torch.nn import functional as nnf


class BaseMetric(object):
    """Common interface for metric accumulators.

    Subclasses implement :meth:`add` to accumulate a batch and :meth:`value`
    to compute the final result from everything accumulated so far.
    """

    def __init__(
        self, metric_names, pred_range=None, gt_index=0, pred_index=0, eval_intermediate=True, eval_validation=True
    ):
        """Store the metric configuration.

        Args:
            metric_names: iterable of names, kept as a tuple and returned by :meth:`names`.
            pred_range: optional ``(start, stop)`` pair used by :meth:`_get_pred_gt`
                to slice the second dimension of the selected prediction.
            gt_index: index of the entry selected from the ground-truth sequence.
            pred_index: index of the entry selected from the prediction sequence.
            eval_intermediate: flag returned by :meth:`eval_intermediate`.
            eval_validation: flag returned by :meth:`eval_validation`.
        """
        self._names = tuple(metric_names)
        self._eval_intermediate = eval_intermediate
        self._eval_validation = eval_validation
        self._pred_range = pred_range
        self._pred_index = pred_index
        self._gt_index = gt_index

        self.predictions = []
        self.ground_truths = []

    def eval_intermediate(self):
        """Return the ``eval_intermediate`` flag given at construction."""
        return self._eval_intermediate

    def eval_validation(self):
        """Return the ``eval_validation`` flag given at construction."""
        return self._eval_validation

    def names(self):
        """Return the metric names as a tuple."""
        return self._names

    def add(self, predictions, ground_truth):
        """Accumulate one batch; must be implemented by subclasses."""
        raise NotImplementedError

    def value(self):
        """Compute the metric result; must be implemented by subclasses."""
        raise NotImplementedError

    def scores(self):
        """Return the result of :meth:`value` in keyed form.

        A dict returned by :meth:`value` is passed through unchanged. A list
        or tuple is paired element-wise with :meth:`names` and returned as a
        list of ``(name, value)`` pairs.
        """
        # value() can be expensive, so compute it once and reuse the result.
        value = self.value()
        if isinstance(value, dict):
            return value
        assert isinstance(value, (list, tuple))
        return list(zip(self.names(), value))

    def _get_pred_gt(self, predictions, ground_truth):
        """Select the configured prediction / ground-truth entries.

        When ``pred_range`` is set, the prediction is additionally sliced along
        its second dimension to ``[pred_range[0]:pred_range[1]]``.
        """
        pred = predictions[self._pred_index]
        gt = ground_truth[self._gt_index]

        if self._pred_range is not None:
            pred = pred[:, self._pred_range[0] : self._pred_range[1]]

        return pred, gt


class FixedIntervalMetrics(BaseMetric):
    """Accumulates per-sample confusion counts over a fixed grid of thresholds.

    For every sample passed to :meth:`add`, the numbers of true/false
    positives/negatives are recorded for each value in ``threshold_values``.
    :meth:`value` aggregates them into average precision, foreground IoU,
    binary IoU and (when class labels were supplied) mean IoU over classes.

    NOTE: the tuple registered with ``BaseMetric`` as metric names is legacy;
    :meth:`value` returns a dict with its own keys.
    """

    def __init__(
        self, sigmoid=False, ignore_mask=False, resize_to=None, resize_pred=None, n_values=51, custom_threshold=None
    ):
        """Configure the metric.

        Args:
            sigmoid: apply ``torch.sigmoid`` to the predictions in :meth:`add`.
            ignore_mask: if true, a mask supplied as ``gt[1]`` is not used.
            resize_to: if not None, size both ground truth and predictions are
                interpolated to (per batch) in :meth:`add`.
            resize_pred: if truthy, each prediction is bilinearly resized to
                the spatial size of its ground truth.
            n_values: number of points of ``np.linspace(0, 1, n_values)``; the
                two end points are dropped to obtain the thresholds.
            custom_threshold: optional threshold for which the ``*_ct`` scores
                are reported; it has to be one of the evaluated thresholds.
        """
        super().__init__(
            (
                "ap",
                "best_fgiou",
                "best_miou",
                "fgiou0.5",
                "fgiou0.1",
                "mean_iou_0p5",
                "mean_iou_0p1",
                "best_biniou",
                "biniou_0.5",
                "fgiou_thresh",
            )
        )
        self.intersections = []
        self.unions = []
        self.sigmoid = sigmoid
        self.resize_to = resize_to
        self.resize_pred = resize_pred  # resize prediction to match ground truth
        self.class_count = defaultdict(lambda: 0)
        self.per_class = defaultdict(lambda: [0, 0])
        self.ignore_mask = ignore_mask
        self.custom_threshold = custom_threshold

        self.scores_ap = []
        self.scores_iou = []
        self.gts, self.preds = [], []
        self.classes = []

        # [1:-1] ignores 0 and 1
        self.threshold_values = np.linspace(0, 1, n_values)[1:-1]
        self.metrics = dict(tp=[], fp=[], fn=[], tn=[])

    def add(self, pred, gt):
        """Accumulate the confusion counts of one batch.

        Args:
            pred: sequence whose first entry is the batch of predictions.
            gt: sequence holding the ground-truth batch at index 0, optionally
                a mask batch at index 1 (only the elements where the mask is
                true are evaluated) and optionally class labels at index 2.
        """
        pred_batch = pred[0].cpu()
        if self.sigmoid:
            pred_batch = torch.sigmoid(pred_batch)

        gt_batch = gt[0].cpu()
        batch_size = len(pred_batch)

        use_mask = len(gt) > 1 and not self.ignore_mask and gt[1].numel() > 0
        # Predictions and ground truth live on the CPU at this point, so the
        # mask must be moved there too before it is used as an index.
        mask_batch = gt[1].cpu() if use_mask else [None] * batch_size
        cls_batch = gt[2] if len(gt) > 2 else [None] * batch_size

        if self.resize_to is not None:
            gt_batch = nnf.interpolate(gt_batch, self.resize_to, mode="nearest")
            pred_batch = nnf.interpolate(pred_batch, self.resize_to, mode="bilinear", align_corners=False)

        if isinstance(cls_batch, torch.Tensor):
            cls_batch = cls_batch.cpu().numpy().tolist()

        assert len(gt_batch) == len(pred_batch) == len(cls_batch), (
            f"{len(gt_batch)} {len(pred_batch)} {len(cls_batch)}"
        )

        for predictions, ground_truth, mask, cls in zip(pred_batch, gt_batch, mask_batch, cls_batch):
            if self.resize_pred:
                predictions = nnf.interpolate(
                    predictions.unsqueeze(0).float(),
                    size=ground_truth.size()[-2:],
                    mode="bilinear",
                    align_corners=True,
                )

            p = predictions.flatten()
            g = ground_truth.flatten()
            assert len(p) == len(g)

            if mask is not None:
                m = mask.flatten().bool()
                p = p[m]
                g = g[m]

            # Sort by score (ascending) so that, for every threshold, the
            # entries above it form the tail of the sorted ground truth.
            p_sorted = p.sort()
            p = p_sorted.values
            g = g[p_sorted.indices]

            total = len(g)
            tps, fps, fns, tns = [], [], [], []
            for thresh in self.threshold_values:
                above = torch.where(p > thresh)[0]
                # n = number of entries that are not above the threshold
                n = int(above[0]) if len(above) > 0 else total

                fn = int(g[:n].sum())
                tp = int(g[n:].sum())

                fns.append(fn)
                tns.append(n - fn)
                tps.append(tp)
                fps.append(total - n - tp)

            self.metrics["tp"].append(tps)
            self.metrics["fp"].append(fps)
            self.metrics["fn"].append(fns)
            self.metrics["tn"].append(tns)

            self.classes.append(cls.item() if isinstance(cls, torch.Tensor) else cls)

    # ------------------------------------------------------------------
    # helpers shared by value() of this class and its subclasses
    # ------------------------------------------------------------------

    def _resolve_classes(self):
        """Return the set of integer class labels seen, or None if none were given."""
        if set(self.classes) == {None}:
            log.warning("classes were not provided, cannot compute mIoU")
            return None
        return {int(c) for c in self.classes}

    def _sum_over_samples(self):
        """Sum each confusion count over all samples, per threshold."""
        n_thresholds = len(self.threshold_values)
        return {
            key: [sum(sample[j] for sample in per_sample) for j in range(n_thresholds)]
            for key, per_sample in self.metrics.items()
        }

    def _sum_by_class(self, all_classes):
        """Sum each confusion count over the samples of every class, per threshold."""
        assert len(self.classes) == len(self.metrics["tp"]) == len(self.metrics["fn"])

        grouped = {c: {key: [] for key in self.metrics} for c in all_classes}
        for i, cls in enumerate(self.classes):
            # all_classes holds int(c), so normalise the lookup key the same way
            bucket = grouped[int(cls)]
            for key, per_sample in self.metrics.items():
                bucket[key].append(per_sample[i])

        return {
            key: {c: np.array(grouped[c][key]).sum(0).tolist() for c in all_classes}
            for key in self.metrics
        }

    def _precision_recall(self, summed):
        """Return ``(recalls, precisions)``, de-duplicated and sorted by recall.

        Only thresholds at which at least one positive prediction was made
        are considered. Both ratios carry a ``+1`` in the denominator.
        """
        assert (np.array(summed["fp"]) + np.array(summed["tp"])).sum(), "no predictions is made"

        pairs = set()
        for tp, fp, fn in zip(summed["tp"], summed["fp"], summed["fn"]):
            if tp + fp > 0:
                pairs.add((tp / (1 + tp + fn), tp / (1 + tp + fp)))

        recalls, precisions = zip(*sorted(pairs, key=lambda pair: pair[0]))
        return recalls, precisions

    @staticmethod
    def _average_precision(recalls, precisions):
        """Integrate precision over recall with Simpson's rule."""
        try:
            from scipy.integrate import simpson
        except ImportError:  # very old SciPy only provides the former name
            from scipy.integrate import simps as simpson

        return simpson(precisions, x=recalls)

    @staticmethod
    def _iou_scores(summed):
        """Return per-threshold ``(fgiou_scores, biniou_scores)``.

        Binary IoU is the average of the foreground IoU and the background
        IoU (true negatives in place of true positives).
        """
        fgiou_scores, biniou_scores = [], []
        for tp, fp, fn, tn in zip(summed["tp"], summed["fp"], summed["fn"], summed["tn"]):
            fgiou_scores.append(tp / (1 + tp + fp + fn))
            biniou_scores.append(0.5 * (tp / (1 + tp + fp + fn)) + 0.5 * (tn / (1 + tn + fn + fp)))
        return fgiou_scores, biniou_scores

    def _custom_threshold_index(self):
        """Return the index of ``custom_threshold`` in ``threshold_values`` (None if unset).

        The thresholds come from ``np.linspace``, so the comparison tolerates
        floating point rounding instead of requiring exact equality.

        Raises:
            ValueError: if ``custom_threshold`` is not one of the thresholds.
        """
        if self.custom_threshold is None:
            return None

        distances = np.abs(np.asarray(self.threshold_values, dtype=float) - self.custom_threshold)
        nearest = int(np.argmin(distances))
        if not np.isclose(self.threshold_values[nearest], self.custom_threshold):
            raise ValueError(f"{self.custom_threshold} is not in the list of evaluated thresholds")
        return nearest

    # ------------------------------------------------------------------

    def value(self):
        """Aggregate everything accumulated so far into a dict of scores.

        Returns:
            dict with average precision (``ap``), best foreground / binary IoU
            and their thresholds, the scores at ``custom_threshold`` (``None``
            when unset), the raw per-threshold curves and summed statistics.
            The mean-IoU entries and ``summed_by_cls_statistics`` are ``None``
            when no class labels were provided.

        Raises:
            AssertionError: if no positive prediction exists at any threshold.
            ValueError: if ``custom_threshold`` is not an evaluated threshold.
        """
        t_start = time.time()

        all_classes = self._resolve_classes()
        summed = self._sum_over_samples()
        summed_by_cls = self._sum_by_class(all_classes) if all_classes is not None else None

        # Compute average precision
        recalls, precisions = self._precision_recall(summed)
        ap = self._average_precision(recalls, precisions)

        # Compute best IoU
        fgiou_scores, biniou_scores = self._iou_scores(summed)

        index_ct = self._custom_threshold_index()

        if all_classes is not None:
            # mean IoU: per threshold, average the foreground IoU over classes
            mean_ious = [
                np.mean(
                    [
                        summed_by_cls["tp"][c][j]
                        / (1 + summed_by_cls["tp"][c][j] + summed_by_cls["fp"][c][j] + summed_by_cls["fn"][c][j])
                        for c in all_classes
                    ]
                )
                for j in range(len(self.threshold_values))
            ]
            mean_iou_dict = {
                "miou_best": max(mean_ious),
                "miou_best_t": self.threshold_values[np.argmax(mean_ious)],
                "mean_iou_ct": mean_ious[index_ct] if index_ct is not None else None,
                "mean_iou_scores": mean_ious,
            }
        else:
            # Without class labels there is no mean IoU; report None rather
            # than failing on values that were never computed.
            mean_iou_dict = {
                "miou_best": None,
                "miou_best_t": None,
                "mean_iou_ct": None,
                "mean_iou_scores": None,
            }

        print(
            f"metric computation on {(len(all_classes) if all_classes is not None else 'no')} classes took {time.time() - t_start:.1f}s"
        )

        return {
            "ap": ap,
            # fgiou
            "fgiou_best": max(fgiou_scores),
            "fgiou_best_t": self.threshold_values[np.argmax(fgiou_scores)],
            # biniou
            "biniou_best": max(biniou_scores),
            "biniou_best_t": self.threshold_values[np.argmax(biniou_scores)],
            # custom threshold
            "fgiou_ct": fgiou_scores[index_ct] if index_ct is not None else None,
            "biniou_ct": biniou_scores[index_ct] if index_ct is not None else None,
            "ct": self.custom_threshold,
            # statistics
            "fgiou_scores": fgiou_scores,
            "biniou_scores": biniou_scores,
            "precision_recall_curve": sorted(set(zip(recalls, precisions))),
            "summed_statistics": summed,
            "summed_by_cls_statistics": summed_by_cls,
            **mean_iou_dict,
        }


class FixedIntervalMetricsWithMatching(FixedIntervalMetrics):
    """Variant of :class:`FixedIntervalMetrics` evaluated at the single threshold 0.5.

    Accumulation (:meth:`add`) is inherited unchanged; :meth:`value` reports
    only the class-independent scores.
    """

    def __init__(
        self, sigmoid=False, ignore_mask=False, resize_to=None, resize_pred=None, n_values=51, custom_threshold=None
    ):
        """Same arguments as the parent; the threshold grid is replaced by ``[0.5]``."""
        super().__init__(sigmoid, ignore_mask, resize_to, resize_pred, n_values, custom_threshold)

        self.threshold_values = np.array([0.5])
        self.metrics = dict(tp=[], fp=[], fn=[], tn=[])

    def value(self):
        """Aggregate the accumulated counts into a dict of scores.

        Returns:
            dict with ``ap``, best foreground / binary IoU and their
            thresholds, the per-threshold score lists, the precision-recall
            curve and the summed statistics.

        Raises:
            AssertionError: if no positive prediction exists at any threshold.
        """
        t_start = time.time()

        all_classes = self._resolve_classes()
        if all_classes is not None:
            log.info(f"compute metrics for {len(all_classes)} classes")
            assert len(self.classes) == len(self.metrics["tp"]) == len(self.metrics["fn"])

        summed = self._sum_over_samples()

        # Compute average precision
        recalls, precisions = self._precision_recall(summed)
        ap = self._average_precision(recalls, precisions)

        # Compute best IoU
        fgiou_scores, biniou_scores = self._iou_scores(summed)

        print(
            f"metric computation on {(len(all_classes) if all_classes is not None else 'no')} classes took {time.time() - t_start:.1f}s"
        )

        return {
            "ap": ap,
            # fgiou
            "fgiou_best": max(fgiou_scores),
            "fgiou_best_t": self.threshold_values[np.argmax(fgiou_scores)],
            # biniou
            "biniou_best": max(biniou_scores),
            "biniou_best_t": self.threshold_values[np.argmax(biniou_scores)],
            # statistics
            "fgiou_scores": fgiou_scores,
            "biniou_scores": biniou_scores,
            "precision_recall_curve": sorted(set(zip(recalls, precisions))),
            "summed_statistics": summed,
        }