Spaces:
Build error
Build error
| # dependences import os | |
| import torch | |
| import numpy as np | |
| import pandas as pd | |
| from PIL import Image | |
| from torchvision import transforms | |
| from monai.optimizers.lr_scheduler import WarmupCosineSchedule | |
| from monai.networks.nets import Transchex | |
| from monai.config import print_config | |
| from monai.utils import set_determinism | |
| from torch.utils.data import Dataset, DataLoader | |
| from transformers import BertTokenizer | |
| import gradio as gr | |
# ---- model ----------------------------------------------------------------
# Build the Transchex multimodal classifier (vision + language + mixed
# transformer layers) and move it to GPU when one is available.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model = Transchex(
    in_channels=3,
    img_size=(256, 256),
    num_classes=14,
    patch_size=(32, 32),
    num_language_layers=2,
    num_vision_layers=2,
    num_mixed_layers=2,
).to(device)
# The checkpoint stores weights under the "state_dict" key.  map_location="cpu"
# keeps loading safe on CPU-only hosts; load_state_dict then copies the
# tensors onto whatever device `model` already lives on.
model.load_state_dict(
    torch.load(r'transchex.pt', map_location=torch.device('cpu'))["state_dict"]
)
model.eval()

# ---- preprocess components -------------------------------------------------
preprocess = transforms.Compose([
    # BUG FIX: Resize(256) resizes only the SHORTER side, so a non-square
    # X-ray came out as e.g. 3x256x341 and the later reshape to
    # [1, 3, 256, 256] crashed.  Resize((256, 256)) forces the exact spatial
    # size the model (img_size=(256, 256)) expects.
    transforms.Resize((256, 256)),
    transforms.ToTensor(),
    transforms.Normalize([0.5, 0.5, 0.5], [0.5, 0.5, 0.5]),
])

## Args
max_seq_length = 512
tokenizer = BertTokenizer.from_pretrained("bert-base-uncased", do_lower_case=False)
| ## encoder | |
def encode_features(sent, max_seq_length=512, tokenizer=None):
    """Encode a report string into fixed-length BERT inputs.

    Args:
        sent: raw report text; surrounding whitespace is stripped.
        max_seq_length: total output length, including the [CLS]/[SEP]
            special tokens.
        tokenizer: a BERT tokenizer; defaults to the module-level
            ``tokenizer``.  Resolved at call time rather than bound at def
            time, so behaviour is unchanged for existing callers while the
            function no longer depends on definition order.

    Returns:
        ``(input_ids, segment_ids)``: two lists of length ``max_seq_length``;
        ``input_ids`` is right-padded with 0 ([PAD]), and every position is
        segment 0 (single-sentence input).
    """
    if tokenizer is None:
        tokenizer = globals()["tokenizer"]
    tokens = tokenizer.tokenize(sent.strip())
    # Reserve two slots for the [CLS]/[SEP] special tokens, truncating the
    # report if it is too long.
    tokens = ["[CLS]"] + tokens[: max_seq_length - 2] + ["[SEP]"]
    input_ids = tokenizer.convert_tokens_to_ids(tokens)
    # Right-pad with 0 up to the fixed length (replaces the original
    # one-element-at-a-time while loop; the arithmetic guarantees the length,
    # so the old post-condition asserts are no longer needed).
    input_ids += [0] * (max_seq_length - len(input_ids))
    segment_ids = [0] * max_seq_length
    return input_ids, segment_ids
| # gradio application function | |
def model_predict(image_address, input_report):
    """Run TransCheX inference on one chest X-ray plus an optional report.

    Args:
        image_address: filesystem path to the uploaded image (gradio passes
            a path because the Image component uses type="filepath").
        input_report: free-text radiology report; may be empty, in which
            case a "Not available" placeholder is used.

    Returns:
        A formatted string with the sigmoid probability for each of the
        14 finding categories.
    """
    # Normalize whitespace in the report; fall back to a placeholder when
    # the textbox is empty.
    report = " ".join(str(input_report).split()) if input_report else "Not available"

    # BUG FIX: X-ray PNGs are frequently grayscale ("L") or RGBA; convert to
    # RGB so the preprocessed tensor really has 3 channels.
    image = Image.open(image_address).convert("RGB")
    image = preprocess(image).unsqueeze(0)  # [1, 3, 256, 256]

    input_ids, segment_ids = encode_features(report)
    input_ids = torch.tensor([input_ids], dtype=torch.long)
    segment_ids = torch.tensor([segment_ids], dtype=torch.long)

    # BUG FIX: the model was moved to `device` at startup but the inputs
    # previously stayed on CPU, which crashes whenever CUDA is available.
    image = image.to(device)
    input_ids = input_ids.to(device)
    segment_ids = segment_ids.to(device)

    # Inference only — skip autograd bookkeeping.
    with torch.no_grad():
        logits_lang = model(
            input_ids=input_ids, vision_feats=image, token_type_ids=segment_ids
        )
    prob = torch.sigmoid(logits_lang).reshape(-1).detach().cpu()

    # Category order matches the classifier head / Open-I labels.
    categories = [
        "Atelectasis", "Cardiomegaly", "Consolidation", "Edema",
        "Enlarged-Cardiomediastinum", "Fracture", "Lung-Lesion",
        "Lung-Opacity", "No-Finding", "Pleural-Effusion", "Pleural_Other",
        "Pneumonia", "Pneumothorax", "Support-Devices",
    ]
    # Original header said "Resultsfor" (missing space) — fixed.
    body = "\n".join(f"{name}: {prob[i]}" for i, name in enumerate(categories))
    return "\nResults for each class in 14 disease categories:\n\n" + body
# Interface part
# Blocks layout: markdown intro, one row with the report textbox + image
# uploader and the result box, an inference button, cached examples, and a
# roadmap footer.  Launched at module import (this file is the Space's app).
with gr.Blocks() as iface:
    gr.Markdown(
        """
# Welcome to covid-19 detection demo
This is a model created by gradio for covid-19 detection. The capstone model is TranChex - Based on the original model, I made some modifications to facilitate fine-tuning on more datasets.
> `TransCheX` model consists of vision, language, and mixed-modality transformer layers for processing chest X-ray and their corresponding radiological reports within a unified framework.
I modified the architecture by varying the number of vision, language and mixed-modality layers and customizing the classification head. In addition, I added image preprocessing and more language processing modules.
Finally, the model is pre-trained on the Open-I dataset and fine-tuned on modified and relabeled [SIIM-FISABIO-RSNA COVID-19 Detection](https://www.kaggle.com/competitions/siim-covid19-detection/data) and [VinBigData Chest X-ray Abnormalities Detection ](https://www.kaggle.com/competitions/vinbigdata-chest-xray-abnormalities-detection/data).
## Components
This demo incorporated with following open source packages
- [MONAI - Home](https://monai.io/): MONAI is a set of open-source, freely available collaborative frameworks built for accelerating research and clinical collaboration in Medical Imaging. The goal is to accelerate the pace of innovation and clinical translation by building a robust software framework that benefits nearly every level of medical imaging, deep learning research, and deployment. In this demo, I used MONAI to build a multimodal model for detection detection, which was implemented in [monai.networks.nets.transchex](https://docs.monai.io/en/latest/_modules/monai/networks/nets/transchex.html), with my own modifications.
- [Hugging Face – The AI community building the future.](https://huggingface.co/) : I used the Bert model of the hugging face for the tokenizer of the text feature encoding part, and I deployed my model demo on the hugging face.
- [scikit-image: Image processing in Python](https://scikit-image.org/): is a collection of algorithms for image processing. I used this module for all the image preprocessing.
- [gradio](https://github.com/gradio-app/gradio): is an open-source Python library that is used to build machine learning and data science demos and web applications. This demo is created by gradio.
- [Modal](https://modal.com/): is a end to end stack for cloud compute. This is used to deploy the model-service for subsequent fine-tuning on more datasets.
- [mage-ai/mage-ai: 🧙 A modern replacement for Airflow.](https://github.com/mage-ai/mage-ai) is a tool to build real-time and batch pipelines to **transform** data using Python, SQL, and R. Run, monitor, and **orchestrate** thousands of pipelines without losing sleep. I'm using this tool to build an online data processing pipeline.
## Dataset
The Open-I dataset provides a collection of 3,996 radiology reports with 8,121 associated images in PA, AP and lateral views. The 14 finding categories in this work include Atelectasis, Cardiomegaly, Consolidation, Edema, Enlarged-Cardiomediastinum, Fracture, Lung-Lesion, Lung-Opacity, No-Finding, Pleural-Effusion, Pleural-Other, Pneumonia, Pneumothorax and Support-Devices. More information can be found in this [link](https://openi.nlm.nih.gov/faq)
## How to use it
You can upload Chest X-ray images in any standard form, and attach your Report , of course, Report_input can also be empty, which will slightly affect the final judgment of the model. You can see examples as following.
- Image_input: Just upload your chest CT image(PNG)
- report_input: You can write down your INDICATIONS, FINDINGS and IMPRESSIONS here. See example part for more information.
""")
    with gr.Row():
        with gr.Column():
            inp1 = gr.Textbox(label="input_report")
            # type="filepath" hands model_predict a path on disk, not an array.
            inp2 = gr.Image(label="input_image", type="filepath")
        out = gr.Textbox(label="inference result")
    with gr.Row():
        btn = gr.Button("Inference now")
    # NOTE: input order is [image, report] to match
    # model_predict(image_address, input_report).
    btn.click(fn=model_predict, inputs=[inp2, inp1], outputs=out)
    # Positional args: examples, inputs, outputs, fn.
    # cache_examples=True precomputes example outputs at launch time.
    gr.Examples(
        [['example1.png', 'The cardiac silhouette and mediastinum size are within normal limits. There is no pulmonary edema. There is no focal consolidation. There are no XXXX of a pleural effusion. There is no evidence of pneumothorax']
        ,['example2.png', 'Stable postsurgical changes. Heart XXXX, mediastinum and lung XXXX are unremarkable. Stable calcified small granuloma in left base']],
        [inp2, inp1],
        out,
        model_predict,
        cache_examples=True,
    )
    gr.Markdown(
        """## Todo/Doing
- TODO -- Fine-tuning on more datasets
- Doing -- Add the object detection model, plan to use yolov5
- Doing -- Build an online machine learning pipeline
"""
    )
iface.launch()