| | import gradio as gr |
| | import torch |
| | import cv2 |
| | import numpy as np |
| | from pathlib import Path |
| | from huggingface_hub import snapshot_download |
| | from fastMONAI.vision_all import * |
| | from fastMONAI.vision_inference import load_system_resources, inference, compute_binary_tumor_volume |
| | import sys |
| |
|
| | |
# Startup debug aid: list every name bound at module level after the imports
# above, including whatever `from fastMONAI.vision_all import *` contributed.
print("[DEBUG] fastMONAI.vision_all symbols:", sorted(globals()))
| | from git import Repo |
| | import os |
| |
|
| | |
| | |
| | |
| | |
| | |
| |
|
# Local checkout of the model repository. Its URI (presumably carrying a
# personal access token, judging by the variable name) comes from the environment.
clone_dir = Path.cwd() / 'clone_dir'
URI = os.getenv('PAT_Token_URI')

# Clone only when no checkout exists yet; later starts reuse the local copy.
if not clone_dir.exists():
    if not URI:
        # Fail with an actionable message instead of handing None to GitPython.
        raise RuntimeError(
            "Environment variable 'PAT_Token_URI' is not set; cannot clone "
            "the model repository into " + str(clone_dir)
        )
    Repo.clone_from(URI, clone_dir)
| |
|
def extract_slices_from_mask(img, mask_data, view, target_size=(320, 320)):
    """Extract, reorient and resize every 2-D slice of a volume and its mask.

    Args:
        img: 3-D image volume indexed as ``img[i, j, k]`` (described by the
            original author as a [W, H, D] layout).
        mask_data: 3-D mask indexed the same way as ``img``.
        view: ``"Sagittal"`` slices along axis 2, ``"Axial"`` along axis 1
            and ``"Coronal"`` along axis 0.
        target_size: ``(width, height)`` that each slice is fitted into by
            ``resize_and_pad``. Defaults to 320x320.

    Returns:
        List of ``(image_slice, mask_slice)`` tuples, one per index along the
        selected axis, each padded to ``target_size``.

    Raises:
        ValueError: If ``view`` is not one of the three supported names.
    """
    axis_by_view = {"Sagittal": 2, "Axial": 1, "Coronal": 0}
    try:
        axis = axis_by_view[view]
    except KeyError:
        # Without this check an unknown view used to surface as an
        # UnboundLocalError deep inside the loop.
        raise ValueError(
            f"Unknown view {view!r}; expected one of {sorted(axis_by_view)}"
        ) from None

    slices = []
    for idx in range(img.shape[axis]):
        # Select position `idx` along `axis` only,
        # e.g. (slice(None), slice(None), idx) for the sagittal view.
        index = [slice(None)] * 3
        index[axis] = idx
        index = tuple(index)

        # Reorient: rotate 90 degrees clockwise (k=-1), then mirror left-right.
        slice_img = np.fliplr(np.rot90(img[index], -1))
        slice_mask = np.fliplr(np.rot90(mask_data[index], -1))

        # resize_and_pad already returns the (image, mask) pair.
        slices.append(resize_and_pad(slice_img, slice_mask, target_size))

    return slices
| |
|
def resize_and_pad(slice_img, slice_mask, target_size):
    """Fit an image/mask pair into ``target_size`` while keeping the aspect ratio.

    Both arrays are scaled by the same factor. That factor is the largest one
    that still fits inside the target, so one side matches the target up to
    integer truncation. The arrays are then zero-padded symmetrically to
    exactly the target size.

    Args:
        slice_img: 2-D image array of shape ``(h, w)``.
        slice_mask: 2-D mask with the same shape as ``slice_img``.
        target_size: ``(width, height)`` of the output in pixels.

    Returns:
        ``(padded_img, padded_mask)``, each of shape ``(height, width)``.
    """
    target_w, target_h = target_size
    h, w = slice_img.shape
    scale = min(target_w / w, target_h / h)
    # Never shrink a side to zero pixels: cv2.resize rejects an empty dsize,
    # which very elongated slices would otherwise produce.
    new_w = max(1, int(w * scale))
    new_h = max(1, int(h * scale))

    # cv2 takes dsize as (width, height). The image uses linear interpolation;
    # the mask uses nearest-neighbour so no new in-between values are created.
    resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # Split the leftover space evenly; an odd extra pixel goes to bottom/right.
    pad_left = (target_w - new_w) // 2
    pad_top = (target_h - new_h) // 2
    pad_width = (
        (pad_top, target_h - new_h - pad_top),
        (pad_left, target_w - new_w - pad_left),
    )

    padded_img = np.pad(resized_img, pad_width, mode='constant', constant_values=0)
    padded_mask = np.pad(resized_mask, pad_width, mode='constant', constant_values=0)
    return padded_img, padded_mask
| |
|
def normalize_image(slice_img):
    """Linearly rescale an image to 0-255 and return it as ``uint8``.

    The minimum maps to 0 and the maximum to 255; fractional results are
    truncated by the ``uint8`` cast. A constant or empty image yields an
    all-zero ``uint8`` array of the same shape.

    Args:
        slice_img: Numeric array (e.g. a 2-D slice).

    Returns:
        ``uint8`` array with the same shape as ``slice_img``.
    """
    img = np.asarray(slice_img)
    # Promote integer/bool data to float before subtracting the minimum.
    # In the original dtype the subtraction can overflow (e.g. int16 spanning
    # most of its range), and numpy refuses `-` on bool arrays.
    if img.dtype == np.bool_ or np.issubdtype(img.dtype, np.integer):
        img = img.astype(np.float64)

    if img.size == 0:
        return np.zeros_like(img, dtype=np.uint8)

    lo, hi = img.min(), img.max()
    if lo == hi:
        # Zero range: avoid dividing by zero.
        return np.zeros_like(img, dtype=np.uint8)
    return ((img - lo) / (hi - lo) * 255).astype(np.uint8)
| |
|
def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Blend a grayscale slice with a coloured mask and orient it for display.

    Args:
        img: 2-D grayscale ``uint8`` slice (the caller in this file passes
            ``normalize_image`` output).
        pred_mask: 2-D mask with the same shape as ``img``. Presumably 0/1
            valued (TODO confirm): values above 1 overflow when scaled by 255
            and cast to ``uint8``.
        view: ``"Sagittal"``, ``"Coronal"`` or ``"Axial"``.
        alpha: Weight of the grayscale image in the blend; the mask gets
            ``1 - alpha``.

    Returns:
        3-channel ``uint8`` image. Sagittal slices are flipped both
        horizontally and vertically. Coronal and axial slices are rotated
        90 degrees counter-clockwise and then mirrored left-right.

    Raises:
        ValueError: If ``view`` is not one of the supported names.
    """
    # Validate up front. The old `view == 'Coronal' or 'Axial'` test was always
    # true, so any unknown view was silently treated as coronal/axial.
    if view not in ('Sagittal', 'Coronal', 'Axial'):
        raise ValueError(
            f"Unknown view {view!r}; expected 'Sagittal', 'Coronal' or 'Axial'"
        )

    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # Full intensity in channel 0 only. Presumably red in the Gradio gallery,
    # which treats arrays as RGB; it would be blue under OpenCV's BGR convention.
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)

    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)

    if view == 'Sagittal':
        # flipCode=-1 flips around both axes.
        return cv2.flip(fused, -1)
    # Coronal / Axial: rotate counter-clockwise, then mirror left-right.
    return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
| |
|
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Segment an uploaded volume and build the Gradio outputs.

    Reads the image and runs ``inference`` with the given learner. Saves the
    predicted mask to ``save_dir`` as ``pred_<stem>``. Renders every slice
    along ``view`` as an image/mask overlay.

    Args:
        fileobj: The uploaded file. Either an object exposing the path as
            ``.name`` or a plain path string.
        learn: Learner passed through to ``inference``.
        reorder: Reorientation setting passed to ``med_img_reader`` and
            ``inference``.
        resample: Resampling setting passed to ``med_img_reader`` and
            ``inference``.
        save_dir: ``pathlib.Path`` directory the predicted mask is written to.
        view: ``"Sagittal"``, ``"Coronal"`` or ``"Axial"``. ``None`` falls back
            to ``"Sagittal"``.

    Returns:
        ``(fused_images, volume)``: the list of overlay images for the gallery,
        and the result of ``compute_binary_tumor_volume`` rounded to 2 decimals.

    Raises:
        gr.Error: If no file was provided, so the UI shows a readable message.
    """
    if fileobj is None:
        raise gr.Error("Please upload an image file first.")
    if view is None:
        view = 'Sagittal'

    # Accept both file objects exposing `.name` and bare path strings.
    img_path = Path(getattr(fileobj, 'name', fileobj))

    # Note: for "x.nii.gz", Path.stem is "x.nii", so the mask is saved as "pred_x.nii".
    save_path = save_dir / f"pred_{img_path.stem}"

    org_img, input_img, org_size = med_img_reader(
        img_path, reorder=reorder, resample=resample, only_tensor=False
    )

    mask_data = inference(
        learn, reorder=reorder, resample=resample,
        org_img=org_img, input_img=input_img, org_size=org_size,
    ).data

    # LSA-oriented inputs: swap the last two axes of the 4-D mask tensor, flip
    # what is then axis 1 of the first channel, and restore the leading axis.
    if "".join(org_img.orientation) == "LSA":
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]

    # Grab the image tensor before org_img's data is replaced by the mask.
    img = org_img.data
    org_img.set_data(mask_data)
    org_img.save(save_path)

    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [
        get_fused_image(normalize_image(slice_img), slice_mask, view)
        for slice_img, slice_mask in slices
    ]

    # org_img now holds the mask, which is what the volume is computed from.
    volume = compute_binary_tumor_volume(org_img)
    return fused_images, round(volume, 2)
| |
|
| | |
# Load the model from the same checkout the clone step created, rather than
# repeating the path expression (the two could otherwise drift apart).
models_path = clone_dir
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

learn, reorder, resample = load_system_resources(
    models_path=models_path,
    learner_fn='heart_model.pkl',
    variables_fn='vars.pkl',
)

output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")

view_selector = gr.Radio(
    choices=["Axial", "Coronal", "Sagittal"],
    value='Sagittal',
    label="Select View (Sagittal by default)",
)

# The lambda binds the module-level model resources and output directory, so
# the UI only supplies the uploaded file and the chosen view.
demo = gr.Interface(
    fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(
        fileobj, learn, reorder, resample, save_dir, view
    ),
    inputs=["file", view_selector],
    outputs=[
        gr.Gallery(
            label="Click an Image, and use Arrow Keys to scroll slices",
            columns=3,
            height=450,
        ),
        output_text,
    ],
    examples=[[str(Path.cwd() / "sample.nii.gz")]],
    allow_flagging='never',
)

demo.launch()