| """ |
| TILA β Image Processor |
| |
| Single processor that handles the full pipeline: |
| raw image (path, numpy, or PIL) β model-ready tensor [1, 3, 448, 448] |
| |
| Combines: |
| 1. Medical image preprocessing (windowing, padding removal, resize) |
| 2. Model transforms (resize, center crop, to tensor, expand channels) |
| |
| Usage: |
| from processor import TILAProcessor |
| |
| processor = TILAProcessor() |
| |
| # From file path (applies full preprocessing) |
| tensor = processor("raw_cxr.png") |
| |
| # From PIL image (skips medical preprocessing, applies model transforms only) |
| tensor = processor(Image.open("preprocessed.png")) |
| |
| # Pair of images for the model |
| current = processor("current.png") |
| previous = processor("previous.png") |
| result = model.get_interval_change_prediction(current, previous) |
| """ |
|
|
import os
from typing import Union

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from preprocess import preprocess_image
|
|
|
|
class TILAProcessor:
    """End-to-end image processor for the TILA model.

    Accepts file paths (str or os.PathLike), numpy arrays, or PIL Images.
    - File paths: full pipeline (windowing -> crop -> resize -> model transform)
    - Numpy arrays: raw pipeline applied when ``raw_preprocess`` is True;
      otherwise the array is converted to PIL and only model transforms run
    - PIL Images: assumed already preprocessed, only model transforms applied

    Args:
        raw_preprocess: Apply medical preprocessing (windowing, padding removal).
            Set False if images are already preprocessed PNGs.
        width_param: Windowing width parameter (default: 4.0)
        max_size: Resize parameter applied before model transforms (default: 512)
        crop_size: Center crop size for model input (default: 448)
        dtype: Output tensor dtype (default: torch.bfloat16)
        device: Output tensor device (default: "cpu")
    """

    def __init__(
        self,
        raw_preprocess: bool = True,
        width_param: float = 4.0,
        max_size: int = 512,
        crop_size: int = 448,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cpu",
    ):
        self.raw_preprocess = raw_preprocess
        self.width_param = width_param
        self.max_size = max_size
        self.crop_size = crop_size  # stored like the other ctor args, for introspection
        self.dtype = dtype
        self.device = device

        # NOTE: torchvision Resize(int) scales the *shorter* side to max_size;
        # CenterCrop then extracts the central crop_size x crop_size region.
        self.model_transform = transforms.Compose([
            transforms.Resize(max_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            _ExpandChannels(),
        ])

    def __call__(
        self, image: Union[str, "os.PathLike", np.ndarray, Image.Image]
    ) -> torch.Tensor:
        """Process a single image into a model-ready tensor.

        Args:
            image: File path (str or os.PathLike), numpy array, or PIL Image

        Returns:
            Tensor of shape [1, 3, crop_size, crop_size] on ``self.device``
            with dtype ``self.dtype``.

        Raises:
            TypeError: If ``image`` is not one of the supported types.
        """
        if isinstance(image, (str, os.PathLike)):
            # Accept pathlib.Path etc., as advertised in the class docstring.
            path = os.fspath(image)
            if self.raw_preprocess:
                img_np = preprocess_image(path, self.width_param, self.max_size)
                pil_img = Image.fromarray(img_np)
            else:
                pil_img = Image.open(path).convert("L")
        elif isinstance(image, np.ndarray):
            if self.raw_preprocess:
                from preprocess import apply_windowing, remove_black_padding, resize_preserve_aspect_ratio
                if image.ndim == 3:
                    # assumes OpenCV-style BGR channel order -- TODO confirm with callers
                    image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
                image = apply_windowing(image, self.width_param)
                image = (image * 255.0).astype(np.uint8)
                image = remove_black_padding(image)
                image = resize_preserve_aspect_ratio(image, self.max_size)
            # Fix: the original left pil_img unbound when raw_preprocess=False,
            # raising UnboundLocalError for already-preprocessed arrays.
            pil_img = Image.fromarray(image)
        elif isinstance(image, Image.Image):
            pil_img = image.convert("L")
        else:
            raise TypeError(f"Expected str, np.ndarray, or PIL.Image, got {type(image)}")

        tensor = self.model_transform(pil_img).unsqueeze(0)
        return tensor.to(dtype=self.dtype, device=self.device)
|
|
|
|
| class _ExpandChannels: |
| """Expand single-channel tensor to 3 channels.""" |
| def __call__(self, x: torch.Tensor) -> torch.Tensor: |
| if x.shape[0] == 1: |
| return x.repeat(3, 1, 1) |
| return x |
|
|