"""
TILA — Image Processor

Single processor that handles the full pipeline:
    raw image (path, numpy, or PIL) → model-ready tensor [1, 3, 448, 448]

Combines:
    1. Medical image preprocessing (windowing, padding removal, resize)
    2. Model transforms (resize, center crop, to tensor, expand channels)

Usage:
    from processor import TILAProcessor

    processor = TILAProcessor()

    # From file path, str or pathlib.Path (applies full preprocessing)
    tensor = processor("raw_cxr.png")

    # From PIL image (skips medical preprocessing, applies model transforms only)
    tensor = processor(Image.open("preprocessed.png"))

    # Pair of images for the model
    current = processor("current.png")
    previous = processor("previous.png")
    result = model.get_interval_change_prediction(current, previous)
"""

import os
from typing import Union

import cv2
import numpy as np
import torch
from PIL import Image
from torchvision import transforms

from preprocess import preprocess_image


class TILAProcessor:
    """End-to-end image processor for the TILA model.

    Accepts file paths (str or os.PathLike such as pathlib.Path), numpy
    arrays, or PIL Images.
    - File paths: full pipeline (windowing → crop → resize → model transform)
      when ``raw_preprocess`` is True; otherwise the file is opened with PIL
      and converted to grayscale.
    - Numpy arrays: treated as raw, full pipeline applied when
      ``raw_preprocess`` is True; otherwise wrapped in a PIL image as-is.
    - PIL Images: assumed already preprocessed, only model transforms applied
      (after conversion to grayscale).

    Args:
        raw_preprocess: Apply medical preprocessing (windowing, padding removal).
            Set False if images are already preprocessed PNGs.
        width_param: Windowing width parameter (default: 4.0)
        max_size: Target size handed to the preprocessing resize and to
            ``transforms.Resize`` (default: 512). Note that
            ``transforms.Resize`` with an int scales the *shorter* edge.
        crop_size: Center crop size for model input (default: 448)
        dtype: Output tensor dtype (default: torch.bfloat16)
        device: Output tensor device (default: "cpu")
    """

    def __init__(
        self,
        raw_preprocess: bool = True,
        width_param: float = 4.0,
        max_size: int = 512,
        crop_size: int = 448,
        dtype: torch.dtype = torch.bfloat16,
        device: str = "cpu",
    ):
        self.raw_preprocess = raw_preprocess
        self.width_param = width_param
        self.max_size = max_size
        # Kept alongside the other settings so the configured output size
        # can be inspected after construction.
        self.crop_size = crop_size
        self.dtype = dtype
        self.device = device

        self.model_transform = transforms.Compose([
            transforms.Resize(max_size),
            transforms.CenterCrop(crop_size),
            transforms.ToTensor(),
            _ExpandChannels(),
        ])

    def __call__(
        self,
        image: Union[str, os.PathLike, np.ndarray, Image.Image],
    ) -> torch.Tensor:
        """Process a single image into a model-ready tensor.

        Args:
            image: File path (str or os.PathLike), numpy array, or PIL Image.

        Returns:
            Tensor with a leading batch dimension of 1, cast to
            ``self.dtype`` on ``self.device``. For grayscale input the shape
            is [1, 3, crop_size, crop_size] ([1, 3, 448, 448] by default).

        Raises:
            TypeError: If ``image`` is none of the supported input types.
        """
        if isinstance(image, (str, os.PathLike)):
            # os.fspath normalises pathlib.Path (and other path-likes) to the
            # plain path the downstream loaders already handle.
            pil_img = self._load_from_path(os.fspath(image))
        elif isinstance(image, np.ndarray):
            pil_img = self._from_array(image)
        elif isinstance(image, Image.Image):
            pil_img = image.convert("L")
        else:
            raise TypeError(
                f"Expected str, os.PathLike, np.ndarray, or PIL.Image, got {type(image)}"
            )

        tensor = self.model_transform(pil_img).unsqueeze(0)
        return tensor.to(dtype=self.dtype, device=self.device)

    def _load_from_path(self, path: str) -> Image.Image:
        """Load the image at *path*, preprocessing it if configured to."""
        if self.raw_preprocess:
            img_np = preprocess_image(path, self.width_param, self.max_size)
            return Image.fromarray(img_np)

        # convert() returns a new, fully loaded image, so the file handle can
        # be released as soon as the conversion is done.
        with Image.open(path) as opened:
            return opened.convert("L")

    def _from_array(self, image: np.ndarray) -> Image.Image:
        """Turn a numpy array into a PIL image, preprocessing it if configured to."""
        if self.raw_preprocess:
            from preprocess import (
                apply_windowing,
                remove_black_padding,
                resize_preserve_aspect_ratio,
            )

            if image.ndim == 3:
                # NOTE(review): treats 3-D arrays as BGR (OpenCV channel
                # order) — confirm callers do not pass RGB arrays.
                image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
            image = apply_windowing(image, self.width_param)
            # Assumes apply_windowing returns values in [0, 1] — TODO confirm
            # against the preprocess module.
            image = (image * 255.0).astype(np.uint8)
            image = remove_black_padding(image)
            image = resize_preserve_aspect_ratio(image, self.max_size)
        return Image.fromarray(image)


class _ExpandChannels:
    """Expand single-channel tensor to 3 channels."""

    def __call__(self, x: torch.Tensor) -> torch.Tensor:
        """Return *x* repeated to 3 channels if its first dimension is 1.

        Tensors whose first dimension is not 1 are returned unchanged.
        """
        if x.shape[0] == 1:
            return x.repeat(3, 1, 1)
        return x