# Source: Hugging Face Space by DylanLi — "Update app.py", commit 71e4d24
# dependencies
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.networks.nets import Transchex
from monai.config import print_config
from monai.utils import set_determinism
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer
import gradio as gr
# model
# Pick the GPU when available; the checkpoint is loaded onto CPU below
# and the whole module is then moved to this device.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
# TransCheX multimodal transformer: 2 vision + 2 language + 2 mixed
# layers over 256x256 RGB images split into 32x32 patches, classifying
# 14 disease categories.
model = Transchex(
in_channels=3,
img_size=(256, 256),
num_classes=14,
patch_size=(32, 32),
num_language_layers=2,
num_vision_layers=2,
num_mixed_layers=2,
).to(device)
# Load the fine-tuned weights stored under the checkpoint's "state_dict" key.
# NOTE(review): torch.load without weights_only=True unpickles arbitrary
# objects — acceptable for a trusted local file, but confirm provenance.
model.load_state_dict(torch.load(r'transchex.pt', map_location=torch.device('cpu'))["state_dict"])
# Inference mode: freezes dropout / batch-norm statistics.
model.eval()
# preprocess components
# Resize to exactly 256x256: a bare Resize(256) only constrains the
# SHORTER edge, so any non-square X-ray would come out e.g. 256x341 and
# break the hard reshape to (1, 3, 256, 256) in model_predict.
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    # Map [0, 1] pixel values to [-1, 1] per channel.
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
## Args
max_seq_length = 512  # total BERT sequence length incl. [CLS]/[SEP]
# do_lower_case=False keeps the original casing of the clinical text.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)
## encoder
def encode_features(sent, max_seq_length=512, tokenizer=None):
    """Tokenize a report string into fixed-length BERT inputs.

    Parameters
    ----------
    sent : str
        Raw report text; surrounding whitespace is stripped.
    max_seq_length : int
        Total sequence length including the [CLS]/[SEP] markers; the
        output is truncated or zero-padded to exactly this length.
    tokenizer : optional
        Object exposing ``tokenize`` and ``convert_tokens_to_ids``
        (e.g. a ``BertTokenizer``). Defaults to the module-level
        tokenizer, resolved lazily at call time rather than bound at
        definition time, so the function also imports standalone.

    Returns
    -------
    tuple[list[int], list[int]]
        ``(input_ids, segment_ids)``, each of length ``max_seq_length``;
        segment ids are all 0 (single-sentence input).
    """
    if tokenizer is None:
        # Late binding of the module-level BERT tokenizer; the parameter
        # shadows the global name, hence the globals() lookup.
        tokenizer = globals()["tokenizer"]
    tokens = tokenizer.tokenize(sent.strip())
    # Reserve two slots for the [CLS] / [SEP] special tokens.
    if len(tokens) > max_seq_length - 2:
        tokens = tokens[: max_seq_length - 2]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Zero-pad up to the fixed length in one step (no per-token loop).
    input_ids = input_ids + [0] * (max_seq_length - len(input_ids))
    # All-zero segment ids: we feed a single text segment.
    segment_ids = [0] * max_seq_length
    return input_ids, segment_ids
# gradio application function
def model_predict(image_address, input_report):
    """Run TransCheX on one chest X-ray and an optional report.

    Parameters
    ----------
    image_address : str
        Filesystem path of the uploaded image (gradio ``type="filepath"``).
    input_report : str or None
        Free-text radiology report; ``None`` or an empty string falls
        back to a "Not available" placeholder.

    Returns
    -------
    str
        Human-readable sigmoid probability for each of the 14 classes.
    """
    # Normalize whitespace in the report; gradio may hand us None or "".
    # (The original `len(input_report)` crashed on None.)
    if input_report:
        report = " ".join(str(input_report).split())
    else:
        report = "Not available"

    # Image -> normalized (1, 3, 256, 256) tensor on the model's device.
    # convert("RGB") guards against grayscale/RGBA uploads whose channel
    # count would otherwise break the reshape below.
    image = Image.open(image_address).convert("RGB")
    image = preprocess(image).reshape([1, 3, 256, 256]).to(device)

    # Text -> fixed-length BERT ids / segment ids, batched and on device.
    input_ids, segment_ids = encode_features(report)
    input_ids = torch.tensor(input_ids, dtype=torch.long).reshape([1, -1]).to(device)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long).reshape([1, -1]).to(device)

    # Pure inference: skip autograd bookkeeping.
    with torch.no_grad():
        logits_lang = model(
            input_ids=input_ids, vision_feats=image, token_type_ids=segment_ids
        )
    prob = torch.sigmoid(logits_lang).reshape([-1]).cpu()

    # Class order matches the model's 14-way head (Open-I label set).
    class_names = [
        "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
        "Enlarged-Cardiomediastinum", "Fracture", "Lung-Lesion",
        "Lung-Opacity", "No-Finding", "Pleural-Effusion", "Pleural_Other",
        "Pneumonia", "Pneumothorax", "Support-Devices",
    ]
    # .item() prints a plain float instead of "tensor(...)" noise;
    # also fixes the "Resultsfor" typo in the original header.
    lines = "\n".join(
        "{}: {:.4f}".format(name, prob[i].item())
        for i, name in enumerate(class_names)
    )
    return "\nResults for each class in 14 disease categories:\n\n" + lines
# Interface part: gradio Blocks UI wiring model_predict to a report
# textbox + image upload, with two cached examples.
# NOTE(review): the source's indentation was lost; the nesting below is
# reconstructed (inputs in a column, output beside it) — confirm layout.
with gr.Blocks() as iface:
    gr.Markdown(
        """
# Welcome to covid-19 detection demo
This is a model created by gradio for covid-19 detection. The capstone model is TranChex - Based on the original model, I made some modifications to facilitate fine-tuning on more datasets.
> `TransCheX` model consists of vision, language, and mixed-modality transformer layers for processing chest X-ray and their corresponding radiological reports within a unified framework.
I modified the architecture by varying the number of vision, language and mixed-modality layers and customizing the classification head. In addition, I added image preprocessing and more language processing modules.
Finally, the model is pre-trained on the Open-I dataset and fine-tuned on modified and relabeled [SIIM-FISABIO-RSNA COVID-19 Detection](https://www.kaggle.com/competitions/siim-covid19-detection/data) and [VinBigData Chest X-ray Abnormalities Detection ](https://www.kaggle.com/competitions/vinbigdata-chest-xray-abnormalities-detection/data).
## Components
This demo incorporated with following open source packages
- [MONAI - Home](https://monai.io/): MONAI is a set of open-source, freely available collaborative frameworks built for accelerating research and clinical collaboration in Medical Imaging. The goal is to accelerate the pace of innovation and clinical translation by building a robust software framework that benefits nearly every level of medical imaging, deep learning research, and deployment. In this demo, I used MONAI to build a multimodal model for detection detection, which was implemented in [monai.networks.nets.transchex](https://docs.monai.io/en/latest/_modules/monai/networks/nets/transchex.html), with my own modifications.
- [Hugging Face – The AI community building the future.](https://huggingface.co/) : I used the Bert model of the hugging face for the tokenizer of the text feature encoding part, and I deployed my model demo on the hugging face.
- [scikit-image: Image processing in Python](https://scikit-image.org/): is a collection of algorithms for image processing. I used this module for all the image preprocessing.
- [gradio](https://github.com/gradio-app/gradio):  is an open-source Python library that is used to build machine learning and data science demos and web applications. This demo is created by gradio.
- [Modal](https://modal.com/): is a end to end stack for cloud compute. This is used to deploy the model-service for subsequent fine-tuning on more datasets.
- [mage-ai/mage-ai: 🧙 A modern replacement for Airflow.](https://github.com/mage-ai/mage-ai) is a tool to build real-time and batch pipelines to **transform** data using Python, SQL, and R. Run, monitor, and **orchestrate** thousands of pipelines without losing sleep. I'm using this tool to build an online data processing pipeline.
## Dataset
The Open-I dataset provides a collection of 3,996 radiology reports with 8,121 associated images in PA, AP and lateral views. The 14 finding categories in this work include Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged-Cardiomediastinum, Fracture, Lung-Lesion, Lung-Opacity, No-Finding, Pleural-Effusion, Pleural-Other, Pneumonia, Pneumothorax and Support-Devices. More information can be found in this [link](https://openi.nlm.nih.gov/faq)
## How to use it
You can upload Chest X-ray images in any standard form, and attach your Report , of course, Report_input can also be empty, which will slightly affect the final judgment of the model. You can see examples as following.
- Image_input: Just upload your chest CT image(PNG)
- report_input: You can write down your INDICATIONS, FINDINGS and IMPRESSIONS here. See example part for more information.
""")
    with gr.Row():
        with gr.Column():
            inp1 = gr.Textbox(label="input_report")
            inp2 = gr.Image(label="input_image", type="filepath")
        out = gr.Textbox(label="inference result")
    with gr.Row():
        btn = gr.Button("Inference now")
    # model_predict takes (image_path, report) — note inp2 comes first.
    btn.click(fn=model_predict, inputs=[inp2, inp1], outputs=out)
    # Pre-computed (cached) example inputs, in the same [image, report] order.
    gr.Examples(
        [
            ['example1.png', 'The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax'],
            ['example2.png', 'Stable postsurgical changes. Heart XXXX, mediastinum and lung XXXX are unremarkable. Stable calcified small granuloma in left base'],
        ],
        [inp2, inp1],
        out,
        model_predict,
        cache_examples=True,
    )
    gr.Markdown(
        """## Todo/Doing
- TODO -- Fine-tuning on more datasets
- Doing -- Add the object detection model, plan to use yolov5
- Doing -- Build an online machine learning pipeline
"""
    )
iface.launch()