import os
from pathlib import Path

import cv2
import gradio as gr
import numpy as np
import torch
from huggingface_hub import snapshot_download  # noqa: F401  (currently unused)

# Kept last, as in the original, so its exported names take precedence.
# med_img_reader, inference, load_system_resources and
# compute_binary_tumor_volume are not defined in this file and are expected
# to come from this star import.
from fastMONAI.vision_all import *  # noqa: F401,F403

# Disabled Windows workaround: temporarily alias pathlib.PosixPath to
# pathlib.WindowsPath and restore it afterwards -- presumably so a learner
# pickled on Linux can be unpickled on Windows. Re-enable around the model
# load if needed:
#   import pathlib
#   temp = pathlib.PosixPath
#   pathlib.PosixPath = pathlib.WindowsPath
#   ...  # load the learner here
#   pathlib.PosixPath = temp


def extract_slices_from_mask(img, mask_data):
    """Split a 3D image and its mask into 2D (image, mask) slice pairs.

    Slices are taken along the last axis (assumed to be a [W, H, D] layout --
    TODO confirm). Each slice is rotated 90 degrees clockwise (``np.rot90``
    with k=-1) and then mirrored left-right for display.

    Args:
        img: 3D image volume; anything ``np.rot90`` accepts (numpy array, or a
            CPU tensor numpy can convert).
        mask_data: mask volume, expected to have the same shape as ``img``.

    Returns:
        list of ``(slice_img, slice_mask)`` numpy array tuples, one per index
        along the last axis of ``img``.
    """
    slices = []
    for idx in range(img.shape[-1]):
        slice_img = np.fliplr(np.rot90(img[:, :, idx], -1))
        slice_mask = np.fliplr(np.rot90(mask_data[:, :, idx], -1))
        slices.append((slice_img, slice_mask))
    return slices


def _to_uint8(slice_img):
    """Min-max rescale a 2D slice to the 0-255 range and cast it to uint8.

    A constant slice (max == min) is returned as all zeros rather than
    dividing by zero, which would yield NaNs and an undefined uint8 cast.
    """
    lo, hi = slice_img.min(), slice_img.max()
    if hi == lo:
        return np.zeros(np.shape(slice_img), dtype=np.uint8)
    return ((slice_img - lo) / (hi - lo) * 255).astype(np.uint8)


def get_fused_image(img, pred_mask, alpha=0.8):
    """Blend a grayscale slice with a colored mask overlay, rotated 180 degrees.

    Args:
        img: 2D uint8 grayscale image.
        pred_mask: 2D mask with the same height/width as ``img``; assumed to
            hold 0/1 values (larger values would overflow the uint8 cast).
        alpha: weight of the grayscale image; the mask gets ``1 - alpha``.

    Returns:
        HxWx3 uint8 array, flipped both vertically and horizontally.
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # Only channel 0 is set; Gradio presumably treats numpy images as RGB,
    # so the overlay renders red.
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)
    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)
    # flipCode=-1 flips around both axes (equivalent to a 180-degree rotation).
    return cv2.flip(fused, -1)


def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir):
    """Segment an uploaded volume, save the predicted mask and build previews.

    Args:
        fileobj: the uploaded file -- either a path (``str`` / ``os.PathLike``,
            e.g. a filepath string or an ``examples`` entry) or an object whose
            ``.name`` attribute holds the path (tempfile-style wrapper).
        learn: trained learner passed to ``inference``.
        reorder: preprocessing flag forwarded to ``med_img_reader`` and
            ``inference``.
        resample: resampling setting forwarded to ``med_img_reader`` and
            ``inference``.
        save_dir: ``pathlib.Path`` directory; the mask is saved there as
            ``pred_<stem of the input file>``.

    Returns:
        tuple ``(fused_images, volume)``: a list with one overlay image per
        slice along the last axis, and the result of
        ``compute_binary_tumor_volume`` rounded to 2 decimals.
    """
    # Older Gradio passes a tempfile-like object, newer Gradio a path string.
    if isinstance(fileobj, (str, os.PathLike)):
        img_path = Path(fileobj)
    else:
        img_path = Path(fileobj.name)
    save_path = save_dir / f"pred_{img_path.stem}"

    org_img, input_img, org_size = med_img_reader(
        img_path, reorder=reorder, resample=resample, only_tensor=False)
    mask_data = inference(
        learn, reorder=reorder, resample=resample, org_img=org_img,
        input_img=input_img, org_size=org_size).data

    if "".join(org_img.orientation) == "LSA":
        # NOTE(review): for LSA-oriented inputs the prediction's last two axes
        # are swapped and the (unbatched) axis 1 is flipped -- presumably to
        # match the image orientation; confirm against inference().
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])[None]

    # Grab the original intensities for display before org_img's data is
    # replaced with the mask below.
    img = org_img.data
    # org_img is reused as the container for the predicted mask (presumably to
    # keep its spatial metadata); the saved file and the volume computation
    # both operate on the mask.
    org_img.set_data(mask_data)
    org_img.save(save_path)

    slices = extract_slices_from_mask(img[0], mask_data[0])
    # Each slice is normalized independently for display.
    fused_images = [get_fused_image(_to_uint8(slice_img), slice_mask)
                    for slice_img, slice_mask in slices]

    volume = compute_binary_tumor_volume(org_img)
    return fused_images, round(volume, 2)


# --- Application setup (runs at import time) ---
models_path = Path.cwd()
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

# Load the model and the preprocessing variables it was trained with.
learn, reorder, resample = load_system_resources(
    models_path=models_path, learner_fn='heart_model.pkl',
    variables_fn='vars.pkl')

output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")

demo = gr.Interface(
    fn=lambda fileobj: gradio_image_segmentation(
        fileobj, learn, reorder, resample, save_dir),
    inputs=["file"],
    outputs=[
        gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices",
                   columns=3, height=450),
        output_text,
    ],
    examples=[[str(Path.cwd() / "sample.nii.gz")]],
    # NOTE(review): newer Gradio releases rename this to `flagging_mode`;
    # confirm against the pinned Gradio version.
    allow_flagging='never',
)

demo.launch()