Spaces:
Running
Running
| from pipline import Transformer_Regression, extract_regions_Last , compute_ratios | |
| import torch | |
| import torchvision.transforms as transforms | |
| from torch.nn import functional as F | |
| import cv2 | |
| import gradio as gr | |
| import numpy as np | |
| from PIL import Image | |
# Prefer the GPU when CUDA is available; otherwise everything runs on the CPU.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')

## Model / preprocessing parameters
# Side length of the square network input. The original note reads
# "512 got 87" — presumably a result from an earlier 512-px experiment.
image_shape = 384
batch_size = 1
dim_patch = 4
num_classes = 3
label_smoothing = 0.1
scale = 1

import time
start = time.time()  # wall-clock timestamp taken when the module is imported
torch.manual_seed(0)  # fixed seed so repeated runs are reproducible

# Preprocessing pipeline applied to PIL images before they reach the model:
# resize to image_shape x image_shape, convert to a tensor, then normalise
# every channel with mean 0.5 / std 0.5, i.e. (x - 0.5) / 0.5.
tfms = transforms.Compose([
    transforms.Resize((image_shape, image_shape)),
    transforms.ToTensor(),
    transforms.Normalize(mean=0.5, std=0.5),
])
def _external_contours(mask):
    """Return the external contours of ``mask`` (non-zero pixels are foreground).

    ``cv2.findContours`` returns ``(contours, hierarchy)`` in OpenCV 4.x but
    ``(image, contours, hierarchy)`` in OpenCV 3.x; the contours are the
    second-to-last element in both, so this works with either version.
    """
    found = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_NONE)
    return found[-2]


def _predict_segmentation(Model, batch_sampler, num_head):
    """Predict a per-pixel class map at the resolution of the original image.

    Returns ``(label_map, Regions_crop)``: ``label_map`` is an integer (H, W)
    numpy array of argmax class indices, and ``Regions_crop`` is the dict from
    ``extract_regions_Last`` with its 'image' entry replaced by an RGB PIL image.
    """
    original = batch_sampler['image_original']
    height, width = original.shape[0], original.shape[1]

    # Coarse pass on the whole (resized) image, upsampled back to full size.
    scores = Model(batch_sampler['image'].to(device=device).unsqueeze(0))
    logits = F.interpolate(scores['seg'], size=(height, width),
                           mode='bilinear', align_corners=True)

    # Locate the region of interest from the coarse prediction.
    coarse_labels = logits.argmax(1).long()[0].detach().cpu().numpy()
    Regions_crop = extract_regions_Last(np.array(original), coarse_labels)
    Regions_crop['image'] = Image.fromarray(np.uint8(Regions_crop['image'])).convert('RGB')

    if num_head == 2:
        cord = Regions_crop['cord']
        rows, cols = slice(cord[0], cord[1]), slice(cord[2], cord[3])
        # Take the window size from an actual slice of the image: slicing is
        # clipped at the border, so it always matches the logits window below.
        crop_h, crop_w = original[rows, cols].shape[:2]
        scores = Model(tfms(Regions_crop['image']).unsqueeze(0).to(device))
        # Overwrite the coarse logits inside the region with the second head's.
        logits[:, :, rows, cols] = F.interpolate(scores['seg_aux_1'], size=(crop_h, crop_w),
                                                 mode='bilinear', align_corners=True)

    # Softmax is monotonic per pixel, so argmax over raw logits picks the same class.
    label_map = logits.argmax(1).detach().cpu().numpy()[0]
    return label_map, Regions_crop


def _draw_prediction_contours(image_tensor, label_map):
    """Return the de-normalised input image, resized to ``label_map``, with predicted outlines drawn.

    The result is a C-contiguous uint8 (H, W, 3) array.
    """
    height, width = label_map.shape
    resized = F.interpolate(image_tensor.to(device=device).unsqueeze(0), size=(height, width),
                            mode='bilinear', align_corners=True)
    # Invert Normalize(0.5, 0.5): [-1, 1] -> [0, 1] -> [0, 255].
    rgb = (resized[0] * 0.5 + 0.5).permute(1, 2, 0).detach().cpu().numpy()
    # cv2.drawContours needs a C-contiguous buffer; the permuted view is not one.
    canvas = np.ascontiguousarray(np.uint8(rgb * 255))

    labels = np.uint8(label_map)
    # Class 2 (presumably the cup) in (0, 255, 0); then classes 1 and 2
    # together (presumably the disc) in (0, 0, 255). Assuming the input image
    # is RGB, these are green and blue respectively.
    cv2.drawContours(canvas, _external_contours(np.uint8(labels > 1)), -1, (0, 255, 0), 2)
    cv2.drawContours(canvas, _external_contours(np.uint8(labels > 0)), -1, (0, 0, 255), 2)
    return canvas


def Final_Compute_regression_results_Sample(Model, batch_sampler, num_head=2, vcdr_threshold=0.6):
    """Segment one image, compute its vertical cup-to-disc ratio, and draw the prediction.

    Args:
        Model: segmentation network; called on a batch of one image and indexed
            with 'seg' (and with 'seg_aux_1' when ``num_head == 2``).
        batch_sampler: dict with 'image' (tensor produced by ``tfms``, without
            a batch dimension) and 'image_original' (the raw image array; its
            first two dimensions are taken as height and width).
        num_head: when 2, the region found by ``extract_regions_Last`` is
            re-segmented by a second model pass and pasted back in.
        vcdr_threshold: ratios at or above this value yield the risk message.

    Returns:
        Tuple ``(image_cont, vcdr, glaucoma, Regions_crop)``:
        - the contour overlay (uint8, original height x width);
        - ``compute_ratios(...).vcdr``;
        - a diagnosis string;
        - the region dict from ``extract_regions_Last``.
    """
    Model.eval()
    with torch.no_grad():
        label_map, Regions_crop = _predict_segmentation(Model, batch_sampler, num_head)
        ratios = compute_ratios(label_map)
        image_cont = _draw_prediction_contours(batch_sampler['image'], label_map)

    # NOTE(review): if compute_ratios can return NaN (e.g. nothing segmented),
    # the comparison is False and the risk message is returned — confirm intended.
    if ratios.vcdr < vcdr_threshold:
        glaucoma = 'None'
    else:
        glaucoma = 'May be there is a risk of Glaucoma'
    return image_cont, ratios.vcdr, glaucoma, Regions_crop
# Build the segmentation network and load its trained weights once, at import time.
MODEL_CHECKPOINT = "TrainAll_Maghrabi84_50iteration_SWIN.pth.tar"
DeepLab = Transformer_Regression(image_dim=image_shape, dim_patch=dim_patch,
                                 num_classes=num_classes, scale=scale, feat_dim=128)
DeepLab.to(device=device)
# map_location lets a checkpoint saved on a GPU be loaded on a CPU-only host.
# NOTE(review): torch.load unpickles the file — only ever load trusted checkpoints.
DeepLab.load_state_dict(torch.load(MODEL_CHECKPOINT, map_location=device))
DeepLab.eval()
def infer(img):
    """Gradio callback: analyse one uploaded retinal fundus image.

    Args:
        img: the image delivered by the ``gr.Image`` input (presumably an
            H x W x 3 uint8 numpy array), or None when nothing was uploaded.

    Returns:
        Tuple ``(ratio, diagnosis, result, zoomed)``:
        - the vertical cup-to-disc ratio;
        - the diagnosis text;
        - the full image with contours drawn;
        - that same overlay cropped to the region found by ``extract_regions_Last``.

    Raises:
        gr.Error: if no image was provided, so Gradio shows a readable message
            instead of a stack trace.
    """
    if img is None:
        raise gr.Error("Please upload a retinal fundus image first.")

    sample_batch = {
        'image_original': img,
        'image': tfms(Image.fromarray(img)),  # resized + normalised model input
    }
    result, ratio, diagnosis, regions = Final_Compute_regression_results_Sample(
        DeepLab, sample_batch, num_head=2)

    # Zoom into the detected region of the annotated image.
    cord = regions['cord']
    zoomed = result[cord[0]:cord[1], cord[2]:cord[3]]
    return ratio, diagnosis, result, zoomed
title = "Glaucoma Detection in Retinal Fundus Images"
description = "The method detects disc and cup in the retinal image, then it computes the Vertical cup to disc ratio"
EXAMPLE_IMAGES = ['M00027.png', 'M00056.png', 'M00073.png', 'M00093.png', 'M00018.png', 'M00034.png']

with gr.Blocks(css='#title {text-align : center;} ') as demo:
    with gr.Row():
        gr.Markdown(f"# {title}\n{description}", elem_id='title')
    with gr.Row():
        with gr.Column():
            prompt = gr.Image(label="Upload Your Retinal Fundus Image")
            btn = gr.Button(value='Submit')
            # Clicking an example only fills the input image; the user then
            # presses Submit. (Outputs are defined later, so no fn/outputs here.)
            examples = gr.Examples(EXAMPLE_IMAGES, inputs=[prompt], cache_examples=False)
        with gr.Column():
            with gr.Row():
                text1 = gr.Textbox(label="Vertical Cup to Disc Ratio:")
                text2 = gr.Textbox(label="Predicted Diagnosis (Rule of thumb ~0.6 or greater is suspicious)")
            img = gr.Image(label='Detected disc and cup')
            zoom = gr.Image(label='Cropped')
    # Order must match the tuple returned by infer().
    outputs = [text1, text2, img, zoom]
    btn.click(fn=infer, inputs=prompt, outputs=outputs)

if __name__ == '__main__':
    demo.launch()