File size: 4,177 Bytes
0129eb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
fc80bee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0129eb2
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
"""Preprocessing for Jolia: raw CT volume -> model-ready tensor.

Reproduces the inference-time CPU transform pipeline of the Magritte
parallel-organs run:

    PrepareVolume -> Resample3D(1.5mm) -> Crop3D(192) -> Pad3D(192)
    -> ApplyWindowing("all" -> 11 channels)

The output is an ``(11, 192, 192, 192)`` float32 tensor (channels first); add
a batch dimension (``image.unsqueeze(0)``) to obtain the
``(1, 11, 192, 192, 192)`` input for :meth:`JoliaModel.forward`. Example::

    from preprocessing_jolia import JoliaPreprocessor
    pre = JoliaPreprocessor()
    # resolution = (row_spacing, col_spacing, slice_thickness) in mm
    image = pre(volume, resolution=(0.7, 0.7, 1.0))   # -> (11, 192, 192, 192)
"""

from __future__ import annotations

from typing import Union

import torch

# Works both inside a package (HF trust_remote_code) and as a top-level module
# (the `snapshot_download` + `sys.path.append` flow in the README).
try:
    from .jolia_atlas_transform import (
        ApplyWindowing,
        AtlasTransform,
        Crop3D,
        Pad3D,
        PrepareVolume,
        Resample3D,
    )
except ImportError:
    from jolia_atlas_transform import (
        ApplyWindowing,
        AtlasTransform,
        Crop3D,
        Pad3D,
        PrepareVolume,
        Resample3D,
    )

try:  # numpy is optional at import time; only needed for ndarray inputs
    import numpy as np

    Tensorable = Union[torch.Tensor, "np.ndarray"]
except ImportError:  # pragma: no cover
    Tensorable = torch.Tensor  # type: ignore[misc]


class JoliaPreprocessor:
    """Deterministic Atlas preprocessing matching the released checkpoint.

    Wraps an :class:`AtlasTransform` (inference mode, ``training=False``)
    whose CPU pipeline is
    ``PrepareVolume -> Resample3D -> Crop3D -> Pad3D -> ApplyWindowing``.

    Args:
        target_shape: Output spatial size (D, H, W). Default ``(192, 192, 192)``.
            Used as the target of both the crop and the pad stage.
        target_spacing: Resample spacing in mm. Default ``(1.5, 1.5, 1.5)``.
        depth_last / flip_depth: Volume orientation handling (run defaults).
            ``depth_last`` is forwarded to both ``AtlasTransform`` and
            ``PrepareVolume``.
        window_type: CT windowing preset(s). ``"all"`` -> 11 channels.
        modality: ``"CT"``.
        padding_value: HU value used to pad. Default ``-1024``.
    """

    def __init__(
        self,
        target_shape: tuple[int, int, int] = (192, 192, 192),
        target_spacing: tuple[float, float, float] = (1.5, 1.5, 1.5),
        depth_last: bool = True,
        flip_depth: bool = True,
        window_type: str | list[str] = "all",
        modality: str = "CT",
        padding_value: float = -1024.0,
    ) -> None:
        self.transform = AtlasTransform(
            precomputed=False,
            depth_last=depth_last,
            training=False,
            cpu_transforms=[
                PrepareVolume(depth_last=depth_last, flip_depth=flip_depth),
                Resample3D(target_spacing=target_spacing),
                # Crop then pad to the same shape, so any input size ends at
                # exactly ``target_shape``.
                Crop3D(target_shape=target_shape, training=False),
                Pad3D(target_shape=target_shape, padding_value=padding_value),
                ApplyWindowing(window_type=window_type, modality=modality),
            ],
        )

    def __call__(
        self,
        volume: "Tensorable",
        resolution: tuple[float, float, float] | None = None,
        metadata: dict | None = None,
    ) -> torch.Tensor:
        """Transform one volume into a model-ready float32 tensor.

        With the default configuration the result is ``(11, 192, 192, 192)``.

        Args:
            volume: A 3D CT volume (H, W, D) in Hounsfield units (tensor or ndarray).
            resolution: Voxel spacing ``(row_spacing, col_spacing, slice_thickness)``
                in mm — required (the volume is resampled to ``target_spacing``).
            metadata: Optional raw metadata dict; ``resolution`` takes precedence.
                The caller's dict is copied, never mutated.

        Raises:
            ValueError: If no usable voxel spacing is available (neither
                ``resolution`` nor a non-``None`` ``metadata["resolution"]``),
                or if ``resolution`` does not have exactly three positive entries.
        """
        md = dict(metadata or {})
        if resolution is not None:
            spacing = tuple(resolution)
            if len(spacing) != 3:
                raise ValueError(
                    "resolution must have exactly 3 entries "
                    "(row_spacing, col_spacing, slice_thickness); "
                    f"got {len(spacing)}."
                )
            if any(s <= 0 for s in spacing):
                raise ValueError(f"resolution entries must be positive; got {spacing}.")
            md["resolution"] = spacing
        # A key that is present but set to None is as unusable as a missing key;
        # reject both here instead of failing deep inside the resampler.
        if md.get("resolution") is None:
            raise ValueError(
                "JoliaPreprocessor needs the voxel spacing — pass "
                "resolution=(row_spacing, col_spacing, slice_thickness) in mm."
            )
        # Windowing emits bfloat16; cast to float32 to match the released weights.
        return self.transform(volume, md).float()

    @classmethod
    def from_pretrained(cls, *_args: object, **_kwargs: object) -> "JoliaPreprocessor":
        """Convenience constructor (defaults match ``raidium/Jolia``).

        All positional and keyword arguments (e.g. a repo id) are accepted and
        ignored; the returned instance always uses the default configuration.
        """
        return cls()