File size: 8,301 Bytes
97a59bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
2d54fe9
97a59bb
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
d10fa6c
 
 
 
 
 
 
 
71e4d24
d10fa6c
 
557305f
 
 
 
 
 
 
71e4d24
d10fa6c
557305f
d10fa6c
 
 
557305f
 
 
 
d10fa6c
 
 
557305f
d10fa6c
 
557305f
36805f1
d10fa6c
 
 
 
 
557305f
 
 
 
 
 
 
 
 
71e4d24
 
 
 
 
 
 
 
 
557305f
71e4d24
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
# dependencies
import os
import torch
import numpy as np
import pandas as pd
from PIL import Image
from torchvision import transforms
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.networks.nets import Transchex
from monai.config import print_config
from monai.utils import set_determinism
from torch.utils.data import Dataset, DataLoader
from transformers import BertTokenizer

import gradio as gr

# Model setup: pick the compute device and build the multimodal TransCheX
# classifier (vision + language + mixed transformer layers, 14 outputs).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

_MODEL_KWARGS = dict(
    in_channels=3,
    img_size=(256, 256),
    num_classes=14,
    patch_size=(32, 32),
    num_language_layers=2,
    num_vision_layers=2,
    num_mixed_layers=2,
)
model = Transchex(**_MODEL_KWARGS).to(device)

# Load fine-tuned weights; map_location='cpu' lets the checkpoint load even
# on CPU-only hosts (parameters are copied onto `device` by load_state_dict).
_checkpoint = torch.load('transchex.pt', map_location=torch.device('cpu'))
model.load_state_dict(_checkpoint["state_dict"])
model.eval()

# preprocess components
# NOTE: Resize must receive an (H, W) pair here. A bare Resize(256) only
# scales the *shorter* edge to 256, so non-square images could not be
# reshaped to [1, 3, 256, 256] by model_predict. Resize((256, 256)) is
# identical for square inputs and fixes every other aspect ratio.
preprocess = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        # Map each RGB channel from [0, 1] to [-1, 1].
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

## Args 
# Maximum BERT sequence length (token count including [CLS]/[SEP]) used by
# encode_features below.
max_seq_length = 512
# WordPiece tokenizer for the language branch; do_lower_case=False keeps the
# report text's original casing.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)

## encoder
def encode_features(sent, max_seq_length=512, tokenizer=None):
    """Convert a report string into fixed-length BERT input ids.

    Parameters
    ----------
    sent : str
        Raw report text; surrounding whitespace is stripped.
    max_seq_length : int, optional
        Total output length including the [CLS]/[SEP] specials; the token
        sequence is truncated or zero-padded to exactly this many ids.
    tokenizer : optional
        Object exposing ``tokenize`` and ``convert_tokens_to_ids``. Defaults
        to the module-level BERT tokenizer, resolved at *call* time — the
        original def-time default froze the global at import and made the
        function impossible to test in isolation.

    Returns
    -------
    tuple[list[int], list[int]]
        ``(input_ids, segment_ids)``, each exactly ``max_seq_length`` long.
    """
    if tokenizer is None:
        # The parameter shadows the module-level name, so fetch it lazily.
        tokenizer = globals()["tokenizer"]
    tokens = tokenizer.tokenize(sent.strip())
    # Reserve two slots for the [CLS]/[SEP] special tokens.
    if len(tokens) > max_seq_length - 2:
        tokens = tokens[: max_seq_length - 2]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Pad with id 0 ([PAD] for bert-base-uncased) up to the fixed length.
    input_ids.extend([0] * (max_seq_length - len(input_ids)))
    # Single-segment input: every position carries segment id 0.
    segment_ids = [0] * max_seq_length
    return input_ids, segment_ids

# gradio application function
def model_predict(image_address, input_report):
    """Run TransCheX inference on one chest X-ray plus an optional report.

    Parameters
    ----------
    image_address : str
        Filesystem path to the uploaded image (gradio ``type="filepath"``).
    input_report : str or None
        Free-text radiology report. Empty string or None falls back to the
        placeholder "Not available".

    Returns
    -------
    str
        Human-readable sigmoid probability for each of the 14 categories.
    """
    # Gradio hands us None when the textbox is empty; `len(None)` would raise.
    if input_report:
        report = " ".join(str(input_report).split())
    else:
        report = "Not available"

    # Force 3 channels: chest X-ray PNGs are frequently single-channel
    # greyscale, which would break the 3-channel Normalize/reshape below.
    image = Image.open(image_address).convert("RGB")
    image = preprocess(image)
    # Batch of one; move to the same device as the model (the original kept
    # everything on CPU and crashed on CUDA hosts).
    image = image.reshape([1, 3, 256, 256]).to(device)

    input_ids, segment_ids = encode_features(report)
    input_ids = torch.tensor(input_ids, dtype=torch.long).reshape([1, -1]).to(device)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long).reshape([1, -1]).to(device)

    # Inference only — no autograd graph needed.
    with torch.no_grad():
        logits_lang = model(
            input_ids=input_ids, vision_feats=image, token_type_ids=segment_ids
        )
    prob = torch.sigmoid(logits_lang).reshape([-1]).cpu()

    # Category order matches the model's 14 outputs (Open-I labels).
    labels = [
        "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
        "Enlarged-Cardiomediastinum", "Fracture", "Lung-Lesion", "Lung-Opacity",
        "No-Finding", "Pleural-Effusion", "Pleural_Other", "Pneumonia",
        "Pneumothorax", "Support-Devices",
    ]
    # Fixed the user-facing typo ("Resultsfor") and removed the stray
    # indentation the old backslash-continued format string leaked into
    # the output.
    body = "\n".join(f"{name}: {p}" for name, p in zip(labels, prob))
    return f"\nResults for each class in 14 disease categories:\n\n{body}"

# Interface part 

# Build the gradio UI: intro markdown, input widgets (report text + image),
# an inference button wired to model_predict, cached examples, and a roadmap.
with gr.Blocks() as iface: 
    gr.Markdown(
    """
    # Welcome to covid-19 detection demo 
    This is a model created by gradio for covid-19 detection. The capstone model is TranChex - Based on the original model, I made some modifications to facilitate fine-tuning on more datasets. 
    > `TransCheX` model consists of vision, language, and mixed-modality transformer layers for processing chest X-ray and their corresponding radiological reports within a unified framework. 

    I modified the architecture by varying the number of vision, language and mixed-modality layers and customizing the classification head. In addition, I added image preprocessing and more language processing modules.

    Finally, the model is pre-trained on the Open-I dataset and fine-tuned on modified and relabeled [SIIM-FISABIO-RSNA COVID-19 Detection](https://www.kaggle.com/competitions/siim-covid19-detection/data) and [VinBigData Chest X-ray Abnormalities Detection ](https://www.kaggle.com/competitions/vinbigdata-chest-xray-abnormalities-detection/data).

    ## Components
    This demo incorporated with following open source packages 
    - [MONAI - Home](https://monai.io/): MONAI is a set of open-source, freely available collaborative frameworks built for accelerating research and clinical collaboration in Medical Imaging. The goal is to accelerate the pace of innovation and clinical translation by building a robust software framework that benefits nearly every level of medical imaging, deep learning research, and deployment. In this demo, I used MONAI to build a multimodal model for detection detection, which was implemented in [monai.networks.nets.transchex](https://docs.monai.io/en/latest/_modules/monai/networks/nets/transchex.html), with my own modifications.
    - [Hugging Face – The AI community building the future.](https://huggingface.co/) : I used the Bert model of the hugging face for the tokenizer of the text feature encoding part, and I deployed my model demo on the hugging face.
    - [scikit-image: Image processing in Python](https://scikit-image.org/): is a collection of algorithms for image processing. I used this module for all the image preprocessing.
    - [gradio](https://github.com/gradio-app/gradio):  is an open-source Python library that is used to build machine learning and data science demos and web applications. This demo is created by gradio. 
    - [Modal](https://modal.com/): is a end to end stack for cloud compute. This is used to deploy the model-service for subsequent fine-tuning on more datasets.
    - [mage-ai/mage-ai: 🧙 A modern replacement for Airflow.](https://github.com/mage-ai/mage-ai) is a tool to build real-time and batch pipelines to **transform** data using Python, SQL, and R. Run, monitor, and **orchestrate** thousands of pipelines without losing sleep. I'm using this tool to build an online data processing pipeline.

    ## Dataset 
    The Open-I dataset provides a collection of 3,996 radiology reports with 8,121 associated images in PA, AP and lateral views. The 14 finding categories in this work include Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged-Cardiomediastinum, Fracture, Lung-Lesion, Lung-Opacity, No-Finding, Pleural-Effusion, Pleural-Other, Pneumonia, Pneumothorax and Support-Devices. More information can be found in this [link](https://openi.nlm.nih.gov/faq)

    ## How to use it 
    You can upload Chest X-ray images in any standard form, and attach your Report , of course, Report_input can also be empty, which will slightly affect the final judgment of the model. You can see examples as following.
    - Image_input: Just upload your chest CT image(PNG)
    - report_input: You can write down your INDICATIONS, FINDINGS and IMPRESSIONS here. See example part for more information.
    """)
    
    # Input widgets: free-text report and an image upload that yields a
    # filesystem path (model_predict opens the file itself), plus the
    # read-only result textbox.
    with gr.Row(): 
        with gr.Column():
            inp1 = gr.Textbox(label="input_report")
            inp2 = gr.Image(label="input_image", type="filepath")
        
        out = gr.Textbox(label="inference result")
        
        
    with gr.Row():
        btn = gr.Button("Inference now")
    
    # model_predict takes (image_address, input_report), hence [inp2, inp1].
    btn.click(fn=model_predict, inputs=[inp2, inp1], outputs=out)        
        
    # Pre-computed sample (image, report) pairs; cache_examples=True runs
    # model_predict on them once at startup and serves stored outputs.
    gr.Examples(
        [['example1.png', 'The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax']
         ,['example2.png', 'Stable postsurgical changes. Heart XXXX, mediastinum and lung XXXX are unremarkable. Stable calcified small granuloma in left base']],
         [inp2, inp1],
         out,
         model_predict,
         cache_examples=True,
    )
    
    gr.Markdown(
    """## Todo/Doing
    - TODO -- Fine-tuning on more datasets 
    - Doing -- Add the object detection model, plan to use yolov5 
    - Doing -- Build an online machine learning pipeline

    """
    )
        
# Launch the demo server (blocking call).
iface.launch()