"""Preprocessing for Jolia: raw CT volume -> model-ready tensor. Reproduces the inference-time CPU transform pipeline of the Magritte parallel-organs run: PrepareVolume -> Resample3D(1.5mm) -> Crop3D(192) -> Pad3D(192) -> ApplyWindowing("all" -> 11 channels) The output is a ``(1, 11, 192, 192, 192)`` float tensor ready for :meth:`JoliaModel.forward`. Use it directly:: from preprocessing_jolia import JoliaPreprocessor pre = JoliaPreprocessor() # resolution = (row_spacing, col_spacing, slice_thickness) in mm image = pre(volume, resolution=(0.7, 0.7, 1.0)) # -> (11, 192, 192, 192) """ from __future__ import annotations from typing import Union import torch # Works both inside a package (HF trust_remote_code) and as a top-level module # (the `snapshot_download` + `sys.path.append` flow in the README). try: from .jolia_atlas_transform import ( ApplyWindowing, AtlasTransform, Crop3D, Pad3D, PrepareVolume, Resample3D, ) except ImportError: from jolia_atlas_transform import ( ApplyWindowing, AtlasTransform, Crop3D, Pad3D, PrepareVolume, Resample3D, ) try: # numpy is optional at import time; only needed for ndarray inputs import numpy as np Tensorable = Union[torch.Tensor, "np.ndarray"] except ImportError: # pragma: no cover Tensorable = torch.Tensor # type: ignore[misc] class JoliaPreprocessor: """Deterministic Atlas preprocessing matching the released checkpoint. Args: target_shape: Output spatial size (D, H, W). Default ``(192, 192, 192)``. target_spacing: Resample spacing in mm. Default ``(1.5, 1.5, 1.5)``. depth_last / flip_depth: Volume orientation handling (run defaults). window_type: CT windowing preset(s). ``"all"`` -> 11 channels. modality: ``"CT"``. padding_value: HU value used to pad. Default ``-1024``. """ def __init__( self, target_shape: tuple[int, int, int] = (192, 192, 192), target_spacing: tuple[float, float, float] = (1.5, 1.5, 1.5), depth_last: bool = True, flip_depth: bool = True, window_type: str | list[str] = "all", modality: str = "CT", padding_value: float = -1024.0, ) -> None: self.transform = AtlasTransform( precomputed=False, depth_last=depth_last, training=False, cpu_transforms=[ PrepareVolume(depth_last=depth_last, flip_depth=flip_depth), Resample3D(target_spacing=target_spacing), Crop3D(target_shape=target_shape, training=False), Pad3D(target_shape=target_shape, padding_value=padding_value), ApplyWindowing(window_type=window_type, modality=modality), ], ) def __call__( self, volume: "Tensorable", resolution: tuple[float, float, float] | None = None, metadata: dict | None = None, ) -> torch.Tensor: """Transform one volume into a ``(11, 192, 192, 192)`` tensor. Args: volume: A 3D CT volume (H, W, D) in Hounsfield units (tensor or ndarray). resolution: Voxel spacing ``(row_spacing, col_spacing, slice_thickness)`` in mm — required (the volume is resampled to 1.5 mm isotropic). metadata: Optional raw metadata dict; ``resolution`` takes precedence. """ md = dict(metadata or {}) if resolution is not None: md["resolution"] = tuple(resolution) if "resolution" not in md: raise ValueError( "JoliaPreprocessor needs the voxel spacing — pass " "resolution=(row_spacing, col_spacing, slice_thickness) in mm." ) # Windowing emits bfloat16; cast to float32 to match the released weights. return self.transform(volume, md).float() @classmethod def from_pretrained(cls, *_args: object, **_kwargs: object) -> "JoliaPreprocessor": """Convenience constructor (defaults match ``raidium/Jolia``).""" return cls()