# TILA / preprocess.py
# lukeingawesome's picture
# Upload folder using huggingface_hub
# f46fb4d verified
"""
TILA — Image Preprocessing
Converts raw chest X-ray images (DICOM-derived PNGs or raw PNGs) into the
normalized format expected by the TILA model.
Pipeline:
1. Read image as-is (preserving bit depth)
2. Windowing: clip to mean +/- 2*std, normalize to [0, 1]
3. Convert to uint8
4. Remove black padding (contour-based crop)
5. Resize preserving aspect ratio (shorter side = 512; the longer side scales proportionally)
Usage:
from preprocess import preprocess_image
img = preprocess_image("raw_cxr.png")
cv2.imwrite("preprocessed.png", img)
"""
import cv2
import numpy as np
from pathlib import Path
def apply_windowing(image: np.ndarray, width_param: float = 4.0) -> np.ndarray:
    """Window an image around its mean intensity and rescale to about [0, 1].

    The window is centred on the image mean and spans ``width_param`` standard
    deviations in total (``width_param / 2`` on each side). Intensities outside
    the window are clipped to its edges before rescaling.

    Args:
        image: Input pixel array; converted to float64 before any statistics
            are computed.
        width_param: Total window width in multiples of the image's standard
            deviation.

    Returns:
        A float64 array with the same shape as ``image``. The lower window
        edge maps to 0 and the upper edge to just under 1, because a small
        epsilon is added to the denominator.
    """
    pixels = image.astype(np.float64)
    center = np.mean(pixels)
    half_span = width_param * np.std(pixels) / 2
    lower = center - half_span
    upper = center + half_span
    windowed = np.clip(pixels, lower, upper)
    # The epsilon keeps the division defined for constant images (std == 0).
    return (windowed - lower) / (upper - lower + 1e-8)
def remove_black_padding(image: np.ndarray) -> np.ndarray:
    """Crop black borders using the bounding box of the largest outer contour.

    Every pixel greater than 0 is treated as foreground. Among the external
    contours of that foreground mask, the one with the largest area defines
    the crop rectangle.

    Args:
        image: Image to crop. ``preprocess_image`` calls this with a uint8
            grayscale array.

    Returns:
        The input sliced to the bounding rectangle of the largest contour, or
        the input unchanged when no contour is found (e.g. an all-black image).
    """
    _, foreground = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY)
    contours, _ = cv2.findContours(
        foreground, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE
    )
    if contours:
        biggest = max(contours, key=cv2.contourArea)
        left, top, box_w, box_h = cv2.boundingRect(biggest)
        return image[top:top + box_h, left:left + box_w]
    # Nothing above zero anywhere: there is no region to crop to.
    return image
def resize_preserve_aspect_ratio(image: np.ndarray, max_size: int = 512) -> np.ndarray:
    """Resize so the shorter side equals max_size, preserving aspect ratio.

    Despite the parameter name, ``max_size`` is applied to the *shorter* image
    side; the longer side is scaled by the same factor and therefore ends up
    greater than or equal to ``max_size``. A square image becomes
    ``max_size x max_size``.

    Args:
        image: 2-D grayscale array, or a 3-D color array, which is converted
            to grayscale with ``cv2.COLOR_BGR2GRAY`` first.
        max_size: Target length in pixels of the shorter side.

    Returns:
        The resized single-channel image, produced with ``cv2.INTER_AREA``
        interpolation.
    """
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    h, w = image.shape
    # Integer floor division gives the exact floor of max_size * ratio. Going
    # through a float ratio can land one pixel short when the product falls
    # just below a whole number, e.g. int(100 * (115 / 100)) == 114.
    if w < h:
        new_w = max_size
        new_h = max_size * h // w
    else:
        new_h = max_size
        new_w = max_size * w // h
    # cv2.resize takes the target size as (width, height).
    return cv2.resize(image, (new_w, new_h), interpolation=cv2.INTER_AREA)
def preprocess_image(
    image_path: str,
    width_param: float = 4.0,
    max_size: int = 512,
) -> np.ndarray:
    """Load a chest X-ray from disk and run the full preprocessing chain.

    Steps, in order: read the file unchanged, convert color input to
    grayscale, apply statistical windowing, quantize to uint8, crop black
    padding, and resize.

    Args:
        image_path: Path to the raw image (PNG, JPEG, etc.).
        width_param: Windowing width in multiples of the image's standard
            deviation; forwarded to ``apply_windowing``.
        max_size: Forwarded to ``resize_preserve_aspect_ratio``, which sets
            the shorter image side to this many pixels.

    Returns:
        The preprocessed uint8 grayscale image.

    Raises:
        ValueError: If ``cv2.imread`` cannot read ``image_path``.
    """
    # IMREAD_UNCHANGED keeps the stored bit depth, which matters for 16-bit
    # DICOM-derived PNGs.
    raw = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise ValueError(f"Could not read image: {image_path}")

    # A third axis means color channels; collapse them to one.
    gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY) if raw.ndim == 3 else raw

    windowed = apply_windowing(gray, width_param)
    quantized = (windowed * 255.0).astype(np.uint8)
    cropped = remove_black_padding(quantized)
    return resize_preserve_aspect_ratio(cropped, max_size)
if __name__ == "__main__":
    # Command-line entry point: preprocess one image and write it to disk.
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Preprocess chest X-ray images for TILA")
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output", required=True, help="Output image path")
    parser.add_argument("--width-param", type=float, default=4.0)
    parser.add_argument("--max-size", type=int, default=512)
    args = parser.parse_args()

    img = preprocess_image(args.input, args.width_param, args.max_size)
    # cv2.imwrite reports failure through its boolean return value, so check
    # it instead of announcing success unconditionally.
    if not cv2.imwrite(args.output, img):
        sys.exit(f"Failed to write preprocessed image to {args.output}")
    print(f"Saved preprocessed image to {args.output} ({img.shape[1]}x{img.shape[0]})")