"""Gradio web app for coronary angiogram vessel segmentation.

Loads a TorchScript segmentation model chosen by the user, preprocesses the
input angiogram (unsharp masking + model-specific CLAHE/normalization),
runs multi-scale tiled inference, and overlays the predicted vessel mask
on the original image (mask pixels get channel 0 set to 255).
"""

import os

# Must be set before CUDA is initialized so device indices match nvidia-smi.
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import glob
import time

device = "cuda" if torch.cuda.is_available() else "cpu"

# TorchScript checkpoints, keyed by the model names shown in the UI dropdown.
model_paths = {
    'SE-RegUNet 4GF': './model/SERegUNet4GF.pt',
    'SE-RegUNet 16GF': './model/SERegUNet16GF.pt',
    'AngioNet': './model/AngioNet.pt',
    'EffUNet++ B5': './model/EffUNetppb5.pt',
    'Reg-SA-UNet++': './model/RegSAUnetpp.pt',
    'UNet3+': './model/UNet3plus.pt',
}

# Tiling factors for multi-scale inference; must match the '1x1'..'16x16'
# choices offered in the UI.
scales = [1, 2, 4, 8, 16]

print("torch: ", torch.__version__)


def filesort(img, model):
    """Convert a color image to grayscale and preprocess it for *model*.

    Args:
        img: HxWx3 uint8 image (treated as BGR by the conversion below —
            NOTE(review): Gradio typically delivers RGB; for grayscale the
            result is close but not identical — confirm intended order).
        model: model name key used to select the preprocessing branch.

    Returns:
        (img_out, h, w, ori): preprocessed CxHxW float array, the grayscale
        height/width, and an unmodified copy of the input.
    """
    ori = img.copy()
    img = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    h, w = img.shape
    img_out = preprocessing(img, model)
    return img_out, h, w, ori


def preprocessing(img, model='SE-RegUNet 4GF'):
    """Build the model-specific 1- or 3-channel float input from a grayscale image.

    - AngioNet / UNet3+: single min-max-normalized channel.
    - SE-RegUNet 4GF: stack of (raw, CLAHE clip=2, CLAHE clip=8), each normalized.
    - everything else: CLAHE clip=2 channel replicated three times.

    Returns a float32 array of shape (C, H, W) with values in [0, 1].
    """
    img = unsharp_masking(img).astype(np.uint8)

    def _norm(x):
        # Min-max normalize to [0, 1]; epsilon guards constant images.
        return np.float32((x - x.min()) / (x.max() - x.min() + 1e-6))

    if model == 'AngioNet' or model == 'UNet3+':
        img_out = np.expand_dims(_norm(img), axis=0)
    elif model == 'SE-RegUNet 4GF':
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
        image1 = clahe1.apply(img)
        image2 = clahe2.apply(img)
        img_out = np.stack((_norm(img), _norm(image1), _norm(image2)), axis=0)
    else:
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        image1 = _norm(clahe1.apply(img))
        img_out = np.stack((image1,) * 3, axis=0)
    return img_out


def inference(pipe, img, model):
    """Run the TorchScript model and return a binary uint8 (H, W) mask.

    AngioNet is fed a batch of two identical images (the model appears to
    require batch size 2 — NOTE(review): confirm against the checkpoint);
    only the first output is used. The mask is channel 0 of the softmax,
    rounded to {0, 1}.
    """
    with torch.no_grad():
        if model == 'AngioNet':
            img = torch.cat([img, img], dim=0)
        logit = np.round(
            torch.softmax(pipe.forward(img), dim=1).detach().cpu().numpy()[0, 0]
        ).astype(np.uint8)
    return logit


def process_input_image(img, model, scale):
    """Gradio handler: segment vessels in *img* and overlay the mask.

    Args:
        img: HxWx3 uint8 angiogram from the image widget.
        model: model name (dropdown choice, key into ``model_paths``).
        scale: tiling choice string such as '4x4'.

    Returns:
        (spent, ori_img): elapsed-time label text and the input image with
        mask pixels marked in channel 0.
    """
    ori_img = img.copy()
    h, w, _ = ori_img.shape

    # Center-crop so both dims are divisible by 32 (network stride).
    # BUGFIX: the previous slicing [r//2:-r//2] dropped only 2*(r//2) pixels,
    # which is one short when the remainder r is odd; split the remainder
    # explicitly into floor/ceil halves instead. Zero-width crops degrade to
    # full slices, so no special-casing of r == 0 is needed.
    rem_h, rem_w = h % 32, w % 32
    top, bottom = rem_h // 2, rem_h - rem_h // 2
    left, right = rem_w // 2, rem_w - rem_w // 2
    img = ori_img[top:h - bottom, left:w - right]
    img_out = img.copy()

    pipe = torch.jit.load(model_paths[model])
    pipe = pipe.to(device).eval()

    # '4x4' -> 4; run every scale up to and including the chosen one.
    max_scale = int(scale.split('x')[0])
    scale_all = scales[:scales.index(max_scale) + 1]

    start = time.time()
    # Per-pixel vote accumulator; any positive vote marks the pixel.
    logit = np.zeros([img.shape[0], img.shape[1]], np.uint8)
    for s in scale_all:
        if s == 1:
            temp_img, _, _, _ = filesort(img, model)
            temp_img = torch.FloatTensor(temp_img).unsqueeze(0).to(device)
            logit += inference(pipe, temp_img, model)
        else:
            # Half-overlapping tiles of size (len_h, len_w); (2s - 1) steps
            # per axis cover the image with 50% stride.
            # NOTE(review): tile sizes are not guaranteed divisible by 32 —
            # presumably the models tolerate this; confirm.
            len_h, len_w = img.shape[0] // s, img.shape[1] // s
            for x in range(2 * s - 1):
                for y in range(2 * s - 1):
                    y0 = len_h * x // 2
                    x0 = len_w * y // 2
                    tile = img[y0:y0 + len_h, x0:x0 + len_w]
                    temp_img, _, _, _ = filesort(tile, model)
                    temp_img = torch.FloatTensor(temp_img).unsqueeze(0).to(device)
                    logit[y0:y0 + len_h, x0:x0 + len_w] += inference(
                        pipe, temp_img, model
                    )
    spent = f"{time.time() - start:.3f} seconds"

    # Mark every voted pixel in channel 0, then paste the cropped region back.
    mask = logit.astype(bool)
    img_out[mask, 0] = 255
    ori_img[top:h - bottom, left:w - right] = img_out
    return spent, ori_img


my_app = gr.Blocks()

with my_app:
    gr.Markdown("Coronary Angiogram Segmentation with Gradio.")
    gr.Markdown("Author: Ching-Ting Lin, Artificial Intelligence Center, China Medical University Hospital, Taichung City, Taiwan.")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    img_source = gr.Image(
                        label="Please select angiogram.",
                        value='./example/angio.png',
                        height=512,
                        width=512,
                    )
                    model_choice = gr.Dropdown(
                        ['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet',
                         'EffUNet++ B5', 'Reg-SA-UNet++', 'UNet3+'],
                        label='Model',
                        info='Which model to infer?',
                    )
                    model_rescale = gr.Dropdown(
                        ['1x1', '2x2', '4x4', '8x8', '16x16'],
                        label='Rescale',
                        info='How many batches?',
                    )
                    source_image_loader = gr.Button("Vessel Segment")
                with gr.Column():
                    time_spent = gr.Label(
                        label="Time Spent (Preprocessing + Inference)"
                    )
                    img_output = gr.Image(label="Output Mask")

    source_image_loader.click(
        process_input_image,
        [img_source, model_choice, model_rescale],
        [time_spent, img_output],
    )

my_app.launch(debug=True)