# dependencies
import torch
import numpy as np
from PIL import Image
from torchvision import transforms
from monai.networks.nets import Transchex
from transformers import BertTokenizer
import gradio as gr
# model
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transchex(
    in_channels=3,
    img_size=(256, 256),
    num_classes=14,
    patch_size=(32, 32),
    num_language_layers=2,
    num_vision_layers=2,
    num_mixed_layers=2,
).to(device)
# Load fine-tuned weights; map_location lets a GPU checkpoint load on CPU.
model.load_state_dict(torch.load("transchex.pt", map_location="cpu")["state_dict"])
model.eval()
# preprocess components: force a fixed 256x256 output so the reshape below always succeeds
# (transforms.Resize(256) alone only fixes the shorter edge, which breaks on non-square images)
preprocess = transforms.Compose([
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])
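The `Normalize` step with mean = std = 0.5 simply maps pixel values from [0, 1] to [-1, 1]; a minimal sketch of the equivalent arithmetic (plain Python, no torchvision needed):

```python
# Equivalent arithmetic of transforms.Normalize(mean=0.5, std=0.5):
# out = (x - mean) / std, mapping [0, 1] inputs to [-1, 1].
def normalize(x, mean=0.5, std=0.5):
    return (x - mean) / std

print(normalize(0.0))  # -1.0 (black pixel)
print(normalize(0.5))  # 0.0  (mid gray)
print(normalize(1.0))  # 1.0  (white pixel)
```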
## Args
max_seq_length = 512
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)
## encoder: BERT-style token ids and segment ids, padded to a fixed length
def encode_features(sent, max_seq_length=512, tokenizer=tokenizer):
    tokens = tokenizer.tokenize(sent.strip())
    # Truncate to leave room for the [CLS] and [SEP] special tokens.
    if len(tokens) > max_seq_length - 2:
        tokens = tokens[: max_seq_length - 2]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    segment_ids = [0] * len(input_ids)
    # Zero-pad both sequences to max_seq_length.
    while len(input_ids) < max_seq_length:
        input_ids.append(0)
        segment_ids.append(0)
    assert len(input_ids) == max_seq_length
    assert len(segment_ids) == max_seq_length
    return input_ids, segment_ids
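The truncate/pad scheme above can be sketched without a real tokenizer; the toy integer ids below are illustrative (101/102 are BERT's usual `[CLS]`/`[SEP]` ids, an assumption here):

```python
# Minimal sketch of encode_features' [CLS]/[SEP] + zero-padding scheme,
# using toy integer token ids instead of a real BERT tokenizer.
CLS_ID, SEP_ID, PAD_ID = 101, 102, 0

def pad_ids(token_ids, max_len=16):
    token_ids = token_ids[: max_len - 2]           # leave room for specials
    ids = [CLS_ID] + token_ids + [SEP_ID]
    segment_ids = [0] * len(ids)
    ids += [PAD_ID] * (max_len - len(ids))         # zero-pad to fixed length
    segment_ids += [0] * (max_len - len(segment_ids))
    return ids, segment_ids

ids, segs = pad_ids([2003, 2025, 2064])
assert len(ids) == len(segs) == 16                 # always a fixed shape
assert ids[:5] == [101, 2003, 2025, 2064, 102]
```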
# gradio application function
def model_predict(image_address, input_report):
    if input_report:
        report = " ".join(str(input_report).split())
    else:
        report = "Not available"
    preds_cls = np.zeros((1, 14))
    # Force 3 channels: chest X-rays are often stored as grayscale PNGs.
    image = Image.open(image_address).convert("RGB")
    image = preprocess(image)
    image = image.reshape([1, 3, 256, 256]).to(device)
    input_ids, segment_ids = encode_features(report)
    input_ids = torch.tensor(input_ids, dtype=torch.long).reshape([1, -1]).to(device)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long).reshape([1, -1]).to(device)
    with torch.no_grad():
        logits_lang = model(
            input_ids=input_ids, vision_feats=image, token_type_ids=segment_ids
        )
    prob = torch.sigmoid(logits_lang).reshape([-1])
    preds_cls[0, :] = prob.detach().cpu().numpy()
    class_names = [
        "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
        "Enlarged-Cardiomediastinum", "Fracture", "Lung-Lesion", "Lung-Opacity",
        "No-Finding", "Pleural-Effusion", "Pleural-Other", "Pneumonia",
        "Pneumothorax", "Support-Devices",
    ]
    result = "\nResults for each class in 14 disease categories:\n\n" + "\n".join(
        "{}: {}".format(name, p) for name, p in zip(class_names, preds_cls[0])
    )
    return result
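Note that the head is multi-label: each of the 14 logits passes through an independent sigmoid, so per-class probabilities need not sum to 1 (unlike softmax). A quick illustration in plain Python:

```python
import math

# Each logit is squashed independently; a patient can score high on
# several findings at once, so the 14 probabilities need not sum to 1.
def sigmoid(z):
    return 1.0 / (1.0 + math.exp(-z))

probs = [sigmoid(z) for z in (-2.0, 0.0, 2.0)]
assert abs(probs[1] - 0.5) < 1e-9
assert abs(probs[0] + probs[2] - 1.0) < 1e-9  # sigmoid(-z) = 1 - sigmoid(z)
```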
# Interface part
with gr.Blocks() as iface:
    gr.Markdown(
        """
# Welcome to the COVID-19 detection demo
This is a Gradio demo for COVID-19 detection. The capstone model is TransCheX; based on the original model, I made some modifications to facilitate fine-tuning on more datasets.
> The `TransCheX` model consists of vision, language, and mixed-modality transformer layers for processing chest X-rays and their corresponding radiological reports within a unified framework.
I modified the architecture by varying the number of vision, language, and mixed-modality layers and customizing the classification head. In addition, I added image preprocessing and more language-processing modules.
Finally, the model is pre-trained on the Open-I dataset and fine-tuned on modified and relabeled versions of [SIIM-FISABIO-RSNA COVID-19 Detection](https://www.kaggle.com/competitions/siim-covid19-detection/data) and [VinBigData Chest X-ray Abnormalities Detection](https://www.kaggle.com/competitions/vinbigdata-chest-xray-abnormalities-detection/data).
## Components
This demo incorporates the following open-source packages:
- [MONAI - Home](https://monai.io/): MONAI is a set of open-source, freely available collaborative frameworks built to accelerate research and clinical collaboration in medical imaging, with the goal of speeding up innovation and clinical translation across medical imaging, deep learning research, and deployment. In this demo, I used MONAI to build the multimodal model, implemented in [monai.networks.nets.transchex](https://docs.monai.io/en/latest/_modules/monai/networks/nets/transchex.html), with my own modifications.
- [Hugging Face – The AI community building the future.](https://huggingface.co/): I used Hugging Face's BERT model for the tokenizer in the text-feature encoding part, and I deployed this demo on Hugging Face.
- [scikit-image: Image processing in Python](https://scikit-image.org/): a collection of algorithms for image processing, used here for image preprocessing.
- [gradio](https://github.com/gradio-app/gradio): an open-source Python library used to build machine learning and data science demos and web applications. This demo is built with Gradio.
- [Modal](https://modal.com/): an end-to-end stack for cloud compute, used to deploy the model service for subsequent fine-tuning on more datasets.
- [mage-ai/mage-ai: 🧙 A modern replacement for Airflow.](https://github.com/mage-ai/mage-ai): a tool to build real-time and batch pipelines that **transform** data using Python, SQL, and R, and to run, monitor, and **orchestrate** thousands of pipelines. I'm using it to build an online data-processing pipeline.
## Dataset
The Open-I dataset provides a collection of 3,996 radiology reports with 8,121 associated images in PA, AP, and lateral views. The 14 finding categories in this work are Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged-Cardiomediastinum, Fracture, Lung-Lesion, Lung-Opacity, No-Finding, Pleural-Effusion, Pleural-Other, Pneumonia, Pneumothorax, and Support-Devices. More information can be found at this [link](https://openi.nlm.nih.gov/faq).
## How to use it
You can upload chest X-ray images in any standard format and attach your report; the report input can also be left empty, which will slightly affect the model's final prediction. See the examples below.
- Image_input: upload your chest X-ray image (PNG).
- report_input: write down your INDICATIONS, FINDINGS, and IMPRESSIONS here. See the examples for more information.
""")
    with gr.Row():
        with gr.Column():
            inp1 = gr.Textbox(label="input_report")
            inp2 = gr.Image(label="input_image", type="filepath")
        out = gr.Textbox(label="inference result")
    with gr.Row():
        btn = gr.Button("Inference now")
    btn.click(fn=model_predict, inputs=[inp2, inp1], outputs=out)
    gr.Examples(
        [
            ["example1.png", "The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax"],
            ["example2.png", "Stable postsurgical changes. Heart XXXX, mediastinum and lung XXXX are unremarkable. Stable calcified small granuloma in left base"],
        ],
        [inp2, inp1],
        out,
        model_predict,
        cache_examples=True,
    )
    gr.Markdown(
        """## Todo/Doing
- TODO: fine-tune on more datasets
- Doing: add an object detection model (planning to use YOLOv5)
- Doing: build an online machine learning pipeline
"""
    )
iface.launch()