Jolia / preprocessing_jolia.py
SovanK's picture
Upload folder using huggingface_hub
fc80bee verified
Raw
History Blame Contribute Delete
4.18 kB
"""Preprocessing for Jolia: raw CT volume -> model-ready tensor.
Reproduces the inference-time CPU transform pipeline of the Magritte
parallel-organs run:
PrepareVolume -> Resample3D(1.5mm) -> Crop3D(192) -> Pad3D(192)
-> ApplyWindowing("all" -> 11 channels)
The output is a ``(11, 192, 192, 192)`` float tensor; add a leading batch
dimension to obtain the ``(1, 11, 192, 192, 192)`` input for
:meth:`JoliaModel.forward`. Use it directly::
from preprocessing_jolia import JoliaPreprocessor
pre = JoliaPreprocessor()
# resolution = (row_spacing, col_spacing, slice_thickness) in mm
image = pre(volume, resolution=(0.7, 0.7, 1.0)) # -> (11, 192, 192, 192)
"""
from __future__ import annotations
from typing import Union
import torch
# Works both inside a package (HF trust_remote_code) and as a top-level module
# (the `snapshot_download` + `sys.path.append` flow in the README).
try:
from .jolia_atlas_transform import (
ApplyWindowing,
AtlasTransform,
Crop3D,
Pad3D,
PrepareVolume,
Resample3D,
)
except ImportError:
from jolia_atlas_transform import (
ApplyWindowing,
AtlasTransform,
Crop3D,
Pad3D,
PrepareVolume,
Resample3D,
)
# numpy is an optional dependency at import time: it is only needed when the
# caller hands in an ndarray. The alias is built in the ``else`` branch so the
# ``try`` body guards nothing but the import itself.
try:
    import numpy as np
except ImportError:  # pragma: no cover
    # Without numpy, tensors are the only accepted volume type.
    Tensorable = torch.Tensor  # type: ignore[misc]
else:
    # "np.ndarray" stays a string so the alias carries a forward reference.
    Tensorable = Union[torch.Tensor, "np.ndarray"]
class JoliaPreprocessor:
    """Deterministic Atlas preprocessing matching the released checkpoint.

    Builds one inference-mode ``AtlasTransform`` whose CPU pipeline is
    ``PrepareVolume -> Resample3D -> Crop3D -> Pad3D -> ApplyWindowing`` and
    applies it to a single volume per call.

    Args:
        target_shape: Output spatial size (D, H, W). Default ``(192, 192, 192)``.
        target_spacing: Resample spacing in mm. Default ``(1.5, 1.5, 1.5)``.
        depth_last / flip_depth: Volume orientation handling (run defaults).
        window_type: CT windowing preset(s). ``"all"`` -> 11 channels.
        modality: ``"CT"``.
        padding_value: HU value used to pad. Default ``-1024``.
    """

    def __init__(
        self,
        target_shape: tuple[int, int, int] = (192, 192, 192),
        target_spacing: tuple[float, float, float] = (1.5, 1.5, 1.5),
        depth_last: bool = True,
        flip_depth: bool = True,
        window_type: str | list[str] = "all",
        modality: str = "CT",
        padding_value: float = -1024.0,
    ) -> None:
        """Assemble the inference-time transform pipeline (see class docstring)."""
        self.transform = AtlasTransform(
            precomputed=False,
            depth_last=depth_last,
            training=False,
            cpu_transforms=[
                PrepareVolume(depth_last=depth_last, flip_depth=flip_depth),
                Resample3D(target_spacing=target_spacing),
                Crop3D(target_shape=target_shape, training=False),
                Pad3D(target_shape=target_shape, padding_value=padding_value),
                ApplyWindowing(window_type=window_type, modality=modality),
            ],
        )

    def __call__(
        self,
        volume: "Tensorable",
        resolution: tuple[float, float, float] | None = None,
        metadata: dict | None = None,
    ) -> torch.Tensor:
        """Transform one volume into a model-ready float32 tensor.

        With the default constructor arguments the result is documented as
        ``(11, 192, 192, 192)``; other ``target_shape`` / ``window_type``
        settings change the spatial size / channel count accordingly.

        Args:
            volume: A 3D CT volume (H, W, D) in Hounsfield units (tensor or ndarray).
            resolution: Voxel spacing ``(row_spacing, col_spacing, slice_thickness)``
                in mm — required, either here or under ``metadata["resolution"]``
                (the volume is resampled to the target spacing).
            metadata: Optional raw metadata dict; ``resolution`` takes precedence
                over any ``"resolution"`` entry it contains. The dict is copied,
                never mutated.

        Returns:
            The transformed volume cast to ``float32``.

        Raises:
            ValueError: If no voxel spacing is available (missing or ``None``),
                or if ``resolution`` does not have exactly three entries.
        """
        # Work on a shallow copy so the caller's metadata dict is left untouched.
        md = dict(metadata or {})
        if resolution is not None:
            spacing = tuple(resolution)
            if len(spacing) != 3:
                raise ValueError(
                    "resolution must have exactly three entries "
                    "(row_spacing, col_spacing, slice_thickness) in mm, "
                    f"got {len(spacing)}."
                )
            md["resolution"] = spacing
        # A present-but-None entry is just as unusable as a missing one; fail
        # here with the actionable message instead of deep inside the pipeline.
        if md.get("resolution") is None:
            raise ValueError(
                "JoliaPreprocessor needs the voxel spacing — pass "
                "resolution=(row_spacing, col_spacing, slice_thickness) in mm."
            )
        # Windowing emits bfloat16; cast to float32 to match the released weights.
        return self.transform(volume, md).float()

    @classmethod
    def from_pretrained(cls, *_args: object, **_kwargs: object) -> "JoliaPreprocessor":
        """Convenience constructor (defaults match ``raidium/Jolia``).

        All positional and keyword arguments are accepted for call-site
        compatibility and deliberately ignored: the returned instance always
        uses the default configuration.
        """
        return cls()