# NOTE: recovered from a Hugging Face Spaces page scrape (page status banner read
# "Spaces: Runtime error"); markdown-table artifacts have been stripped below.
import os

# Must be set BEFORE torch is imported so CUDA device indices follow the
# physical PCI bus order (otherwise torch may enumerate GPUs differently).
os.environ["CUDA_DEVICE_ORDER"] = "PCI_BUS_ID"

import gradio as gr
import torch
import cv2
import numpy as np
from preprocess import unsharp_masking
import glob
import time

# Run on GPU when one is visible, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"

# TorchScript checkpoints, keyed by the label shown in the UI dropdown below.
model_paths = {
    'SE-RegUNet 4GF': './model/SERegUNet4GF.pt',
    'SE-RegUNet 16GF': './model/SERegUNet16GF.pt',
    'AngioNet': './model/AngioNet.pt',
    'EffUNet++ B5': './model/EffUNetppb5.pt',
    'Reg-SA-UNet++': './model/RegSAUnetpp.pt',
    'UNet3+': './model/UNet3plus.pt',
}

# Multi-scale inference levels: 1 = whole image, k = k*k tiling (see
# process_input_image, which runs every level up to the one selected).
scales = [1, 2, 4, 8, 16]

print(
    "torch: ", torch.__version__,
)
def filesort(img, model):
    """Grayscale-convert a BGR frame, preprocess it for *model*, and return
    ``(preprocessed, height, width, original BGR copy)``."""
    original = img.copy()
    gray = cv2.cvtColor(img, cv2.COLOR_BGR2GRAY)
    height, width = gray.shape
    prepared = preprocessing(gray, model)
    return prepared, height, width, original
def preprocessing(img, model='SE-RegUNet 4GF'):
    """Build the model-specific network input from one grayscale frame.

    Parameters
    ----------
    img : np.ndarray
        2-D grayscale image (H, W); sharpened and cast to uint8 first.
    model : str
        Model name selecting the preprocessing recipe.

    Returns
    -------
    np.ndarray
        float32 array: shape (1, H, W) for AngioNet / UNet3+, otherwise
        (3, H, W).
    """
    def _minmax(x):
        # Min-max normalize to [0, 1]; epsilon guards a constant image.
        return np.float32((x - x.min()) / (x.max() - x.min() + 1e-6))

    img = unsharp_masking(img).astype(np.uint8)
    if model in ('AngioNet', 'UNet3+'):
        # Single-channel input.
        img_out = np.expand_dims(_minmax(img), axis=0)
    elif model == 'SE-RegUNet 4GF':
        # Three channels: raw + two CLAHE variants at different clip limits.
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        clahe2 = cv2.createCLAHE(clipLimit=8.0, tileGridSize=(8, 8))
        img_out = np.stack(
            (_minmax(img), _minmax(clahe1.apply(img)), _minmax(clahe2.apply(img))),
            axis=0,
        )
    else:
        # Remaining models: one CLAHE-enhanced channel repeated three times.
        clahe1 = cv2.createCLAHE(clipLimit=2.0, tileGridSize=(8, 8))
        image1 = _minmax(clahe1.apply(img))
        img_out = np.stack((image1,) * 3, axis=0)
    return img_out
def inference(pipe, img, model):
    """Run a forward pass and return the binary (uint8) class-0 mask for the
    first batch element.

    AngioNet expects a two-element batch, so the input tensor is duplicated
    along the batch axis before the pass. The channel-softmax map at index
    [0, 0] is rounded to {0, 1}.
    """
    with torch.no_grad():
        batch = torch.cat([img, img], dim=0) if model == 'AngioNet' else img
        probs = torch.softmax(pipe.forward(batch), dim=1)
        mask = probs.detach().cpu().numpy()[0, 0]
    return np.round(mask).astype(np.uint8)
def process_input_image(img, model, scale):
    """Segment vessels in *img* with the chosen model at one or more scales.

    Args (grounded in this function's code):
        img: H x W x 3 array; a copy is kept so the caller's input survives.
        model: key into the module-level ``model_paths`` dict.
        scale: string like '4x4'; the leading integer selects how far into
            the module-level ``scales`` list to go — inference runs at every
            level up to and including it.

    Returns:
        (elapsed-time string, image with segmented pixels' channel 0 set to 255).
    """
    ori_img = img.copy()
    h, w, _ = ori_img.shape
    # Amount by which each dimension exceeds a multiple of 32; the image is
    # centre-cropped below so the network sees 32-divisible sizes.
    pad_h = h % 32
    pad_w = w % 32
    # NOTE(review): for odd pads the two bounds remove `pad` rows/cols in
    # total only because Python floor-divides negatives (e.g. pad=3 gives
    # [1:-2]); do not "simplify" these slices.
    if pad_h == 0 and pad_w > 0:
        img = ori_img[:, pad_w//2:-pad_w//2]
    elif pad_h > 0 and pad_w == 0:
        img = ori_img[pad_h//2:-pad_h//2, :]
    elif pad_h > 0 and pad_w > 0:
        img = ori_img[pad_h//2:-pad_h//2, pad_w//2:-pad_w//2]
    img_out = img.copy()
    # TorchScript model is (re)loaded on every click — presumably acceptable
    # for this demo; confirm if latency matters.
    pipe = torch.jit.load(model_paths[model])
    pipe = pipe.to(device).eval()
    scale = int(scale.split('x')[0])
    scale_all = scales[:scales.index(scale)+1]
    start = time.time()
    # Accumulates 0/1 votes from every scale/tile; cast to bool below, so any
    # positive vote marks the pixel as vessel.
    logit = np.zeros([img.shape[0], img.shape[1]], np.uint8)
    for scale in scale_all:
        if scale == 1:
            # Whole-image pass.
            temp_img, _, _, _ = filesort(img, model)
            temp_img = torch.FloatTensor(temp_img).unsqueeze(0).to(device)
            logit += inference(pipe, temp_img, model)
        else:
            # Tiled pass: (2*scale-1)^2 half-overlapping windows of size
            # (H // scale, W // scale), each stepping by half a window.
            len_h, len_w = img.shape[0] // scale, img.shape[1] // scale
            for x in range(2*scale-1):
                for y in range(2*scale-1):
                    temp_img, _, _, _ = filesort(img[len_h * x // 2 : (len_h * x // 2) + len_h,
                                                     len_w * y // 2 : (len_w * y // 2) + len_w], model)
                    temp_img = torch.FloatTensor(temp_img).unsqueeze(0).to(device)
                    logit[len_h * x // 2 : (len_h * x // 2) + len_h,
                          len_w * y // 2 : (len_w * y // 2) + len_w] += inference(pipe, temp_img, model)
    spent = time.time() - start
    spent = f"{spent:.3f} seconds"
    logit = logit.astype(bool)
    # Overlay: force channel 0 to 255 wherever any scale voted vessel.
    # (Channel 0 is red if the array is RGB — TODO confirm what Gradio feeds in.)
    img_out[logit, 0] = 255
    # Paste the (possibly cropped) overlay back into the original-size frame,
    # mirroring the crop cases above.
    if pad_h == 0 and pad_w == 0:
        ori_img = img_out
    elif pad_h == 0 and pad_w > 0:
        ori_img[:, pad_w//2:-pad_w//2] = img_out
    elif pad_h > 0 and pad_w == 0:
        ori_img[pad_h//2:-pad_h//2, :] = img_out
    elif pad_h > 0 and pad_w > 0:
        ori_img[pad_h//2:-pad_h//2, pad_w//2:-pad_w//2] = img_out
    return spent, ori_img
# --- Gradio UI: one tab with inputs on the left, results on the right. -----
my_app = gr.Blocks()
with my_app:
    gr.Markdown("Coronary Angiogram Segmentation with Gradio.")
    gr.Markdown("Author: Ching-Ting Lin, Artificial Intelligence Center, China Medical University Hospital, Taichung City, Taiwan.")
    with gr.Tabs():
        with gr.TabItem("Select your image"):
            with gr.Row():
                with gr.Column():
                    # Inputs: angiogram image, model choice, tiling depth.
                    img_source = gr.Image(label="Please select angiogram.", value='./example/angio.png', height=512, width=512)
                    model_choice = gr.Dropdown(['SE-RegUNet 4GF', 'SE-RegUNet 16GF', 'AngioNet', 'EffUNet++ B5',
                                                'Reg-SA-UNet++', 'UNet3+'], label='Model', info='Which model to infer?')
                    model_rescale = gr.Dropdown(['1x1', '2x2', '4x4', '8x8', '16x16'], label='Rescale', info='How many batches?')
                    source_image_loader = gr.Button("Vessel Segment")
                with gr.Column():
                    # Outputs: elapsed time and the red-overlaid mask.
                    time_spent = gr.Label(label="Time Spent (Preprocessing + Inference)")
                    img_output = gr.Image(label="Output Mask")
            # Wire the button: inputs -> process_input_image -> outputs.
            source_image_loader.click(
                process_input_image,
                [
                    img_source,
                    model_choice,
                    model_rescale
                ],
                [
                    time_spent,
                    img_output
                ]
            )
my_app.launch(debug=True)