File size: 4,176 Bytes
f46fb4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
"""
TILA — Image Processor

Single processor that handles the full pipeline:
    raw image (path, numpy, or PIL) → model-ready tensor [1, 3, 448, 448]

Combines:
    1. Medical image preprocessing (windowing, padding removal, resize)
    2. Model transforms (resize, center crop, to tensor, expand channels)

Usage:
    from processor import TILAProcessor

    processor = TILAProcessor()

    # From file path (applies full preprocessing)
    tensor = processor("raw_cxr.png")

    # From PIL image (skips medical preprocessing, applies model transforms only)
    tensor = processor(Image.open("preprocessed.png"))

    # Pair of images for the model
    current = processor("current.png")
    previous = processor("previous.png")
    result = model.get_interval_change_prediction(current, previous)
"""

from pathlib import Path
from typing import Union

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from preprocess import preprocess_image


class TILAProcessor:
    """End-to-end image processor for the TILA model.

    Accepts file paths (str/Path), numpy arrays, or PIL Images.
    - File paths: full pipeline (windowing → crop → resize → model transform)
    - Numpy arrays: treated as raw, full pipeline applied
    - PIL Images: assumed already preprocessed, only model transforms applied

    Args:
        raw_preprocess: Apply medical preprocessing (windowing, padding removal).
                        Set False if images are already preprocessed PNGs.
        width_param: Windowing width parameter (default: 4.0)
        max_size: Resize longest side to this before model transforms (default: 512)
        crop_size: Center crop size for model input (default: 448)
        dtype: Output tensor dtype (default: torch.bfloat16)
        device: Output tensor device (default: "cpu")
    """

    def __init__(
        self,
        raw_preprocess: bool = True,
        width_param: float = 4.0,
        max_size: int = 512,
        crop_size: int = 448,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cpu",
    ):
        self.raw_preprocess = raw_preprocess
        self.width_param = width_param
        self.max_size = max_size
        self.crop_size = crop_size
        self.dtype = dtype
        self.device = device

        # Resize → center crop → [0,1] float tensor → replicate gray to 3 channels.
        self.model_transform = transforms.Compose([
            transforms.Resize(max_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            _ExpandChannels(),
        ])

    def __call__(self, image: Union[str, Path, np.ndarray, Image.Image]) -> torch.Tensor:
        """Process a single image into a model-ready tensor.

        Args:
            image: File path (str or Path), numpy array, or PIL Image

        Returns:
            Tensor of shape [1, 3, crop_size, crop_size] on the configured
            dtype and device (defaults: [1, 3, 448, 448], bfloat16, cpu).

        Raises:
            TypeError: If `image` is not one of the supported input types.
        """
        # Bug fix: the docstring promises str/Path support, but the previous
        # check only matched `str`, so a pathlib.Path fell through to TypeError.
        if isinstance(image, (str, Path)):
            path = str(image)
            if self.raw_preprocess:
                img_np = preprocess_image(path, self.width_param, self.max_size)
                pil_img = Image.fromarray(img_np)
            else:
                # Already-preprocessed PNG: just load and force grayscale.
                pil_img = Image.open(path).convert("L")
        elif isinstance(image, np.ndarray):
            if self.raw_preprocess:
                # Local import keeps the optional raw pipeline lazy.
                from preprocess import apply_windowing, remove_black_padding, resize_preserve_aspect_ratio
                if len(image.shape) == 3:
                    # assumes 3-channel arrays are BGR (OpenCV convention) — TODO confirm with callers
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                image = apply_windowing(image, self.width_param)
                # apply_windowing output is scaled to [0,1] here before uint8 conversion.
                image = (image * 255.0).astype(np.uint8)
                image = remove_black_padding(image)
                image = resize_preserve_aspect_ratio(image, self.max_size)
            pil_img = Image.fromarray(image)
        elif isinstance(image, Image.Image):
            pil_img = image.convert("L")
        else:
            raise TypeError(f"Expected str, np.ndarray, or PIL.Image, got {type(image)}")

        # Add the batch dimension expected by the model: [C,H,W] → [1,C,H,W].
        tensor = self.model_transform(pil_img).unsqueeze(0)
        return tensor.to(dtype=self.dtype, device=self.device)


class _ExpandChannels:
    """Expand single-channel tensor to 3 channels."""
    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        if x.shape[0] == 1:
            return x.repeat(3, 1, 1)
        return x