---
library_name: transformers
tags:
- mammography
- cancer
- breast_cancer
- radiology
- breast_density
license: apache-2.0
base_model:
- timm/tf_efficientnetv2_s.in21k_ft_in1k
pipeline_tag: image-classification
---
|
|
|
|
|
This is an ensemble model for predicting breast cancer and breast density from screening mammography.

The ensemble comprises 3 CNNs (`tf_efficientnetv2_s` backbone), each of which performs inference on every provided image (i.e., the CC and MLO views). Each net in the ensemble uses a different input resolution: 2048 x 1024, 1920 x 1280, and 1536 x 1536. The final outputs are averaged across the provided views and the neural nets. The model can also perform inference on a single view (image), although performance will be reduced.
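
Conceptually, the ensemble prediction is just the mean over nets and views (an illustrative sketch only, not the internal implementation; the values are made up):

```
import torch

# Hypothetical per-net, per-view cancer probabilities: 3 nets x 2 views (CC, MLO)
probs = torch.tensor([[0.02, 0.03],
                      [0.04, 0.02],
                      [0.03, 0.05]])
ensemble_prob = probs.mean()  # average across nets and views -> single prediction
```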
|
|
|
|
|
A hybrid classification-segmentation model was first pretrained on the Curated Breast Imaging Subset of the Digital Database for Screening Mammography [(CBIS-DDSM)](https://www.cancerimagingarchive.net/collection/cbis-ddsm/). This dataset contains film (as opposed to digital) mammography studies with accompanying ROI annotations for benign and malignant masses and calcifications.
|
|
|
|
|
The resultant model was further trained on data from the [RSNA Screening Mammography Breast Cancer Detection challenge](https://www.kaggle.com/competitions/rsna-breast-cancer-detection/). The data was split 80%/10%/10% into train/val/test, and evaluation was performed on the 10% held-out test split. This procedure was repeated 3 separate times to better assess the model's performance. The provided weights are from the first data split.
|
|
|
|
|
Exponential moving averaging (EMA) of the model weights was used during training and improved performance.
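
For illustration, a minimal sketch of the general EMA technique (not the exact training code; the decay value here is an assumption):

```
import torch

def ema_update(ema_model, model, decay=0.999):
    # Shadow weights: w_ema <- decay * w_ema + (1 - decay) * w
    with torch.no_grad():
        for ema_p, p in zip(ema_model.parameters(), model.parameters()):
            ema_p.mul_(decay).add_(p, alpha=1 - decay)

# Typical usage: ema_model = copy.deepcopy(model), then call
# ema_update(ema_model, model) after each optimizer step and
# evaluate/export the EMA weights.
```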
|
|
|
|
|
Note that the model was trained using cropped images, so it is recommended to crop the image prior to inference. A cropping model is provided at [ianpan/mammo-crop](https://huggingface.co/ianpan/mammo-crop).
|
|
|
|
|
The primary evaluation metric is the area under the receiver operating characteristic curve (AUC/AUROC). Below are the per-split results, along with the mean and standard deviation across the 3 splits.
|
|
|
|
|
```
Split 1: 0.9464
Split 2: 0.9467
Split 3: 0.9422

Mean (std.): 0.9451 (0.002)
```
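
For reference, an AUROC like those above can be computed from held-out labels and predicted probabilities with scikit-learn (a sketch; the arrays here are hypothetical):

```
import numpy as np
from sklearn.metrics import roc_auc_score

# Hypothetical held-out labels and model probabilities
y_true = np.array([0, 0, 1, 0, 1])
y_prob = np.array([0.01, 0.20, 0.85, 0.05, 0.60])

print(roc_auc_score(y_true, y_prob))  # area under the ROC curve
```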
|
|
|
|
|
As this is a screening test, high sensitivity is desirable. We also calculate the specificity at varying sensitivities, shown below (averaged across the 3 splits):
|
|
|
|
|
```
Sensitivity: 98.1%, Specificity: 65.4% +/- 7.2%, Threshold: 0.0072 +/- 0.0021
Sensitivity: 94.3%, Specificity: 78.7% +/- 0.9%, Threshold: 0.0127 +/- 0.0011
Sensitivity: 90.5%, Specificity: 84.8% +/- 2.7%, Threshold: 0.0184 +/- 0.0027
```
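
Operating points like these can be derived from the ROC curve; a minimal sketch, again with hypothetical arrays:

```
import numpy as np
from sklearn.metrics import roc_curve

# Hypothetical held-out labels and model probabilities
y_true = np.array([0, 0, 1, 0, 1, 1, 0, 0])
y_prob = np.array([0.003, 0.020, 0.850, 0.005, 0.600, 0.040, 0.010, 0.002])

fpr, tpr, thresholds = roc_curve(y_true, y_prob)
idx = np.argmax(tpr >= 0.90)          # first ROC point reaching the target sensitivity
print(1 - fpr[idx], thresholds[idx])  # specificity and operating threshold at that point
```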
|
|
|
|
|
Example usage: |
|
|
|
|
|
```
import cv2
import torch
from transformers import AutoModel


def crop_mammo(img, model, device):
    # Predict the breast bounding box with the crop model, then crop to it.
    img_shape = torch.tensor([img.shape[:2]]).to(device)
    x = model.preprocess(img)
    x = torch.from_numpy(x).expand(1, 1, -1, -1).float().to(device)
    with torch.inference_mode():
        coords = model(x, img_shape)
    x, y, w, h = coords[0].cpu().numpy()
    return img[y: y + h, x: x + w]


device = "cuda:0"

crop_model = AutoModel.from_pretrained("ianpan/mammo-crop", trust_remote_code=True)
crop_model = crop_model.eval().to(device)

model = AutoModel.from_pretrained("ianpan/mammoscreen", trust_remote_code=True)
model = model.eval().to(device)

# Load the CC and MLO views of one breast as grayscale images.
cc_img = cv2.imread("mammo_cc.png", cv2.IMREAD_GRAYSCALE)
mlo_img = cv2.imread("mammo_mlo.png", cv2.IMREAD_GRAYSCALE)

cc_img = crop_mammo(cc_img, crop_model, device)
mlo_img = crop_mammo(mlo_img, crop_model, device)

with torch.inference_mode():
    output = model({"cc": cc_img, "mlo": mlo_img}, device=device)
```
|
|
|
|
|
Note that the model preprocesses the data into the necessary format within the `forward` function. `output` is a dictionary containing two keys: `cancer` and `density`. `output['cancer']` is a tensor of shape (N, 1) and `output['density']` is a tensor of shape (N, 4). If you want the predicted density class, take the argmax: `output['density'].argmax(1)`. If only a single study is provided, then N = 1.
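
For example, to convert the raw outputs into per-breast predictions (a small sketch; treating `output['cancer']` as a probability and the 0-3 density indices as ordered categories follows from the description above):

```
cancer_prob = output["cancer"].squeeze(1)    # shape (N,): predicted cancer probability per breast
density_class = output["density"].argmax(1)  # shape (N,): predicted density class index (0-3)
```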
|
|
|
|
|
You can also access each neural net separately using `model.net{i}` (e.g., `model.net0`). However, you must then apply the preprocessing yourself, outside of the `forward` function:
|
|
```
input_dict = model.net0.preprocess({"cc": cc_img, "mlo": mlo_img}, device=device)
with torch.inference_mode():
    out = model.net0(input_dict)
```
|
|
|
|
|
The model also supports batch inference. Construct a dictionary for each breast and pass a list of dictionaries to the model. For example, if you want to perform inference on both breasts of 2 patients (`pt1`, `pt2`):
|
|
|
|
|
```
cc_images = ["rt_pt1_cc.png", "lt_pt1_cc.png", "rt_pt2_cc.png", "lt_pt2_cc.png"]
mlo_images = ["rt_pt1_mlo.png", "lt_pt1_mlo.png", "rt_pt2_mlo.png", "lt_pt2_mlo.png"]

cc_images = [cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in cc_images]
mlo_images = [cv2.imread(f, cv2.IMREAD_GRAYSCALE) for f in mlo_images]

cc_images = [crop_mammo(img, crop_model, device) for img in cc_images]
mlo_images = [crop_mammo(img, crop_model, device) for img in mlo_images]

# One dict per breast, pairing the CC and MLO views.
input_dict = [{"cc": cc_img, "mlo": mlo_img} for cc_img, mlo_img in zip(cc_images, mlo_images)]
with torch.inference_mode():
    output = model(input_dict, device=device)
```
|
|
|
|
|
Note that if you are converting images from DICOM to 8-bit PNG/JPEG, it is important to apply the lookup table to the pixel values, which can be done using `pydicom.pixels.apply_voi_lut`. |
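
A minimal sketch of such a conversion (assuming pydicom >= 3.0 for `pydicom.pixels`; the normalization and inversion choices here are illustrative):

```
import cv2
import numpy as np
import pydicom
from pydicom.pixels import apply_voi_lut

ds = pydicom.dcmread("mammo_cc.dcm")
arr = apply_voi_lut(ds.pixel_array, ds).astype(np.float32)
if ds.PhotometricInterpretation == "MONOCHROME1":
    arr = arr.max() - arr  # invert so that higher values are brighter
arr = (arr - arr.min()) / (arr.max() - arr.min())  # rescale to [0, 1]
cv2.imwrite("mammo_cc.png", (arr * 255).astype(np.uint8))
```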
|
|
If you have `pydicom` installed, you can also load a DICOM image directly, which handles the proper 8-bit conversion for you: |
|
|
```
img = model.load_image_from_dicom(path_to_dicom)
```