| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
|
|
|
|
| import warnings |
| from typing import Tuple, Literal |
|
|
| import cv2 |
| import numpy as np |
| from scipy.ndimage import measurements |
| from scipy.ndimage.morphology import binary_fill_holes |
| from skimage.segmentation import watershed |
| import torch |
|
|
| |
| |
| from tools import get_bounding_box, remove_small_objects |
|
|
|
|
| def noop(*args, **kargs): |
| pass |
|
|
|
|
| warnings.warn = noop |
|
|
|
|
| class DetectionCellPostProcessor: |
| def __init__( |
| self, |
| nr_types: int = None, |
| magnification: Literal[20, 40] = 40, |
| gt: bool = False, |
| ) -> None: |
| """DetectionCellPostProcessor for postprocessing prediction maps and get detected cells |
| |
| Args: |
| nr_types (int, optional): Number of cell types, including background (background = 0). Defaults to None. |
| magnification (Literal[20, 40], optional): Which magnification the data has. Defaults to 40. |
| gt (bool, optional): If this is gt data (used that we do not suppress tiny cells that may be noise in a prediction map). |
| Defaults to False. |
| |
| Raises: |
| NotImplementedError: Unknown magnification |
| """ |
| self.nr_types = nr_types |
| self.magnification = magnification |
| self.gt = gt |
|
|
| if magnification == 40: |
| self.object_size = 10 |
| self.k_size = 21 |
| elif magnification == 20: |
| self.object_size = 3 |
| self.k_size = 11 |
| else: |
| raise NotImplementedError("Unknown magnification") |
| if gt: |
| self.object_size = 100 |
| self.k_size = 21 |
|
|
| def post_process_cell_segmentation( |
| self, |
| pred_map: np.ndarray, |
| ) -> Tuple[np.ndarray, dict]: |
| """Post processing of one image tile |
| |
| Args: |
| pred_map (np.ndarray): Combined output of tp, np and hv branches, in the same order. Shape: (H, W, 4) |
| |
| Returns: |
| Tuple[np.ndarray, dict]: |
| np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W) |
| dict: Instance dictionary. Main Key is the nuclei instance number (int), with a dict as value. |
| For each instance, the dictionary contains the keys: bbox (bounding box), centroid (centroid coordinates), |
| contour, type_prob (probability), type (nuclei type) |
| """ |
| if self.nr_types is not None: |
| pred_type = pred_map[..., 1:2] |
| pred_inst = pred_map[..., 2:] |
| pred_type = pred_type.astype(np.int32) |
| |
| else: |
| pred_inst = pred_map |
|
|
| pred_inst = np.squeeze(pred_inst) |
| pred_inst = self.__proc_np_hv( |
| pred_inst, object_size=self.object_size, ksize=self.k_size |
| ) |
| |
| inst_id_list = np.unique(pred_inst)[1:] |
| |
| inst_info_dict = {} |
| for inst_id in inst_id_list: |
| inst_map = pred_inst == inst_id |
| rmin, rmax, cmin, cmax = get_bounding_box(inst_map) |
| inst_bbox = np.array([[rmin, cmin], [rmax, cmax]]) |
| inst_map = inst_map[ |
| inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1] |
| ] |
| inst_map = inst_map.astype(np.uint8) |
| inst_moment = cv2.moments(inst_map) |
| inst_contour = cv2.findContours( |
| inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE |
| ) |
| |
| inst_contour = np.squeeze(inst_contour[0][0].astype("int32")) |
| |
| |
| if inst_contour.shape[0] < 3: |
| continue |
| if len(inst_contour.shape) != 2: |
| continue |
| inst_centroid = [ |
| (inst_moment["m10"] / inst_moment["m00"]), |
| (inst_moment["m01"] / inst_moment["m00"]), |
| ] |
| inst_centroid = np.array(inst_centroid) |
| inst_contour[:, 0] += inst_bbox[0][1] |
| inst_contour[:, 1] += inst_bbox[0][0] |
| inst_centroid[0] += inst_bbox[0][1] |
| inst_centroid[1] += inst_bbox[0][0] |
| inst_info_dict[inst_id] = { |
| "bbox": inst_bbox, |
| "centroid": inst_centroid, |
| "contour": inst_contour, |
| "type_prob": None, |
| "type": None, |
| "all_type_prob": [] |
| } |
|
|
| |
| for inst_id in list(inst_info_dict.keys()): |
| rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten() |
| inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] |
| inst_type_crop = pred_type[rmin:rmax, cmin:cmax] |
| inst_map_crop = inst_map_crop == inst_id |
| inst_type = inst_type_crop[inst_map_crop] |
| type_list, type_pixels = np.unique(inst_type, return_counts=True) |
| type_list = list(zip(type_list, type_pixels)) |
|
|
| type_probs = {} |
| total_pixels = np.sum(inst_map_crop) + 1.0e-6 |
| for cell_type, pixel_count in type_list: |
| type_probs[cell_type] = float(pixel_count / total_pixels) |
|
|
| all_type_prob = [type_probs.get(i, 0.0) for i in range(self.nr_types)] |
| inst_info_dict[inst_id]["all_type_prob"] = all_type_prob |
| |
|
|
| type_list = sorted(type_list, key=lambda x: x[1], reverse=True) |
| inst_type = type_list[0][0] |
| if inst_type == 0: |
| if len(type_list) > 1: |
| inst_type = type_list[1][0] |
| type_dict = {v[0]: v[1] for v in type_list} |
| type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6) |
| inst_info_dict[inst_id]["type"] = int(inst_type) |
| inst_info_dict[inst_id]["type_prob"] = float(type_prob) |
|
|
| return pred_inst, inst_info_dict |
|
|
| def __proc_np_hv( |
| self, pred: np.ndarray, object_size: int = 10, ksize: int = 21 |
| ) -> np.ndarray: |
| """Process Nuclei Prediction with XY Coordinate Map and generate instance map (each instance has unique integer) |
| |
| Separate Instances (also overlapping ones) from binary nuclei map and hv map by using morphological operations and watershed |
| |
| Args: |
| pred (np.ndarray): Prediction output, assuming. Shape: (H, W, 3) |
| * channel 0 contain probability map of nuclei |
| * channel 1 containing the regressed X-map |
| * channel 2 containing the regressed Y-map |
| object_size (int, optional): Smallest oject size for filtering. Defaults to 10 |
| k_size (int, optional): Sobel Kernel size. Defaults to 21 |
| Returns: |
| np.ndarray: Instance map for one image. Each nuclei has own integer. Shape: (H, W) |
| """ |
| pred = np.array(pred, dtype=np.float32) |
|
|
| blb_raw = pred[..., 0] |
| h_dir_raw = pred[..., 1] |
| v_dir_raw = pred[..., 2] |
|
|
| |
| blb = np.array(blb_raw >= 0.5, dtype=np.int32) |
|
|
| blb = measurements.label(blb)[0] |
| blb = remove_small_objects(blb, min_size=10) |
| blb[blb > 0] = 1 |
|
|
| h_dir = cv2.normalize( |
| h_dir_raw, |
| None, |
| alpha=0, |
| beta=1, |
| norm_type=cv2.NORM_MINMAX, |
| dtype=cv2.CV_32F, |
| ) |
| v_dir = cv2.normalize( |
| v_dir_raw, |
| None, |
| alpha=0, |
| beta=1, |
| norm_type=cv2.NORM_MINMAX, |
| dtype=cv2.CV_32F, |
| ) |
|
|
| |
| |
|
|
| sobelh = cv2.Sobel(h_dir, cv2.CV_64F, 1, 0, ksize=ksize) |
| sobelv = cv2.Sobel(v_dir, cv2.CV_64F, 0, 1, ksize=ksize) |
|
|
| sobelh = 1 - ( |
| cv2.normalize( |
| sobelh, |
| None, |
| alpha=0, |
| beta=1, |
| norm_type=cv2.NORM_MINMAX, |
| dtype=cv2.CV_32F, |
| ) |
| ) |
| sobelv = 1 - ( |
| cv2.normalize( |
| sobelv, |
| None, |
| alpha=0, |
| beta=1, |
| norm_type=cv2.NORM_MINMAX, |
| dtype=cv2.CV_32F, |
| ) |
| ) |
|
|
| overall = np.maximum(sobelh, sobelv) |
| overall = overall - (1 - blb) |
| overall[overall < 0] = 0 |
|
|
| dist = (1.0 - overall) * blb |
| |
| dist = -cv2.GaussianBlur(dist, (3, 3), 0) |
|
|
| overall = np.array(overall >= 0.4, dtype=np.int32) |
|
|
| marker = blb - overall |
| marker[marker < 0] = 0 |
| marker = binary_fill_holes(marker).astype("uint8") |
| kernel = cv2.getStructuringElement(cv2.MORPH_ELLIPSE, (5, 5)) |
| marker = cv2.morphologyEx(marker, cv2.MORPH_OPEN, kernel) |
| marker = measurements.label(marker)[0] |
| marker = remove_small_objects(marker, min_size=object_size) |
|
|
| proced_pred = watershed(dist, markers=marker, mask=blb) |
|
|
| return proced_pred |
|
|
|
|
| def calculate_instances( |
| pred_types: torch.Tensor, pred_insts: torch.Tensor |
| ) -> list[dict]: |
| """Best used for GT |
| |
| Args: |
| pred_types (torch.Tensor): Binary or type map ground-truth. |
| Shape must be (B, C, H, W) with C=1 for binary or num_nuclei_types for multi-class. |
| pred_insts (torch.Tensor): Ground-Truth instance map with shape (B, H, W) |
| |
| Returns: |
| list[dict]: Dictionary with nuclei informations, output similar to post_process_cell_segmentation |
| """ |
| type_preds = [] |
| pred_types = pred_types.permute(0, 2, 3, 1) |
| for i in range(pred_types.shape[0]): |
| pred_type = torch.argmax(pred_types, dim=-1)[i].detach().cpu().numpy() |
| pred_inst = pred_insts[i].detach().cpu().numpy() |
| inst_id_list = np.unique(pred_inst)[1:] |
| inst_info_dict = {} |
| for inst_id in inst_id_list: |
| inst_map = pred_inst == inst_id |
| rmin, rmax, cmin, cmax = get_bounding_box(inst_map) |
| inst_bbox = np.array([[rmin, cmin], [rmax, cmax]]) |
| inst_map = inst_map[ |
| inst_bbox[0][0] : inst_bbox[1][0], inst_bbox[0][1] : inst_bbox[1][1] |
| ] |
| inst_map = inst_map.astype(np.uint8) |
| inst_moment = cv2.moments(inst_map) |
| inst_contour = cv2.findContours( |
| inst_map, cv2.RETR_TREE, cv2.CHAIN_APPROX_SIMPLE |
| ) |
| |
| inst_contour = np.squeeze(inst_contour[0][0].astype("int32")) |
| |
| |
| if inst_contour.shape[0] < 3: |
| continue |
| if len(inst_contour.shape) != 2: |
| continue |
| inst_centroid = [ |
| (inst_moment["m10"] / inst_moment["m00"]), |
| (inst_moment["m01"] / inst_moment["m00"]), |
| ] |
| inst_centroid = np.array(inst_centroid) |
| inst_contour[:, 0] += inst_bbox[0][1] |
| inst_contour[:, 1] += inst_bbox[0][0] |
| inst_centroid[0] += inst_bbox[0][1] |
| inst_centroid[1] += inst_bbox[0][0] |
| inst_info_dict[inst_id] = { |
| "bbox": inst_bbox, |
| "centroid": inst_centroid, |
| "contour": inst_contour, |
| "type_prob": None, |
| "type": None, |
| } |
| |
| for inst_id in list(inst_info_dict.keys()): |
| rmin, cmin, rmax, cmax = (inst_info_dict[inst_id]["bbox"]).flatten() |
| inst_map_crop = pred_inst[rmin:rmax, cmin:cmax] |
| inst_type_crop = pred_type[rmin:rmax, cmin:cmax] |
| inst_map_crop = inst_map_crop == inst_id |
| inst_type = inst_type_crop[inst_map_crop] |
| type_list, type_pixels = np.unique(inst_type, return_counts=True) |
| type_list = list(zip(type_list, type_pixels)) |
| type_list = sorted(type_list, key=lambda x: x[1], reverse=True) |
| inst_type = type_list[0][0] |
| if inst_type == 0: |
| if len(type_list) > 1: |
| inst_type = type_list[1][0] |
| type_dict = {v[0]: v[1] for v in type_list} |
| type_prob = type_dict[inst_type] / (np.sum(inst_map_crop) + 1.0e-6) |
| inst_info_dict[inst_id]["type"] = int(inst_type) |
| inst_info_dict[inst_id]["type_prob"] = float(type_prob) |
| type_preds.append(inst_info_dict) |
|
|
| return type_preds |
|
|
|
|
|
|
|
|
|
|
|
|