# brain-mri-siglip / offline_aligned_preprocessing.py
# Uploaded by: shenxiaochen
# Commit 8360541 (verified): "Add files using upload-large-folder tool"
"""Shared offline-aligned preprocessing helpers for 3D brain MRI volumes."""
from __future__ import annotations
import math
from pathlib import Path
from typing import Any, Mapping
import nibabel as nib
import numpy as np
import torch
import torch.nn.functional as F
# scipy is optional: when it cannot be imported, largest_connected_component()
# returns its input mask unchanged instead of filtering components.
try:
    import scipy.ndimage as scipy_ndimage
except Exception:  # pragma: no cover - optional import surface
    scipy_ndimage = None
# Shape of the padded output volume, per axis of the canonical array.
TARGET_SHAPE: tuple[int, int, int] = (128, 192, 192)
# Voxel spacing each volume is resampled to, per axis, in the same units as the
# NIfTI header zooms (presumably millimetres).
TARGET_SPACING: tuple[float, float, float] = (1.25, 1.0, 1.0)
# Margin added around the foreground bounding box, converted to voxels per axis
# by dividing by that axis' spacing.
CROP_MARGIN_MM: float = 5.0
# Voxels with abs(intensity) strictly above this value count as raw foreground.
FOREGROUND_THRESHOLD: float = 1e-3
# Fill value written outside the foreground mask and into the centre padding.
BACKGROUND_VALUE: float = -1.0
# Metadata label stored in each payload to identify the masking strategy.
FOREGROUND_STRATEGY: str = "largest_component_nonzero"
# Recipe identifier and cache version stored in each payload's metadata.
GENERIC_RECIPE_ID: str = "generic_foreground_128x192x192_fp16_v1"
GENERIC_CACHE_VERSION: int = 1
def load_canonical_nifti(path: str | Path):
    """Load a NIfTI file and reorient it with ``nib.as_closest_canonical``.

    ``path`` may be a string or a ``Path``; it is converted with ``str()``
    before being passed to ``nib.load``. Returns the reoriented nibabel image.
    """
    loaded_image = nib.load(str(path))
    return nib.as_closest_canonical(loaded_image)
def load_image_spacing(image) -> tuple[float, float, float]:
    """Return the first three header zooms of ``image`` as floats.

    ``image`` must expose ``header.get_zooms()``, as nibabel images do. Any
    zoom entries beyond the third (e.g. a 4th-dimension step) are ignored.

    Raises:
        ValueError: if the header reports fewer than three zooms.
    """
    spatial_zooms = image.header.get_zooms()[:3]
    if len(spatial_zooms) == 3:
        return tuple(map(float, spatial_zooms))
    raise ValueError(f"Expected a 3D image spacing tuple, got {spatial_zooms}.")
def coerce_volume_to_3d(volume: np.ndarray) -> np.ndarray:
    """Reduce ``volume`` to a single 3D float32 array.

    * 3D input is cast to float32 (without copying when it already is float32).
    * 4D input keeps one 3D sub-volume: index 0 of the first axis when that axis
      is short (<= 4) and the last axis is longer (> 4), i.e. the first axis is
      treated as a leading channel axis; otherwise index 0 of the last axis.

    Raises:
        ValueError: for any dimensionality other than 3 or 4.
    """
    ndim = volume.ndim
    if ndim == 3:
        return volume.astype(np.float32, copy=False)
    if ndim == 4:
        leading_channels = volume.shape[0] <= 4 and volume.shape[-1] > 4
        first_channel = volume[0] if leading_channels else volume[..., 0]
        return np.asarray(first_channel, dtype=np.float32)
    raise ValueError(f"Expected a 3D or 4D volume, got shape {volume.shape}.")
def largest_connected_component(mask: np.ndarray) -> np.ndarray:
    """Keep only the largest face-connected component of ``mask``.

    Connectivity is ``generate_binary_structure(mask.ndim, 1)`` (6-connectivity
    in 3D). ``mask`` itself is returned unchanged when it has no set voxels,
    when scipy is unavailable, or when it has at most one component. Otherwise
    a new boolean array is returned that is True only on the largest component;
    ties are resolved in favour of the lowest label.
    """
    if not mask.any() or scipy_ndimage is None:
        return mask
    connectivity = scipy_ndimage.generate_binary_structure(mask.ndim, 1)
    labeled, component_total = scipy_ndimage.label(mask, structure=connectivity)
    if component_total < 2:
        return mask
    # With >= 2 components every label 1..component_total has a positive count,
    # so after zeroing the background bin the argmax is always a real component.
    sizes = np.bincount(labeled.ravel())
    sizes[0] = 0
    return labeled == int(np.argmax(sizes))
def build_foreground_mask(
    volume: np.ndarray,
    threshold: float = FOREGROUND_THRESHOLD,
    *,
    min_component_voxels: int = 512,
) -> np.ndarray:
    """Build a boolean foreground mask for ``volume``.

    NaN/inf are zeroed, then voxels with ``abs(intensity) > threshold`` form the
    raw mask. The largest connected component of that raw mask is returned,
    with these fallbacks:

    * no voxel exceeds ``threshold``: an all-True mask is returned, so the whole
      volume is treated as foreground;
    * the largest component is empty: the raw mask is returned;
    * the largest component has fewer than ``min_component_voxels`` voxels while
      other foreground voxels exist outside it: the raw mask is returned
      (presumably because such a small "largest" component is unreliable).

    Args:
        volume: Input intensity array.
        threshold: Strict lower bound on ``abs(intensity)`` for raw foreground.
        min_component_voxels: Minimum size for the largest component to replace
            the raw mask. Defaults to 512.

    Returns:
        Boolean array with the same shape as ``volume``.
    """
    sanitized = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0)
    raw_mask = np.abs(sanitized) > float(threshold)
    raw_count = int(np.count_nonzero(raw_mask))
    if raw_count == 0:
        return np.ones_like(sanitized, dtype=bool)
    component_mask = largest_connected_component(raw_mask)
    component_count = int(np.count_nonzero(component_mask))
    if component_count <= 0:
        return raw_mask
    if component_count < int(min_component_voxels) and raw_count > component_count:
        return raw_mask
    return component_mask
def compute_crop_bbox(
    mask: np.ndarray,
    spacing: tuple[float, float, float],
    margin_mm: float = CROP_MARGIN_MM,
) -> tuple[tuple[int, int], ...]:
    """Return the per-axis ``(start, stop)`` bounding box of ``mask`` plus a margin.

    The margin is given in physical units and converted to voxels per axis as
    ``ceil(margin_mm / spacing[axis])``. The box is clipped to the array bounds.
    ``stop`` is exclusive, so each pair can be used directly as a slice.

    Args:
        mask: Array whose truthy voxels define the foreground.
        spacing: Voxel spacing for each axis of ``mask``, in the same units as
            ``margin_mm``.
        margin_mm: Margin added on both sides of the foreground along every axis.

    Raises:
        ValueError: if ``mask`` has no truthy voxels, or if the spacing of an
            axis is not a positive number (zero, negative or NaN).
    """
    if not np.any(mask):
        raise ValueError("Foreground mask contains no positive voxels after selection.")
    bbox = []
    for axis in range(mask.ndim):
        axis_spacing = float(spacing[axis])
        if not axis_spacing > 0.0:
            # A zero spacing would divide by zero; a negative one would produce a
            # negative margin that silently shrinks the box.
            raise ValueError(
                f"Spacing must be positive, got {spacing[axis]!r} on axis {axis}."
            )
        # Project the mask onto this axis instead of np.where(mask): the min/max
        # indices are identical, but no int64 coordinate array per voxel is built.
        other_axes = tuple(a for a in range(mask.ndim) if a != axis)
        hits = np.flatnonzero(np.any(mask, axis=other_axes))
        margin_voxels = int(math.ceil(float(margin_mm) / axis_spacing))
        start = max(0, int(hits[0]) - margin_voxels)
        stop = min(mask.shape[axis], int(hits[-1]) + margin_voxels + 1)
        bbox.append((start, stop))
    return tuple(bbox)
def crop_volume_and_mask(
    volume: np.ndarray,
    mask: np.ndarray,
    spacing: tuple[float, float, float],
    margin_mm: float = CROP_MARGIN_MM,
) -> tuple[np.ndarray, np.ndarray, tuple[tuple[int, int], ...]]:
    """Crop ``volume`` and ``mask`` to the margin-padded bounding box of ``mask``.

    The box comes from ``compute_crop_bbox(mask, spacing, margin_mm=margin_mm)``
    and the same slices are applied to both arrays, so they stay aligned. Basic
    slicing is used, so for numpy inputs the results are views, not copies.

    Returns:
        ``(cropped_volume, cropped_mask, bbox)``, where ``bbox`` holds one
        ``(start, stop)`` pair per axis.
    """
    bbox = compute_crop_bbox(mask, spacing, margin_mm=margin_mm)
    window = tuple(slice(lo, hi) for lo, hi in bbox)
    cropped_volume = volume[window]
    cropped_mask = mask[window]
    return cropped_volume, cropped_mask, bbox
def normalize_foreground_only(volume: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Rescale ``volume`` to [-1, 1] using robust percentiles of its foreground.

    NaN/inf are zeroed first. The 0.5th and 99.5th percentiles of the voxels
    selected by ``mask`` define a window. Above one million voxels, the
    percentiles use a strided subsample. Every voxel, including those outside
    ``mask``, is clipped to that window and mapped linearly onto [-1, 1]. If the
    window is degenerate (non-finite, or ``high <= low``), the result is all zeros.

    Raises:
        ValueError: if ``mask`` selects no voxels.
    """
    clean = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
    samples = clean[mask]
    sample_count = samples.size
    if sample_count == 0:
        raise ValueError("Cannot normalize volume because the foreground mask is empty.")
    if sample_count > 1_000_000:
        # Deterministic strided subsample bounds the cost of the percentile call.
        samples = samples[:: max(1, sample_count // 1_000_000)]
    low, high = np.percentile(samples, [0.5, 99.5])
    window_ok = np.isfinite(low) and np.isfinite(high) and high > low
    if not window_ok:
        return np.zeros_like(clean, dtype=np.float32)
    span = float(high - low)
    unit = np.clip((np.clip(clean, float(low), float(high)) - float(low)) / span, 0.0, 1.0)
    return (unit * 2.0 - 1.0).astype(np.float32, copy=False)
def resize_volume(volume: np.ndarray, size: tuple[int, int, int], mode: str) -> np.ndarray:
    """Resize a 3D array to ``size`` with ``torch.nn.functional.interpolate``.

    The array is wrapped as a ``(1, 1, *volume.shape)`` tensor and interpolated
    with ``mode``. ``align_corners=False`` is passed only for the linear-family
    modes, which accept that flag. The result is returned as a float32 numpy array.

    ``np.ascontiguousarray`` is applied first because ``torch.from_numpy``
    rejects arrays with negative strides (e.g. views created by ``arr[::-1]``).
    For inputs that are already C-contiguous it returns the same array, so
    results for previously working inputs are unchanged.
    """
    tensor = torch.from_numpy(np.ascontiguousarray(volume)).unsqueeze(0).unsqueeze(0)
    kwargs = {}
    if mode in {"linear", "bilinear", "bicubic", "trilinear"}:
        kwargs["align_corners"] = False
    tensor = F.interpolate(tensor, size=size, mode=mode, **kwargs)
    return tensor.squeeze(0).squeeze(0).cpu().numpy().astype(np.float32, copy=False)
def resize_mask(mask: np.ndarray, size: tuple[int, int, int]) -> np.ndarray:
    """Resize a 3D mask to ``size`` with nearest-neighbour interpolation.

    The mask is cast to float32, interpolated with ``mode="nearest"`` and
    thresholded at 0.5, so the result is a boolean array of shape ``size``.

    ``astype(..., copy=False)`` does not copy a float32 input, so a float32
    view with negative strides would reach ``torch.from_numpy``, which rejects
    negative strides. ``np.ascontiguousarray`` makes such input contiguous and
    leaves already-contiguous arrays untouched.
    """
    as_float = np.ascontiguousarray(mask.astype(np.float32, copy=False))
    tensor = torch.from_numpy(as_float).unsqueeze(0).unsqueeze(0)
    tensor = F.interpolate(tensor, size=size, mode="nearest")
    return tensor.squeeze(0).squeeze(0).cpu().numpy() > 0.5
def resample_to_target_spacing(
    volume: np.ndarray,
    mask: np.ndarray,
    source_spacing: tuple[float, float, float],
    target_spacing: tuple[float, float, float] = TARGET_SPACING,
) -> tuple[np.ndarray, np.ndarray]:
    """Resample ``volume`` and ``mask`` from ``source_spacing`` to ``target_spacing``.

    Each axis gets ``max(1, round(size * src / dst))`` voxels, so the physical
    extent is preserved up to rounding. The volume goes through
    ``resize_volume(..., mode="trilinear")`` and the mask through
    ``resize_mask``. When the computed shape equals the current shape, nothing
    is interpolated: the volume is only cast to float32 (no copy if it already
    is) and the mask is returned as-is.
    """
    new_shape = tuple(
        max(1, int(round(float(extent) * float(src) / float(dst))))
        for extent, src, dst in zip(volume.shape, source_spacing, target_spacing)
    )
    if new_shape == tuple(int(extent) for extent in volume.shape):
        return volume.astype(np.float32, copy=False), mask
    resampled_volume = resize_volume(volume, new_shape, mode="trilinear")
    resampled_mask = resize_mask(mask, new_shape)
    return resampled_volume, resampled_mask
def downscale_to_fit(
    volume: np.ndarray,
    mask: np.ndarray,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
) -> tuple[np.ndarray, np.ndarray]:
    """Uniformly shrink ``volume`` and ``mask`` so every axis fits ``target_shape``.

    Arrays that already fit are returned unchanged; this function never
    upscales. Otherwise one scale factor is applied to all axes, so the aspect
    ratio is kept. That factor is the smallest ``target / current`` ratio. Each
    new size is ``floor(current * scale)`` clamped to ``[1, target]``. The
    volume is resized trilinearly via ``resize_volume``; the mask via
    ``resize_mask``.
    """
    shape_now = tuple(int(extent) for extent in volume.shape)
    axis_pairs = list(zip(shape_now, target_shape))
    if all(now <= limit for now, limit in axis_pairs):
        return volume, mask
    ratio = min(float(limit) / float(now) for now, limit in axis_pairs)
    if ratio >= 1.0:
        return volume, mask
    # NOTE(review): because (limit / now) * now can round just below `limit` in
    # floating point, the limiting axis may end up one voxel short of `limit`
    # for some shapes. Changing the rounding would alter outputs for those
    # shapes, so it is left as-is here.
    shrunk_shape = tuple(
        min(limit, max(1, int(math.floor(float(now) * ratio))))
        for now, limit in axis_pairs
    )
    return (
        resize_volume(volume, shrunk_shape, mode="trilinear"),
        resize_mask(mask, shrunk_shape),
    )
def center_pad(
    array: np.ndarray,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
    fill_value: float = BACKGROUND_VALUE,
) -> np.ndarray:
    """Pad ``array`` with ``fill_value`` so it sits centred inside ``target_shape``.

    When an axis needs an odd amount of padding, the extra element goes after
    the data: the leading pad is ``delta // 2``.

    Raises:
        ValueError: if any axis of ``array`` is already larger than the target.
    """
    if any(now > limit for now, limit in zip(array.shape, target_shape)):
        raise ValueError(f"Cannot center-pad shape {array.shape} into smaller target {target_shape}.")
    pad_width = []
    for now, limit in zip(array.shape, target_shape):
        half, remainder = divmod(limit - now, 2)
        pad_width.append((half, half + remainder))
    return np.pad(array, pad_width=tuple(pad_width), mode="constant", constant_values=fill_value)
def preprocess_image_with_foreground_mask(
    image_path: str | Path,
    *,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
    target_spacing: tuple[float, float, float] = TARGET_SPACING,
    crop_margin_mm: float = CROP_MARGIN_MM,
    foreground_threshold: float = FOREGROUND_THRESHOLD,
    background_value: float = BACKGROUND_VALUE,
    foreground_strategy: str = FOREGROUND_STRATEGY,
    recipe_id: str = GENERIC_RECIPE_ID,
    cache_version: int = GENERIC_CACHE_VERSION,
) -> dict[str, object]:
    """Turn one NIfTI file into a fixed-shape float16 tensor plus metadata.

    Steps, in order:
      1. load and reorient (``load_canonical_nifti``) and read the header spacing;
      2. reduce the data to a 3D float32 volume (``coerce_volume_to_3d``);
      3. build a foreground mask and crop to its bounding box plus ``crop_margin_mm``;
      4. rescale intensities to [-1, 1] from foreground percentiles;
      5. resample to ``target_spacing``, then shrink uniformly to fit ``target_shape``;
      6. clip to [-1, 1], write ``background_value`` outside the mask, and
         centre-pad to ``target_shape`` with ``background_value``.

    ``foreground_strategy``, ``recipe_id`` and ``cache_version`` are only copied
    into the returned metadata; they do not influence the processing.

    Returns:
        A dict whose keys are:
          * ``pixel_values``: float16 tensor with a leading axis of size 1;
          * ``source_image``;
          * ``source_shape``;
          * ``source_spacing``;
          * ``crop_bbox``;
          * ``foreground_strategy``;
          * ``recipe_id``;
          * ``cache_version``.
    """
    path = Path(image_path)
    image = load_canonical_nifti(path)
    original_shape = [int(extent) for extent in image.shape]
    spacing = load_image_spacing(image)
    raw_volume = coerce_volume_to_3d(
        np.asarray(image.get_fdata(dtype=np.float32), dtype=np.float32)
    )

    foreground = build_foreground_mask(raw_volume, threshold=foreground_threshold)
    cropped, cropped_mask, bbox = crop_volume_and_mask(
        raw_volume, foreground, spacing, margin_mm=crop_margin_mm
    )
    normalized = normalize_foreground_only(cropped, cropped_mask)
    resampled, resampled_mask = resample_to_target_spacing(
        normalized, cropped_mask, source_spacing=spacing, target_spacing=target_spacing
    )
    fitted, fitted_mask = downscale_to_fit(resampled, resampled_mask, target_shape=target_shape)

    fill = float(background_value)
    fitted = np.clip(fitted, -1.0, 1.0).astype(np.float32, copy=False)
    fitted[~fitted_mask] = fill
    padded = center_pad(fitted, target_shape=target_shape, fill_value=fill).astype(np.float32, copy=False)
    pixel_values = torch.from_numpy(padded).unsqueeze(0).to(dtype=torch.float16).contiguous()

    return {
        "pixel_values": pixel_values,
        "source_image": str(path),
        "source_shape": original_shape,
        "source_spacing": list(spacing),
        "crop_bbox": [[int(start), int(stop)] for start, stop in bbox],
        "foreground_strategy": foreground_strategy,
        "recipe_id": recipe_id,
        "cache_version": int(cache_version),
    }
def validate_fixed_payload(
    payload: Mapping[str, Any],
    *,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
) -> None:
    """Check that ``payload["pixel_values"]`` satisfies the fixed-tensor contract.

    The value must be a ``torch.Tensor`` of shape ``(1, *target_shape)`` and
    dtype ``torch.float16``. It must contain only finite values, and its
    min/max must lie within [-1.01, 1.01], a small tolerance around [-1, 1].

    Raises:
        TypeError: if ``pixel_values`` is missing or is not a tensor.
        ValueError: if the shape, dtype, finiteness or value range is wrong.
    """
    tensor = payload.get("pixel_values")
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("`pixel_values` must be a torch.Tensor.")
    expected_shape = (1, *target_shape)
    actual_shape = tuple(tensor.shape)
    if actual_shape != expected_shape:
        raise ValueError(f"Expected tensor shape {expected_shape}, got {actual_shape}.")
    if tensor.dtype != torch.float16:
        raise ValueError(f"Expected tensor dtype torch.float16, got {tensor.dtype}.")
    if not bool(torch.isfinite(tensor).all()):
        raise ValueError("Tensor contains non-finite values.")
    lowest = float(tensor.min().item())
    highest = float(tensor.max().item())
    if lowest < -1.01 or highest > 1.01:
        raise ValueError(f"Expected tensor values in [-1, 1]. Got min={lowest}, max={highest}.")