# Ari — "Update app.py" (commit 084f1ac)
import os
from os.path import splitext
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torchvision
import wget
import gradio as gr
import yolov5 # Make sure YOLOv5 is installed
# Fetch the U-Net segmentation weights once; skip the download when the
# checkpoint file is already sitting next to this script.
segmentationWeightsURL = 'https://huggingface.co/spaces/aritheanalyst/unetdiagnosis/resolve/main/unet.pt'
filename = os.path.basename(segmentationWeightsURL)
if os.path.exists(filename):
    print("Segmentation Weights already present")
else:
    print("Downloading Segmentation Weights from", segmentationWeightsURL)
    wget.download(segmentationWeightsURL)

# Release any cached GPU memory left over from a previous run.
torch.cuda.empty_cache()
def load_unet():
    """Build the DeepLabV3-ResNet50 network and load the downloaded weights.

    Returns:
        tuple: ``(model, device)`` — the network moved to the selected
        device, and the ``torch.device`` it lives on.
    """
    model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
    # Swap the default 21-class head for a single-channel (binary mask) output.
    model.classifier[-1] = torch.nn.Conv2d(
        model.classifier[-1].in_channels, 1,
        kernel_size=model.classifier[-1].kernel_size)
    # BUG FIX: the checkpoint is downloaded into the current working
    # directory (see the wget call above), but the original code joined it
    # with an undefined `destination_for_weights` variable, raising a
    # NameError the first time this function was called.
    weights_path = os.path.basename(segmentationWeightsURL)
    if torch.cuda.is_available():
        device = torch.device("cuda")
        # The checkpoint was saved from a DataParallel-wrapped model, so its
        # state-dict keys carry a "module." prefix; wrap before loading so
        # the keys line up.
        model = torch.nn.DataParallel(model)
        model.to(device)
        checkpoint = torch.load(weights_path)
        model.load_state_dict(checkpoint['state_dict'])
    else:
        device = torch.device("cpu")
        checkpoint = torch.load(weights_path, map_location="cpu")
        # Strip the leading "module." (7 chars) that DataParallel saving adds.
        state_dict_cpu = {k[7:]: v for (k, v) in checkpoint['state_dict'].items()}
        model.load_state_dict(state_dict_cpu)
    return model, device
def load_yolo():
    """Load the small YOLOv5 detector and place it on GPU when available.

    Returns:
        tuple: ``(model, device)`` — the detector and the ``torch.device``
        it was moved to.
    """
    detector = yolov5.load('yolov5s')
    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    detector.to(target)
    return detector, target
def segment(input, model_type="unet"):
    """Segment a brain MRI image with the selected model.

    Args:
        input: HxWx3 RGB image as a numpy array (supplied by Gradio).
        model_type: ``"yolo"`` renders YOLOv5 detections onto the image;
            any other value (default ``"unet"``) overlays the DeepLabV3
            binary mask in blue.

    Returns:
        numpy array: the annotated image.
    """
    if model_type == "yolo":
        model, device = load_yolo()
        # Run detection and take YOLO's rendered (boxes drawn) image.
        results = model(input)
        mask = results.render()[0]
    else:
        model, device = load_unet()
        inp = input
        # HWC -> NCHW batch of one.
        x = inp.transpose([2, 0, 1])
        x = np.expand_dims(x, axis=0)
        # Per-channel normalization statistics computed from this image.
        mean = x.mean(axis=(0, 2, 3))
        std = x.std(axis=(0, 2, 3))
        x = x - mean.reshape(1, 3, 1, 1)
        x = x / std.reshape(1, 3, 1, 1)
        with torch.no_grad():
            x = torch.from_numpy(x).type('torch.FloatTensor').to(device)
            output = model(x)
        # BUG FIX: bring the prediction back to the CPU before .numpy() —
        # calling .numpy() directly on a CUDA tensor raises a TypeError,
        # so the original code crashed on any GPU machine.
        y = output['out'].cpu().numpy()
        y = y.squeeze()
        # Positive logits mark tumor pixels.
        out = y > 0
        # Paint the predicted region blue on a copy of the input image.
        mask = inp.copy()
        mask[out] = np.array([0, 0, 255])
    return mask
# ---- Gradio UI wiring --------------------------------------------------
image_input = gr.inputs.Image(shape=(112, 112), label="Input Brain MRI")
model_choice = gr.inputs.Dropdown(choices=["unet", "yolo"], label="Model Type")
image_output = gr.outputs.Image(label="")

# Bundled sample images paired with the model that should process them.
examples = [
    ["TCGA_CS_5395_19981004_12.png", "unet"],
    ["TCGA_CS_5395_19981004_14.png", "unet"],
    ["TCGA_DU_5849_19950405_20.png", "yolo"],
    ["TCGA_DU_5849_19950405_24.png", "yolo"],
    ["TCGA_DU_5849_19950405_28.png", "unet"],
]

title = "MRI Segmentation With Artificial Intelligence"
description = "Accurately segmenting brain MRIs into regions of peak interest. Built using the UBNet-Seg Architecture trained on a large dataset of manually annotated brain images."
article = "<p style='text-align: center'></p>"

demo = gr.Interface(
    segment,
    [image_input, model_choice],
    image_output,
    allow_flagging=False,
    description=description,
    title=title,
    article=article,
    examples=examples,
    analytics_enabled=False,
)
demo.launch()