# Scraped page header (not code): author drankush-ai, commit 3abb5ec "Update app.py" (verified).
import gradio as gr
import torch
import cv2
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from fastMONAI.vision_all import *
from fastMONAI.vision_inference import load_system_resources, inference, compute_binary_tumor_volume
import sys
# Debug: list the public names that `from fastMONAI.vision_all import *` pulled
# in -- the module's __all__ when it defines one, otherwise every attribute not
# starting with an underscore (the rules Python applies to star imports).
_vision_all = sys.modules["fastMONAI.vision_all"]
_star_names = getattr(_vision_all, "__all__", None)
if _star_names is None:
    _star_names = [name for name in dir(_vision_all) if not name.startswith("_")]
print("[DEBUG] fastMONAI.vision_all symbols:", list(_star_names))
from git import Repo
import os
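# Optional workaround for running locally on Windows: uncomment the lines below
# to temporarily alias pathlib.PosixPath to WindowsPath and then restore it
# (presumably so files pickled on Linux can be loaded -- confirm before use):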
#import pathlib
#temp = pathlib.PosixPath
#pathlib.PosixPath = pathlib.WindowsPath
#pathlib.PosixPath = temp
# Clone the model repository into ./clone_dir unless a previous run already
# created that directory. The clone URL comes from the PAT_Token_URI
# environment variable (presumably a URL embedding an access token).
clone_dir = Path.cwd() / 'clone_dir'
URI = os.getenv('PAT_Token_URI')
if not clone_dir.exists():
    # Fail with an actionable message instead of letting GitPython choke on None.
    if not URI:
        raise RuntimeError(
            "Environment variable PAT_Token_URI is not set; "
            "cannot clone the model repository."
        )
    Repo.clone_from(URI, clone_dir)
def extract_slices_from_mask(img, mask_data, view):
    """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view.

    Args:
        img: 3D image volume supporting ``[:, :, i]``-style indexing
            (a torch tensor or NumPy array).
        mask_data: 3D mask volume indexed the same way as ``img``.
        view: ``"Sagittal"`` (slices along axis 2), ``"Axial"`` (axis 1)
            or ``"Coronal"`` (axis 0).

    Returns:
        A list of ``(image_slice, mask_slice)`` pairs, one per index along the
        chosen axis, each rotated/mirrored and then resized and padded to
        320x320 by ``resize_and_pad``.

    Raises:
        ValueError: If ``view`` is not one of the supported view names.
    """
    view_axes = {"Sagittal": 2, "Axial": 1, "Coronal": 0}
    if view not in view_axes:
        raise ValueError(
            f"Unsupported view {view!r}; expected one of {sorted(view_axes)}"
        )
    axis = view_axes[view]
    target_size = (320, 320)

    slices = []
    for idx in range(img.shape[axis]):
        # Build e.g. (slice(None), slice(None), idx) to take one plane along `axis`.
        index = [slice(None)] * 3
        index[axis] = idx
        index = tuple(index)
        slice_img, slice_mask = img[index], mask_data[index]
        # Rotate 90 degrees clockwise, then mirror left-right (display orientation).
        slice_img = np.fliplr(np.rot90(slice_img, -1))
        slice_mask = np.fliplr(np.rot90(slice_mask, -1))
        slices.append(resize_and_pad(slice_img, slice_mask, target_size))
    return slices
def resize_and_pad(slice_img, slice_mask, target_size):
    """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio.

    Args:
        slice_img: 2D image array of shape (h, w).
        slice_mask: 2D mask array with the same shape as ``slice_img``.
        target_size: ``(width, height)`` of the output canvas.

    Returns:
        ``(padded_img, padded_mask)``, both of shape ``(height, width)``.
        The image is resized with bilinear interpolation and the mask with
        nearest-neighbour (so label values are not blended). Both are centred
        on a zero-filled canvas.
    """
    target_w, target_h = target_size
    h, w = slice_img.shape
    scale = min(target_w / w, target_h / h)
    # round() instead of int() so float error (e.g. 319.99999) does not drop a
    # pixel; clamp to [1, target] because cv2.resize rejects zero-sized outputs
    # (very thin slices would otherwise scale to 0 along one axis).
    new_w = max(1, min(target_w, round(w * scale)))
    new_h = max(1, min(target_h, round(h * scale)))

    resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    # Centre the resized slice; any odd leftover pixel goes to the bottom/right.
    pad_top = (target_h - new_h) // 2
    pad_left = (target_w - new_w) // 2
    pad_width = (
        (pad_top, target_h - new_h - pad_top),
        (pad_left, target_w - new_w - pad_left),
    )
    padded_img = np.pad(resized_img, pad_width, mode='constant', constant_values=0)
    padded_mask = np.pad(resized_mask, pad_width, mode='constant', constant_values=0)
    return padded_img, padded_mask
def normalize_image(slice_img):
    """Normalize the image to the range [0, 255] safely.

    The minimum maps to 0 and the maximum to 255 (values in between are
    truncated when cast to uint8). A constant image returns all zeros.

    The computation is done on a float64 copy. As a result, integer inputs
    cannot overflow in the subtraction: with int16 data spanning more than
    32767, for example, ``x - x.min()`` would otherwise wrap around.

    Args:
        slice_img: Numeric NumPy array (any shape).

    Returns:
        A uint8 array with the same shape as ``slice_img``.
    """
    values = np.asarray(slice_img, dtype=np.float64)
    lo, hi = values.min(), values.max()
    if lo == hi:  # Avoid division by zero
        return np.zeros_like(values, dtype=np.uint8)
    return ((values - lo) / (hi - lo) * 255).astype(np.uint8)
def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Blend a grayscale slice with a coloured mask overlay and orient it for display.

    Args:
        img: 2D single-channel slice (in this file, the uint8 output of
            ``normalize_image``).
        pred_mask: 2D mask with the same height/width as ``img``; each pixel
            value scales the overlay colour (assumed binary 0/1 -- confirm).
        view: ``"Sagittal"`` -> the blended image is flipped both vertically
            and horizontally (a 180 degree rotation). ``"Coronal"`` or
            ``"Axial"`` -> rotated 90 degrees counter-clockwise, then
            mirrored horizontally.
        alpha: Weight of the grayscale image; the mask overlay gets ``1 - alpha``.

    Returns:
        A 3-channel uint8 image.

    Raises:
        ValueError: If ``view`` is not one of the three supported names.
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # Only channel 0 is set. Gradio displays NumPy arrays as RGB, so this
    # presumably renders as red even though cvtColor labels the layout BGR.
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)

    if view == 'Sagittal':
        return cv2.flip(fused, -1)  # flip around both axes
    if view in ('Coronal', 'Axial'):
        return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)
    raise ValueError(
        f"Unsupported view {view!r}; expected 'Sagittal', 'Coronal' or 'Axial'"
    )
def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Predict function using the learner and other resources.

    Reads the uploaded volume and runs ``inference`` to obtain a mask. It then
    saves the mask as ``save_dir / 'pred_<stem>'``, builds the per-slice
    overlays for the chosen view and computes the mask volume.

    Args:
        fileobj: The uploaded file. This is either an object exposing its path
            as ``.name`` or a plain path (str/Path); which one arrives depends
            on the Gradio version.
        learn, reorder, resample: Resources from ``load_system_resources``,
            forwarded to ``med_img_reader`` / ``inference``.
        save_dir: ``Path`` of the directory where the predicted mask is written.
        view: ``"Axial"``, ``"Coronal"`` or ``"Sagittal"``; ``None`` means
            ``"Sagittal"``.

    Returns:
        ``(fused_images, volume)``: a list of overlay images (one per slice)
        and the value of ``compute_binary_tumor_volume`` rounded to 2 decimals.

    Raises:
        gr.Error: If no file was uploaded.
    """
    if view is None:
        view = 'Sagittal'
    if fileobj is None:
        raise gr.Error("Please upload an image file before submitting.")

    # Accept both a tempfile-like object (path in .name) and a plain path.
    img_path = Path(getattr(fileobj, 'name', fileobj))
    save_path = save_dir / f'pred_{img_path.stem}'

    org_img, input_img, org_size = med_img_reader(img_path,
                                                  reorder=reorder,
                                                  resample=resample,
                                                  only_tensor=False)
    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size).data

    if "".join(org_img.orientation) == "LSA":
        # Swap the last two axes, keep only the first channel and flip the
        # second spatial axis. Presumably this realigns the mask with LSA
        # inputs -- TODO confirm against inference().
        mask_data = torch.flip(mask_data.permute(0, 1, 3, 2)[:1], dims=[2])

    img = org_img.data
    # Reuse org_img as the container for the mask, presumably so the saved
    # file keeps the input's spatial metadata; the volume is computed from it.
    org_img.set_data(mask_data)
    org_img.save(save_path)

    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [
        get_fused_image(normalize_image(slice_img), slice_mask, view)
        for slice_img, slice_mask in slices
    ]
    volume = compute_binary_tumor_volume(org_img)
    return fused_images, round(volume, 2)
# Initialize the system: load the trained learner and its preprocessing
# settings from the repository cloned into clone_dir.
# NOTE(review): the *.pkl files are presumably unpickled by
# load_system_resources, so only point this at a trusted repository.
models_path = clone_dir
# Directory where predicted masks are written (created if missing).
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)
learn, reorder, resample = load_system_resources(models_path=models_path,
                                                 learner_fn='heart_model.pkl',
                                                 variables_fn='vars.pkl')
# Gradio interface setup
def _segment(fileobj, view='Sagittal'):
    """Gradio callback: run gradio_image_segmentation with the loaded model resources."""
    return gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)


output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal',
                         label="Select View (Sagittal by default)")
demo = gr.Interface(
    fn=_segment,
    inputs=["file", view_selector],
    outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices",
                        columns=3, height=450),
             output_text],
    # One value per input component: the sample volume and the view to show.
    examples=[[str(Path.cwd() / "sample.nii.gz"), "Sagittal"]],
    allow_flagging='never')
# Launch the Gradio interface
demo.launch()