File size: 13,679 Bytes
01bd570 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 |
# Copyright (c) MONAI Consortium
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
# http://www.apache.org/licenses/LICENSE-2.0
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
from __future__ import annotations
from collections.abc import Sequence
import torch
from monai.metrics.metric import CumulativeIterationMetric
from monai.metrics.utils import do_metric_reduction, remap_instance_id
from monai.utils import MetricReduction, ensure_tuple, optional_import
# scipy's Hungarian/Munkres solver is only needed by `_get_paired_iou` when match_iou_threshold < 0.5;
# `optional_import` presumably defers/soft-fails the import so scipy stays optional (TODO confirm semantics).
linear_sum_assignment, _ = optional_import("scipy.optimize", name="linear_sum_assignment")
# Public API of this module.
__all__ = ["PanopticQualityMetric", "compute_panoptic_quality"]
class PanopticQualityMetric(CumulativeIterationMetric):
    """
    Compute Panoptic Quality between two instance segmentation masks. If `metric_name` is set to "SQ" or "RQ",
    Segmentation Quality (SQ) or Recognition Quality (RQ) is returned instead.

    Panoptic Quality is a metric used in panoptic segmentation tasks. This task unifies the typically distinct tasks
    of semantic segmentation (assign a class label to each pixel) and instance segmentation (detect and segment
    each object instance). Compared with semantic segmentation, panoptic segmentation distinguishes different
    instances that belong to the same class. Compared with instance segmentation, panoptic segmentation does not
    allow overlap: only one semantic label and one instance id can be assigned to each pixel.

    Please refer to the following paper for more details:
    https://openaccess.thecvf.com/content_CVPR_2019/papers/Kirillov_Panoptic_Segmentation_CVPR_2019_paper.pdf

    This class also refers to the following implementation:
    https://github.com/TissueImageAnalytics/CoNIC

    Args:
        num_classes: number of classes. The number should not count the background.
        metric_name: output metric. The value can be "pq", "sq" or "rq" (the long forms "panoptic quality",
            "segmentation quality" and "recognition quality" are accepted too, case-insensitively).
            A sequence of metric names such as ("pq", "sq", "rq") is also supported; in that case `aggregate`
            returns a list of results in the same order as the input names. Names are validated in `aggregate`.
        reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
            available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
            ``"mean_channel"``, ``"sum_channel"``}, defaults to ``"mean_batch"``. If "none", will not do reduction.
        match_iou_threshold: IoU threshold to determine the pairing between `y_pred` and `y`; a pair is accepted
            only when its IoU is strictly greater than this value. It must lie in (0, 1]. With a value >= 0.5 every
            instance can match at most one instance of the other mask, so pairs are found by thresholding alone;
            with a value < 0.5 Munkres assignment is used to find the maximal set of unique pairs.
        smooth_numerator: a small constant added to the *denominator* of each ratio to avoid division by zero
            (the parameter name is kept for backward compatibility).
    """

    def __init__(
        self,
        num_classes: int,
        metric_name: Sequence[str] | str = "pq",
        reduction: MetricReduction | str = MetricReduction.MEAN_BATCH,
        match_iou_threshold: float = 0.5,
        smooth_numerator: float = 1e-6,
    ) -> None:
        super().__init__()
        self.num_classes = num_classes
        self.reduction = reduction
        self.match_iou_threshold = match_iou_threshold
        self.smooth_numerator = smooth_numerator
        # `ensure_tuple` is expected to wrap a single name into a 1-tuple, so `aggregate` can always iterate
        self.metric_name = ensure_tuple(metric_name)

    def _compute_tensor(self, y_pred: torch.Tensor, y: torch.Tensor) -> torch.Tensor:  # type: ignore[override]
        """
        Args:
            y_pred: Predictions. It must be in the form of B2HW and have integer type. The first channel and the
                second channel represent the instance predictions and classification predictions respectively.
            y: ground truth. It must have the same shape as `y_pred` and have integer type. The first channel and the
                second channel represent the instance labels and classification labels respectively.
                Values in the second channel of `y_pred` and `y` should be in the range of 0 to `self.num_classes`,
                where 0 represents the background.

        Returns:
            A tensor of shape (B, num_classes, 4). For each image and class the last axis holds
            [true positives, false positives, false negatives, sum of IoU over matched pairs].

        Raises:
            ValueError: when `y_pred` and `y` have different shapes.
            ValueError: when `y_pred` and `y` have != 4 dimensions.
            ValueError: when `y_pred` and `y` have != 2 channels.
        """
        if y_pred.shape != y.shape:
            raise ValueError(f"y_pred and y should have same shapes, got {y_pred.shape} and {y.shape}.")
        # Check the rank before reading `shape[1]`, so that low-rank inputs raise the documented
        # ValueError rather than an IndexError.
        dims = y_pred.ndimension()
        if dims != 4:
            raise ValueError(f"y_pred should have 4 dimensions (batch, 2, h, w), got {dims}.")
        if y_pred.shape[1] != 2:
            raise ValueError(
                f"for panoptic quality calculation, only 2 channels input is supported, got {y_pred.shape[1]}."
            )

        batch_size = y_pred.shape[0]
        outputs = torch.zeros([batch_size, self.num_classes, 4], device=y_pred.device)
        for b in range(batch_size):
            true_instance, pred_instance = y[b, 0], y_pred[b, 0]
            true_class, pred_class = y[b, 1], y_pred[b, 1]
            for c in range(self.num_classes):
                # keep only the instances labelled as class `c + 1`; every other pixel becomes 0 (background)
                pred_instance_c = (pred_class == c + 1) * pred_instance
                true_instance_c = (true_class == c + 1) * true_instance
                outputs[b, c] = compute_panoptic_quality(
                    pred=pred_instance_c,
                    gt=true_instance_c,
                    remap=True,
                    match_iou_threshold=self.match_iou_threshold,
                    output_confusion_matrix=True,
                )
        return outputs

    def aggregate(self, reduction: MetricReduction | str | None = None) -> torch.Tensor | list[torch.Tensor]:
        """
        Execute reduction logic for the output of `compute_panoptic_quality`.

        The reduction is applied to the accumulated [tp, fp, fn, iou_sum] counts first; the requested ratios are
        then formed from the reduced counts.

        Args:
            reduction: define mode of reduction to the metrics, will only apply reduction on `not-nan` values,
                available reduction modes: {``"none"``, ``"mean"``, ``"sum"``, ``"mean_batch"``, ``"sum_batch"``,
                ``"mean_channel"``, ``"sum_channel"``}, default to `self.reduction`. if "none", will not do reduction.

        Returns:
            One tensor when a single metric was requested, otherwise a list with one tensor per entry of
            `metric_name`, in the same order.

        Raises:
            ValueError: when the buffered data is not a PyTorch Tensor, or a metric name is not recognised.
        """
        data = self.get_buffer()
        if not isinstance(data, torch.Tensor):
            raise ValueError("the data to aggregate must be PyTorch Tensor.")

        # do metric reduction
        f, _ = do_metric_reduction(data, reduction or self.reduction)
        tp, fp, fn, iou_sum = f[..., 0], f[..., 1], f[..., 2], f[..., 3]
        results = []
        for metric_name in self.metric_name:
            metric_name = _check_panoptic_metric_name(metric_name)
            if metric_name == "rq":
                results.append(tp / (tp + 0.5 * fp + 0.5 * fn + self.smooth_numerator))
            elif metric_name == "sq":
                results.append(iou_sum / (tp + self.smooth_numerator))
            else:
                results.append(iou_sum / (tp + 0.5 * fp + 0.5 * fn + self.smooth_numerator))

        return results[0] if len(results) == 1 else results
def compute_panoptic_quality(
    pred: torch.Tensor,
    gt: torch.Tensor,
    metric_name: str = "pq",
    remap: bool = True,
    match_iou_threshold: float = 0.5,
    smooth_numerator: float = 1e-6,
    output_confusion_matrix: bool = False,
) -> torch.Tensor:
    """Computes Panoptic Quality (PQ). If `metric_name` is set to "SQ" or "RQ",
    Segmentation Quality (SQ) or Recognition Quality (RQ) is returned instead.

    In addition, if `output_confusion_matrix` is True, the function returns a tensor with shape 4, which
    holds the true positive count, false positive count, false negative count and the sum of IoU over matched
    pairs. These four values are used to calculate PQ, and returning them directly enables further calculation
    over all images.

    Args:
        pred: input data to compute, it must be in the form of HW and have integer type.
        gt: ground truth. It must have the same shape as `pred` and have integer type.
        metric_name: output metric. The value can be "pq", "sq" or "rq".
        remap: whether to remap `pred` and `gt` to ensure contiguous ordering of instance id.
        match_iou_threshold: IoU threshold to determine the pairing between `pred` and `gt`; a pair is accepted
            only when its IoU is strictly greater than this value. With a value >= 0.5 every instance can match
            at most one instance of the other mask, so pairs are found by thresholding alone. With a value < 0.5
            Munkres assignment is used to find the maximal set of unique pairs.
        smooth_numerator: a small constant added to the *denominator* of each ratio to avoid division by zero
            (the parameter name is kept for backward compatibility).
        output_confusion_matrix: if True, return [tp, fp, fn, iou_sum] instead of the metric value.

    Returns:
        A 0-d tensor with the requested metric, or a shape-4 tensor when `output_confusion_matrix` is True.

    Raises:
        ValueError: when `pred` and `gt` have different shapes.
        ValueError: when `match_iou_threshold` <= 0.0 or > 1.0.
        ValueError: when `metric_name` is not recognised.
    """
    if gt.shape != pred.shape:
        raise ValueError(f"pred and gt should have same shapes, got {pred.shape} and {gt.shape}.")
    if match_iou_threshold <= 0.0 or match_iou_threshold > 1.0:
        raise ValueError(f"'match_iou_threshold' should be within (0, 1], got: {match_iou_threshold}.")

    gt = gt.int()
    pred = pred.int()
    if remap:
        gt = remap_instance_id(gt)
        pred = remap_instance_id(pred)

    pairwise_iou, true_id_list, pred_id_list = _get_pairwise_iou(pred, gt, device=pred.device)
    paired_iou, paired_true, paired_pred = _get_paired_iou(
        pairwise_iou, match_iou_threshold, device=pairwise_iou.device
    )

    # Both pairing strategies match every instance at most once, so the unmatched counts follow directly
    # from the number of foreground instances. The first entry of each id list is the background id and is
    # not counted, mirroring the `[1:]` convention used when the IoU matrix is built. Counting instead of
    # testing id membership against the paired indices avoids an O(N*M) loop of tiny tensor ops, and does
    # not assume that id values coincide with matrix positions.
    tp = len(paired_true)
    fn = len(true_id_list) - 1 - tp
    fp = len(pred_id_list) - 1 - len(paired_pred)
    # `paired_iou` may be a tensor or an array depending on the pairing branch; use a plain float either way
    iou_sum = float(paired_iou.sum())

    if output_confusion_matrix:
        return torch.as_tensor([tp, fp, fn, iou_sum], device=pred.device)

    metric_name = _check_panoptic_metric_name(metric_name)
    if metric_name == "rq":
        value = tp / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator)
    elif metric_name == "sq":
        value = iou_sum / (tp + smooth_numerator)
    else:
        value = iou_sum / (tp + 0.5 * fp + 0.5 * fn + smooth_numerator)
    return torch.as_tensor(value, device=pred.device)
def _get_id_list(gt: torch.Tensor) -> list[torch.Tensor]:
id_list = list(gt.unique())
# ensure id 0 is included
if 0 not in id_list:
id_list.insert(0, torch.tensor(0).int())
return id_list
def _get_pairwise_iou(
    pred: torch.Tensor, gt: torch.Tensor, device: str | torch.device = "cpu"
) -> tuple[torch.Tensor, list[torch.Tensor], list[torch.Tensor]]:
    """
    Compute the IoU between every foreground gt instance and every foreground pred instance.

    Args:
        pred: predicted instance id map (same shape as `gt`).
        gt: ground-truth instance id map.
        device: device on which the masks and the IoU matrix are created.

    Returns:
        ``(pairwise_iou, true_id_list, pred_id_list)``. ``pairwise_iou[i, j]`` is the IoU between the instances
        ``true_id_list[i + 1]`` and ``pred_id_list[j + 1]``; entry 0 of each id list is skipped and has no
        row/column. Pairs that do not overlap keep an IoU of 0.
    """
    pred_id_list = _get_id_list(pred)
    true_id_list = _get_id_list(gt)
    pred_fg_ids = pred_id_list[1:]
    true_fg_ids = true_id_list[1:]
    pairwise_iou = torch.zeros([len(true_fg_ids), len(pred_fg_ids)], dtype=torch.float, device=device)

    # Column of each pred id *value*. Looking ids up here, rather than using `id - 1` as the position,
    # keeps the matrix correct when ids are not contiguous (e.g. when remapping was disabled by the caller).
    pred_column = {int(p): col for col, p in enumerate(pred_fg_ids)}
    pred_masks = [torch.as_tensor(pred == p, device=device).int() for p in pred_fg_ids]

    for row, t in enumerate(true_fg_ids):
        t_mask = torch.as_tensor(gt == t, device=device).int()
        # only pred instances overlapping this gt instance can have a non-zero IoU
        for pred_id in pred[t_mask > 0].unique():
            if pred_id == 0:
                continue
            col = pred_column[int(pred_id)]
            p_mask = pred_masks[col]
            total = (t_mask + p_mask).sum()
            inter = (t_mask * p_mask).sum()
            # union = |t| + |p| - |t ∩ p|
            pairwise_iou[row, col] = inter / (total - inter)

    return pairwise_iou, true_id_list, pred_id_list
def _get_paired_iou(
pairwise_iou: torch.Tensor, match_iou_threshold: float = 0.5, device: str | torch.device = "cpu"
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
if match_iou_threshold >= 0.5:
pairwise_iou[pairwise_iou <= match_iou_threshold] = 0.0
paired_true, paired_pred = torch.nonzero(pairwise_iou)[:, 0], torch.nonzero(pairwise_iou)[:, 1]
paired_iou = pairwise_iou[paired_true, paired_pred]
paired_true += 1
paired_pred += 1
return paired_iou, paired_true, paired_pred
pairwise_iou = pairwise_iou.cpu().numpy()
paired_true, paired_pred = linear_sum_assignment(-pairwise_iou)
paired_iou = pairwise_iou[paired_true, paired_pred]
paired_true = torch.as_tensor(list(paired_true[paired_iou > match_iou_threshold] + 1), device=device)
paired_pred = torch.as_tensor(list(paired_pred[paired_iou > match_iou_threshold] + 1), device=device)
paired_iou = paired_iou[paired_iou > match_iou_threshold]
return paired_iou, paired_true, paired_pred
def _check_panoptic_metric_name(metric_name: str) -> str:
metric_name = metric_name.replace(" ", "_")
metric_name = metric_name.lower()
if metric_name in ["panoptic_quality", "pq"]:
return "pq"
if metric_name in ["segmentation_quality", "sq"]:
return "sq"
if metric_name in ["recognition_quality", "rq"]:
return "rq"
raise ValueError(f"metric name: {metric_name} is wrong, please use 'pq', 'sq' or 'rq'.")
|