Spaces:
Runtime error
Runtime error
# AUTOGENERATED! DO NOT EDIT! File to edit: ../main.ipynb.
# %% auto 0
# Names exported by `from <this module> import *`. "to_oberlay_image" (sic)
# must keep this spelling because it matches the function definition below.
__all__ = ['ORGAN', 'IMAGE_SIZE', 'MODEL_NAME', 'THRESHOLD', 'CODES', 'learn', 'title', 'description', 'examples',
           'interpretation', 'demo', 'x_getter', 'y_getter', 'splitter', 'make3D', 'predict', 'infer',
           'remove_small_segs', 'to_oberlay_image']
| # %% ../main.ipynb 1 | |
| import numpy as np | |
| import pandas as pd | |
| import skimage | |
| from fastai.vision.all import * | |
| import segmentation_models_pytorch as smp | |
| import gradio as gr | |
| # %% ../main.ipynb 2 | |
# Dataset / model configuration.
ORGAN = "kidney"  # exported, but not referenced by the functions below
IMAGE_SIZE = 512  # side length of the square Gradio input image
MODEL_NAME = "unetpp_b4_th60_d9414.pkl"
# The probability threshold is encoded in the third "_"-separated token of the
# checkpoint name as "th<percent>", e.g. "th60" -> 0.60.
THRESHOLD = float(MODEL_NAME.split("_")[2][len("th"):]) / 100.0
# Class names by index; index 1 is the channel `infer` thresholds.
CODES = ["Background", "FTU"]  # FTU = functional tissue unit
| # %% ../main.ipynb 3 | |
def x_getter(r):
    """Return the ``"fnames"`` entry of record ``r``.

    Presumably the image path column of the training dataframe; the function
    is not called by the inference code in this file.
    """
    fname = r["fnames"]
    return fname
def y_getter(r):
    """Decode the run-length-encoded mask stored in record ``r``.

    Reads ``r["rle"]``, ``r["img_height"]`` and ``r["img_width"]``, decodes the
    RLE string to a ``(height, width)`` array and returns it transposed. Not
    called by the inference code in this file; presumably kept so the pickled
    learner can resolve it on load.
    """
    # NOTE(review): rle_decode is not defined in this file; presumably it comes
    # from the star import or the source notebook. Confirm.
    encoded = r["rle"]
    height, width = int(r["img_height"]), int(r["img_width"])
    decoded = rle_decode(encoded, (height, width))
    return decoded.T
def splitter(model):
    """Split ``model``'s parameters into two groups.

    Group 0 holds the encoder parameters. Group 1 holds the decoder parameters
    followed by the segmentation-head parameters. Presumably these are the
    freshly initialised layers used for freezing / discriminative learning
    rates (assumption; the training code is not shown).
    """
    encoder_group = L(model.encoder.parameters())
    head_group = L([*model.decoder.parameters(),
                    *model.segmentation_head.parameters()])
    return L([encoder_group, head_group])
| # %% ../main.ipynb 4 | |
# Load the exported fastai Learner from MODEL_NAME (a file in the working
# directory). fastai's load_learner deserializes with pickle, so any functions
# referenced by the saved learner must already be defined in this module.
# Presumably this is why x_getter, y_getter and splitter exist here. TODO confirm.
# Only load trusted checkpoints: unpickling can execute arbitrary code.
learn = load_learner(MODEL_NAME)
| # %% ../main.ipynb 5 | |
def make3D(t: np.ndarray) -> np.ndarray:
    """Replicate an array into three identical channels along axis 2.

    Args:
        t: Array to replicate. A 2-D ``(H, W)`` input yields ``(H, W, 3)``.

    Returns:
        A new array of the same dtype whose ``[:, :, k]`` slices equal ``t``
        for k = 0, 1, 2.
    """
    # np.stack inserts the new axis at position 2 and stacks the three copies in
    # one call; this is equivalent to expand_dims(axis=2) followed by
    # concatenate(axis=2).
    return np.stack((t, t, t), axis=2)
def predict(fn, cutoff_area=200):
    """Gradio entry point: segment ``fn`` and return the overlay plus a region table.

    Args:
        fn: Image passed on to ``infer`` (whatever ``PILImage.create`` accepts).
        cutoff_area: Connected regions smaller than this many pixels are
            dropped before rendering. This default (200) differs from
            ``remove_small_segs``'s own default (250).

    Returns:
        ``(overlay, table)``: the RGBA image built by ``to_oberlay_image`` and
        the ``pandas.DataFrame`` that ``remove_small_segs`` stores under ``"df"``.
    """
    cleaned = remove_small_segs(infer(fn), cutoff_area=cutoff_area)
    overlay = to_oberlay_image(cleaned)
    return overlay, cleaned["df"]
def infer(fn):
    """Run the segmentation model on one image and binarize the foreground class.

    Args:
        fn: Anything ``PILImage.create`` accepts (e.g. the array supplied by the
            Gradio image input, or a file path).

    Returns:
        dict with
          - ``"tf_image"``: the model-input image as a channels-last uint8
            array (the ``(C, H, W)`` tensor transposed to ``(H, W, C)``);
          - ``"tf_mask"``: a 2-D uint8 mask, 1 where the softmax score of class
            index 1 (``"FTU"`` in CODES) exceeds THRESHOLD, else 0.
    """
    img = PILImage.create(fn)
    # with_input=True prepends the decoded input to fastai's result tuple; the
    # last element is the per-class prediction tensor for this single item.
    tf_img, _, _, preds = learn.predict(img, with_input=True)
    # NOTE(review): fastai's get_preds already applies the loss function's
    # activation. If that activation is softmax, this is a second softmax.
    # THRESHOLD was presumably tuned against this exact pipeline, so it is
    # kept as-is. Confirm before changing either.
    probs = F.softmax(preds.float(), dim=0)
    mask = (probs > THRESHOLD).int()[1]
    return {
        "tf_image": tf_img.numpy().transpose(1, 2, 0).astype(np.uint8),
        "tf_mask": mask.cpu().numpy().astype(np.uint8),
    }
def remove_small_segs(data, cutoff_area=250):
    """Remove connected regions smaller than ``cutoff_area`` pixels from the mask.

    Args:
        data: dict holding a 2-D binary mask under ``"tf_mask"`` (as produced by
            ``infer``). It is updated in place and also returned.
        cutoff_area: Minimum area, in pixels, for a region to be kept.

    Returns:
        ``data``, with
          - ``"tf_mask"`` replaced by a uint8 0/1 mask containing only the
            kept regions, and
          - ``"df"`` set to a DataFrame with one row per kept region:
            ``"Glomerulus"`` (1-based running number) and ``"Area (in px)"``.
    """
    labeled_mask = skimage.measure.label(data["tf_mask"])
    # regionprops yields one entry per label in ascending label order, so the
    # running numbers below follow label order.
    kept = [prop for prop in skimage.measure.regionprops(labeled_mask)
            if prop.area >= cutoff_area]
    # A single vectorised pass over the image (instead of one full-array scan
    # per discarded region). It keys on each region's own label rather than
    # assuming label == enumeration index + 1.
    kept_labels = [prop.label for prop in kept]
    data["tf_mask"] = np.isin(labeled_mask, kept_labels).astype(np.uint8)
    data["df"] = pd.DataFrame({
        "Glomerulus": list(range(1, len(kept) + 1)),
        "Area (in px)": [prop.area for prop in kept],
    })
    return data
def to_oberlay_image(data):
    """Paste a semi-transparent red tint over ``data["tf_image"]`` where the mask is set.

    Args:
        data: dict with ``"tf_image"`` (a ``(H, W, C)`` uint8 array, C >= 3) and
            ``"tf_mask"`` (a uint8 mask, assumed 0/1 as produced by
            ``infer`` / ``remove_small_segs``).

    Returns:
        An RGBA PIL image: the original image with RGB (255, 80, 80) pasted
        through the mask at alpha 127 (about 50% opacity) where the mask is 1.
    """
    image_arr = data["tf_image"]
    mask_arr = data["tf_mask"]

    # Solid tint layer with the same shape/dtype as the image.
    tint_arr = np.zeros_like(image_arr)
    for channel, level in enumerate((255, 80, 80)):
        tint_arr[:, :, channel] = level

    overlay = Image.fromarray(image_arr).convert("RGBA")
    tint = Image.fromarray(tint_arr).convert("RGBA")
    # Paste mask: 1 -> 1 * 255 * 0.5 = 127.5 -> uint8 127; 0 stays transparent.
    alpha = Image.fromarray((mask_arr * 255 * 0.5).astype(np.uint8))
    overlay.paste(tint, (0, 0), alpha)
    return overlay
| # %% ../main.ipynb 6 | |
# User-facing text for the Gradio interface (in Portuguese).
# The title means "Glomerulus Segmentation". The description says:
# - the app segments glomeruli in kidney histology sections;
# - the model is a UNet++ with an efficientnet-b4 encoder from
#   segmentation_models_pytorch;
# - the example images are a random subset of kidney sections from the
#   Human Protein Atlas, collected separately from model training and not part
#   of the training, validation or test sets.
title = "Segmentação de Glomérulos"
description = """
Um aplicativo web que segmenta glomérulos em cortes histológicos de rim!
O modelo implantado aqui é um UNet++ com um codificador efficientnet-b4 da biblioteca segmentation_models_pytorch.
As imagens de exemplo fornecidas são um subconjunto aleatório de cortes de rim do Atlas de Proteínas Humanas. Essas imagens foram coletadas separadamente do treinamento do modelo e não fizeram parte do conjunto de treinamento, validação ou teste.
"""
#article="<p style='text-align: center'><a href='Blog post URL' target='_blank'>Blog post</a></p>"
# Example inputs: the image files fastai's get_image_files finds under the
# local "example_images" directory, as path strings.
examples = [str(p) for p in get_image_files("example_images")]
# Passed to gr.Interface(interpretation=...) below.
interpretation='default'
| # %% ../main.ipynb 7 | |
# Gradio UI wiring. There is one image input and two outputs, which match
# predict's return value: the overlay image and the per-region DataFrame.
# NOTE(review): `shape=` on Image and `interpretation=` appear to be Gradio 3.x
# arguments that later major releases dropped. If the Space fails at startup,
# check the installed Gradio version first. TODO confirm against the pinned
# version.
demo = gr.Interface(
    fn=predict,
    inputs=gr.components.Image(shape=(IMAGE_SIZE, IMAGE_SIZE)),
    outputs=[gr.components.Image(), gr.components.DataFrame()],
    title=title,
    description=description,
    examples=examples,
    interpretation=interpretation,
    # Fixes error when set to True:
    # https://github.com/gradio-app/gradio/pull/1949
    # but generated file names are too long
    _api_mode=False
)
| # %% ../main.ipynb 9 | |
# Start the Gradio server. This is a top-level statement with no
# `if __name__ == "__main__":` guard, so importing this module also launches
# the app.
demo.launch()