File size: 6,252 Bytes
c34fb59
febd0e1
 
 
c34fb59
 
 
3abb5ec
2064622
 
 
 
c34fb59
 
 
b0bd5fc
 
 
 
 
 
 
 
 
 
 
 
 
 
c34fb59
 
 
 
5638d3f
 
c34fb59
 
 
 
 
 
5638d3f
c34fb59
 
5638d3f
c34fb59
 
5638d3f
c34fb59
 
 
 
 
5638d3f
c34fb59
5638d3f
c34fb59
 
5638d3f
 
 
 
 
 
 
c34fb59
 
 
 
 
5638d3f
c34fb59
 
 
 
 
 
 
 
 
5638d3f
c34fb59
5638d3f
 
 
 
b0bd5fc
c34fb59
b0bd5fc
c34fb59
 
 
490c5a6
c34fb59
5638d3f
b0bd5fc
c34fb59
5638d3f
490c5a6
b0bd5fc
c34fb59
f37e491
b0bd5fc
 
 
 
5638d3f
c34fb59
5638d3f
b0bd5fc
5638d3f
b0bd5fc
5638d3f
 
c34fb59
 
5638d3f
27298f5
 
c34fb59
b0bd5fc
5638d3f
b0bd5fc
5638d3f
 
 
f37e491
c34fb59
5638d3f
2fecb9a
5638d3f
b0bd5fc
 
 
5638d3f
2064622
 
3abb5ec
 
 
36d522b
b0bd5fc
 
904a4b1
b0bd5fc
805baa7
b0bd5fc
 
 
 
 
 
c34fb59
5638d3f
b0bd5fc
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import gradio as gr
import torch
import cv2
import numpy as np
from pathlib import Path
from huggingface_hub import snapshot_download
from fastMONAI.vision_all import *
from fastMONAI.vision_inference import load_system_resources, inference, compute_binary_tumor_volume
import sys

# Debug: print every name currently defined in this module's namespace. This
# includes the star-import from fastMONAI.vision_all but also every other
# module/name imported above, so it is not limited to fastMONAI symbols.
print("[DEBUG] fastMONAI.vision_all symbols:", dir())
from git import Repo
import os

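# Optional workaround for running locally on Windows (uncomment if needed): temporarily alias
# pathlib.PosixPath to WindowsPath (presumably so files pickled on Linux can be loaded), then restore it: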
#import pathlib
#temp = pathlib.PosixPath
#pathlib.PosixPath = pathlib.WindowsPath
#pathlib.PosixPath = temp

clone_dir = Path.cwd() / 'clone_dir'
URI = os.getenv('PAT_Token_URI')

# Clone the model repository once; later runs reuse the existing checkout.
if not clone_dir.exists():
    if not URI:
        # Without this check Repo.clone_from(None, ...) fails with an obscure error.
        raise RuntimeError(
            "Environment variable 'PAT_Token_URI' is not set; it must hold the "
            "git URI of the model repository to clone into 'clone_dir'."
        )
    Repo.clone_from(URI, clone_dir)

def extract_slices_from_mask(img, mask_data, view):
    """Extract paired image/mask slices along the axis selected by *view*.

    Each 2-D slice is rotated 90 degrees clockwise and mirrored left-right
    (``np.fliplr(np.rot90(s, -1))``), then resized and zero-padded to
    320x320 by ``resize_and_pad``.

    Args:
        img: 3-D image volume (NumPy array or torch tensor). The axis order
            was previously described as [W, H, D] -- TODO confirm.
        mask_data: 3-D mask with the same shape as *img*.
        view: "Sagittal" slices along axis 2, "Axial" along axis 1 and
            "Coronal" along axis 0.

    Returns:
        List of ``(image_slice, mask_slice)`` tuples, one per index along
        the chosen axis.

    Raises:
        ValueError: if *view* is not one of the supported names.
    """
    view_axes = {"Sagittal": 2, "Axial": 1, "Coronal": 0}
    if view not in view_axes:
        # Reject unknown views up front instead of failing later on an
        # unbound slice variable.
        raise ValueError(
            f"Unknown view {view!r}; expected one of {sorted(view_axes)}"
        )
    axis = view_axes[view]
    target_size = (320, 320)

    slices = []
    for idx in range(img.shape[axis]):
        # Build an index such as (:, :, idx); plain tuple indexing works for
        # both NumPy arrays and torch tensors.
        index = [slice(None)] * 3
        index[axis] = idx
        index = tuple(index)

        slice_img = np.fliplr(np.rot90(img[index], -1))
        slice_mask = np.fliplr(np.rot90(mask_data[index], -1))

        slices.append(resize_and_pad(slice_img, slice_mask, target_size))

    return slices

def resize_and_pad(slice_img, slice_mask, target_size):
    """Fit a 2-D image/mask pair into *target_size* without distorting it.

    Both arrays are scaled by the same factor so the slice fits entirely
    inside the target. They are then zero-padded symmetrically (any odd
    leftover pixel goes to the bottom/right) to exactly the target shape.
    The image uses bilinear interpolation. The mask uses nearest-neighbour,
    so its values are not blended.

    Args:
        slice_img: 2-D image array of shape (h, w).
        slice_mask: 2-D mask array with the same shape as *slice_img*.
        target_size: ``(width, height)`` of the output in pixels.

    Returns:
        ``(padded_img, padded_mask)``, each of shape ``(height, width)``.
    """
    h, w = slice_img.shape
    target_w, target_h = target_size
    scale = min(target_w / w, target_h / h)
    # Round rather than truncate, so float error (e.g. 319.9999) does not
    # lose a pixel. Clamp to [1, target] so very thin slices still give
    # cv2.resize a non-empty destination size.
    new_w = min(target_w, max(1, round(w * scale)))
    new_h = min(target_h, max(1, round(h * scale)))

    # cv2.resize takes dsize as (width, height).
    resized_img = cv2.resize(slice_img, (new_w, new_h), interpolation=cv2.INTER_LINEAR)
    resized_mask = cv2.resize(slice_mask, (new_w, new_h), interpolation=cv2.INTER_NEAREST)

    pad_left = (target_w - new_w) // 2
    pad_top = (target_h - new_h) // 2
    # np.pad widths are ((rows_before, rows_after), (cols_before, cols_after)).
    pad_spec = (
        (pad_top, target_h - new_h - pad_top),
        (pad_left, target_w - new_w - pad_left),
    )

    padded_img = np.pad(resized_img, pad_spec, mode='constant', constant_values=0)
    padded_mask = np.pad(resized_mask, pad_spec, mode='constant', constant_values=0)

    return padded_img, padded_mask

def normalize_image(slice_img):
    """Min-max rescale a slice into the uint8 range [0, 255].

    A slice with no intensity variation (min equals max) comes back as all
    zeros instead of triggering a division by zero. The final uint8 cast
    truncates rather than rounds.
    """
    lowest = slice_img.min()
    highest = slice_img.max()
    if lowest != highest:
        unit = (slice_img - lowest) / (highest - lowest)
        return (unit * 255).astype(np.uint8)
    # Flat slice: nothing to stretch, so emit black.
    return np.zeros_like(slice_img, dtype=np.uint8)

def get_fused_image(img, pred_mask, view, alpha=0.8):
    """Blend a grayscale slice with a colored mask overlay and orient it for display.

    Args:
        img: 2-D grayscale slice; callers pass the uint8 output of
            ``normalize_image``.
        pred_mask: 2-D mask of the same shape. Each pixel is multiplied by
            the overlay color, so values are presumably 0/1 -- TODO confirm.
        view: "Sagittal" output is flipped both vertically and horizontally.
            Any other value (including "Axial" and "Coronal") is rotated
            90 degrees counter-clockwise and then mirrored left-right.
        alpha: Weight of the grayscale image; the overlay gets ``1 - alpha``.
            Pixels outside the mask are therefore dimmed by ``alpha``.

    Returns:
        A 3-channel uint8 image.
    """
    gray_img_colored = cv2.cvtColor(img, cv2.COLOR_GRAY2BGR)
    # Overlay color lives in channel 0; it displays as red when the consumer
    # interprets the array as RGB -- NOTE(review): confirm intended color.
    mask_color = np.array([255, 0, 0])
    colored_mask = (pred_mask[..., None] * mask_color).astype(np.uint8)

    fused = cv2.addWeighted(gray_img_colored, alpha, colored_mask, 1 - alpha, 0)

    if view == 'Sagittal':
        # Flip code -1 flips around both axes.
        return cv2.flip(fused, -1)

    # Axial, Coronal, and any unrecognized view take this path.
    return cv2.flip(cv2.rotate(fused, cv2.ROTATE_90_COUNTERCLOCKWISE), 1)

def gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view):
    """Segment an uploaded volume, save the predicted mask, and build overlay slices.

    Args:
        fileobj: The uploaded file. Either an object exposing a ``.name``
            path attribute or a plain path string.
        learn, reorder, resample: Values returned by
            ``load_system_resources``, passed to ``med_img_reader`` and
            ``inference``.
        save_dir: ``Path`` directory where ``pred_<stem>`` is written.
        view: "Axial", "Coronal" or "Sagittal"; ``None`` falls back to
            "Sagittal".

    Returns:
        ``(fused_images, volume)``: the list of overlay slices for the
        gallery, and the value from ``compute_binary_tumor_volume`` rounded
        to 2 decimals.

    Raises:
        gr.Error: if no file was provided.
    """
    if fileobj is None:
        # Surface a readable message in the UI instead of an AttributeError.
        raise gr.Error("Please upload an image file before running the segmentation.")

    if view is None:
        view = 'Sagittal'

    # Depending on the Gradio version/config, a file input arrives either as
    # an object exposing `.name` or as a plain path string.
    img_path = Path(getattr(fileobj, 'name', fileobj))
    save_path = save_dir / ('pred_' + img_path.stem)

    org_img, input_img, org_size = med_img_reader(img_path,
                                                  reorder=reorder,
                                                  resample=resample,
                                                  only_tensor=False)

    mask_data = inference(learn, reorder=reorder, resample=resample,
                          org_img=org_img, input_img=input_img,
                          org_size=org_size).data

    if "".join(org_img.orientation) == "LSA":
        # NOTE(review): swaps the last two axes and flips one of them so the
        # mask lines up with LSA-oriented inputs -- TODO confirm intent.
        mask_data = mask_data.permute(0, 1, 3, 2)
        mask_data = torch.flip(mask_data[0], dims=[1])
        mask_data = torch.Tensor(mask_data)[None]

    # Grab the image data before org_img is overwritten with the mask; the
    # mask is then saved through org_img (presumably to keep the input's
    # spatial metadata).
    img = org_img.data
    org_img.set_data(mask_data)
    org_img.save(save_path)

    slices = extract_slices_from_mask(img[0], mask_data[0], view)
    fused_images = [get_fused_image(normalize_image(slice_img), slice_mask, view)
                    for slice_img, slice_mask in slices]

    # org_img now holds the predicted mask, so the volume is the mask's.
    volume = compute_binary_tumor_volume(org_img)

    return fused_images, round(volume, 2)

# Initialize the system
# The model files come from the repository cloned into clone_dir above.
models_path = clone_dir
save_dir = Path.cwd() / 'hs_pred'
save_dir.mkdir(parents=True, exist_ok=True)

# Load the trained learner ('heart_model.pkl') together with the reorder /
# resample settings later passed to med_img_reader and inference
# (presumably read from 'vars.pkl').
learn, reorder, resample = load_system_resources(models_path=models_path,
                                                 learner_fn='heart_model.pkl',
                                                 variables_fn='vars.pkl')

# Gradio interface setup: one file input plus a view selector, producing a
# slice gallery and the computed volume.
view_selector = gr.Radio(choices=["Axial", "Coronal", "Sagittal"], value='Sagittal', label="Select View (Sagittal by default)")
output_text = gr.Textbox(label="Volume of the Left Atrium (mL):")
slice_gallery = gr.Gallery(label="Click an Image, and use Arrow Keys to scroll slices", columns=3, height=450)


def _segment_for_gradio(fileobj, view='Sagittal'):
    """Gradio callback that binds the module-level model resources."""
    return gradio_image_segmentation(fileobj, learn, reorder, resample, save_dir, view)


example_inputs = [[str(Path.cwd() / "sample.nii.gz")]]

demo = gr.Interface(
    fn=_segment_for_gradio,
    inputs=["file", view_selector],
    outputs=[slice_gallery, output_text],
    examples=example_inputs,
    allow_flagging='never',
)

# Launch the Gradio interface
demo.launch()