Feature Extraction
Transformers
Safetensors
PyTorch
brain-mri-siglip
medical-imaging
mri
brain-mri
siglip
vision-language
contrastive-learning
custom-code
custom_code
Instructions to use shenxiaochen/brain-mri-siglip with libraries, inference providers, notebooks, and local apps. Follow these links to get started.
- Libraries
- Transformers
How to use shenxiaochen/brain-mri-siglip with Transformers:
```python
# Use a pipeline as a high-level helper
from transformers import pipeline

pipe = pipeline("feature-extraction", model="shenxiaochen/brain-mri-siglip", trust_remote_code=True)

# Load model directly
from transformers import AutoModel

model = AutoModel.from_pretrained("shenxiaochen/brain-mri-siglip", trust_remote_code=True, dtype="auto")
```
- Notebooks
- Google Colab
- Kaggle
| """Shared offline-aligned preprocessing helpers for 3D brain MRI volumes.""" | |
| from __future__ import annotations | |
| import math | |
| from pathlib import Path | |
| from typing import Any, Mapping | |
| import nibabel as nib | |
| import numpy as np | |
| import torch | |
| import torch.nn.functional as F | |
| try: | |
| from scipy import ndimage as scipy_ndimage | |
| except Exception: # pragma: no cover - optional import surface | |
| scipy_ndimage = None | |
# Default output grid: volumes are downscaled to fit inside it and then
# center-padded up to exactly this shape.
TARGET_SHAPE = (128, 192, 192)
# Default per-axis voxel spacing that cropped volumes are resampled to
# (same units as the image header zooms; presumably millimetres — confirm).
TARGET_SPACING = (1.25, 1.0, 1.0)
# Margin kept around the foreground bounding box; divided by the per-axis
# spacing to obtain a margin in voxels.
CROP_MARGIN_MM = 5.0
# Voxels whose absolute intensity exceeds this value count as foreground.
FOREGROUND_THRESHOLD = 1e-3
# Value written to non-foreground voxels and used as the padding fill.
BACKGROUND_VALUE = -1.0
# Label copied into the output payload under "foreground_strategy".
FOREGROUND_STRATEGY = "largest_component_nonzero"
# Identifier and version copied into the output payload under "recipe_id"
# and "cache_version".
GENERIC_RECIPE_ID = "generic_foreground_128x192x192_fp16_v1"
GENERIC_CACHE_VERSION = 1
def load_canonical_nifti(path: str | Path):
    """Load the image at ``path`` with nibabel and reorient it.

    Args:
        path: Location of the image file; converted to ``str`` before loading.

    Returns:
        The result of ``nib.as_closest_canonical`` applied to the loaded image.
    """
    loaded = nib.load(str(path))
    return nib.as_closest_canonical(loaded)
def load_image_spacing(image) -> tuple[float, float, float]:
    """Return the first three header zooms of ``image`` as plain floats.

    Args:
        image: Object exposing ``header.get_zooms()`` (for example a nibabel
            image).

    Returns:
        The spacing of the first three axes.

    Raises:
        ValueError: If fewer than three zooms are available, or if any of the
            three is non-finite or not strictly positive.
    """
    zooms = image.header.get_zooms()[:3]
    if len(zooms) != 3:
        raise ValueError(f"Expected a 3D image spacing tuple, got {zooms}.")
    spacing = tuple(float(value) for value in zooms)
    # The crop margin is divided by the spacing and resampled shapes are
    # scaled by it, so a zero, negative, or NaN entry would either crash with
    # an unrelated error or silently produce a degenerate grid. Fail here.
    if not all(math.isfinite(value) and value > 0.0 for value in spacing):
        raise ValueError(f"Expected finite, positive image spacing, got {spacing}.")
    return spacing
def coerce_volume_to_3d(volume: np.ndarray) -> np.ndarray:
    """Reduce a 3D or 4D array to a 3D ``float32`` array.

    A 3D input is only cast. For a 4D input, index 0 of the first axis is
    taken when that axis has at most 4 entries and the last axis has more than
    4; otherwise index 0 of the last axis is taken.

    Raises:
        ValueError: If ``volume`` is neither 3D nor 4D.
    """
    rank = volume.ndim
    if rank not in (3, 4):
        raise ValueError(f"Expected a 3D or 4D volume, got shape {volume.shape}.")
    if rank == 3:
        return volume.astype(np.float32, copy=False)
    leading_is_small = volume.shape[0] <= 4 and volume.shape[-1] > 4
    picked = volume[0] if leading_is_small else volume[..., 0]
    return np.asarray(picked, dtype=np.float32)
def largest_connected_component(mask: np.ndarray) -> np.ndarray:
    """Keep only the biggest connected region of a boolean mask.

    The input mask is returned unchanged when it is empty, when SciPy's
    ``ndimage`` is unavailable, or when labelling finds at most one component.
    Otherwise a boolean array selecting the most populous label is returned.
    """
    if scipy_ndimage is None or not mask.any():
        return mask
    # Connectivity 1: neighbours share a face, not just an edge or corner.
    connectivity = scipy_ndimage.generate_binary_structure(mask.ndim, 1)
    labeled, component_total = scipy_ndimage.label(mask, structure=connectivity)
    if component_total <= 1:
        return mask
    sizes = np.bincount(labeled.ravel())
    if sizes.size <= 1:
        return mask
    # Label 0 is the background; zero it so it can never win the argmax.
    sizes[0] = 0
    best = int(sizes.argmax())
    if best <= 0 or sizes[best] <= 0:
        return mask
    return labeled == best
def build_foreground_mask(volume: np.ndarray, threshold: float = FOREGROUND_THRESHOLD) -> np.ndarray:
    """Build a boolean foreground mask from voxel magnitudes.

    Non-finite values are treated as 0. Voxels whose absolute value exceeds
    ``threshold`` are candidates. If there are none, an all-True mask is
    returned. Otherwise the largest connected component is used, unless it is
    empty or it is smaller than 512 voxels while other candidates exist, in
    which case the full thresholded mask is returned.
    """
    finite_volume = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0)
    above_threshold = np.abs(finite_volume) > float(threshold)
    if not above_threshold.any():
        return np.ones_like(finite_volume, dtype=bool)

    largest = largest_connected_component(above_threshold)
    largest_size = int(largest.sum())
    candidate_size = int(above_threshold.sum())

    # A tiny winning component that discards other candidates is not trusted.
    too_small = largest_size < 512 and candidate_size > largest_size
    keep_everything = largest_size <= 0 or too_small
    return above_threshold if keep_everything else largest
def compute_crop_bbox(
    mask: np.ndarray,
    spacing: tuple[float, float, float],
    margin_mm: float = CROP_MARGIN_MM,
) -> tuple[tuple[int, int], ...]:
    """Compute per-axis ``(start, stop)`` bounds around the True voxels of ``mask``.

    The margin is converted to voxels per axis as ``ceil(margin_mm / spacing)``
    and the resulting bounds are clamped to the mask extent. ``stop`` is
    exclusive.

    Raises:
        ValueError: If ``mask`` has no True voxel.
    """
    hits = np.where(mask)
    if hits[0].size == 0:
        raise ValueError("Foreground mask contains no positive voxels after selection.")

    def axis_bounds(axis: int, axis_hits: np.ndarray) -> tuple[int, int]:
        pad = int(math.ceil(float(margin_mm) / float(spacing[axis])))
        lower = max(0, int(axis_hits.min()) - pad)
        upper = min(mask.shape[axis], int(axis_hits.max()) + pad + 1)
        return (lower, upper)

    return tuple(axis_bounds(axis, axis_hits) for axis, axis_hits in enumerate(hits))
def crop_volume_and_mask(
    volume: np.ndarray,
    mask: np.ndarray,
    spacing: tuple[float, float, float],
    margin_mm: float = CROP_MARGIN_MM,
) -> tuple[np.ndarray, np.ndarray, tuple[tuple[int, int], ...]]:
    """Crop ``volume`` and ``mask`` to the margin-expanded foreground box.

    Returns:
        The cropped volume, the cropped mask, and the per-axis
        ``(start, stop)`` bounds that were applied.
    """
    bounds = compute_crop_bbox(mask, spacing, margin_mm=margin_mm)
    window = tuple(slice(*pair) for pair in bounds)
    cropped_volume = volume[window]
    cropped_mask = mask[window]
    return cropped_volume, cropped_mask, bounds
def normalize_foreground_only(volume: np.ndarray, mask: np.ndarray) -> np.ndarray:
    """Rescale intensities to ``[-1, 1]`` using foreground percentiles.

    Non-finite values are replaced by 0. The 0.5th and 99.5th percentiles of
    the voxels selected by ``mask`` (strided down when there are more than
    1,000,000 of them) define the clipping window, which is mapped linearly
    onto ``[-1, 1]``. If the window is degenerate, an all-zero array is
    returned.

    Raises:
        ValueError: If ``mask`` selects no voxel.
    """
    cleaned = np.nan_to_num(volume, nan=0.0, posinf=0.0, neginf=0.0).astype(np.float32, copy=False)
    samples = cleaned[mask]
    sample_count = samples.size
    if sample_count == 0:
        raise ValueError("Cannot normalize volume because the foreground mask is empty.")
    if sample_count > 1_000_000:
        # Take an evenly strided subset so the percentile stays cheap.
        stride = max(1, sample_count // 1_000_000)
        samples = samples[::stride]

    lower, upper = np.percentile(samples, [0.5, 99.5])
    window_ok = np.isfinite(lower) and np.isfinite(upper) and not (upper <= lower)
    if not window_ok:
        return np.zeros_like(cleaned, dtype=np.float32)

    clipped = np.clip(cleaned, float(lower), float(upper))
    unit = np.clip((clipped - float(lower)) / float(upper - lower), 0.0, 1.0)
    scaled = unit * 2.0 - 1.0
    return scaled.astype(np.float32, copy=False)
def resize_volume(volume: np.ndarray, size: tuple[int, int, int], mode: str) -> np.ndarray:
    """Resize ``volume`` to ``size`` with ``torch.nn.functional.interpolate``.

    Two leading singleton dimensions are added for the call and removed
    afterwards. ``align_corners=False`` is passed only for the interpolating
    modes that accept it. The result is returned as a ``float32`` array.
    """
    batched = torch.from_numpy(volume)[None, None]
    if mode in {"linear", "bilinear", "bicubic", "trilinear"}:
        interpolate_options = {"align_corners": False}
    else:
        interpolate_options = {}
    resized = F.interpolate(batched, size=size, mode=mode, **interpolate_options)
    return resized[0, 0].cpu().numpy().astype(np.float32, copy=False)
def resize_mask(mask: np.ndarray, size: tuple[int, int, int]) -> np.ndarray:
    """Resize a mask to ``size`` with nearest-neighbour interpolation.

    The mask is cast to ``float32`` for interpolation and thresholded at 0.5
    afterwards, so the result is a boolean array.
    """
    as_float = torch.from_numpy(mask.astype(np.float32, copy=False))[None, None]
    resized = F.interpolate(as_float, size=size, mode="nearest")
    return resized[0, 0].cpu().numpy() > 0.5
def resample_to_target_spacing(
    volume: np.ndarray,
    mask: np.ndarray,
    source_spacing: tuple[float, float, float],
    target_spacing: tuple[float, float, float] = TARGET_SPACING,
) -> tuple[np.ndarray, np.ndarray]:
    """Resample ``volume`` and ``mask`` from ``source_spacing`` to ``target_spacing``.

    Each axis length becomes ``round(length * source / target)``, with a
    minimum of 1. If that leaves the shape unchanged, the volume is only cast
    to ``float32`` and the mask is returned as given. Otherwise the volume is
    resized trilinearly and the mask with nearest-neighbour interpolation.
    """
    resampled_shape = tuple(
        max(1, int(round(float(extent) * float(src) / float(dst))))
        for extent, src, dst in zip(volume.shape, source_spacing, target_spacing)
    )
    shape_unchanged = resampled_shape == tuple(int(v) for v in volume.shape)
    if shape_unchanged:
        return volume.astype(np.float32, copy=False), mask

    new_volume = resize_volume(volume, resampled_shape, mode="trilinear")
    new_mask = resize_mask(mask, resampled_shape)
    return new_volume, new_mask
def downscale_to_fit(
    volume: np.ndarray,
    mask: np.ndarray,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
) -> tuple[np.ndarray, np.ndarray]:
    """Shrink ``volume`` and ``mask`` uniformly until they fit ``target_shape``.

    Inputs that already fit on every axis are returned untouched. Otherwise a
    single scale factor (the smallest per-axis ``target / current`` ratio) is
    applied to all axes, so the aspect ratio is kept up to flooring.
    """
    shape_now = tuple(int(v) for v in volume.shape)
    axis_pairs = list(zip(shape_now, target_shape))
    if all(cur <= tgt for cur, tgt in axis_pairs):
        return volume, mask

    factor = min(float(tgt) / float(cur) for cur, tgt in axis_pairs)
    if factor >= 1.0:
        return volume, mask

    # Floor so no axis overshoots; clamp into [1, target] per axis.
    shrunk_shape = tuple(
        min(tgt, max(1, int(math.floor(float(cur) * factor))))
        for cur, tgt in axis_pairs
    )
    new_volume = resize_volume(volume, shrunk_shape, mode="trilinear")
    new_mask = resize_mask(mask, shrunk_shape)
    return new_volume, new_mask
def center_pad(
    array: np.ndarray,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
    fill_value: float = BACKGROUND_VALUE,
) -> np.ndarray:
    """Pad ``array`` with ``fill_value`` so it sits centred in ``target_shape``.

    When the padding on an axis is odd, the extra voxel goes after the data.

    Raises:
        ValueError: If ``array`` exceeds ``target_shape`` on any axis.
    """
    axis_pairs = list(zip(array.shape, target_shape))
    if any(have > want for have, want in axis_pairs):
        raise ValueError(f"Cannot center-pad shape {array.shape} into smaller target {target_shape}.")

    def split(missing: int) -> tuple[int, int]:
        leading = missing // 2
        return (leading, missing - leading)

    padding = tuple(split(want - have) for have, want in axis_pairs)
    return np.pad(array, pad_width=padding, mode="constant", constant_values=fill_value)
def preprocess_image_with_foreground_mask(
    image_path: str | Path,
    *,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
    target_spacing: tuple[float, float, float] = TARGET_SPACING,
    crop_margin_mm: float = CROP_MARGIN_MM,
    foreground_threshold: float = FOREGROUND_THRESHOLD,
    background_value: float = BACKGROUND_VALUE,
    foreground_strategy: str = FOREGROUND_STRATEGY,
    recipe_id: str = GENERIC_RECIPE_ID,
    cache_version: int = GENERIC_CACHE_VERSION,
) -> dict[str, object]:
    """Turn an image file into a fixed-size ``float16`` tensor plus metadata.

    Steps, in order: load and reorient the image, reduce it to 3D, build a
    foreground mask, crop to the mask with a margin, normalize to ``[-1, 1]``
    from foreground percentiles, resample to ``target_spacing``, downscale to
    fit ``target_shape``, overwrite non-foreground voxels with
    ``background_value``, and center-pad to ``target_shape``.

    ``foreground_strategy``, ``recipe_id`` and ``cache_version`` are only
    copied into the returned dictionary; they do not alter the processing.

    Returns:
        A dictionary with the tensor under ``"pixel_values"`` (one leading
        singleton dimension added) and the source path, source shape, source
        spacing, crop bounds, and the three pass-through fields.
    """
    source_path = Path(image_path)
    nifti = load_canonical_nifti(source_path)
    original_shape = tuple(int(value) for value in nifti.shape)
    spacing = load_image_spacing(nifti)

    raw = np.asarray(nifti.get_fdata(dtype=np.float32), dtype=np.float32)
    raw = coerce_volume_to_3d(raw)

    mask = build_foreground_mask(raw, threshold=foreground_threshold)
    cropped, cropped_mask, bbox = crop_volume_and_mask(raw, mask, spacing, margin_mm=crop_margin_mm)

    # Normalization happens on the cropped grid, before any interpolation.
    normalized = normalize_foreground_only(cropped, cropped_mask)
    resampled, resampled_mask = resample_to_target_spacing(
        normalized,
        cropped_mask,
        source_spacing=spacing,
        target_spacing=target_spacing,
    )
    fitted, fitted_mask = downscale_to_fit(resampled, resampled_mask, target_shape=target_shape)

    fitted = np.clip(fitted, -1.0, 1.0).astype(np.float32, copy=False)
    fitted[~fitted_mask] = float(background_value)
    padded = center_pad(fitted, target_shape=target_shape, fill_value=float(background_value))
    padded = padded.astype(np.float32, copy=False)

    tensor = torch.from_numpy(padded).unsqueeze(0).to(dtype=torch.float16).contiguous()
    payload = {
        "pixel_values": tensor,
        "source_image": str(source_path),
        "source_shape": list(original_shape),
        "source_spacing": list(spacing),
        "crop_bbox": [[int(lo), int(hi)] for lo, hi in bbox],
        "foreground_strategy": foreground_strategy,
        "recipe_id": recipe_id,
        "cache_version": int(cache_version),
    }
    return payload
def validate_fixed_payload(
    payload: Mapping[str, Any],
    *,
    target_shape: tuple[int, int, int] = TARGET_SHAPE,
) -> None:
    """Check that ``payload["pixel_values"]`` is a well-formed fixed-size tensor.

    The tensor must have shape ``(1, *target_shape)``, dtype ``torch.float16``,
    only finite values, and all values within ``[-1.01, 1.01]``.

    Raises:
        TypeError: If ``pixel_values`` is missing or not a ``torch.Tensor``.
        ValueError: If the shape, dtype, finiteness, or value-range check fails.
    """
    tensor = payload.get("pixel_values")
    if not isinstance(tensor, torch.Tensor):
        raise TypeError("`pixel_values` must be a torch.Tensor.")

    wanted_shape = (1, *tuple(target_shape))
    actual_shape = tuple(tensor.shape)
    if actual_shape != wanted_shape:
        raise ValueError(f"Expected tensor shape {wanted_shape}, got {actual_shape}.")

    if tensor.dtype != torch.float16:
        raise ValueError(f"Expected tensor dtype torch.float16, got {tensor.dtype}.")

    if not torch.isfinite(tensor).all():
        raise ValueError("Tensor contains non-finite values.")

    lowest = float(tensor.min().item())
    highest = float(tensor.max().item())
    # A small tolerance around [-1, 1] is accepted.
    out_of_range = lowest < -1.01 or highest > 1.01
    if out_of_range:
        raise ValueError(f"Expected tensor values in [-1, 1]. Got min={lowest}, max={highest}.")