# curia-2 / curia_image_processor.py
# Custom image processor for the Curia-2 DINOv2-based medical model
# (Safetensors repo, upload commit cd3fc3a, asaporta-raidium).
from typing import Any, Dict, Union
import numpy as np
import torch
from PIL import Image
from torchvision.transforms.functional import convert_image_dtype
from transformers.image_processing_utils import BaseImageProcessor
from transformers.image_utils import ChannelDimension
# BatchFeature moved between transformers modules across versions; import compatibly.
try:
from transformers.image_processing_utils import BatchFeature
except Exception:
from transformers.feature_extraction_utils import BatchFeature
_BICUBIC = Image.BICUBIC
class CuriaImageProcessor(BaseImageProcessor):
    """
    1-channel medical preprocessor replicating:
        NumpyToTensor -> float32 -> Resize(crop_size, BICUBIC, antialias)
        -> optional ClipIntensity(min=-1000) -> NormalizeIntensity(channel_wise=True)

    Outputs: pixel_values as (B, 1, crop_size, crop_size) for 2D inputs, or
    (B, D, 1, crop_size, crop_size) when an (H, W, D) volume is passed.

    Images needs to be in:
    - PL for axial
    - IL for coronal
    - IP for sagittal
    for CT, no windowing, just hounsfield or normalized image
    for MRI, similar, no windowing, just raw values or normalized image
    """

    model_input_names = ["pixel_values"]

    def __init__(
        self,
        crop_size: int = 512,
        clip_below_air: bool = False,
        eps: float = 1e-6,
        do_resize: bool = True,
        do_normalize: bool = True,
        **kwargs,
    ):
        """
        Args:
            crop_size: target square side length after resizing.
            clip_below_air: if True, clamp intensities below -1000 (air in HU).
            eps: std threshold below which z-scoring only centers (constant image).
            do_resize: whether to bicubic-resize to (crop_size, crop_size).
            do_normalize: whether to apply the per-image z-score in __call__.
        """
        super().__init__(**kwargs)
        self.crop_size = int(crop_size)
        self.clip_below_air = bool(clip_below_air)
        self.eps = float(eps)
        self.do_resize = bool(do_resize)
        self.do_normalize = bool(do_normalize)

    def _to_tensor(self, image: Union[np.ndarray, torch.Tensor, Image.Image]) -> torch.Tensor:
        """Accepts (H,W), (1,H,W) or PIL; returns torch.float32 tensor (H, W) in grayscale.

        Raises:
            ValueError: if the array/tensor is not 2D or (1, H, W).
            TypeError: if the input is not a PIL image, numpy array, or torch tensor.
        """
        if isinstance(image, Image.Image):
            # force single channel; keep "F" (32-bit float) mode lossless
            if image.mode != "L" and image.mode != "F":
                image = image.convert("L")
            arr = np.array(image)
            return torch.from_numpy(arr).float()
        if isinstance(image, torch.Tensor):
            tensor = image.detach().cpu()
            if tensor.ndim == 3 and tensor.shape[0] == 1:
                tensor = tensor[0]
            if tensor.ndim != 2:
                raise ValueError(f"Expected 2D grayscale tensor or (1,H,W); got shape {tensor.shape}")
            return tensor.float()
        if isinstance(image, np.ndarray):
            arr = image
            # squeeze singleton channel dim if present
            if arr.ndim == 3 and arr.shape[0] == 1:
                arr = arr[0]
            if arr.ndim != 2:
                raise ValueError(f"Expected 2D grayscale array or (1,H,W); got shape {arr.shape}")
            # BUG FIX: this branch previously returned int16, which truncated float
            # inputs and made convert_image_dtype rescale by 1/32767 — silently
            # destroying raw HU values and defeating the -1000 air clip. Return
            # float32 like the PIL and torch branches (z-score is affine-invariant,
            # so normalized outputs are unaffected for integer inputs).
            return torch.from_numpy(arr).float()
        # Previously fell through returning None; fail loudly instead.
        raise TypeError(f"Unsupported image type: {type(image)}")

    def _resize(self, tensor: torch.Tensor) -> torch.Tensor:
        """
        Resize a 2D torch.Tensor (H, W) to (crop_size, crop_size) using bicubic
        interpolation with antialiasing. If do_resize is False, returns the input
        tensor unchanged.
        """
        if not self.do_resize:
            return tensor
        if tensor.ndim != 2:
            raise ValueError(f"Expected 2D tensor (H, W), got shape {tensor.shape}")
        # interpolate expects (N, C, H, W): add batch/channel dims, strip them after
        resized = torch.nn.functional.interpolate(
            tensor.unsqueeze(0).unsqueeze(0),
            size=(self.crop_size, self.crop_size),
            mode="bicubic",
            align_corners=False,
            antialias=True,
        )
        return resized[0, 0]

    def _clip_min(self, tensor: torch.Tensor) -> torch.Tensor:
        """Clamp intensities below -1000 (air in HU) when clip_below_air is set.

        Returns a clamped copy instead of mutating in place, so caller-provided
        tensors are never modified (the old `out=tensor` could alias user input
        when do_resize=False).
        """
        if self.clip_below_air:
            return tensor.clamp_min(-1000.0)
        return tensor

    def _zscore_per_image(self, tensor: torch.Tensor) -> torch.Tensor:
        """Per-image z-score (channel_wise=True with a single channel)."""
        mean = float(tensor.mean())
        std = float(tensor.std())
        if std < self.eps:
            # avoid exploding when image is constant; center only
            return tensor - mean
        return (tensor - mean) / std

    def __call__(self, images, return_tensors="pt", data_format=ChannelDimension.FIRST, **kwargs):
        """
        Preprocess a single image/volume or a list of them.

        2D inputs ((H,W), (1,H,W) arrays/tensors, or PIL) -> (B, 1, H', W').
        3D array/tensor volumes (H, W, D) -> (B, D, 1, H', W'), z-scored over
        the whole volume.

        Returns:
            BatchFeature with "pixel_values", converted per `return_tensors`.
        """
        if not isinstance(images, (list, tuple)):
            images = [images]
        batch = []
        for img in images:
            # PIL images have no .shape attribute; treat them as 2D.
            if hasattr(img, "shape") and len(img.shape) == 3:
                slices = []
                for i in range(img.shape[-1]):
                    x = self._to_tensor(img[:, :, i])
                    x = convert_image_dtype(x, torch.float32)
                    x = self._resize(x)
                    x = self._clip_min(x)  # optional
                    slices.append(x[None, ...])  # (1, H', W')
                x = torch.stack(slices, dim=0)  # (D, 1, H', W')
                if self.do_normalize:
                    x = self._zscore_per_image(x)  # z-score over the whole volume
            else:
                x = self._to_tensor(img)
                x = convert_image_dtype(x, torch.float32)
                x = self._resize(x)
                x = self._clip_min(x)  # optional
                if self.do_normalize:
                    x = self._zscore_per_image(x)  # per-image z-score
                x = x[None, ...]  # -> (1, H', W')
            batch.append(x)
        pixel_values = np.stack([t.numpy() for t in batch], axis=0)  # (B,1,H',W') or (B,D,1,H',W')
        return BatchFeature(
            data={"pixel_values": pixel_values},
            tensor_type=return_tensors,  # "pt" | "np" | "tf" | "jax" | None
        )

    # Serialized into preprocessor_config.json by save_pretrained().
    def to_dict(self) -> Dict[str, Any]:
        out = super().to_dict()
        out.update(
            dict(
                crop_size=self.crop_size,
                clip_below_air=self.clip_below_air,
                eps=self.eps,
                do_resize=self.do_resize,
                do_normalize=self.do_normalize,
            )
        )
        # Make AutoImageProcessor discover this custom class (trust_remote_code)
        out["auto_map"] = {"AutoImageProcessor": "curia_image_processor.CuriaImageProcessor"}
        return out