import gradio as gr
import torch
import cv2
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from fastMONAI.vision_all import *
from fastMONAI.vision_inference import load_system_resources, inference, compute_binary_tumor_volume
import sys
from git import Repo
import os

# Additional support for local execution (e.g. on Windows) -- uncomment when needed:
# import pathlib
# temp = pathlib.PosixPath
# pathlib.PosixPath = pathlib.WindowsPath
# pathlib.PosixPath = temp

# Local checkout of the model repository; the learner files are loaded from here below.
clone_dir = Path.cwd() / 'clone_dir'
# Clone URI taken from the environment (presumably embeds an access token -- do not log it).
URI = os.getenv('PAT_Token_URI')

if not clone_dir.exists():
    if not URI:
        # Fail with an actionable message instead of an opaque GitPython error on a None URI.
        raise RuntimeError(
            "Environment variable 'PAT_Token_URI' is not set; "
            f"cannot clone the model repository into {clone_dir}."
        )
    Repo.clone_from(URI, clone_dir)

# Axis of the 3D volume that is iterated over to produce the slices of each view.
_VIEW_AXIS = {"Sagittal": 2, "Axial": 1, "Coronal": 0}


def extract_slices_from_mask(img, mask_data, view):
    """Extract and resize slices from the 3D [W, H, D] image and mask data based on the selected view.

    Args:
        img: 3D image volume (numpy array or CPU torch tensor).
        mask_data: 3D mask volume with the same shape as ``img``.
        view: One of ``"Sagittal"`` (slices along axis 2), ``"Axial"`` (axis 1)
            or ``"Coronal"`` (axis 0).

    Returns:
        List of ``(image_slice, mask_slice)`` numpy pairs, each padded to 320x320.

    Raises:
        ValueError: If ``view`` is not one of the supported views.
    """
    try:
        axis = _VIEW_AXIS[view]
    except KeyError:
        raise ValueError(
            f"Unknown view {view!r}; expected one of {sorted(_VIEW_AXIS)}"
        ) from None

    target_size = (320, 320)
    slices = []
    for idx in range(img.shape[axis]):
        # Build an index equivalent to img[:, :, idx] / img[:, idx, :] / img[idx, :, :].
        index = [slice(None)] * 3
        index[axis] = idx
        index = tuple(index)
        # Rotate 90 degrees clockwise, then mirror left-right, for display orientation.
        slice_img = np.fliplr(np.rot90(img[index], -1))
        slice_mask = np.fliplr(np.rot90(mask_data[index], -1))
        slices.append(resize_and_pad(slice_img, slice_mask, target_size))
    return slices


def resize_and_pad(slice_img, slice_mask, target_size):
    """Resize and pad the image and mask to fit the target size while maintaining the aspect ratio.

    Args:
        slice_img: 2D image slice of shape (h, w).
        slice_mask: 2D mask slice with the same shape as ``slice_img``.
        target_size: ``(width, height)`` of the output canvas.

    Returns:
        ``(padded_img, padded_mask)``, both of shape (height, width), centred and zero-padded.
    """
    target_w, target_h = target_size
    h, w = slice_img.shape
    scale = min(target_w / w, target_h / h)
    # Never resize a side to 0 pixels on very elongated slices (cv2.resize would reject it).
    new_w, new_h = max(1, int(w * scale)), max(1, int(h * scale))

    resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    # Nearest-neighbour keeps mask labels discrete.
    resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    pad_w = (target_w - new_w) // 2
    pad_h = (target_h - new_h) // 2
    pad_width = ((pad_h, target_h - new_h - pad_h), (pad_w, target_w - new_w - pad_w))
    padded_img = np.pad(resized_img, pad_width, mode='constant', constant_values=0)
    padded_mask = np.pad(resized_mask, pad_width, mode='constant', constant_values=0)
    return padded_img, padded_mask


def normalize_image(slice_img):
    """Normalize the image to the range [0, 255] safely.

    Returns an all-zero uint8 array for a constant slice instead of dividing by zero.
    """
    slice_img_min, slice_img_max = slice_img.min(), slice_img.max()
    if slice_img_min == slice_img_max:  # Avoid division by zero
        return np.zeros_like(slice_img, dtype=np.uint8)
    normalized_img = (slice_img - slice_img_min) / (slice_img_max - slice_img_min) * 255
    return normalized_img.astype(np.uint8)


def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Blend a grayscale slice with a coloured mask overlay and orient it for display.

    Args:
        img: 2D uint8 grayscale slice.
        pred_mask: 2D mask slice with the same shape as ``img``.
        view: ``"Sagittal"`` flips the result both vertically and horizontally;
            any other view is rotated 90 degrees counter-clockwise and then mirrored
            horizontally.
        alpha: Weight of the grayscale image; the mask gets ``1 - alpha``.

    Returns:
        3-channel uint8 image.
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)

    if view == 'Sagittal':
        # Flip both vertically and horizontally.
        return cv2.flip(fused, -1)
    # Axial / Coronal (the original `elif view=='Coronal' or 'Axial'` was always true,
    # so every non-sagittal view already took this path).
    return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)


def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Predict function using the learner and other resources.

    Segments the uploaded volume, saves the predicted mask to ``save_dir`` and
    builds the overlay gallery for the chosen view.

    Args:
        fileobj: Uploaded file -- either an object with a ``.name`` path attribute
            (older Gradio) or a plain path string (newer Gradio ``File`` default).
        learn, reorder, resample: Values returned by ``load_system_resources``.
        save_dir: Directory where the predicted mask is written.
        view: ``"Axial"``, ``"Coronal"`` or ``"Sagittal"``; ``None`` means Sagittal.

    Returns:
        ``(fused_images, volume)``, where ``volume`` is the value from
        ``compute_binary_tumor_volume``, rounded to 2 decimals.
    """
    if view is None:
        view = 'Sagittal'

    img_path = Path(getattr(fileobj, 'name', fileobj))
    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn

    org_img, input_img, org_size = med_img_reader(img_path, reorder=reorder, resample=resample, only_tensor=False)
    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size).data

    if "".join(org_img.orientation) == "LSA":
        # Re-order the mask for LSA-oriented inputs: swap the last two axes of the
        # 4-D tensor, flip, then restore the leading dimension dropped by [0].
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]

    img = org_img.data
    # Reuse the original image's header/affine to save the predicted mask.
    org_img.set_data(mask_data)
    org_img.save(save_path)

    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [
        get_fused_image(normalize_image(slice_img), slice_mask, view)
        for slice_img, slice_mask in slices
    ]

    volume = compute_binary_tumor_volume(org_img)
    return fused_images, round(volume, 2)


# Initialize the system
models_path = clone_dir
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

learn, reorder, resample = load_system_resources(models_path=models_path,
                                                 learner_fn='heart_model.pkl',
                                                 variables_fn='vars.pkl')

# Gradio interface setup
output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")

view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal',
                         label="Select View (Sagittal by default)")

demo = gr.Interface(
    fn=lambda fileobj, view='Sagittal': gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view),
    inputs=["file", view_selector],
    outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450),
             output_text],
    examples=[[str(Path.cwd() / "sample.nii.gz")]],
    allow_flagging='never')

# Launch the Gradio interface
demo.launch()