Spaces:
Sleeping
Sleeping
| import os | |
| import zipfile | |
| from pathlib import Path | |
| from time import time | |
| from typing import Union | |
| import matplotlib.pyplot as plt | |
| import dosma | |
| import numpy as np | |
| import wget | |
| import cv2 | |
| import scipy.misc | |
| from PIL import Image | |
| import dicom2nifti | |
| import math | |
| import pydicom | |
| import operator | |
| import moviepy.video.io.ImageSequenceClip | |
| from tkinter import Tcl | |
| import pandas as pd | |
| import warnings | |
| import numpy as np | |
| from skimage.morphology import skeletonize_3d | |
| from scipy.spatial.distance import pdist, squareform | |
| from scipy.interpolate import splprep, splev | |
| import nibabel as nib | |
| from nibabel.processing import resample_to_output | |
| import matplotlib.pyplot as plt | |
| from scipy.interpolate import interp1d | |
| from totalsegmentator.libs import ( | |
| download_pretrained_weights, | |
| nostdout, | |
| setup_nnunet, | |
| ) | |
| from comp2comp.inference_class_base import InferenceClass | |
| from comp2comp.models.models import Models | |
| from comp2comp.spine import spine_utils | |
| import nibabel as nib | |
class AortaSegmentation(InferenceClass):
    """Aorta segmentation with an nnU-Net model (Task253_Aorta).

    The ``spine``-prefixed method names and the ``spine.nii.gz`` output file
    name are kept for backward compatibility; the model downloaded and run
    here is the aorta model.
    """

    def __init__(self, save=True):
        super().__init__()
        self.model_name = "totalsegmentator"
        # NOTE(review): stored but never read in this class.
        self.save_segmentations = save

    def __call__(self, inference_pipeline):
        """Segment the aorta and attach axial slices to the pipeline.

        Reads ``<output_dir>/segmentations/converted_dcm.nii.gz`` and writes the
        segmentation to ``<output_dir>/segmentations/spine.nii.gz``.

        Sets ``inference_pipeline.ct_image`` and
        ``inference_pipeline.axial_masks`` to lists of 2D arrays, one per index
        along the third axis of the respective volume.

        Returns:
            dict: Always empty.
        """
        self.output_dir = inference_pipeline.output_dir
        self.output_dir_segmentations = os.path.join(self.output_dir, "segmentations/")
        os.makedirs(self.output_dir_segmentations, exist_ok=True)

        self.model_dir = inference_pipeline.model_dir

        seg, mv = self.spine_seg(
            os.path.join(self.output_dir_segmentations, "converted_dcm.nii.gz"),
            self.output_dir_segmentations + "spine.nii.gz",
            inference_pipeline.model_dir,
        )

        seg = seg.get_fdata()
        medical_volume = mv.get_fdata()

        # Input CT slices and aorta masks, split along the third (axial) axis.
        inference_pipeline.ct_image = [
            medical_volume[:, :, i] for i in range(medical_volume.shape[2])
        ]
        inference_pipeline.axial_masks = [seg[:, :, i] for i in range(seg.shape[2])]
        return {}

    def setup_nnunet_c2c(self, model_dir: Union[str, Path]):
        """Create the nnU-Net folder layout and point nnU-Net at it.

        Adapted from TotalSegmentator. Sets ``self.weights_dir`` and the
        ``nnUNet_raw_data_base``, ``nnUNet_preprocessed`` and ``RESULTS_FOLDER``
        environment variables.

        Args:
            model_dir (Union[str, Path]): Root directory for model files.
        """
        model_dir = Path(model_dir)
        config_dir = model_dir / Path("." + self.model_name)
        (config_dir / "nnunet/results/nnUNet/3d_fullres").mkdir(exist_ok=True, parents=True)
        (config_dir / "nnunet/results/nnUNet/2d").mkdir(exist_ok=True, parents=True)
        weights_dir = config_dir / "nnunet/results"
        self.weights_dir = weights_dir

        # The first two are not used for inference; they only need to be
        # existing directories.
        os.environ["nnUNet_raw_data_base"] = str(weights_dir)
        os.environ["nnUNet_preprocessed"] = str(weights_dir)
        os.environ["RESULTS_FOLDER"] = str(weights_dir)

    def download_spine_model(self, model_dir: Union[str, Path]):
        """Download the Task253_Aorta weights into ``self.weights_dir`` if missing.

        Requires ``setup_nnunet_c2c`` to have been called first. Each artifact
        (``fold_0`` and ``plans.pkl``) is fetched independently, so a run that
        was interrupted after extracting ``fold_0`` is repaired on the next call.

        Args:
            model_dir (Union[str, Path]): Unused; kept for interface compatibility.
        """
        download_dir = (
            Path(self.weights_dir)
            / "nnUNet/3d_fullres/Task253_Aorta/nnUNetTrainerV2_ep4000_nomirror__nnUNetPlansv2.1"
        )
        print(download_dir)
        fold_0_path = download_dir / "fold_0"
        plans_path = download_dir / "plans.pkl"

        if fold_0_path.exists() and plans_path.exists():
            print("Spine model already downloaded.")
            return

        download_dir.mkdir(parents=True, exist_ok=True)
        if not fold_0_path.exists():
            zip_path = download_dir / "fold_0.zip"
            wget.download(
                "https://huggingface.co/AdritRao/aaa_test/resolve/main/fold_0.zip",
                out=str(zip_path),
            )
            with zipfile.ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(download_dir)
            zip_path.unlink()
        if not plans_path.exists():
            # NOTE(review): plans.pkl is a pickle fetched from a remote host and
            # later unpickled by nnU-Net; only use trusted sources.
            wget.download(
                "https://huggingface.co/AdritRao/aaa_test/resolve/main/plans.pkl",
                out=str(plans_path),
            )
        print("Spine model downloaded.")

    def spine_seg(self, input_path: Union[str, Path], output_path: Union[str, Path], model_dir):
        """Run aorta segmentation with nnU-Net (task 253, 3d_fullres, fold 0).

        Args:
            input_path (Union[str, Path]): Input NIfTI path.
            output_path (Union[str, Path]): Path the segmentation is written to.
            model_dir: Root directory for model files.

        Returns:
            tuple: ``(seg, img)`` as produced by ``nnUNet_predict_image``, with
            ``seg`` rebuilt as a ``nib.Nifti1Image`` from its float data.
        """
        print("Segmenting spine...")
        st = time()
        # os.environ only accepts str values; model_dir may be a Path.
        os.environ["SCRATCH"] = str(self.model_dir)
        print(self.model_dir)

        model = "3d_fullres"
        folds = [0]
        trainer = "nnUNetTrainerV2_ep4000_nomirror"
        crop_path = None
        task_id = [253]

        self.setup_nnunet_c2c(model_dir)
        self.download_spine_model(model_dir)

        from totalsegmentator.nnunet import nnUNet_predict_image

        with nostdout():
            img, seg = nnUNet_predict_image(
                input_path,
                output_path,
                task_id,
                model=model,
                folds=folds,
                trainer=trainer,
                tta=False,
                multilabel_image=True,
                resample=1.5,
                crop=None,
                crop_path=crop_path,
                task_name="total",
                nora_tag="None",
                preview=False,
                nr_threads_resampling=1,
                nr_threads_saving=6,
                quiet=False,
                verbose=False,
                test=0,
            )
        end = time()

        print(f"Total time for spine segmentation: {end-st:.2f}s.")

        seg_data = seg.get_fdata()
        seg = nib.Nifti1Image(seg_data, seg.affine, seg.header)
        return seg, img
class AortaDiameter(InferenceClass):
    """Measure aortic diameter on axial slices and along a 3D centerline.

    Reads ``ct_image`` and ``axial_masks`` (lists of 2D arrays) from the
    pipeline and writes, under ``<output_dir>/images``:

    * ``slices/slice<N>.png``: slice with the largest mask contour, its fitted
      ellipse, and area/diameter labels;
    * ``summary/diameter_graph.png``: bar chart of diameter per slice;
    * ``summary/out.png``: raw vs. annotated slice with the largest diameter;
    * ``summary/aaa.mp4``: video of the annotated slices;
    * ``summary/planar_reconstruction.png``: centerline / cross-section plot.

    Sets ``inference_pipeline.max_diameter`` to the largest per-slice
    diameter in cm.
    """

    # Isotropic voxel size (mm) the NIfTI volumes are resampled to for the 3D
    # centerline analysis.
    RESAMPLED_SPACING_MM = 1.5

    def __init__(self):
        super().__init__()

    def normalize_img(self, img: np.ndarray) -> np.ndarray:
        """Linearly rescale the image to [0, 1].

        Args:
            img (np.ndarray): Input image.

        Returns:
            np.ndarray: Normalized image. A constant image yields all zeros
            instead of the NaNs a plain min-max division would produce.
        """
        lo = img.min()
        value_range = img.max() - lo
        if value_range == 0:
            return np.zeros_like(img, dtype=float)
        return (img - lo) / value_range

    def __call__(self, inference_pipeline):
        """Run per-slice and 3D centerline diameter measurements.

        Raises:
            ValueError: If no axial slice contains a measurable contour.

        Returns:
            dict: Always empty.
        """
        ct_img = inference_pipeline.ct_image
        axial_masks = inference_pipeline.axial_masks

        output_dir = inference_pipeline.output_dir
        output_dir_slices = os.path.join(output_dir, "images/slices/")
        output_dir_summary = os.path.join(output_dir, "images/summary/")
        os.makedirs(output_dir_slices, exist_ok=True)
        os.makedirs(output_dir_summary, exist_ok=True)

        pixel_spacing = self._read_pixel_spacing(inference_pipeline.dicom_series_path)
        # NOTE(review): assumes ct_image slices keep the DICOM's native
        # in-plane spacing - confirm against the segmentation step.
        mm_per_pixel = float(pixel_spacing[0])
        mm2_per_pixel = float(pixel_spacing[0]) * float(pixel_spacing[1])

        # Slices are numbered from len(ct_img) (index 0) down to 1.
        slice_count = len(ct_img)
        diameter_dict = {}
        for i, (ct_slice, mask_slice) in enumerate(zip(ct_img, axial_masks)):
            slice_num = slice_count - i
            diameter_cm = self._measure_slice(
                ct_slice,
                mask_slice,
                slice_num,
                mm_per_pixel,
                mm2_per_pixel,
                output_dir_slices,
            )
            if diameter_cm is not None:
                diameter_dict[slice_num] = diameter_cm

        if not diameter_dict:
            raise ValueError("No measurable aorta contour found in any axial slice.")

        self._save_diameter_graph(diameter_dict, output_dir_summary + "diameter_graph.png")

        max_slice = max(diameter_dict, key=diameter_dict.get)
        max_diameter = diameter_dict[max_slice]
        print(diameter_dict)
        print(max_slice)
        print(max_diameter)
        inference_pipeline.max_diameter = max_diameter

        self._save_summary_image(
            ct_img[slice_count - max_slice],
            output_dir_slices + "slice" + str(max_slice) + ".png",
            output_dir_summary + "out.png",
        )
        self._write_slice_video(output_dir_slices, output_dir_summary + "aaa.mp4")

        # AortaSegmentation stores this directory on itself, not on the
        # pipeline, so fall back to its known location.
        segmentation_dir = getattr(
            inference_pipeline,
            "output_dir_segmentations",
            os.path.join(output_dir, "segmentations/"),
        )
        self._planar_reconstruction(
            segmentation_dir, output_dir_summary + "planar_reconstruction.png"
        )
        return {}

    # ------------------------------------------------------------------ #
    # Per-slice (2D) measurement
    # ------------------------------------------------------------------ #

    @staticmethod
    def _read_pixel_spacing(dicom_series_path):
        """Return ``PixelSpacing`` from one DICOM file of the series."""
        # Sorted (and hidden files skipped) so the chosen file does not depend
        # on directory listing order.
        files = sorted(f for f in os.listdir(dicom_series_path) if not f.startswith("."))
        dicom = pydicom.dcmread(os.path.join(dicom_series_path, files[0]))
        pixel_spacing = dicom.PixelSpacing
        print("Pixel conversion: " + str(pixel_spacing))
        return pixel_spacing

    def _ct_slice_to_rgb(self, ct_slice):
        """Window a CT slice to [-300, 1800], scale to [0, 255], stack to 3 channels."""
        img = np.clip(ct_slice, -300, 1800)
        img = self.normalize_img(img) * 255.0
        return np.tile(img[:, :, np.newaxis], (1, 1, 3))

    @staticmethod
    def _draw_axis(img, xc, yc, angle_deg, radius, color):
        """Draw a line of half-length ``radius`` through (xc, yc) at ``angle_deg``."""
        x1 = xc + math.cos(math.radians(angle_deg)) * radius
        y1 = yc + math.sin(math.radians(angle_deg)) * radius
        x2 = xc + math.cos(math.radians(angle_deg + 180)) * radius
        y2 = yc + math.sin(math.radians(angle_deg + 180)) * radius
        cv2.line(img, (int(x1), int(y1)), (int(x2), int(y2)), color, 3)

    def _measure_slice(
        self, ct_slice, mask_slice, slice_num, mm_per_pixel, mm2_per_pixel, output_dir_slices
    ):
        """Annotate one axial slice, save it as PNG, and return its diameter in cm.

        The largest mask contour is fitted with an ellipse. The reported
        diameter is the ellipse's minor axis.

        Returns:
            float | None: Diameter in cm, or None (nothing written) when the
            mask has no contour with the >= 5 points ``cv2.fitEllipse`` needs.
        """
        mask = mask_slice.astype("uint8")
        contours, _ = cv2.findContours(mask, cv2.RETR_LIST, cv2.CHAIN_APPROX_NONE)
        if len(contours) == 0:
            return None
        contour = max(contours, key=cv2.contourArea)
        if len(contour) < 5:
            return None

        img = self._ct_slice_to_rgb(ct_slice)

        # Blend a filled contour over the slice at 25% opacity.
        back = img.copy()
        cv2.drawContours(back, [contour], 0, (0, 255, 0), -1)
        alpha = 0.25
        img = cv2.addWeighted(img, 1 - alpha, back, alpha, 0)

        ellipse = cv2.fitEllipse(contour)
        (xc, yc), (d1, d2), angle = ellipse
        cv2.ellipse(img, ellipse, (0, 255, 0), 1)
        cv2.circle(img, (int(xc), int(yc)), 5, (0, 0, 255), -1)
        rmajor = max(d1, d2) / 2
        rminor = min(d1, d2) / 2

        # Major axis is drawn at the fitted angle shifted by 90 degrees, the
        # minor axis after shifting that angle once more.
        major_angle = angle - 90 if angle > 90 else angle + 90
        self._draw_axis(img, xc, yc, major_angle, rmajor, (0, 0, 255))
        minor_angle = major_angle - 90 if major_angle > 90 else major_angle + 90
        self._draw_axis(img, xc, yc, minor_angle, rminor, (255, 0, 0))

        pixel_length = rminor * 2
        print("Pixel_length_minor: " + str(pixel_length))

        # Areas scale with the pixel area (mm^2), not the linear spacing.
        area_px = cv2.contourArea(contour)
        area_mm = round(area_px * mm2_per_pixel)
        area_cm = area_mm / 100
        diameter_mm = round(pixel_length * mm_per_pixel)
        diameter_cm = diameter_mm / 10

        img = cv2.rotate(img, cv2.ROTATE_90_COUNTERCLOCKWISE)
        h, w, _ = img.shape
        labels = [
            "Area (mm^2): " + str(area_mm) + "mm^2",
            "Area (cm^2): " + str(area_cm) + "cm^2",
            "Diameter (mm): " + str(diameter_mm) + "mm",
            "Diameter (cm): " + str(diameter_cm) + "cm",
            "Slice: " + str(slice_num),
        ]
        font = cv2.FONT_HERSHEY_SIMPLEX
        font_scale = min(w, h) / (25 / 0.03)
        for row, text in enumerate(labels):
            cv2.putText(img, text, (10, 40 + 30 * row), font, font_scale, (0, 255, 0), 2)

        cv2.imwrite(output_dir_slices + "slice" + str(slice_num) + ".png", img)
        return diameter_cm

    # ------------------------------------------------------------------ #
    # Summary outputs
    # ------------------------------------------------------------------ #

    @staticmethod
    def _save_diameter_graph(diameter_dict, output_path):
        """Save a bar chart of diameter (cm) per slice number."""
        fig = plt.figure()
        plt.bar(list(diameter_dict.keys()), list(diameter_dict.values()), color="b")
        plt.title(r"$\bf{Diameter}$" + " " + r"$\bf{Progression}$")
        plt.xlabel("Slice Number")
        plt.ylabel("Diameter Measurement (cm)")
        plt.savefig(output_path, dpi=500)
        plt.close(fig)

    def _save_summary_image(self, ct_slice, annotated_path, output_path):
        """Save the raw slice (blue border) beside its annotated PNG (green border)."""
        raw = cv2.rotate(self._ct_slice_to_rgb(ct_slice), cv2.ROTATE_90_COUNTERCLOCKWISE)
        annotated = cv2.imread(annotated_path)
        border_size = 3
        annotated = cv2.copyMakeBorder(
            annotated,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[0, 244, 0],
        )
        raw = cv2.copyMakeBorder(
            raw,
            top=border_size,
            bottom=border_size,
            left=border_size,
            right=border_size,
            borderType=cv2.BORDER_CONSTANT,
            value=[244, 0, 0],
        )
        cv2.imwrite(output_path, np.concatenate((raw, annotated), axis=1))

    @staticmethod
    def _write_slice_video(slice_dir, video_path, fps=20):
        """Write the PNGs in ``slice_dir``, in Tcl dictionary order, to an MP4."""
        # Tcl's "lsort -dict" sorts embedded numbers naturally (slice2 < slice10).
        image_files = [
            os.path.join(slice_dir, name)
            for name in Tcl().call("lsort", "-dict", os.listdir(slice_dir))
            if name.endswith(".png")
        ]
        clip = moviepy.video.io.ImageSequenceClip.ImageSequenceClip(image_files, fps=fps)
        clip.write_videofile(video_path)

    # ------------------------------------------------------------------ #
    # 3D centerline analysis
    # ------------------------------------------------------------------ #

    def _planar_reconstruction(self, segmentation_dir, output_path):
        """Fit a centerline to the aorta mask and plot its max-diameter cross-section."""
        # converted_dcm.nii.gz is the CT fed to the model; spine.nii.gz is the
        # segmentation it produced.
        image = nib.load(os.path.join(segmentation_dir, "converted_dcm.nii.gz"))
        segmentation = nib.load(os.path.join(segmentation_dir, "spine.nii.gz"))

        spacing = (self.RESAMPLED_SPACING_MM,) * 3
        image = resample_to_output(image, spacing).get_fdata()
        # Nearest-neighbour interpolation keeps label values intact.
        segmentation = resample_to_output(segmentation, spacing, order=0).get_fdata()

        # NOTE(review): maps label 42 to 1 but leaves any other labels as-is;
        # skeletonize_3d expects a binary mask - confirm the model's label set.
        segmentation[segmentation == 42] = 1
        print(segmentation.shape)
        print(np.unique(segmentation))

        centerline_points = self._compute_centerline_3d(segmentation)
        tck = self._fit_bspline(centerline_points)
        evaluated_points = self._evaluate_bspline(tck)
        interpolated_points = self._interpolate_points(evaluated_points, image.shape[2])
        planes = self._compute_orthogonal_planes(tck)

        (
            _,
            max_diameter_point,
            max_diameter_normal,
            max_intersecting_points,
        ) = self._compute_maximum_diameter(segmentation, planes, self.RESAMPLED_SPACING_MM)

        self._plot_2d_planar_reconstruction(
            image,
            segmentation,
            interpolated_points,
            max_diameter_point,
            max_diameter_normal,
            max_intersecting_points,
            output_path,
        )

    @staticmethod
    def _compute_centerline_3d(aorta_segmentation):
        """Skeletonize the mask; return points as (axis2, axis1, axis0) sorted by column 0."""
        skeleton = skeletonize_3d(aorta_segmentation)
        z, y, x = np.where(skeleton)
        centerline_points = np.vstack((x, y, z)).T
        return centerline_points[centerline_points[:, 0].argsort()]

    @staticmethod
    def _fit_bspline(centerline_points, smoothness=1e8):
        """Fit a smoothing parametric B-spline through the centerline points."""
        x, y, z = centerline_points.T
        tck, _ = splprep([x, y, z], s=smoothness)
        return tck

    @staticmethod
    def _evaluate_bspline(tck, num_points=1000):
        """Sample the spline at ``num_points`` evenly spaced parameter values."""
        u = np.linspace(0, 1, num_points)
        x, y, z = splev(u, tck)
        return np.vstack((x, y, z)).T

    @staticmethod
    def _interpolate_points(data, num_points=32):
        """Resample columns 1.. onto integer column-0 positions 0..num_points-1.

        Nearest-neighbour interpolation with extrapolation; output is rounded.
        """
        x = data[:, 0]
        y = data[:, 1:]
        f_y = interp1d(x, y, kind="nearest", fill_value="extrapolate", axis=0)
        new_x = np.arange(0, num_points)
        new_y = f_y(new_x)
        return np.round(np.hstack((new_x.reshape(-1, 1), new_y)))

    @staticmethod
    def _compute_orthogonal_planes(tck, num_points=100):
        """Return ``(unit_normal, d)`` planes orthogonal to the spline at sampled points."""
        u = np.linspace(0, 1, num_points)
        points = np.vstack(splev(u, tck)).T
        tangents = np.vstack(splev(u, tck, der=1)).T
        normals = tangents / np.linalg.norm(tangents, axis=1)[:, np.newaxis]
        return [(normal, -np.dot(point, normal)) for point, normal in zip(points, normals)]

    @staticmethod
    def _compute_maximum_diameter(aorta_segmentation, planes, spacing_mm=1.5):
        """Find the plane whose slab through the mask has the largest point spread.

        For each plane, voxels within 0.5 voxel of it are collected and their
        maximum pairwise distance is taken as that plane's diameter.

        NOTE(review): a plane may intersect the mask in more than one place,
        which would inflate the measured diameter.

        Args:
            aorta_segmentation: 3D mask; nonzero voxels are foreground.
            planes: ``(unit_normal, d)`` tuples.
            spacing_mm: Isotropic voxel size of ``aorta_segmentation``.

        Returns:
            tuple: (diameters in cm for every plane with >= 2 points, ``d`` of
            the max plane, its normal, its intersecting points).

        Raises:
            ValueError: If no plane intersects at least two voxels.
        """
        z, y, x = np.where(aorta_segmentation)
        aorta_points = np.vstack((x, y, z)).T

        diameters_px = []
        kept_planes = []
        intersections = []
        for normal, d in planes:
            distances = np.dot(aorta_points, normal) + d
            intersecting_points = aorta_points[np.abs(distances) < 0.5]
            if len(intersecting_points) < 2:
                continue
            # max of the condensed distance vector == max of the square matrix
            diameters_px.append(pdist(intersecting_points).max())
            # Track planes alongside diameters so indices stay aligned even
            # when some planes are skipped.
            kept_planes.append((normal, d))
            intersections.append(intersecting_points)

        if not diameters_px:
            raise ValueError("No centerline plane intersects the aorta segmentation.")

        max_index = int(np.argmax(diameters_px))
        max_diameter_in_pixels = diameters_px[max_index]
        print(f"Maximum Diameter in Pixels: {max_diameter_in_pixels}")
        print(f"Maximum Diameter in mm: {round(max_diameter_in_pixels * spacing_mm)}")

        max_diameters = np.array(diameters_px) * (spacing_mm / 10)
        max_diameter_normal, max_diameter_point = kept_planes[max_index]
        return (
            max_diameters,
            max_diameter_point,
            max_diameter_normal,
            intersections[max_index],
        )

    @staticmethod
    def _plot_2d_planar_reconstruction(
        image,
        segmentation,
        interpolated_points,
        max_diameter_point,
        max_diameter_normal,
        max_intersecting_points,
        output_path,
    ):
        """Plot the centerline (red) and max-diameter cross-section (blue) in two projections.

        ``image``, ``segmentation``, ``max_diameter_point`` and
        ``max_diameter_normal`` are currently unused (the image overlays were
        disabled); they are accepted for future use.
        """
        fig, axs = plt.subplots(nrows=2, ncols=1, figsize=(15, 10))

        axs[0].scatter(
            interpolated_points[:, 1].astype(int),
            interpolated_points[:, 0].astype(int),
            color="red",
            s=1,
        )
        axs[0].plot(
            max_intersecting_points[:, 1].astype(int),
            max_intersecting_points[:, 0].astype(int),
            color="blue",
        )

        axs[1].scatter(
            interpolated_points[:, 2].astype(int),
            interpolated_points[:, 0].astype(int),
            color="red",
            s=1,
        )
        axs[1].plot(
            max_intersecting_points[:, 2].astype(int),
            max_intersecting_points[:, 0].astype(int),
            color="blue",
        )

        fig.savefig(output_path)
        plt.close(fig)
class AortaMetricsSaver(InferenceClass):
    """Save aorta metrics to a CSV file."""

    def __init__(self):
        super().__init__()

    def __call__(self, inference_pipeline):
        """Write ``<output_dir>/metrics/aorta_metrics.csv`` from pipeline results.

        Reads ``max_diameter``, ``dicom_series_path`` and ``output_dir`` from the
        pipeline.

        Returns:
            dict: Always empty.
        """
        self.max_diameter = inference_pipeline.max_diameter
        self.dicom_series_path = inference_pipeline.dicom_series_path
        self.output_dir = inference_pipeline.output_dir
        self.csv_output_dir = os.path.join(self.output_dir, "metrics")
        os.makedirs(self.csv_output_dir, exist_ok=True)
        self.save_results()
        return {}

    def save_results(self):
        """Write one row (series folder name, max diameter) to ``aorta_metrics.csv``."""
        # normpath strips a trailing separator, which would otherwise make the
        # basename empty (e.g. "/data/series/" -> "").
        filename = os.path.basename(os.path.normpath(str(self.dicom_series_path)))
        df = pd.DataFrame(
            [[filename, str(self.max_diameter)]],
            columns=["Filename", "Max Diameter"],
        )
        df.to_csv(os.path.join(self.csv_output_dir, "aorta_metrics.csv"), index=False)