---
library_name: transformers
tags:
- ct
- computed_tomography
- crop
- dicom
- radiology
license: apache-2.0
base_model:
- timm/mobilenetv3_small_100.lamb_in1k
pipeline_tag: object-detection
---
This model crops the foreground from the background in CT slices. It is a lightweight `mobilenetv3_small_100`
model trained on CT examinations from the [public TotalSegmentator dataset](https://zenodo.org/records/10047292), version 2.0.1.
The following function was used to generate masks for each CT:
```
import nibabel as nib
import numpy as np
from scipy.ndimage import binary_closing, binary_fill_holes, minimum_filter
from skimage.measure import label


def generate_mask(array):
    # threshold: everything above the lower bound of the window is foreground
    mask = (array > 0).astype("uint8")
    # keep only the largest connected component
    mask_label = label(mask)
    labels, counts = np.unique(mask_label, return_counts=True)
    labels, counts = labels[1:], counts[1:]  # drop background label 0
    max_label = labels[np.argmax(counts)]
    mask = mask_label == max_label
    # close gaps and fill holes slice by slice
    mask = np.stack([
        binary_fill_holes(binary_closing(mask[:, :, i]))
        for i in range(mask.shape[2])
    ], axis=2).astype("uint8")
    # shrink the mask slightly with a 3x3 minimum filter
    mask = np.stack([
        minimum_filter(mask[:, :, i], size=3)
        for i in range(mask.shape[2])
    ], axis=2)
    return mask


array = nib.load("ct.nii.gz").get_fdata()
# apply soft tissue window
# note: apply_ct_window is not defined in this snippet
array = apply_ct_window(array, window_level=50, window_width=400)
mask = generate_mask(array)
```
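The snippet above calls `apply_ct_window` without showing it. A minimal sketch of such a helper follows, assuming it clips to the window and rescales to 0-255 so that everything at or below the lower bound becomes 0 (which is what makes the `array > 0` threshold meaningful). The body is an assumption, not the code used to train the model.

```python
import numpy as np


def apply_ct_window(array, window_level, window_width):
    # Hypothetical implementation: the model card does not show this helper.
    lower = window_level - window_width / 2
    upper = window_level + window_width / 2
    # Clip Hounsfield units to the window bounds.
    array = np.clip(array, lower, upper)
    # Rescale to 0-255; values at or below the lower bound map to 0.
    array = (array - lower) / (upper - lower) * 255.0
    return array.astype("uint8")
```

With the soft tissue window (level=50, width=400), the bounds are -150 and 250 HU, so air and other low-density background map to 0.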
Bounding box coordinates were generated from the masks for individual slices.
The model was then trained to predict normalized (0-1) `xywh` coordinates, given an individual CT slice.
If the mask was empty, the coordinates were set to all zero.
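As an illustration of that step, here is a minimal sketch of turning one 2D mask into a normalized box. The function name `mask_to_box` is hypothetical, and the convention used (top-left corner `x`, `y`, then width and height, each divided by the image size) is an assumption; the model card does not spell out the exact convention.

```python
import numpy as np


def mask_to_box(mask_2d):
    # Hypothetical helper; assumes (x, y) is the top-left corner of the box.
    rows, cols = np.where(mask_2d > 0)
    if rows.size == 0:
        # Empty mask: all coordinates are zero.
        return np.zeros(4, dtype="float32")
    height, width = mask_2d.shape
    x_min, x_max = cols.min(), cols.max()
    y_min, y_max = rows.min(), rows.max()
    # Normalize x and width by image width, y and height by image height.
    return np.array([
        x_min / width,
        y_min / height,
        (x_max - x_min + 1) / width,
        (y_max - y_min + 1) / height,
    ], dtype="float32")
```

Applied to a 3D mask, this would be called once per slice, for example `[mask_to_box(mask[:, :, i]) for i in range(mask.shape[2])]`.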
Images were converted from Hounsfield units (HU) to 4 CT windows:
1. Soft tissue (level=50, width=400)
2. Brain (level=40, width=80)
3. Lung (level=-600, width=1500)
4. Bone (level=400, width=1800)
During training, random combinations of channels were selected. If more than 1 channel was selected,
the images were averaged channel-wise to create a single-channel output. Strong data augmentation was also applied.
Thus, this model should be robust to different CT windows and combinations thereof.
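The window sampling can be sketched as follows. This is only an illustration under stated assumptions: the helper names, the 0-1 rescaling, and the uniform sampling of how many windows to combine are all hypothetical, and the actual training pipeline may differ.

```python
import numpy as np

# (level, width) pairs from the list above
WINDOWS = {
    "soft_tissue": (50, 400),
    "brain": (40, 80),
    "lung": (-600, 1500),
    "bone": (400, 1800),
}


def window_image(hu, level, width):
    # Clip to the window and rescale to the 0-1 range.
    lower, upper = level - width / 2, level + width / 2
    return (np.clip(hu, lower, upper) - lower) / (upper - lower)


def random_window_average(hu, rng):
    # Pick between 1 and 4 windows without replacement (assumed sampling scheme).
    names = list(WINDOWS)
    k = rng.integers(1, len(names) + 1)
    chosen = rng.choice(names, size=k, replace=False)
    # One channel per chosen window, channels-last.
    channels = np.stack(
        [window_image(hu, *WINDOWS[str(name)]) for name in chosen], axis=-1
    )
    # Average across channels to get a single-channel image.
    return channels.mean(axis=-1)
```

For example, `random_window_average(hu_slice, np.random.default_rng())` returns an array with the same height and width as `hu_slice`.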
Example usage below:
```
import cv2
import numpy as np
import torch
from transformers import AutoModel

device = "cuda" if torch.cuda.is_available() else "cpu"
cropper = AutoModel.from_pretrained("ianpan/ct-crop", trust_remote_code=True).eval().to(device)

# single image
img = cv2.imread("ct_slice.png", cv2.IMREAD_GRAYSCALE)
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=None)

# expand all 4 sides by 2.5% each
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=0.025)

# expand box height by 2.5% in each direction
# and box width by 5% in each direction
buffer = (0.05, 0.025)
cropped_img = cropper.crop(img, mode="2d", device=device, raw_hu=False, add_buffer=buffer)

# stack of images
img_list = ["ct_slice_1.png", "ct_slice_2.png", ...]
stack = np.stack([cv2.imread(path, cv2.IMREAD_GRAYSCALE) for path in img_list], axis=0)
cropped_stack = cropper.crop(stack, mode="3d", device=device, raw_hu=False, add_buffer=None)
```
You can also get the coordinates directly and do the cropping yourself.
You must separately preprocess the input. Example below:
```
# single image
img0 = cv2.imread("ct_slice.png", cv2.IMREAD_GRAYSCALE)
img_shapes = torch.tensor([_.shape[:2] for _ in [img0]]).to(device)
img = cropper.preprocess(img0, mode="2d")
# if multi-channel, need to convert from channels-last -> channels-first
img = torch.from_numpy(img).expand(1, 1, -1, -1).float().to(device)
with torch.inference_mode():
    coords = cropper(img, img_shape=img_shapes, add_buffer=None)
# if you do not provide img_shape, output will be normalized (0-1) coordinates
# otherwise will be scaled to img_shape
```
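Once you have coordinates scaled to the image size, the crop itself is a slicing operation. The sketch below assumes one box in `(x, y, w, h)` pixel order with `(x, y)` as the top-left corner; that layout is an assumption, so check the returned tensor against your images before relying on it. `crop_with_box` is a hypothetical helper, not part of the model.

```python
import numpy as np


def crop_with_box(image, box):
    # Hypothetical helper; assumes box = (x, y, w, h) in pixels, top-left origin.
    x, y, w, h = (int(round(float(v))) for v in box)
    height, width = image.shape[:2]
    # Clamp the box to the image bounds before slicing.
    x0, y0 = max(x, 0), max(y, 0)
    x1, y1 = min(x + w, width), min(y + h, height)
    return image[y0:y1, x0:x1]
```

With a PyTorch tensor of predictions, you would move one row to the CPU first, for example `crop_with_box(img0, coords[0].cpu().numpy())`.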
The model also contains methods to load DICOM images, if you have `pydicom` installed:
```
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=None)
# note: RescaleSlope and RescaleIntercept already applied in the method
# apply CT window
brain_window = (40, 80)
img = cropper.load_image_from_dicom(path_to_dicom_file, windows=brain_window)
# or multiple windows
soft_tissue_window = (50, 400)
img = cropper.load_image_from_dicom(path_to_dicom_file,
                                    windows=[brain_window, soft_tissue_window])
# each window is a separate channel, img will be channels-last format
```
You can also load a stack of DICOM images from a folder:
```
dicom_folder = "/path/to/ct/head/images/"
# dicom_extension is used to filter files, default is ".dcm"
# can pass "" if you do not want to filter files
# default sort is by ImagePositionPatient using automatically determined
# orientation, can also sort by InstanceNumber
# can also apply CT windows, as above
stack = cropper.load_stack_from_dicom_folder(dicom_folder,
                                             windows=None,
                                             dicom_extension=".dcm",
                                             sort_by_instance_number=False)
# can input raw Hounsfield units into cropper
cropped_stack = cropper.crop(stack, mode="3d", device=device, raw_hu=True)
```
By default, the cropper will not remove slices in a stack, even if they are predicted to be empty.
You can enable this by specifying `remove_empty_slices=True`, which will also return
the indices in the original input of the removed empty slices.
```
cropped_stack, empty_slice_indices = cropper.crop(stack, mode="3d", remove_empty_slices=True)
```
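If you need to map the remaining slices back to their positions in the original stack, the kept indices are the complement of the removed ones. A minimal sketch, assuming the returned indices are integers along the first (slice) axis; `kept_slice_indices` is a hypothetical helper:

```python
import numpy as np


def kept_slice_indices(num_slices, empty_slice_indices):
    # Hypothetical helper: indices of the original stack that were not removed.
    removed = np.asarray(empty_slice_indices, dtype=int)
    return np.setdiff1d(np.arange(num_slices), removed)
```

For example, `kept_slice_indices(len(stack), empty_slice_indices)` gives, in order, the original position of each slice in `cropped_stack`.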