File size: 5,958 Bytes
cd59216
 
 
 
 
 
 
 
 
 
 
 
f23a0e9
 
 
 
 
 
 
 
b0b2697
cd59216
 
 
 
 
 
 
 
 
 
 
070d37c
cd59216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
070d37c
 
 
 
 
 
 
 
cd59216
 
 
 
733f80d
 
 
 
38a7cbb
733f80d
b0b2697
cd59216
f23a0e9
cd59216
 
eafaea5
b0b2697
070d37c
cd59216
b0b2697
 
 
53ccf33
 
 
b0b2697
 
a17b102
b0b2697
 
 
 
 
 
 
cd59216
 
 
 
 
 
38a7cbb
 
 
733f80d
 
 
38a7cbb
733f80d
cd59216
 
 
 
 
 
 
 
 
 
 
5ea03ca
cd59216
 
070d37c
cd59216
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
import os
os.environ["CUDA_DEVICE_ORDER"]="PCI_BUS_ID"

import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import glob
import time

# Select the GPU when one is visible, otherwise fall back to the CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# Display name -> TorchScript checkpoint path for every supported model.
model_paths = {
    'SE-RegUNet 4GF': './model/SERegUNet4GF.pt',
    'SE-RegUNet 16GF': './model/SERegUNet16GF.pt',
    'AngioNet': './model/AngioNet.pt',
    'EffUNet++ B5': './model/EffUNetppb5.pt',
    'Reg-SA-UNet++': './model/RegSAUnetpp.pt',
    'UNet3+': './model/UNet3plus.pt',
}

# Tiling factors offered by the UI dropdown ("1x1" .. "16x16").
scales = [1, 2, 4, 8, 16]

print("torch: ", torch.__version__)

def filesort(img, model):
    """Turn a BGR frame into the model's preprocessed input array.

    Returns a 4-tuple: (preprocessed array, original height, original
    width, untouched copy of the incoming frame).
    """
    original = img.copy()
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape
    return preprocessing(gray, model), height, width, original

def preprocessing(img, model='SE-RegUNet 4GF'):
    """Build the per-model network input from a grayscale frame.

    Every model gets an unsharp-masked frame first; the channel layout
    then depends on the architecture:
      * AngioNet / UNet3+  -> single normalized channel, shape (1, H, W)
      * SE-RegUNet 4GF     -> (raw, CLAHE clip=2, CLAHE clip=8), (3, H, W)
      * everything else    -> one CLAHE clip=2 channel replicated x3
    """
    def normalize(arr):
        # Min-max scale to [0, 1] as float32; epsilon guards flat images.
        return np.float32((arr - arr.min()) / (arr.max() - arr.min() + 1e-6))

    img = unsharp_masking(img).astype(np.uint8)

    if model in ('AngioNet', 'UNet3+'):
        return np.expand_dims(normalize(img), axis=0)

    if model == 'SE-RegUNet 4GF':
        mild = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(img)
        strong = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8)).apply(img)
        return np.stack((normalize(img), normalize(mild), normalize(strong)), axis=0)

    enhanced = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8)).apply(img)
    return np.stack((normalize(enhanced),) * 3, axis=0)

def inference(pipe, img, model):
    """Run the network and return a binary uint8 mask for batch item 0.

    Softmax is taken over the channel dimension and channel 0 is rounded
    to {0, 1}. AngioNet receives the batch duplicated along dim 0, as it
    expects a doubled batch.
    """
    with torch.no_grad():
        batch = torch.cat([img, img], dim=0) if model == 'AngioNet' else img
        probs = torch.softmax(pipe.forward(batch), dim=1)
        channel0 = probs.detach().cpu().numpy()[0, 0]
    return np.round(channel0).astype(np.uint8)

def process_input_image(img, model, scale):
    """Segment vessels in an angiogram and overlay the mask on the image.

    Gradio callback. Crops the frame so each side is a multiple of 32,
    loads the selected TorchScript model, runs inference at every tile
    scale up to the requested one, and paints channel 0 of the overlay
    red (255) wherever any scale predicted vessel.

    Args:
        img: HxWx3 color image (numpy array from the Gradio Image input).
        model: display name, a key of the module-level `model_paths` dict.
        scale: rescale choice string such as '4x4'; only the leading
            integer before the 'x' is used.

    Returns:
        (spent, ori_img): elapsed-time string like "1.234 seconds" and the
        full-size image with the mask written into its cropped region.
    """
    ori_img = img.copy()
    h, w, _ = ori_img.shape
    # Despite the name, pad_h/pad_w are the number of *excess* pixels to
    # crop away so both sides become multiples of 32.
    pad_h = h % 32
    pad_w = w % 32
    # NOTE: `-pad_w//2` parses as `(-pad_w)//2`, i.e. -ceil(pad_w/2), so
    # left + right crops always sum to exactly pad_w even when it is odd.
    if pad_h == 0 and pad_w > 0:
        img = ori_img[:, pad_w//2:-pad_w//2]
    elif pad_h > 0 and pad_w == 0:
        img = ori_img[pad_h//2:-pad_h//2, :]
    elif pad_h > 0 and pad_w > 0:
        img = ori_img[pad_h//2:-pad_h//2, pad_w//2:-pad_w//2]
    img_out = img.copy()
    
    # Load the TorchScript checkpoint fresh on every click; moved to the
    # module-level `device` and switched to eval mode.
    pipe = torch.jit.load(model_paths[model])
    pipe = pipe.to(device).eval()
    
    # '4x4' -> 4; run every scale from 1 up to and including the request.
    scale = int(scale.split('x')[0])
    scale_all = scales[:scales.index(scale)+1]
    
    start = time.time()
    # uint8 accumulator; overlapping tiles add their 0/1 predictions and
    # the final bool cast treats any nonzero vote as vessel.
    logit = np.zeros([img.shape[0], img.shape[1]], np.uint8)
    for scale in scale_all:  # NOTE: shadows the parsed `scale` above
        if scale == 1:
            # Whole-frame inference.
            temp_img, _, _, _ = filesort(img, model)
            temp_img = torch.FloatTensor(temp_img).unsqueeze(0).to(device)
            logit += inference(pipe, temp_img, model)
        else:
            # Tile the frame into scale x scale windows with 50 % overlap
            # (2*scale-1 window positions per axis, stepping half a tile).
            len_h, len_w = img.shape[0] // scale, img.shape[1] // scale
            # logit = np.zeros([img.shape[0], img.shape[1]], np.uint8)
            for x in range(2*scale-1):
                for y in range(2*scale-1):
                    temp_img, _, _, _ = filesort(img[len_h * x // 2 : (len_h * x // 2) + len_h, 
                                                     len_w * y // 2 : (len_w * y // 2) + len_w], model)
                    temp_img = torch.FloatTensor(temp_img).unsqueeze(0).to(device)
                    logit[len_h * x // 2 : (len_h * x // 2) + len_h, 
                          len_w * y // 2 : (len_w * y // 2) + len_w] += inference(pipe, temp_img, model)
    spent = time.time() - start
    spent = f"{spent:.3f} seconds"
    
    # Any positive vote at any scale marks the pixel as vessel.
    logit = logit.astype(bool)
    # img_out = cv2.cvtColor(ori, cv2.COLOR_GRAY2RGB)
    # Paint the mask into channel 0 (red in RGB) of the cropped copy.
    img_out[logit, 0] = 255
    # Write the cropped result back into the full-size original frame,
    # using the same slicing as the crop above.
    if pad_h == 0 and pad_w == 0:
        ori_img = img_out
    elif pad_h == 0 and pad_w > 0:
        ori_img[:, pad_w//2:-pad_w//2] = img_out
    elif pad_h > 0 and pad_w == 0:
        ori_img[pad_h//2:-pad_h//2, :] = img_out
    elif pad_h > 0 and pad_w > 0:
        ori_img[pad_h//2:-pad_h//2, pad_w//2:-pad_w//2] = img_out
    return spent, ori_img

    
# Gradio UI: one tab with the input image + model/rescale dropdowns on the
# left, elapsed time and the output mask on the right.
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Coronary Angiogram Segmentation with Gradio.")
    gr.Markdown("Author: Ching-Ting Lin, Artificial Intelligence Center, China Medical University Hospital, Taichung City, Taiwan.")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    # Ships with a bundled example angiogram preloaded.
                    img_source = gr.Image(label="Please select angiogram.", value='./example/angio.png', height=512, width=512)
                    # Choices must match the keys of module-level `model_paths`.
                    model_choice = gr.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5', 
                                                'Reg-SA-UNet++', 'UNet3+'], label='Model', info='Which model to infer?')
                    # Choices must match the module-level `scales` list ('NxN' form).
                    model_rescale = gr.Dropdown(['1x1', '2x2', '4x4', '8x8', '16x16'], label='Rescale', info='How many batches?')
                    source_image_loader = gr.Button("Vessel Segment")
                with gr.Column():
                    time_spent = gr.Label(label="Time Spent (Preprocessing + Inference)")
                    img_output = gr.Image(label="Output Mask")
                    
        # Wire the button to the segmentation callback: three inputs in the
        # order process_input_image expects, two outputs (time, overlay).
        source_image_loader.click(
            process_input_image,
            [
                img_source,
                model_choice,
                model_rescale
            ],
            [
                time_spent,
                img_output
            ]
        )
    
# debug=True surfaces callback tracebacks in the browser/console.
my_app.launch(debug=True)