| | import gradio as gr |
| | import torch |
| | import cv2 |
| | import numpy as np |
| | from pathlib import Path |
| | from huggingface_hub import snapshot_download |
| | from fastMONAI.vision_all import * |
| | |
| | |
| | |
| | |
| |
|
def extract_slices_from_mask(img, mask_data):
    """Split a 3D image and its mask into oriented 2D slice pairs.

    The volumes are indexed as ``[:, :, idx]`` along the last axis (the
    layout is documented as [W, H, D]). The number of slices is taken from
    ``img`` only, so ``mask_data`` must provide at least as many slices.

    Args:
        img: 3D image volume.
        mask_data: 3D mask volume indexed the same way as ``img``.

    Returns:
        A list of ``(image_slice, mask_slice)`` tuples, one per index of the
        last axis of ``img``, each slice rotated and mirrored for display.
    """

    def _orient(plane):
        # One quarter turn with k=-1, followed by a left-right mirror.
        return np.fliplr(np.rot90(plane, -1))

    depth = img.shape[-1]
    return [
        (_orient(img[:, :, z]), _orient(mask_data[:, :, z]))
        for z in range(depth)
    ]
| |
|
def get_fused_image(img, pred_mask, alpha=0.8):
    """Blend a grayscale slice with a colored mask overlay and flip the result.

    Args:
        img: Single-channel image that ``cv2.cvtColor`` can expand with
            ``COLOR_GRAY2BGR`` (the caller in this file passes uint8 data).
        pred_mask: Mask with the same height and width as ``img``; it is
            multiplied by the overlay color and cast to uint8.
        alpha: Weight given to the image in the blend; the overlay receives
            ``1 - alpha``.

    Returns:
        The blended 3-channel image, flipped around both axes.
    """
    # Only the first channel of the overlay is populated.
    # NOTE(review): this shows as red if the consumer reads the array as RGB,
    # blue if it reads it as BGR -- confirm against the display component.
    overlay_color = np.array([255, 0, 0])
    overlay = (pred_mask[..., None] * overlay_color).astype(np.uint8)

    # Replicate the gray channel so the image and the overlay share a shape.
    base = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    blended = cv2.addWeighted(base, alpha, overlay, 1 - alpha, 0)

    # A negative flip code mirrors the image around both axes.
    return cv2.flip(blended, -1)
| |
|
def _slice_to_uint8(slice_img):
    """Min-max scale one 2D slice to the 0-255 range and return it as uint8."""
    lowest = slice_img.min()
    span = slice_img.max() - lowest
    if span == 0:
        # A constant slice (e.g. an empty border slice) has no contrast to
        # stretch; dividing by zero here would yield NaNs and an undefined
        # uint8 cast, so return a black slice instead.
        return np.zeros(slice_img.shape, dtype=np.uint8)
    return ((slice_img - lowest) / span * 255).astype(np.uint8)


def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir):
    """Segment an uploaded volume and return overlay slices plus the volume.

    Args:
        fileobj: Uploaded file. Either an object exposing the path through a
            ``name`` attribute or a plain path string.
        learn: Learner passed through to ``inference``.
        reorder: Reorder setting forwarded to ``med_img_reader`` and
            ``inference``.
        resample: Resample setting forwarded to ``med_img_reader`` and
            ``inference``.
        save_dir: ``Path`` of the directory the predicted mask is saved in.

    Returns:
        A tuple ``(fused_images, volume)``: the list of per-slice overlay
        images and the value from ``compute_binary_tumor_volume`` rounded to
        two decimals.
    """
    # Accept both upload wrappers that carry the path in `.name` and bare
    # path strings, which have no such attribute.
    img_path = Path(getattr(fileobj, "name", fileobj))

    # NOTE(review): `stem` strips only the last suffix, so "x.nii.gz" is saved
    # as "pred_x.nii" and "x.nii" as "pred_x" (no extension) -- confirm the
    # writer behind `org_img.save` handles the latter.
    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn

    org_img, input_img, org_size = med_img_reader(
        img_path,
        reorder=reorder,
        resample=resample,
        only_tensor=False,
    )

    mask_data = inference(
        learn,
        reorder=reorder,
        resample=resample,
        org_img=org_img,
        input_img=input_img,
        org_size=org_size,
    ).data

    if "".join(org_img.orientation) == "LSA":
        # For this orientation the mask is re-aligned with the image: swap the
        # last two axes, flip axis 1 of the channel-less volume, then restore
        # the leading channel dimension.
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]

    # Keep the original intensities before the image object is reused as the
    # container for the predicted mask.
    img = org_img.data
    org_img.set_data(mask_data)
    org_img.save(save_path)

    slices = extract_slices_from_mask(img[0], mask_data[0])
    fused_images = [
        get_fused_image(_slice_to_uint8(slice_img), slice_mask)
        for slice_img, slice_mask in slices
    ]

    # `org_img` now holds the mask, so the volume is computed from the
    # prediction rather than from the input intensities.
    volume = compute_binary_tumor_volume(org_img)

    return fused_images, round(volume, 2)
| |
|
| | |
# Model artifacts are looked up in the working directory; predicted masks are
# written to ./hs_pred, which is created on startup if missing.
models_path = Path.cwd()
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

# Load the learner together with the stored reorder/resample settings.
learn, reorder, resample = load_system_resources(
    models_path=models_path,
    learner_fn='heart_model.pkl',
    variables_fn='vars.pkl',
)

# Text output showing the second value returned by the prediction function.
output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")

# The lambda binds the loaded resources so the interface only has to supply
# the uploaded file.
demo = gr.Interface(
    fn=lambda fileobj: gradio_image_segmentation(
        fileobj, learn, reorder, resample, save_dir
    ),
    inputs=["file"],
    outputs=[
        gr.Gallery(
            label="Click an Image, and use Arrow Keys to scroll slices",
            columns=3,
            height=450,
        ),
        output_text,
    ],
    examples=[[str(Path.cwd() / "sample.nii.gz")]],
    allow_flagging='never',
)

# Start the web app.
demo.launch()