DHEIVER's picture
Update app.py
b9913f3
import os
from os.path import splitext
import numpy as np
import sys
import matplotlib.pyplot as plt
import torch
import torchvision
import wget
destination_folder = "output"
destination_for_weights = "weights"
if os.path.exists(destination_for_weights):
print("Os pesos estão em", destination_for_weights)
else:
print("Criando pasta em ", destination_for_weights, " para armazenar os pesos")
os.mkdir(destination_for_weights)
segmentationWeightsURL = 'https://github.com/douyang/EchoNetDynamic/releases/download/v1.0.0/deeplabv3_resnet50_random.pt'
if not os.path.exists(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL))):
print("Baixando pesos de segmentação, ", segmentationWeightsURL," para ",os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
filename = wget.download(segmentationWeightsURL, out = destination_for_weights)
else:
print("Os pesos de segmentação já estão presentes")
torch.cuda.empty_cache()
def collate_fn(x):
x, f = zip(*x)
i = list(map(lambda t: t.shape[1], x))
x = torch.as_tensor(np.swapaxes(np.concatenate(x, 1), 0, 1))
return x, f, i
model = torchvision.models.segmentation.deeplabv3_resnet50(pretrained=False, aux_loss=False)
model.classifier[-1] = torch.nn.Conv2d(model.classifier[-1].in_channels, 1, kernel_size=model.classifier[-1].kernel_size)
print("Carregando pesos de ", os.path.join(destination_for_weights, "deeplabv3_resnet50_random"))
if torch.cuda.is_available():
print("cuda está disponível, pesos originais")
device = torch.device("cuda")
model = torch.nn.DataParallel(model)
model.to(device)
checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)))
model.load_state_dict(checkpoint['state_dict'])
else:
print("cuda não está disponível, pesos da CPU")
device = torch.device("cpu")
checkpoint = torch.load(os.path.join(destination_for_weights, os.path.basename(segmentationWeightsURL)), map_location = "cpu")
state_dict_cpu = {k[7:]: v for (k, v) in checkpoint['state_dict'].items()}
model.load_state_dict(state_dict_cpu)
model.eval()
def segment(input):
inp = input
x = inp.transpose([2, 0, 1])
x = np.expand_dims(x, axis=0)
mean = x.mean(axis=(0, 2, 3))
std = x.std(axis=(0, 2, 3))
x = x - mean.reshape(1, 3, 1, 1)
x = x / std.reshape(1, 3, 1, 1)
with torch.no_grad():
x = torch.from_numpy(x).type('torch.FloatTensor').to(device)
output = model(x)
y = output['out'].numpy()
y = y.squeeze()
out = y>0
mask = inp.copy()
mask[out] = np.array([0, 0, 255])
return mask
import gradio as gr
i = gr.inputs.Image(shape=(112, 112), label="Imagem de Ressonância Magnética do Cérebro de Entrada")
o = gr.outputs.Image(label="Resultado da Segmentação")
examples = [["TCGA_CS_5395_19981004_12.png"],
["TCGA_CS_5395_19981004_14.png"],
["TCGA_DU_5849_19950405_20.png"],
["TCGA_DU_5849_19950405_24.png"],
["TCGA_DU_5849_19950405_28.png"]]
title = "Sistema de Segmentação de Imagens de Ressonância Magnética do Cérebro baseado em Inteligência Artificial"
description = "Este sistema foi desenvolvido para automatizar o processo de segmentação precisa e eficiente de imagens de Ressonância Magnética do cérebro em regiões de interesse. Ele utiliza a arquitetura UBNet-Seg, que foi treinada em um grande conjunto de dados de imagens de cérebros com anotações manuais."
article = "<p style='text-align: center'>Criado por <a target='_blank' href='https://fi.ub.ac.id/'>Jurusan Fisika, FMIPA, Universitas Brawijaya </a></p>"
gr.Interface(segment, i, o,
allow_flagging = False,
description = description,
title = title,
article = article,
examples = examples,
analytics_enabled = False).launch()