File size: 3,557 Bytes
dc75c07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35f3b91
 
 
dc75c07
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
import gradio as gr
import torch
import cv2
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from fastMONAI.vision_all import *
#import pathlib
#temp = pathlib.PosixPath
#pathlib.PosixPath = pathlib.WindowsPath
#pathlib.PosixPath = temp

def extract_slices_from_mask(img, mask_data):
    """Split a 3D [W, H, D] image/mask pair into per-slice 2D pairs.

    Each axial slice is reoriented for display: rotated 90 degrees
    clockwise and then mirrored left-right.

    Returns a list of (slice_img, slice_mask) tuples, one per depth index.
    """
    def reorient(plane):
        # Clockwise quarter-turn followed by a horizontal mirror.
        return np.fliplr(np.rot90(plane, -1))

    depth = img.shape[-1]
    return [(reorient(img[..., k]), reorient(mask_data[..., k]))
            for k in range(depth)]

def get_fused_image(img, pred_mask, alpha=0.8):
    """Blend a grayscale slice with a red mask overlay, then flip both axes.

    Parameters
    ----------
    img : 2D uint8 grayscale image.
    pred_mask : 2D binary mask, same spatial size as `img`.
    alpha : weight of the grayscale image in the blend (mask gets 1 - alpha).

    Returns the blended BGR image flipped vertically and horizontally.
    """
    base = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    overlay = (pred_mask[..., None] * np.array([255, 0, 0])).astype(np.uint8)
    blended = cv2.addWeighted(base, alpha, overlay, 1 - alpha, 0)
    # flipCode -1 flips around both axes in a single call
    return cv2.flip(blended, -1)

def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir):
    """Segment a 3D medical image and prepare results for the Gradio UI.

    Reads the uploaded volume, runs the learner's inference to get a binary
    mask, saves the mask as 'pred_<stem>' under `save_dir`, and returns
    (list of fused slice images, tumor volume rounded to 2 decimals).

    Parameters:
        fileobj: Gradio file object; only its `.name` (path) is used.
        learn: fastMONAI learner used by `inference`.
        reorder, resample: preprocessing settings loaded with the model.
        save_dir: Path directory where the predicted mask is written.
    """
    img_path = Path(fileobj.name)

    # Predicted mask is written next to save_dir as 'pred_<original stem>'.
    save_fn = 'pred_' + img_path.stem
    save_path = save_dir / save_fn
    # only_tensor=False returns the original image, the preprocessed input,
    # and the original size (needed to map predictions back).
    org_img, input_img, org_size = med_img_reader(img_path, 
                                                  reorder=reorder,
                                                  resample=resample,
                                                  only_tensor=False)
    
    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size).data
    
    # NOTE(review): special-case for LSA-oriented volumes — swap the last two
    # axes and flip one of them so the mask matches the stored orientation.
    # Presumably matches fastMONAI's canonical orientation handling; verify
    # against med_img_reader/inference docs.
    if "".join(org_img.orientation) == "LSA":        
        mask_data = mask_data.permute(0,1,3,2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]

    # Keep the raw image data, then reuse org_img as a container to save the
    # predicted mask with the original header/affine.
    img = org_img.data
    org_img.set_data(mask_data)
    org_img.save(save_path)

    # Build per-slice overlays; each slice is min-max normalized to uint8
    # before fusing with the mask.
    slices = extract_slices_from_mask(img[0], mask_data[0])
    fused_images = [(get_fused_image(
                        ((slice_img - slice_img.min()) / (slice_img.max() - slice_img.min()) * 255).astype(np.uint8), 
                        slice_mask))
                    for slice_img, slice_mask in slices]
    
    # Volume is computed from the mask now stored inside org_img.
    volume = compute_binary_tumor_volume(org_img)

    return fused_images, round(volume, 2)

# Initialize the system: working directory holds the model; predictions go
# into a dedicated subdirectory.
models_path = Path.cwd()
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

# Load the model and other required resources.
# Fix: pass the models_path variable defined above instead of re-calling
# Path.cwd() (the variable was previously assigned but never used).
learn, reorder, resample = load_system_resources(models_path=models_path,
                                                 learner_fn='heart_model.pkl',
                                                 variables_fn='vars.pkl')

# Gradio interface setup
output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")

demo = gr.Interface(
    # The lambda binds the loaded model/resources into the single-argument
    # signature Gradio expects.
    fn=lambda fileobj: gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir),
    inputs=["file"],
    outputs=[gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450), output_text],
    examples=[[str(Path.cwd() /"sample.nii.gz")]],
    allow_flagging='never')

# Launch the Gradio interface
demo.launch()