# dependencies
import os

import gradio as gr
import numpy as np
import pandas as pd
import torch
from monai.config import print_config
from monai.networks.nets import Transchex
from monai.optimizers.lr_scheduler import WarmupCosineSchedule
from monai.utils import set_determinism
from PIL import Image
from torch.utils.data import DataLoader, Dataset
from torchvision import transforms
from transformers import BertTokenizer

# ---------------------------------------------------------------------------
# Model
# ---------------------------------------------------------------------------
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

model = Transchex(
    in_channels=3,
    img_size=(256, 256),
    num_classes=14,
    patch_size=(32, 32),
    num_language_layers=2,
    num_vision_layers=2,
    num_mixed_layers=2,
).to(device)

# Checkpoint was saved as {"state_dict": ...}.  Map directly onto the active
# device so the weights land on GPU when one is available (the original always
# mapped to CPU while the model itself was moved to `device`).
checkpoint = torch.load("transchex.pt", map_location=device)
model.load_state_dict(checkpoint["state_dict"])
model.eval()

# ---------------------------------------------------------------------------
# Preprocessing
# ---------------------------------------------------------------------------
# Resize to an exact 256x256: Resize(256) alone only fixes the shorter side,
# which breaks the later (1, 3, 256, 256) reshape for non-square uploads.
preprocess = transforms.Compose(
    [
        transforms.Resize((256, 256)),
        transforms.ToTensor(),
        transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
    ]
)

MAX_SEQ_LENGTH = 512

# NOTE(review): do_lower_case=False with an *uncased* BERT checkpoint is
# unusual but preserved from the trained pipeline — confirm before changing.
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)

# Output order must match the classifier head's 14 logits.
CLASS_NAMES = [
    "Atelectasis",
    "Cardiomegaly",
    "Consolidation",
    "Edema",
    "Enlarged-Cardiomediastinum",
    "Fracture",
    "Lung-Lesion",
    "Lung-Opacity",
    "No-Finding",
    "Pleural-Effusion",
    "Pleural_Other",
    "Pneumonia",
    "Pneumothorax",
    "Support-Devices",
]


def encode_features(sent, max_seq_length=MAX_SEQ_LENGTH, tokenizer=tokenizer):
    """Encode a report string into fixed-length BERT inputs.

    Args:
        sent: raw report text.
        max_seq_length: total sequence length after padding (default 512).
        tokenizer: a BERT tokenizer providing ``tokenize`` and
            ``convert_tokens_to_ids``.

    Returns:
        Tuple ``(input_ids, segment_ids)`` — two equal-length lists of ints,
        zero-padded to ``max_seq_length``.
    """
    tokens = tokenizer.tokenize(sent.strip())
    # Reserve two slots for the [CLS] / [SEP] special tokens.
    if len(tokens) > max_seq_length - 2:
        tokens = tokens[: max_seq_length - 2]
    tokens = ["[CLS]"] + tokens + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    segment_ids = [0] * len(input_ids)
    # Zero-pad both sequences to the fixed length (single-segment input, so
    # segment ids are all zeros anyway).
    pad = max_seq_length - len(input_ids)
    input_ids += [0] * pad
    segment_ids += [0] * pad
    return input_ids, segment_ids


def model_predict(image_address, input_report):
    """Run TransCheX on one chest X-ray plus its (optional) report.

    Args:
        image_address: filesystem path to the uploaded image.
        input_report: free-text radiology report; may be empty, in which case
            the placeholder "Not available" is fed to the language branch.

    Returns:
        A formatted multi-line string with one sigmoid probability per
        disease category.
    """
    # Collapse whitespace; fall back to a placeholder for an empty report.
    report = " ".join(str(input_report).split()) if input_report else "Not available"

    # Force RGB so grayscale / RGBA uploads still yield a 3-channel tensor,
    # and move every input to the same device as the model (the original left
    # inputs on CPU, crashing whenever CUDA was available).
    image = Image.open(image_address).convert("RGB")
    image = preprocess(image).reshape([1, 3, 256, 256]).to(device)

    input_ids, segment_ids = encode_features(report)
    input_ids = torch.tensor(input_ids, dtype=torch.long).reshape([1, -1]).to(device)
    segment_ids = torch.tensor(segment_ids, dtype=torch.long).reshape([1, -1]).to(device)

    # Inference only — no autograd bookkeeping.
    with torch.no_grad():
        logits_lang = model(
            input_ids=input_ids, vision_feats=image, token_type_ids=segment_ids
        )
    probs = torch.sigmoid(logits_lang).reshape([-1]).cpu().numpy()

    body = "\n".join(f"{name}: {p}" for name, p in zip(CLASS_NAMES, probs))
    return "\nResults for each class in 14 disease categories:\n\n" + body


# ---------------------------------------------------------------------------
# Gradio interface
# ---------------------------------------------------------------------------
with gr.Blocks() as iface:
    gr.Markdown(
        """
# Welcome to covid-19 detection demo
This is a model created by gradio for covid-19 detection. The capstone model is TranChex - Based on the original model, I made some modifications to facilitate fine-tuning on more datasets.
> `TransCheX` model consists of vision, language, and mixed-modality transformer layers for processing chest X-ray and their corresponding radiological reports within a unified framework.

I modified the architecture by varying the number of vision, language and mixed-modality layers and customizing the classification head. In addition, I added image preprocessing and more language processing modules. Finally, the model is pre-trained on the Open-I dataset and fine-tuned on modified and relabeled [SIIM-FISABIO-RSNA COVID-19 Detection](https://www.kaggle.com/competitions/siim-covid19-detection/data) and [VinBigData Chest X-ray Abnormalities Detection ](https://www.kaggle.com/competitions/vinbigdata-chest-xray-abnormalities-detection/data).

## Components
This demo incorporates the following open source packages
- [MONAI - Home](https://monai.io/): MONAI is a set of open-source, freely available collaborative frameworks built for accelerating research and clinical collaboration in Medical Imaging. The goal is to accelerate the pace of innovation and clinical translation by building a robust software framework that benefits nearly every level of medical imaging, deep learning research, and deployment. In this demo, I used MONAI to build a multimodal model for disease detection, which was implemented in [monai.networks.nets.transchex](https://docs.monai.io/en/latest/_modules/monai/networks/nets/transchex.html), with my own modifications.
- [Hugging Face – The AI community building the future.](https://huggingface.co/) : I used the Bert model of the hugging face for the tokenizer of the text feature encoding part, and I deployed my model demo on the hugging face.
- [scikit-image: Image processing in Python](https://scikit-image.org/): is a collection of algorithms for image processing. I used this module for all the image preprocessing.
- [gradio](https://github.com/gradio-app/gradio):  is an open-source Python library that is used to build machine learning and data science demos and web applications. This demo is created by gradio.
- [Modal](https://modal.com/): is an end-to-end stack for cloud compute. This is used to deploy the model-service for subsequent fine-tuning on more datasets.
- [mage-ai/mage-ai: 🧙 A modern replacement for Airflow.](https://github.com/mage-ai/mage-ai) is a tool to build real-time and batch pipelines to **transform** data using Python, SQL, and R. Run, monitor, and **orchestrate** thousands of pipelines without losing sleep. I'm using this tool to build an online data processing pipeline.

## Dataset
The Open-I dataset provides a collection of 3,996 radiology reports with 8,121 associated images in PA, AP and lateral views.
The 14 finding categories in this work include Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged-Cardiomediastinum, Fracture, Lung-Lesion, Lung-Opacity, No-Finding, Pleural-Effusion, Pleural-Other, Pneumonia, Pneumothorax and Support-Devices. More information can be found in this [link](https://openi.nlm.nih.gov/faq)

## How to use it
You can upload Chest X-ray images in any standard form, and attach your Report , of course, Report_input can also be empty, which will slightly affect the final judgment of the model. You can see examples as following.
- Image_input: Just upload your chest CT image(PNG)
- report_input: You can write down your INDICATIONS, FINDINGS and IMPRESSIONS here. See example part for more information.
"""
    )

    with gr.Row():
        with gr.Column():
            inp1 = gr.Textbox(label="input_report")
            inp2 = gr.Image(label="input_image", type="filepath")
            out = gr.Textbox(label="inference result")
            with gr.Row():
                btn = gr.Button("Inference now")

    btn.click(fn=model_predict, inputs=[inp2, inp1], outputs=out)

    gr.Examples(
        [
            [
                "example1.png",
                "The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax",
            ],
            [
                "example2.png",
                "Stable postsurgical changes. Heart XXXX, mediastinum and lung XXXX are unremarkable. Stable calcified small granuloma in left base",
            ],
        ],
        [inp2, inp1],
        out,
        model_predict,
        cache_examples=True,
    )

    gr.Markdown(
        """## Todo/Doing
- TODO -- Fine-tuning on more datasets
- Doing -- Add the object detection model, plan to use yolov5
- Doing -- Build an online machine learning pipeline
"""
    )

iface.launch()