---
library_name: transformers
tags:
- chest_x_ray
- x_ray
- medical_imaging
- radiology
- segmentation
- classification
- lungs
- heart
base_model:
- timm/tf_efficientnetv2_s.in21k_ft_in1k
pipeline_tag: image-segmentation
---

This model performs both segmentation and classification on chest radiographs (X-rays).
It uses a `tf_efficientnetv2_s` backbone with a U-Net decoder for segmentation and a linear layer for classification.
For frontal radiographs, the model segments the: 1) right lung, 2) left lung, and 3) heart.
It also predicts the chest X-ray view (AP, PA, lateral), patient age, and patient sex.
The [CheXpert](https://stanfordmlgroup.github.io/competitions/chexpert/) (small version) and [NIH Chest X-ray](https://nihcc.app.box.com/v/ChestXray-NIHCC) datasets were used to train the model.
Segmentation masks were obtained from the CheXmask [dataset](https://physionet.org/content/chexmask-cxr-segmentation-data/0.4/) ([paper](https://www.nature.com/articles/s41597-024-03358-1)).
The final dataset comprised 335,516 images from 96,385 patients and was split into 80% training/20% validation; a holdout test set was not used, since minimal tuning was performed.
The view classifier was trained only on CheXpert images (NIH images were excluded from the loss function), because lateral radiographs are present only in CheXpert.
This avoids unwanted bias, which can occur when one class originates entirely from a single dataset.
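Restricting one head's loss to a single dataset can be implemented by masking per-sample loss terms; a minimal PyTorch sketch (the function and variable names are hypothetical, not the model's actual training code):

```python
import torch
import torch.nn.functional as F

def masked_view_loss(view_logits, view_labels, use_in_loss):
    """Cross-entropy averaged over only the samples flagged for inclusion."""
    per_sample = F.cross_entropy(view_logits, view_labels, reduction="none")
    mask = use_in_loss.float()
    # clamp avoids division by zero if a batch contains no included samples
    return (per_sample * mask).sum() / mask.sum().clamp(min=1.0)

logits = torch.tensor([[2.0, 0.1, 0.1], [0.1, 2.0, 0.1]])
labels = torch.tensor([0, 1])
is_chexpert = torch.tensor([True, False])  # second (NIH) sample excluded
loss = masked_view_loss(logits, labels, is_chexpert)
```

The excluded samples still pass through the network for the other heads; only the view head's gradient ignores them.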


Validation performance is as follows:
```
Segmentation (Dice similarity coefficient):
Right Lung: 0.957
Left Lung: 0.948
Heart: 0.943

Age Prediction:
Mean Absolute Error: 5.25 years

Classification:
View (AP, PA, lateral): 99.42% accuracy
Female: 0.999 AUC
```
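For reference, the Dice similarity coefficient between two binary masks can be computed as follows (a generic sketch, not the exact evaluation code used for these numbers):

```python
import numpy as np

def dice(pred, target):
    """Dice similarity coefficient: 2 * |A intersect B| / (|A| + |B|)."""
    pred, target = pred.astype(bool), target.astype(bool)
    denom = pred.sum() + target.sum()
    if denom == 0:
        return 1.0  # both masks empty: treat as perfect agreement
    return 2.0 * np.logical_and(pred, target).sum() / denom

a = np.zeros((4, 4), dtype=np.uint8); a[:2] = 1   # 8 pixels
b = np.zeros((4, 4), dtype=np.uint8); b[1:3] = 1  # 8 pixels, 4 overlapping
# dice(a, b) -> 2 * 4 / (8 + 8) = 0.5
```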


To use the model:
```python
import cv2
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"

model = AutoModel.from_pretrained("ianpan/chest-x-ray-basic", trust_remote_code=True)
model = model.eval().to(device)
img = cv2.imread(..., 0)  # read as grayscale; replace ... with your image path
x = model.preprocess(img)  # only takes a single image as input
x = torch.from_numpy(x).unsqueeze(0).unsqueeze(0)  # add channel, batch dims
x = x.float()

with torch.inference_mode():
    out = model(x.to(device))
```
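Since `preprocess` handles one image at a time, a batch can be formed by preprocessing each image and stacking the results; a minimal sketch with dummy arrays standing in for the preprocessed images (the 512x512 shape is illustrative, not guaranteed by the model):

```python
import numpy as np
import torch

# Dummy stand-ins for model.preprocess(img) outputs (one 2-D array per image)
processed = [np.random.rand(512, 512).astype("float32") for _ in range(4)]

# Stack into a batch and add the channel dimension: (batch, 1, height, width)
batch = torch.stack([torch.from_numpy(p) for p in processed]).unsqueeze(1)
```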


The output is a dictionary containing 4 keys:
* `mask` contains the segmentation output, one channel per class. Take the argmax over the channel dimension to create a single label map (i.e., `out["mask"].argmax(1)`): 1 = right lung, 2 = left lung, 3 = heart (0 = background).
* `age`, predicted age in years.
* `view`, with one score for each of the 3 possible views. Take the argmax to select the predicted view (i.e., `out["view"].argmax(1)`): 0 = AP, 1 = PA, 2 = lateral.
* `female`, the predicted probability that the patient is female; binarize with `out["female"] >= 0.5`.
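For illustration, the decoding steps applied to a dummy output dictionary (the tensor shapes and values here are assumptions for the example, not produced by the model):

```python
import torch

# Dummy outputs: 4 mask channels (background + right lung + left lung + heart)
out = {
    "mask": torch.randn(1, 4, 256, 256),
    "age": torch.tensor([[57.3]]),
    "view": torch.tensor([[0.1, 0.8, 0.1]]),
    "female": torch.tensor([[0.92]]),
}

label_map = out["mask"].argmax(1)  # (1, 256, 256), values in 0-3
view = ["AP", "PA", "lateral"][out["view"].argmax(1).item()]
age = out["age"].item()
is_female = bool((out["female"] >= 0.5).item())
```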


You can use the segmentation mask to crop the region containing the lungs from the rest of the X-ray.
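For example, a simple bounding-box crop around the lung labels (a sketch; the margin, image, and mask here are illustrative):

```python
import numpy as np

def crop_to_lungs(img, mask, margin=10):
    """Crop img to the bounding box of the lung labels (1 and 2), plus a margin."""
    ys, xs = np.where((mask == 1) | (mask == 2))
    y0, y1 = max(ys.min() - margin, 0), min(ys.max() + margin, img.shape[0] - 1)
    x0, x1 = max(xs.min() - margin, 0), min(xs.max() + margin, img.shape[1] - 1)
    return img[y0:y1 + 1, x0:x1 + 1]

img = np.random.rand(100, 100)
mask = np.zeros((100, 100), dtype=np.uint8)
mask[20:61, 10:30] = 1   # right lung
mask[20:61, 30:41] = 2   # left lung
cropped = crop_to_lungs(img, mask)
```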
You can also calculate the [cardiothoracic ratio (CTR)](https://radiopaedia.org/articles/cardiothoracic-ratio?lang=us) using this function:
```python
import numpy as np

def calculate_ctr(mask):  # single label map with dims (height, width)
    # Binary masks for the lungs (labels 1 and 2) and heart (label 3)
    lungs = (mask == 1) | (mask == 2)
    heart = mask == 3
    # Maximal horizontal extent of each structure
    _, lung_x = np.where(lungs)
    _, heart_x = np.where(heart)
    lung_range = lung_x.max() - lung_x.min()
    heart_range = heart_x.max() - heart_x.min()
    return heart_range / lung_range
```
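A quick sanity check on a synthetic label map, with the widths computed inline so the snippet is self-contained: lungs spanning 80 columns and a heart spanning 40 columns should give a CTR of 0.5.

```python
import numpy as np

# Synthetic label map: lungs span columns 10-90, heart spans columns 30-70,
# so the expected CTR is (70 - 30) / (90 - 10) = 0.5.
mask = np.zeros((100, 100), dtype=np.uint8)
mask[30:70, 10:45] = 1   # right lung
mask[30:70, 55:91] = 2   # left lung
mask[50:80, 30:71] = 3   # heart (overwrites some lung pixels, as in a real overlap)

lung_x = np.where((mask == 1) | (mask == 2))[1]
heart_x = np.where(mask == 3)[1]
ctr = (heart_x.max() - heart_x.min()) / (lung_x.max() - lung_x.min())
# ctr -> 0.5
```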


If you have `pydicom` installed, you can also load a DICOM image directly:
```python
img = model.load_image_from_dicom(path_to_dicom)
```

This model is for demonstration and research purposes only and has NOT been approved by any regulatory agency for clinical use.
The user assumes any and all responsibility regarding their own use of this model and its outputs.