| | from monai.transforms import Transform, Compose, LoadImage, EnsureChannelFirst |
| | import torch |
| | import skimage |
| | import torch |
| | import SimpleITK as sitk |
| | import numpy as np |
| | from PIL import Image |
| | from io import BytesIO |
| | import matplotlib.pyplot as plt |
| | import SimpleITK as sitk |
| | from matplotlib.colors import ListedColormap |
| | import base64 |
| | import numpy as np |
| | from cv2 import dilate |
| | from scipy.ndimage import label |
| | from Model_Seg import RgbaToGrayscale |
| |
|
def image_to_base64(image_path):
    """Return the contents of a file as a base64-encoded string.

    Args:
        image_path: Path of the file to read; it is opened in binary mode.

    Returns:
        str: Base64 encoding of the file's bytes, decoded as UTF-8 text.
    """
    with open(image_path, "rb") as handle:
        raw_bytes = handle.read()
    encoded = base64.b64encode(raw_bytes)
    return encoded.decode('utf-8')
| |
|
class CustomCLAHE(Transform):
    """Contrast enhancement built around CLAHE, as described by Qiu et al.

    The transform rescales the first channel to [0, 1], applies a median
    filter, subtracts a weighted Gaussian-blurred copy (sigma=1), rescales
    again and finally runs Contrast-Limited Adaptive Histogram Equalization.

    Attributes:
        p1 (float): Weighting factor of the subtracted Gaussian-blurred image;
            determines the degree of contour enhancement. Default is 0.6.
        p2 (None or int): Kernel size for the adaptive histogram. Default is None.
        p3 (float): Clip limit for histogram equalization. Default is 0.01.
    """

    def __init__(self, p1=0.6, p2=None, p3=0.01):
        self.p1 = p1
        self.p2 = p2
        self.p3 = p3

    def __call__(self, data):
        """Apply the enhancement pipeline to the input data.

        Args:
            data (Union[dict, torch.Tensor, np.ndarray]): Either the image
                itself or a dictionary holding it under the key ``"image"``.
                Only the first entry along the leading axis is processed.

        Returns:
            Union[dict, torch.Tensor]: For dictionary input, the same
            dictionary with ``"image"`` replaced by the enhanced tensor;
            otherwise the enhanced tensor itself.
        """
        is_dict = isinstance(data, dict)
        im = data["image"] if is_dict else data

        # Tensors expose ``.numpy()``, plain arrays do not; handle both so the
        # documented ndarray input does not fail with AttributeError.
        if isinstance(im, torch.Tensor):
            im = im.detach().cpu().numpy()
        else:
            im = np.asarray(im)

        # Keep only the first channel and restore a leading axis of length 1.
        im = im[0]
        im = im[None, :, :]

        im = skimage.exposure.rescale_intensity(im, in_range="image", out_range=(0, 1))
        im_noi = skimage.filters.median(im)
        # Subtracting a weighted blurred copy emphasises edges.
        im_fil = im_noi - self.p1 * skimage.filters.gaussian(im_noi, sigma=1)
        im_fil = skimage.exposure.rescale_intensity(im_fil, in_range="image", out_range=(0, 1))
        im_ce = skimage.exposure.equalize_adapthist(im_fil, kernel_size=self.p2, clip_limit=self.p3)

        enhanced = torch.Tensor(im_ce)
        if is_dict:
            data["image"] = enhanced
            return data
        return enhanced
| |
|
| |
|
| |
|
def custom_colormap():
    """Build the four-entry RGBA colormap used when drawing mask overlays.

    Returns:
        ListedColormap: Entries, in order: fully transparent black, then
        green, red and yellow, each with an alpha of 0.5.
    """
    rgba_entries = [
        (0, 0, 0, 0),
        (0, 1, 0, 0.5),
        (1, 0, 0, 0.5),
        (1, 1, 0, 0.5),
    ]
    return ListedColormap(rgba_entries)
| |
|
def read_image(image_path):
    """Load an image as a squeezed NumPy array, trying MONAI first and SimpleITK second.

    Args:
        image_path: Path of the image file to read.

    Returns:
        numpy.ndarray or None: The squeezed pixel array. The MONAI path casts
        to ``uint8``; the SimpleITK fallback keeps the dtype it reads.
        ``None`` is returned when both loaders fail.
    """
    read_transforms = Compose([
        LoadImage(image_only=True),
        EnsureChannelFirst(),
        RgbaToGrayscale(),
    ])

    # Deliberately broad: any failure of the MONAI pipeline triggers the
    # SimpleITK fallback. The error is kept so it can be reported if the
    # fallback fails as well, instead of being silently discarded.
    try:
        loaded = read_transforms(image_path)
        return loaded.numpy().astype(np.uint8).squeeze()
    except Exception as monai_error:
        first_error = monai_error

    try:
        fallback_image = sitk.ReadImage(image_path)
        return sitk.GetArrayFromImage(fallback_image).squeeze()
    except Exception as sitk_error:
        print("Failed Loading the Image: ", sitk_error)
        print("MONAI loader error was: ", first_error)
        return None
| |
|
def overlay_mask(image_path, image_mask):
    """Render a segmentation mask on top of an image and return both as arrays.

    Args:
        image_path: Path of the image, loaded through ``read_image``.
        image_mask: Array-like mask drawn over the image with ``custom_colormap()``.

    Returns:
        tuple: ``(overlay_image_np, original_image_np)`` where the first item is
        the rendered PNG decoded into a NumPy array and the second is the
        loaded image as ``uint8``.
    """
    original_image_np = read_image(image_path).squeeze().astype(np.uint8)

    buffer = BytesIO()
    fig = plt.figure(figsize=(10, 10))
    try:
        plt.imshow(original_image_np, cmap='gray')
        plt.imshow(image_mask, cmap=custom_colormap(), alpha=0.5)
        plt.axis('off')
        plt.savefig(buffer, format='png', bbox_inches='tight', pad_inches=0)
    finally:
        # pyplot keeps every figure alive until it is closed explicitly;
        # without this each call leaks one figure.
        plt.close(fig)

    buffer.seek(0)
    overlay_image_np = np.array(Image.open(buffer))
    return overlay_image_np, original_image_np
| |
|
| |
|
def bounding_box_mask(image, label):
    """Generate a binary bounding-box mask around the labelled region.

    Args:
        image (SimpleITK.Image): The input image; only its spacing is used.
        label (SimpleITK.Image): The labelled image containing the region of
            interest. Its squeezed array is expected to be 2D.

    Returns:
        SimpleITK.Image: A ``uint8`` image (with a leading axis of length 1
        added to the array) in which pixels inside the axis-aligned bounding
        box of all non-zero label pixels are 1 and all others are 0. It carries
        the same spacing as ``image``.

    Raises:
        IndexError: If ``label`` contains no non-zero pixel.
    """
    original_spacing = image.GetSpacing()

    label_array = np.squeeze(sitk.GetArrayFromImage(label))
    foreground = label_array != 0

    # Indices of the rows / columns that contain at least one labelled pixel.
    occupied_rows = np.nonzero(np.any(foreground, axis=1))[0]
    occupied_cols = np.nonzero(np.any(foreground, axis=0))[0]
    top, bottom = occupied_rows[0], occupied_rows[-1]
    left, right = occupied_cols[0], occupied_cols[-1]

    # Every non-zero label pixel lies inside the box, so starting from zeros
    # gives the same result as overwriting a copy of the label.
    bounding_box_array = np.zeros(label_array.shape, dtype=np.uint8)
    bounding_box_array[top:bottom + 1, left:right + 1] = 1

    bounding_box_image = sitk.GetImageFromArray(bounding_box_array[None, ...])
    bounding_box_image.SetSpacing(original_spacing)
    return bounding_box_image
| |
|
| |
|
def threshold_based_crop(image):
    """Crop an image to the bounding box found by Otsu thresholding.

    Otsu's estimator splits the image into two classes, labelled 0 (inside
    value) and 255 (outside value). The image is then cropped to the
    axis-aligned bounding box of the region labelled 255.

    Args:
        image (SimpleITK image): An image whose anatomy and background
            intensities form a bi-modal distribution (the assumption
            underlying Otsu's method).

    Return:
        The region of ``image`` inside that bounding box.
    """
    otsu_inside = 0
    otsu_outside = 255

    thresholded = sitk.OtsuThreshold(image, otsu_inside, otsu_outside)
    shape_stats = sitk.LabelShapeStatisticsImageFilter()
    shape_stats.Execute(thresholded)

    # The bounding box is a flat sequence: start indices first, sizes second.
    box = shape_stats.GetBoundingBox(otsu_outside)
    half = len(box) // 2
    start_index = box[:half]
    size = box[half:]
    return sitk.RegionOfInterest(image, size, start_index)
| |
|
def _sij_side_mask(dilated_sacrum_mask, mask_array, pelvis_value, side):
    """Return the uint8 SIJ region for one pelvis side (``side`` is used in the warning only)."""
    pelvis_mask = mask_array == pelvis_value
    side_mask = dilated_sacrum_mask & pelvis_mask

    if not side_mask.any():
        print(f"Warning: No {side} intersection")
        # Approximate replacement derived from the pelvis mask alone.
        side_mask = create_median_height_array(pelvis_mask)
        if side_mask is None:
            # The pelvis label is absent altogether: contribute nothing
            # instead of passing None on to the component analysis.
            side_mask = np.zeros(mask_array.shape, dtype=np.uint8)

    return keep_largest_component(side_mask).astype(np.uint8)


def creat_SIJ_mask(image, input_label):
    """Create a mask for the sacroiliac joints (SIJ) from a pelvis and sacrum segmentation mask.

    The sacrum mask is dilated by 5% of its column extent and intersected with
    each pelvis label. If a side has no intersection, an approximation built by
    ``create_median_height_array`` is used instead. Only the largest connected
    component is kept per side.

    Args:
        image (SimpleITK.Image): X-ray image; only its spacing is used.
        input_label (SimpleITK.Image): Segmentation mask with the labels
            1 (sacrum), 2 (left pelvis) and 3 (right pelvis).

    Returns:
        SimpleITK.Image: Binary ``uint8`` mask of the SIJ (with a leading axis
        of length 1 added to the array), carrying the spacing of ``image``.

    Raises:
        IndexError: If the segmentation contains no sacrum pixel.
    """
    original_spacing = image.GetSpacing()
    mask_array = sitk.GetArrayFromImage(input_label).squeeze()

    sacrum_value = 1
    left_pelvis_value = 2
    right_pelvis_value = 3

    sacrum_mask = mask_array == sacrum_value

    # Horizontal extent of the sacrum determines how far it is dilated.
    sacrum_columns = np.nonzero(np.any(sacrum_mask, axis=0))[0]
    box_width = sacrum_columns[-1] - sacrum_columns[0]
    dilation_extent = int(np.round(0.05 * box_width))
    dilated_sacrum_mask = dilate_mask(sacrum_mask, dilation_extent)

    intersection_left = _sij_side_mask(
        dilated_sacrum_mask, mask_array, left_pelvis_value, "left"
    )
    intersection_right = _sij_side_mask(
        dilated_sacrum_mask, mask_array, right_pelvis_value, "right"
    )

    # Bitwise OR keeps the result binary even if the two sides overlap.
    intersection_mask = (intersection_left | intersection_right)[None, ...]

    intersection_mask_im = sitk.GetImageFromArray(intersection_mask)
    intersection_mask_im.SetSpacing(original_spacing)
    return intersection_mask_im
| |
|
def dilate_mask(mask, extent):
    """Dilate a binary mask with a square structuring element.

    Args:
        mask (numpy.ndarray): The binary segmentation mask; it is converted to
            ``uint8`` before dilation.
        extent (int): Half-width of the structuring element. The kernel is a
            ``(2 * extent + 1) x (2 * extent + 1)`` array of ones, so an extent
            of 0 uses a 1x1 kernel.

    Returns:
        numpy.ndarray: The dilated mask produced by a single ``dilate`` pass.
    """
    kernel_size = 2 * extent + 1
    kernel = np.ones((kernel_size, kernel_size), np.uint8)
    return dilate(mask.astype(np.uint8), kernel, iterations=1)
| |
|
def mask_and_crop(image, input_label):
    """Mask an image with its SIJ region and crop the results.

    Args:
        image (SimpleITK.Image): The image to be processed.
        input_label (SimpleITK.Image): The corresponding segmentation label,
            converted to an SIJ mask via ``creat_SIJ_mask``.

    Returns:
        tuple: Two SimpleITK.Image objects.
            - cropped_boxed_image: The image masked with the SIJ bounding box
              and then cropped.
            - mask: Binary image that is 1 wherever the SIJ-masked, cropped
              image is greater than 0.

    Note:
        Relies on creat_SIJ_mask(), bounding_box_mask() and
        threshold_based_crop().
    """
    sij_label = creat_SIJ_mask(image, input_label)
    box_mask = bounding_box_mask(image, sij_label)

    # Zero out everything outside the bounding box / outside the SIJ mask.
    image_in_box = sitk.Mask(image, box_mask, maskingValue=0, outsideValue=0)
    image_in_sij = sitk.Mask(image, sij_label, maskingValue=0, outsideValue=0)

    cropped_boxed_image = threshold_based_crop(image_in_box)
    cropped_sij_image = threshold_based_crop(image_in_sij)

    sij_pixels = np.squeeze(sitk.GetArrayFromImage(cropped_sij_image))
    binary_pixels = np.where(sij_pixels > 0, 1, 0)
    mask = sitk.GetImageFromArray(binary_pixels[None, ...])
    return cropped_boxed_image, mask
| |
|
def create_median_height_array(mask, half_width=5):
    """Create a mask from the columns whose filled height is closest to the median.

    For every column containing non-zero elements, the height is the span from
    its first to its last non-zero row. Columns whose height equals the rounded
    median height are selected; if no column matches exactly (e.g. heights 3
    and 5 give a median of 4), the columns closest to it are used so the result
    is not empty. Around each selected column a block of ``median_height`` rows,
    starting at that column's first non-zero row, is set to 1.

    Args:
        mask (numpy.ndarray): A 2D binary mask with 1s representing the label
            and 0s the background.
        half_width (int): Number of columns the block extends to the left of a
            selected column; to the right it ends just before
            ``col + half_width``. Default is 5.

    Returns:
        numpy.ndarray: A new integer mask of the same shape as ``mask``,
        or None if the input mask has no non-zero columns.

    Note:
        This function is only used when there is no intersection between pelvis
        and sacrum, and creates an alternative SIJ mask that serves as an
        approximate replacement.
    """
    mask = np.asarray(mask)
    _, n_cols = mask.shape

    # (height, first non-zero row, column index) for every occupied column.
    columns = []
    for col in range(n_cols):
        filled_rows = np.flatnonzero(mask[:, col])
        if filled_rows.size:
            start = int(filled_rows[0])
            height = int(filled_rows[-1]) - start + 1
            columns.append((height, start, col))

    if not columns:
        return None

    median_height = int(round(float(np.median([height for height, _, _ in columns]))))
    # A gap of 0 means exact matches exist and only those are used.
    smallest_gap = min(abs(height - median_height) for height, _, _ in columns)

    new_array = np.zeros_like(mask, dtype=int)
    for height, start, col in columns:
        if abs(height - median_height) != smallest_gap:
            continue
        first_col = max(0, col - half_width)
        last_col = min(n_cols, col + half_width)
        new_array[start:start + median_height, first_col:last_col] = 1
    return new_array
| |
|
def keep_largest_component(mask):
    """Retain only the largest connected component of a binary mask.

    Args:
        mask (numpy.ndarray): A binary mask with 1s representing the label and
            0s the background.

    Returns:
        numpy.ndarray: ``mask`` itself when it holds at most one component;
        otherwise a new array of the same dtype that is 1 on the largest
        component and 0 elsewhere.
    """
    components, component_count = label(mask)

    # Nothing to remove when there is no component or just a single one.
    if component_count < 2:
        return mask

    sizes = np.bincount(labeled_flat := components.ravel()) if False else np.bincount(components.ravel())
    # Label 0 is the background; exclude it from the size comparison.
    foreground_sizes = sizes[1:]
    winner = int(np.argmax(foreground_sizes)) + 1

    keep = components == winner
    return keep.astype(mask.dtype)