import gradio as gr
import torch
import cv2
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from fastMONAI.vision_all import *
from fastMONAI.vision_inference import load_system_resources, inference, compute_binary_tumor_volume
import sys
# Debug: List all symbols imported from fastMONAI.vision_all
print("[DEBUG] fastMONAI.vision_all symbols:", dir())
from git import Repo
import os
#Additional support for local execution:-
#import pathlib
#temp = pathlib.PosixPath
#pathlib.PosixPath = pathlib.WindowsPath
#pathlib.PosixPath = temp
# Clone the private model repository on first run.  The remote URI comes from
# the PAT_Token_URI env var (presumably a URL embedding a personal access
# token — TODO confirm with deployment config).
clone_dir = Path.cwd() / 'clone_dir'
URI = os.getenv('PAT_Token_URI')
if not clone_dir.exists():
    Repo.clone_from(URI, clone_dir)
def extract_slices_from_mask(img, mask_data, view):
    """Extract and resize 2D slices from a 3D [W, H, D] image/mask pair.

    `view` selects the anatomical plane ('Sagittal', 'Axial' or 'Coronal');
    each slice pair is reoriented and resized to 320x320 via resize_and_pad.
    Returns a list of (image_slice, mask_slice) tuples.
    """
    target_size = (320, 320)
    # Number of slices depends on which axis the view cuts along.
    if view == "Sagittal":
        num_slices = img.shape[2]
    elif view == "Axial":
        num_slices = img.shape[1]
    else:
        num_slices = img.shape[0]
    slices = []
    for idx in range(num_slices):
        if view == "Sagittal":
            raw_img, raw_mask = img[:, :, idx], mask_data[:, :, idx]
        elif view == "Axial":
            raw_img, raw_mask = img[:, idx, :], mask_data[:, idx, :]
        elif view == "Coronal":
            raw_img, raw_mask = img[idx, :, :], mask_data[idx, :, :]
        # Rotate -90 degrees then mirror so the slice displays upright.
        oriented_img = np.fliplr(np.rot90(raw_img, -1))
        oriented_mask = np.fliplr(np.rot90(raw_mask, -1))
        slices.append(resize_and_pad(oriented_img, oriented_mask, target_size))
    return slices
def resize_and_pad(slice_img, slice_mask, target_size):
    """Resize the image/mask pair to fit inside `target_size` (W, H),
    preserving aspect ratio, then zero-pad symmetrically to exactly that size.

    Returns (padded_img, padded_mask).
    """
    height, width = slice_img.shape
    ratio = min(target_size[0] / width, target_size[1] / height)
    scaled_w = int(width * ratio)
    scaled_h = int(height * ratio)
    img_small = cv2.resize(slice_img, (scaled_w, scaled_h), interpolation=cv2.INTER_LINEAR)
    # Nearest-neighbour keeps mask labels discrete (no interpolation blur).
    mask_small = cv2.resize(slice_mask, (scaled_w, scaled_h), interpolation=cv2.INTER_NEAREST)
    left = (target_size[0] - scaled_w) // 2
    top = (target_size[1] - scaled_h) // 2
    # Pad spec: ((top, bottom), (left, right)); remainders go to bottom/right.
    pad_spec = ((top, target_size[1] - scaled_h - top),
                (left, target_size[0] - scaled_w - left))
    padded_img = np.pad(img_small, pad_spec, mode='constant', constant_values=0)
    padded_mask = np.pad(mask_small, pad_spec, mode='constant', constant_values=0)
    return padded_img, padded_mask
def normalize_image(slice_img):
    """Normalize the image to the range [0, 255] and return it as uint8.

    A constant image (min == max) is mapped to all zeros to avoid a
    division by zero.
    """
    lo = slice_img.min()
    hi = slice_img.max()
    if lo == hi:
        return np.zeros_like(slice_img, dtype=np.uint8)
    scaled = (slice_img - lo) / (hi - lo) * 255
    return scaled.astype(np.uint8)
def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Fuse a grayscale slice with a red mask overlay and orient it for display.

    Args:
        img: 2D uint8 grayscale slice.
        pred_mask: 2D binary mask (same H/W as img).
        view: 'Sagittal', 'Axial' or 'Coronal' — selects the display orientation.
        alpha: weight of the grayscale image in the blend (mask gets 1 - alpha).

    Returns:
        BGR image ready for display.
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
    if view == 'Sagittal':
        # Flip both vertically and horizontally.
        return cv2.flip(fused, -1)
    # BUG FIX: the original condition was `view=='Coronal' or 'Axial'`, which is
    # always truthy ('Axial' is a non-empty string). Coronal/Axial (and any
    # other value) fell into this branch anyway, so behavior is preserved.
    # Rotate 90 degrees counter-clockwise, then mirror horizontally.
    return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Run segmentation on an uploaded volume and build the display output.

    Args:
        fileobj: Gradio file object; `.name` is the path of the uploaded image.
        learn: fastMONAI learner used for inference.
        reorder, resample: preprocessing settings loaded with the model.
        save_dir: directory (Path) where the predicted mask is saved.
        view: 'Sagittal', 'Axial' or 'Coronal'; defaults to 'Sagittal' if None.

    Returns:
        (fused_images, volume): list of overlay slices for the gallery and the
        predicted tumor volume rounded to 2 decimals.
    """
    if view is None:  # `is None`, not `== None` (PEP 8)
        view = 'Sagittal'
    img_path = Path(fileobj.name)
    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn
    org_img, input_img, org_size = med_img_reader(img_path,
                                                  reorder=reorder,
                                                  resample=resample,
                                                  only_tensor=False)
    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size).data
    # Reorient the mask when the source volume is in LSA orientation so it
    # lines up with the image data — NOTE(review): verify axis order against
    # med_img_reader's output convention.
    if "".join(org_img.orientation) == "LSA":
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]
    img = org_img.data
    # Save the predicted mask using the original image's metadata.
    org_img.set_data(mask_data)
    org_img.save(save_path)
    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [get_fused_image(normalize_image(slice_img), slice_mask, view)
                    for slice_img, slice_mask in slices]
    volume = compute_binary_tumor_volume(org_img)
    return fused_images, round(volume, 2)
# Initialize the system: model files live in the cloned repo, predicted masks
# are written to ./hs_pred.
models_path = Path.cwd() / 'clone_dir'
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)
learn, reorder, resample = load_system_resources(models_path=models_path,
                                                 learner_fn='heart_model.pkl',
                                                 variables_fn='vars.pkl')
# Gradio interface setup
output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
demo = gr.Interface(
    fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
    inputs=["file", view_selector],
    outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
    examples=[[str(Path.cwd() / "sample.nii.gz")]],
    allow_flagging='never')
# Launch the Gradio interface
demo.launch()