|
|
"""Image processor for Sybil CT scan preprocessing""" |
|
|
|
|
|
import cv2 |
|
|
import numpy as np |
|
|
import torch |
|
|
from typing import Dict, List, Optional, Union, Tuple |
|
|
from transformers.image_processing_utils import BaseImageProcessor, BatchFeature |
|
|
from transformers.utils import TensorType |
|
|
import pydicom |
|
|
from PIL import Image |
|
|
import torchio as tio |
|
|
|
|
|
|
|
|
def order_slices(dicoms: List) -> List:
    """Order DICOM slices along the scan axis, best effort.

    Sorts primarily by the z-component of ``ImagePositionPatient`` (the
    anatomically correct ordering) and falls back to ``InstanceNumber``
    (scanner acquisition order) when position data is missing or
    malformed. If neither tag is usable, the input order is returned
    unchanged.

    Args:
        dicoms: List of pydicom datasets (or objects exposing the same
            attributes).

    Returns:
        The slices sorted into anatomical order when possible, otherwise
        the original list.
    """
    try:
        # Preferred: geometric position along the patient z-axis.
        # ValueError/IndexError are included so malformed tag values
        # (non-numeric strings, short position vectors) trigger the
        # fallback instead of propagating.
        return sorted(dicoms, key=lambda x: float(x.ImagePositionPatient[2]))
    except (AttributeError, TypeError, ValueError, IndexError):
        pass
    try:
        # Fallback: acquisition order recorded by the scanner.
        return sorted(dicoms, key=lambda x: int(x.InstanceNumber))
    except (AttributeError, TypeError, ValueError):
        return dicoms
|
|
|
|
|
|
|
|
class SybilImageProcessor(BaseImageProcessor): |
|
|
""" |
|
|
Constructs a Sybil image processor for preprocessing CT scans. |
|
|
|
|
|
Args: |
|
|
voxel_spacing (`List[float]`, *optional*, defaults to `[0.703125, 0.703125, 2.5]`): |
|
|
Target voxel spacing for resampling (row, column, slice thickness). |
|
|
img_size (`List[int]`, *optional*, defaults to `[512, 512]`): |
|
|
Target image size after resizing. |
|
|
num_images (`int`, *optional*, defaults to `208`): |
|
|
Number of slices to use from the CT scan. |
|
|
windowing (`Dict[str, float]`, *optional*): |
|
|
Windowing parameters for CT scan visualization. |
|
|
Default uses lung window: center=-600, width=1500. |
|
|
normalize (`bool`, *optional*, defaults to `True`): |
|
|
Whether to normalize pixel values to [0, 1]. |
|
|
**kwargs: |
|
|
Additional keyword arguments passed to the parent class. |
|
|
""" |
|
|
|
|
|
model_input_names = ["pixel_values"] |
|
|
|
|
|
def __init__( |
|
|
self, |
|
|
voxel_spacing: List[float] = None, |
|
|
img_size: List[int] = None, |
|
|
num_images: int = 208, |
|
|
windowing: Dict[str, float] = None, |
|
|
normalize: bool = True, |
|
|
**kwargs |
|
|
): |
|
|
super().__init__(**kwargs) |
|
|
|
|
|
self.voxel_spacing = voxel_spacing if voxel_spacing is not None else [0.703125, 0.703125, 2.5] |
|
|
self.img_size = img_size if img_size is not None else [512, 512] |
|
|
self.num_images = num_images |
|
|
|
|
|
|
|
|
self.windowing = windowing if windowing is not None else { |
|
|
"center": -600, |
|
|
"width": 1500 |
|
|
} |
|
|
self.normalize = normalize |
|
|
|
|
|
|
|
|
self.resample_transform = tio.transforms.Resample(target=self.voxel_spacing) |
|
|
|
|
|
self.default_depth = 200 |
|
|
self.default_size = [256, 256] |
|
|
|
|
|
self.padding_transform = tio.transforms.CropOrPad( |
|
|
target_shape=tuple(self.default_size + [self.default_depth]), |
|
|
padding_mode=0 |
|
|
) |
|
|
|
|
|
def load_dicom_series(self, paths: List[str]) -> Tuple[np.ndarray, Dict]: |
|
|
""" |
|
|
Load a series of DICOM files. |
|
|
|
|
|
Args: |
|
|
paths: List of paths to DICOM files. |
|
|
|
|
|
Returns: |
|
|
Tuple of (volume array, metadata dict) |
|
|
""" |
|
|
dicoms = [] |
|
|
for path in paths: |
|
|
try: |
|
|
dcm = pydicom.dcmread(path, stop_before_pixels=False) |
|
|
dicoms.append(dcm) |
|
|
except Exception as e: |
|
|
print(f"Error reading DICOM file {path}: {e}") |
|
|
continue |
|
|
|
|
|
if not dicoms: |
|
|
raise ValueError("No valid DICOM files found") |
|
|
|
|
|
|
|
|
dicoms = order_slices(dicoms) |
|
|
|
|
|
|
|
|
volume = np.stack([dcm.pixel_array.astype(np.float32) for dcm in dicoms]) |
|
|
|
|
|
|
|
|
metadata = { |
|
|
"slice_thickness": float(dicoms[0].SliceThickness) if hasattr(dicoms[0], 'SliceThickness') else None, |
|
|
"pixel_spacing": list(map(float, dicoms[0].PixelSpacing)) if hasattr(dicoms[0], 'PixelSpacing') else None, |
|
|
"manufacturer": str(dicoms[0].Manufacturer) if hasattr(dicoms[0], 'Manufacturer') else None, |
|
|
"num_slices": len(dicoms) |
|
|
} |
|
|
|
|
|
|
|
|
if hasattr(dicoms[0], 'RescaleSlope') and hasattr(dicoms[0], 'RescaleIntercept'): |
|
|
slope = float(dicoms[0].RescaleSlope) |
|
|
intercept = float(dicoms[0].RescaleIntercept) |
|
|
volume = volume * slope + intercept |
|
|
|
|
|
return volume, metadata |
|
|
|
|
|
def load_png_series(self, paths: List[str]) -> np.ndarray: |
|
|
""" |
|
|
Load a series of PNG files. |
|
|
|
|
|
Args: |
|
|
paths: List of paths to PNG files (must be in anatomical order). |
|
|
|
|
|
Returns: |
|
|
3D volume array |
|
|
""" |
|
|
images = [] |
|
|
for path in paths: |
|
|
img = Image.open(path).convert('L') |
|
|
images.append(np.array(img, dtype=np.float32)) |
|
|
|
|
|
return np.stack(images) |
|
|
|
|
|
def resize_slices(self, volume: np.ndarray, target_size: List[int] = None) -> np.ndarray: |
|
|
""" |
|
|
Resize each slice in the volume to target size using OpenCV bilinear interpolation. |
|
|
This exactly matches the original Sybil's per-slice 2D resize operation. |
|
|
|
|
|
Args: |
|
|
volume: 3D volume array (D, H, W). |
|
|
target_size: Target size [H, W]. Defaults to [256, 256]. |
|
|
|
|
|
Returns: |
|
|
Resized volume. |
|
|
""" |
|
|
if target_size is None: |
|
|
target_size = self.default_size |
|
|
|
|
|
|
|
|
resized_slices = [] |
|
|
for i in range(volume.shape[0]): |
|
|
slice_2d = volume[i] |
|
|
|
|
|
resized = cv2.resize( |
|
|
slice_2d, |
|
|
dsize=(target_size[1], target_size[0]), |
|
|
interpolation=cv2.INTER_LINEAR |
|
|
) |
|
|
resized_slices.append(resized) |
|
|
|
|
|
|
|
|
return np.stack(resized_slices, axis=0) |
|
|
|
|
|
def apply_windowing(self, volume: np.ndarray) -> np.ndarray: |
|
|
""" |
|
|
Apply DICOM-standard windowing to CT scan, matching the original Sybil implementation. |
|
|
|
|
|
This implements the same windowing as the original Sybil: |
|
|
- Uses DICOM standard formula with center-0.5 and width-1 adjustments |
|
|
- Outputs to 16-bit range [0, 65535] then divides by 256 for 8-bit parity |
|
|
- Results in [0, 255] range that will be normalized later |
|
|
|
|
|
Args: |
|
|
volume: 3D CT volume in Hounsfield Units. |
|
|
|
|
|
Returns: |
|
|
Windowed volume in [0, 255] range. |
|
|
""" |
|
|
center = self.windowing["center"] |
|
|
width = self.windowing["width"] |
|
|
|
|
|
|
|
|
bit_size = 16 |
|
|
y_min = 0 |
|
|
y_max = 2 ** bit_size - 1 |
|
|
y_range = y_max - y_min |
|
|
|
|
|
|
|
|
c = center - 0.5 |
|
|
w = width - 1 |
|
|
|
|
|
|
|
|
lower_bound = c - w / 2 |
|
|
upper_bound = c + w / 2 |
|
|
|
|
|
|
|
|
below = volume <= lower_bound |
|
|
above = volume > upper_bound |
|
|
between = np.logical_and(~below, ~above) |
|
|
|
|
|
|
|
|
windowed = np.zeros_like(volume, dtype=np.float32) |
|
|
|
|
|
|
|
|
windowed[below] = y_min |
|
|
windowed[above] = y_max |
|
|
|
|
|
if between.any(): |
|
|
|
|
|
windowed[between] = ((volume[between] - c) / w + 0.5) * y_range + y_min |
|
|
|
|
|
|
|
|
|
|
|
windowed = windowed // 256 |
|
|
|
|
|
return windowed |
|
|
|
|
|
def resample_volume( |
|
|
self, |
|
|
volume: torch.Tensor, |
|
|
original_spacing: Optional[List[float]] = None |
|
|
) -> torch.Tensor: |
|
|
""" |
|
|
Resample volume to target voxel spacing. |
|
|
Uses affine matrix approach matching original Sybil exactly. |
|
|
|
|
|
Args: |
|
|
volume: 3D or 4D volume tensor (D, H, W) or (C, D, H, W). |
|
|
original_spacing: Original voxel spacing [H_spacing, W_spacing, D_spacing]. |
|
|
|
|
|
Returns: |
|
|
Resampled volume with same number of dimensions. |
|
|
""" |
|
|
|
|
|
if len(volume.shape) == 3: |
|
|
|
|
|
volume_4d = volume.unsqueeze(0) |
|
|
squeeze_output = True |
|
|
elif len(volume.shape) == 4: |
|
|
|
|
|
volume_4d = volume |
|
|
squeeze_output = False |
|
|
else: |
|
|
raise ValueError(f"Expected 3D or 4D volume, got shape {volume.shape}") |
|
|
|
|
|
|
|
|
volume_tio = volume_4d.permute(0, 2, 3, 1) |
|
|
|
|
|
|
|
|
|
|
|
if original_spacing is not None: |
|
|
|
|
|
voxel_spacing_4d = torch.tensor(original_spacing + [1.0], dtype=torch.float32) |
|
|
affine = torch.diag(voxel_spacing_4d) |
|
|
else: |
|
|
affine = None |
|
|
|
|
|
|
|
|
subject = tio.Subject( |
|
|
image=tio.ScalarImage(tensor=volume_tio, affine=affine) |
|
|
) |
|
|
|
|
|
|
|
|
resampled = self.resample_transform(subject) |
|
|
|
|
|
|
|
|
result = resampled['image'].data.permute(0, 3, 1, 2) |
|
|
|
|
|
|
|
|
if squeeze_output: |
|
|
return result.squeeze(0) |
|
|
else: |
|
|
return result |
|
|
|
|
|
def pad_or_crop_volume(self, volume: torch.Tensor) -> torch.Tensor: |
|
|
""" |
|
|
Pad or crop volume to target shape. |
|
|
|
|
|
Args: |
|
|
volume: 3D or 4D volume tensor (D, H, W) or (C, D, H, W). |
|
|
|
|
|
Returns: |
|
|
Padded/cropped volume with same number of dimensions. |
|
|
""" |
|
|
|
|
|
if len(volume.shape) == 3: |
|
|
|
|
|
volume_4d = volume.unsqueeze(0) |
|
|
squeeze_output = True |
|
|
elif len(volume.shape) == 4: |
|
|
|
|
|
volume_4d = volume |
|
|
squeeze_output = False |
|
|
else: |
|
|
raise ValueError(f"Expected 3D or 4D volume, got shape {volume.shape}") |
|
|
|
|
|
|
|
|
volume_tio = volume_4d.permute(0, 2, 3, 1) |
|
|
|
|
|
|
|
|
subject = tio.Subject( |
|
|
image=tio.ScalarImage(tensor=volume_tio) |
|
|
) |
|
|
|
|
|
|
|
|
transformed = self.padding_transform(subject) |
|
|
|
|
|
|
|
|
result = transformed['image'].data.permute(0, 3, 1, 2) |
|
|
|
|
|
|
|
|
if squeeze_output: |
|
|
return result.squeeze(0) |
|
|
else: |
|
|
return result |
|
|
|
|
|
def preprocess( |
|
|
self, |
|
|
images: Union[List[str], np.ndarray, torch.Tensor], |
|
|
file_type: str = "dicom", |
|
|
voxel_spacing: Optional[List[float]] = None, |
|
|
return_tensors: Optional[Union[str, TensorType]] = None, |
|
|
**kwargs |
|
|
) -> BatchFeature: |
|
|
""" |
|
|
Preprocess CT scan images. |
|
|
|
|
|
Args: |
|
|
images: Either list of file paths or numpy/torch array of images. |
|
|
file_type: Type of input files ("dicom" or "png"). |
|
|
voxel_spacing: Original voxel spacing (required for PNG files). |
|
|
return_tensors: The type of tensors to return. |
|
|
|
|
|
Returns: |
|
|
BatchFeature with preprocessed images. |
|
|
""" |
|
|
|
|
|
if isinstance(images, list) and isinstance(images[0], str): |
|
|
if file_type == "dicom": |
|
|
volume, metadata = self.load_dicom_series(images) |
|
|
if voxel_spacing is None and metadata["pixel_spacing"]: |
|
|
voxel_spacing = metadata["pixel_spacing"] + [metadata["slice_thickness"]] |
|
|
elif file_type == "png": |
|
|
if voxel_spacing is None: |
|
|
raise ValueError("voxel_spacing must be provided for PNG files") |
|
|
volume = self.load_png_series(images) |
|
|
else: |
|
|
raise ValueError(f"Unknown file type: {file_type}") |
|
|
elif isinstance(images, (np.ndarray, torch.Tensor)): |
|
|
volume = images |
|
|
else: |
|
|
raise ValueError("Images must be file paths, numpy array, or torch tensor") |
|
|
|
|
|
|
|
|
if isinstance(volume, torch.Tensor): |
|
|
volume_np = volume.numpy() |
|
|
else: |
|
|
volume_np = volume |
|
|
|
|
|
|
|
|
volume_np = self.apply_windowing(volume_np) |
|
|
|
|
|
|
|
|
volume_np = self.resize_slices(volume_np, target_size=self.default_size) |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
volume = torch.from_numpy(volume_np).float() |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
img_mean = 128.1722 |
|
|
img_std = 87.1849 |
|
|
volume = (volume - img_mean) / img_std |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
volume = volume.unsqueeze(0).repeat(3, 1, 1, 1) |
|
|
|
|
|
|
|
|
|
|
|
if voxel_spacing is not None: |
|
|
volume = self.resample_volume(volume, voxel_spacing) |
|
|
|
|
|
|
|
|
volume = self.pad_or_crop_volume(volume) |
|
|
|
|
|
|
|
|
volume = volume.unsqueeze(0) |
|
|
|
|
|
|
|
|
data = {"pixel_values": volume} |
|
|
|
|
|
|
|
|
if return_tensors == "pt": |
|
|
return BatchFeature(data=data, tensor_type=TensorType.PYTORCH) |
|
|
elif return_tensors == "np": |
|
|
data = {k: v.numpy() for k, v in data.items()} |
|
|
return BatchFeature(data=data, tensor_type=TensorType.NUMPY) |
|
|
else: |
|
|
return BatchFeature(data=data) |
|
|
|
|
|
def __call__( |
|
|
self, |
|
|
images: Union[List[str], List[List[str]], np.ndarray, torch.Tensor], |
|
|
**kwargs |
|
|
) -> BatchFeature: |
|
|
""" |
|
|
Main method to prepare images for the model. |
|
|
|
|
|
Args: |
|
|
images: Images to preprocess. Can be: |
|
|
- List of file paths for a single series |
|
|
- List of lists of file paths for multiple series |
|
|
- Numpy array or torch tensor |
|
|
|
|
|
Returns: |
|
|
BatchFeature with preprocessed images ready for model input. |
|
|
""" |
|
|
|
|
|
if isinstance(images, list) and images and isinstance(images[0], list): |
|
|
|
|
|
batch_volumes = [] |
|
|
for series_paths in images: |
|
|
result = self.preprocess(series_paths, **kwargs) |
|
|
batch_volumes.append(result["pixel_values"]) |
|
|
|
|
|
|
|
|
pixel_values = torch.stack(batch_volumes) |
|
|
return BatchFeature(data={"pixel_values": pixel_values}) |
|
|
else: |
|
|
|
|
|
return self.preprocess(images, **kwargs) |