from typing import Any, Dict, Union
import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import convert_image_dtype
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import ChannelDimension
# BatchFeature moved between transformers modules across versions; try both locations.
try:
from transformers.image_processing_utils import BatchFeature
except Exception:
from transformers.feature_extraction_utils import BatchFeature
_BICUBIC = Image.BICUBIC
class CuriaImageProcessor(BaseImageProcessor):
    """
    1-channel medical preprocessor replicating:
    NumpyToTensor -> float32 -> Resize(crop_size, BICUBIC, antialias)
    -> optional ClipIntensity(min=-1000) -> NormalizeIntensity(channel_wise=True)

    Outputs: pixel_values as (B, 1, crop_size, crop_size) for 2D inputs.
    A 3D (H, W, D) input yields (B, D, 1, crop_size, crop_size), with the
    z-score computed jointly over the whole volume. Do not mix 2D and 3D
    items in one batch: the final stack would fail on mismatched shapes.

    Images need to be in:
    - PL for axial
    - IL for coronal
    - IP for sagittal
    For CT: no windowing, just Hounsfield units or a normalized image.
    For MRI: similar, no windowing, just raw values or a normalized image.
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        crop_size: int = 512,
        clip_below_air: bool = False,
        eps: float = 1e-6,
        do_resize: bool = True,
        do_normalize: bool = True,
        **kwargs,
    ):
        """
        Args:
            crop_size: target side length; images are resized to (crop_size, crop_size).
            clip_below_air: clamp intensities below -1000 (air in HU) when True.
            eps: std threshold under which z-scoring only centers (avoids divide-by-~0).
            do_resize: whether to apply the bicubic resize step.
            do_normalize: whether to apply the per-image z-score step.
        """
        super().__init__(**kwargs)
        self.crop_size = int(crop_size)
        self.clip_below_air = bool(clip_below_air)
        self.eps = float(eps)
        self.do_resize = bool(do_resize)
        self.do_normalize = bool(do_normalize)

    def _to_tensor(self, image: Union[np.ndarray, torch.Tensor, Image.Image]) -> torch.Tensor:
        """Accept (H,W), (1,H,W) or PIL; return a torch.float32 tensor of shape (H, W).

        Always returns float32: returning an integer dtype here would make the
        downstream convert_image_dtype rescale by the dtype max (e.g. /32767
        for int16), silently defeating the -1000 HU clip and truncating
        fractional intensities.
        """
        if isinstance(image, Image.Image):
            # Force single channel; "F" (32-bit float) is already grayscale.
            if image.mode != "L" and image.mode != "F":
                image = image.convert("L")
            return torch.from_numpy(np.array(image)).float()
        if isinstance(image, torch.Tensor):
            tensor = image.detach().cpu()
            if tensor.ndim == 3 and tensor.shape[0] == 1:
                tensor = tensor[0]
            if tensor.ndim != 2:
                raise ValueError(f"Expected 2D grayscale tensor or (1,H,W); got shape {tensor.shape}")
            return tensor.float()
        if isinstance(image, np.ndarray):
            arr = image
            # squeeze singleton channel dim if present
            if arr.ndim == 3 and arr.shape[0] == 1:
                arr = arr[0]
            if arr.ndim != 2:
                raise ValueError(f"Expected 2D grayscale array or (1,H,W); got shape {arr.shape}")
            # Contiguous float32 copy: torch.from_numpy rejects some medical
            # dtypes (e.g. uint16) and non-contiguous slice views.
            return torch.from_numpy(np.ascontiguousarray(arr, dtype=np.float32))
        raise TypeError(f"Unsupported image type: {type(image)!r}")

    def _resize(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Resize a 2D torch.Tensor (H, W) to (crop_size, crop_size) using
        antialiased bicubic interpolation.
        If do_resize is False, returns the input tensor unchanged.
        """
        if not self.do_resize:
            return tensor
        if tensor.ndim != 2:
            raise ValueError(f"Expected 2D tensor (H, W), got shape {tensor.shape}")
        # interpolate needs (N, C, H, W); add and strip singleton dims around it.
        resized = torch.nn.functional.interpolate(
            tensor[None, None, ...],
            size=(self.crop_size, self.crop_size),
            mode="bicubic",
            align_corners=False,
            antialias=True,
        )
        return resized[0, 0]

    def _clip_min(self, tensor: torch.Tensor) -> torch.Tensor:
        """Clamp intensities below -1000 (air in HU) when clip_below_air is set.

        Out-of-place on purpose: detach()/cpu()/float() can all return views of
        the caller's tensor, so an in-place clamp could mutate user data.
        """
        if self.clip_below_air:
            return tensor.clamp_min(-1000.0)
        return tensor

    def _zscore_per_image(self, tensor: torch.Tensor) -> torch.Tensor:
        """Per-image z-score (channel_wise=True with a single channel)."""
        mean = float(tensor.mean())
        std = float(tensor.std())
        if std < self.eps:
            # avoid exploding when image is constant; center only
            return tensor - mean
        return (tensor - mean) / std

    def __call__(self, images, return_tensors="pt", data_format=ChannelDimension.FIRST, **kwargs):
        """Preprocess one image/volume or a list of them.

        Args:
            images: (H,W)/(1,H,W) array or tensor, PIL image, or (H,W,D)
                volume — or a list of such items of the same dimensionality.
            return_tensors: "pt" | "np" | "tf" | "jax" | None.
            data_format: accepted for API compatibility; output is channels-first.

        Returns:
            BatchFeature with "pixel_values" of shape (B, 1, H, W), or
            (B, D, 1, H, W) for volumetric input.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        batch = []
        for img in images:
            # PIL images have no .ndim/.shape, so they always take the 2D path.
            # A (1, H, W) array is a channel-first 2D image (see _to_tensor),
            # not a volume, so it is routed to the 2D path as well.
            if getattr(img, "ndim", 2) == 3 and img.shape[0] != 1:
                # (H, W, D) volume: preprocess every slice, stack to
                # (D, 1, H, W), then z-score jointly over the whole volume.
                slices = []
                for i in range(img.shape[-1]):
                    x = self._to_tensor(img[:, :, i])
                    x = convert_image_dtype(x, torch.float32)
                    x = self._resize(x)
                    x = self._clip_min(x)  # optional
                    slices.append(x[None, ...])
                x = torch.stack(slices, dim=0)
                if self.do_normalize:
                    x = self._zscore_per_image(x)  # per-volume z-score
            else:
                x = self._to_tensor(img)
                x = convert_image_dtype(x, torch.float32)
                x = self._resize(x)
                x = self._clip_min(x)  # optional
                if self.do_normalize:
                    x = self._zscore_per_image(x)  # per-image z-score
                x = x[None, ...]  # -> (1,H,W)
            batch.append(x)
        pixel_values = np.stack(batch, axis=0)  # (B,1,H,W) or (B,D,1,H,W)
        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,  # "pt" | "np" | "tf" | "jax" | None
        )

    # saved as preprocessor_config.json
    def to_dict(self) -> Dict[str, Any]:
        """Serialize the config, including the auto_map for AutoImageProcessor."""
        out = super().to_dict()
        out.update(
            dict(
                crop_size=self.crop_size,
                clip_below_air=self.clip_below_air,
                eps=self.eps,
                do_resize=self.do_resize,
                do_normalize=self.do_normalize,
            )
        )
        # Make AutoImageProcessor discoverable
        out["auto_map"] = {"AutoImageProcessor": "curia_image_processor.CuriaImageProcessor"}
        return out