Spaces:
Runtime error
Runtime error
| import os | |
| from os.path import splitext | |
| import numpy as np | |
| import sys | |
| import matplotlib.pyplot as plt | |
| import torch | |
| import torchvision | |
| import wget | |
| import gradio as gr | |
| import yolov5 # Make sure YOLOv5 is installed | |
# Fetch the U-Net checkpoint once; later starts reuse the copy already on disk.
segmentationWeightsURL = 'https://huggingface.co/spaces/aritheanalyst/unetdiagnosis/resolve/main/unet.pt'
# wget saves into the working directory under the URL's last path component.
filename = os.path.basename(segmentationWeightsURL)
if os.path.exists(filename):
    print("Segmentation Weights already present")
else:
    print("Downloading Segmentation Weights from", segmentationWeightsURL)
    wget.download(segmentationWeightsURL)

# Release any cached CUDA memory before the models are built.
torch.cuda.empty_cache()
def load_unet():
    """Build the segmentation network and load the downloaded checkpoint.

    A torchvision DeepLabV3/ResNet-50 model is created, its final classifier
    layer is swapped for a single-output-channel convolution, and the weights
    from the checkpoint file are loaded into it.

    The checkpoint is read from the current working directory, under the
    basename of ``segmentationWeightsURL`` (the location the module-level
    download step writes to).

    Returns:
        tuple: ``(model, device)`` where ``model`` is in eval mode and
        ``device`` is the ``torch.device`` it lives on.
    """
    weights_path = os.path.basename(segmentationWeightsURL)

    model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
    # Replace the last classifier conv so the network emits one channel.
    head = model.classifier[-1]
    model.classifier[-1] = torch.nn.Conv2d(head.in_channels, 1, kernel_size=head.kernel_size)

    if torch.cuda.is_available():
        device = torch.device("cuda")
        # Wrapping in DataParallel makes the parameter names carry the
        # "module." prefix, so the checkpoint keys are loaded unchanged here.
        model = torch.nn.DataParallel(model)
        model.to(device)
        checkpoint = torch.load(weights_path)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        device = torch.device("cpu")
        checkpoint = torch.load(weights_path, map_location="cpu")
        # The bare model has no "module." prefix on its parameter names;
        # strip it from the checkpoint keys (presumably saved from a
        # DataParallel-wrapped model -- confirm against the training code).
        prefix = "module."
        state_dict_cpu = {
            (k[len(prefix):] if k.startswith(prefix) else k): v
            for k, v in checkpoint['state_dict'].items()
        }
        model.load_state_dict(state_dict_cpu)

    # Inference only: use stored BatchNorm statistics rather than batch
    # statistics (single-image batches are otherwise problematic).
    model.eval()
    return model, device
def load_yolo():
    """Load the YOLOv5 small model and move it to the available device.

    Returns:
        tuple: ``(model, device)``; ``device`` is CUDA when available,
        otherwise CPU.
    """
    model = yolov5.load('yolov5s')  # Load YOLOv5 small model
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
    model.to(device)
    return model, device
def segment(input, model_type="unet"):
    """Run the selected model on an image and return the annotated image.

    Args:
        input: Image array. The U-Net path transposes it with ``[2, 0, 1]``
            and normalises three channels, so it is expected to be a
            channels-last array with 3 channels (presumably the NumPy image
            Gradio passes in -- confirm against the interface).
        model_type: ``"yolo"`` to run the YOLOv5 model; any other value
            selects the U-Net path.

    Returns:
        For ``"yolo"``: the first rendered result image from the detector.
        Otherwise: a copy of ``input`` in which every pixel with a positive
        model output is set to ``[0, 0, 255]``.
    """
    if model_type == "yolo":
        model, device = load_yolo()
        # Process input with YOLO and return the rendered result.
        results = model(input)
        return results.render()[0]

    model, device = load_unet()
    inp = input

    # Channels-last -> channels-first, plus a leading batch axis.
    x = np.expand_dims(inp.transpose([2, 0, 1]), axis=0)

    # Per-channel standardisation using this image's own statistics.
    mean = x.mean(axis=(0, 2, 3))
    std = x.std(axis=(0, 2, 3))
    # A constant channel has zero std; dividing by it would feed NaN/inf
    # into the network, so leave such channels merely mean-centred.
    std = np.where(std == 0, 1.0, std)
    x = (x - mean.reshape(1, 3, 1, 1)) / std.reshape(1, 3, 1, 1)

    with torch.no_grad():
        x = torch.from_numpy(x).float().to(device)
        output = model(x)
        # Bring the result back to host memory first: .numpy() is not
        # available on CUDA tensors.
        y = output['out'].cpu().numpy()

    y = y.squeeze()
    out = y > 0  # positive output marks a segmented pixel

    mask = inp.copy()
    mask[out] = np.array([0, 0, 255])
    return mask
# Input/output widgets for the demo page.
i = gr.inputs.Image(shape=(112, 112), label="Input Brain MRI")
model_choice = gr.inputs.Dropdown(choices=["unet", "yolo"], label="Model Type")
o = gr.outputs.Image(label="")

# Each example pairs an image file with the model type to preselect.
examples = [
    ["TCGA_CS_5395_19981004_12.png", "unet"],
    ["TCGA_CS_5395_19981004_14.png", "unet"],
    ["TCGA_DU_5849_19950405_20.png", "yolo"],
    ["TCGA_DU_5849_19950405_24.png", "yolo"],
    ["TCGA_DU_5849_19950405_28.png", "unet"],
]

# Text shown on the page.
title = "MRI Segmentation With Artificial Intelligence"
description = "Accurately segmenting brain MRIs into regions of peak interest. Built using the UBNet-Seg Architecture trained on a large dataset of manually annotated brain images."
article = "<p style='text-align: center'></p>"

# Wire `segment` to the widgets and start the app.
interface = gr.Interface(
    fn=segment,
    inputs=[i, model_choice],
    outputs=o,
    allow_flagging=False,
    description=description,
    title=title,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
interface.launch()