File size: 3,846 Bytes
f46fb4d
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
"""
TILA — Image Preprocessing

Converts raw chest X-ray images (DICOM-derived PNGs or raw PNGs) into the
normalized format expected by the TILA model.

Pipeline:
    1. Read image as-is (preserving bit depth)
    2. Windowing: clip to mean +/- 2*std, normalize to [0, 1]
    3. Convert to uint8
    4. Remove black padding (contour-based crop)
    5. Resize preserving aspect ratio (shorter side = 512)

Usage:
    from preprocess import preprocess_image

    img = preprocess_image("raw_cxr.png")
    cv2.imwrite("preprocessed.png", img)
"""

import cv2
import numpy as np
from pathlib import Path


def apply_windowing(image: np.ndarray, width_param: float = 4.0) -> np.ndarray:
    """Window intensities around the global mean and rescale them to the unit range.

    The window is centred on the mean of all elements and spans ``width_param``
    standard deviations, i.e. ``[mean - width_param/2 * std,
    mean + width_param/2 * std]``. Values outside the window are clipped to its
    edges, then everything is shifted so the lower edge maps to 0 and divided by
    the window width. A ``1e-8`` term in the denominator prevents division by
    zero for constant images (which therefore come out as all zeros); for
    typical windows it also leaves the upper edge marginally below 1.

    Args:
        image: Numeric image array; statistics are taken over every element.
        width_param: Window width expressed in multiples of the standard deviation.

    Returns:
        float64 array with the same shape as ``image`` and values in ``[0, 1]``.
    """
    pixels = image.astype(np.float64)
    centre = np.mean(pixels)
    spread = width_param * np.std(pixels)
    lower = centre - spread / 2
    upper = centre + spread / 2
    clipped = np.clip(pixels, lower, upper)
    # Epsilon guards against a zero-width window (std == 0).
    return (clipped - lower) / (upper - lower + 1e-8)


def remove_black_padding(image: np.ndarray) -> np.ndarray:
    """Crop away black borders, keeping the bounding box of the largest bright region.

    Every pixel with a value above 0 is treated as foreground. The external
    contours of that foreground mask are found, the contour enclosing the
    largest area is selected, and the image is cropped to its axis-aligned
    bounding rectangle. Note that any pixel that is exactly 0 (including
    anatomy clipped to 0 by windowing) counts as background here.

    Args:
        image: Single-channel image. Presumably uint8, as produced by
            ``preprocess_image``, because ``cv2.findContours`` expects an
            8-bit mask — TODO confirm behaviour for other dtypes.

    Returns:
        The cropped image (a NumPy view into ``image``), or ``image``
        unchanged when it contains no non-zero pixels.
    """
    _, mask = cv2.threshold(image, 0, 255, cv2.THRESH_BINARY)
    # OpenCV 3.x returns (image, contours, hierarchy) while 4.x returns
    # (contours, hierarchy); the contours are second-to-last in both.
    contours = cv2.findContours(mask, cv2.RETR_EXTERNAL, cv2.CHAIN_APPROX_SIMPLE)[-2]
    if not contours:
        return image
    largest = max(contours, key=cv2.contourArea)
    x, y, w, h = cv2.boundingRect(largest)
    return image[y:y + h, x:x + w]


def resize_preserve_aspect_ratio(image: np.ndarray, max_size: int = 512) -> np.ndarray:
    """Resize so the *shorter* side equals ``max_size``, preserving aspect ratio.

    Despite the parameter name, ``max_size`` is the target length of the
    shorter side. A portrait image (width < height) ends up ``max_size`` wide.
    Any other image ends up ``max_size`` tall. The remaining side is scaled by
    the same factor and truncated to an integer, so it is at least
    ``max_size``.

    Args:
        image: Grayscale ``(H, W)`` image, or a 3-D colour image, which is
            first converted with ``cv2.COLOR_BGR2GRAY``.
        max_size: Target length in pixels of the shorter side.

    Returns:
        Grayscale image resized with ``cv2.INTER_AREA``.

    Raises:
        ValueError: If the image has a zero-length side.
    """
    if image.ndim == 3:
        image = cv2.cvtColor(image, cv2.COLOR_BGR2GRAY)
    h, w = image.shape
    if h == 0 or w == 0:
        raise ValueError(f"Cannot resize an empty image of shape {image.shape}")
    # Integer floor division gives the exact truncated length. The float form
    # max_size / (w / h) rounds twice and can, in principle, land just below
    # an integer and lose a pixel when truncated.
    if w < h:
        new_w, new_h = max_size, max_size * h // w
    else:
        new_w, new_h = max_size * w // h, max_size
    return cv2.resize(image, (int(new_w), int(new_h)), interpolation=cv2.INTER_AREA)


def preprocess_image(
    image_path: str,
    width_param: float = 4.0,
    max_size: int = 512,
) -> np.ndarray:
    """Run the full TILA preprocessing pipeline on one chest X-ray file.

    Steps:
        1. Read the file at its native bit depth.
        2. Collapse colour to grayscale.
        3. Window around the mean (see ``apply_windowing``).
        4. Quantise to uint8 by scaling to 0-255 and truncating.
        5. Crop black borders.
        6. Resize with the aspect ratio preserved.

    Args:
        image_path: Path to the raw image (PNG, JPEG, or anything
            ``cv2.imread`` can decode); a ``str`` or ``pathlib.Path``.
        width_param: Windowing width in multiples of the standard deviation
            (default: 4.0).
        max_size: Forwarded to ``resize_preserve_aspect_ratio``. With the
            current implementation it is the target length of the *shorter*
            side (default: 512).

    Returns:
        Preprocessed uint8 grayscale image.

    Raises:
        ValueError: If ``cv2.imread`` cannot read ``image_path``.
    """
    # IMREAD_UNCHANGED keeps the stored bit depth, so 16-bit DICOM-derived PNGs
    # are windowed at full precision instead of being reduced to 8 bits on load.
    raw = cv2.imread(str(image_path), cv2.IMREAD_UNCHANGED)
    if raw is None:
        raise ValueError(f"Could not read image: {image_path}")

    # Colour (3-D) inputs are reduced to a single channel before windowing.
    gray = cv2.cvtColor(raw, cv2.COLOR_BGR2GRAY) if raw.ndim == 3 else raw

    normalized = apply_windowing(gray, width_param)
    # astype truncates toward zero rather than rounding to nearest.
    quantized = (normalized * 255.0).astype(np.uint8)
    cropped = remove_black_padding(quantized)
    return resize_preserve_aspect_ratio(cropped, max_size)


if __name__ == "__main__":
    # Command-line entry point: preprocess a single image and write the result.
    import argparse
    import sys

    parser = argparse.ArgumentParser(description="Preprocess chest X-ray images for TILA")
    parser.add_argument("--input", required=True, help="Input image path")
    parser.add_argument("--output", required=True, help="Output image path")
    parser.add_argument(
        "--width-param",
        type=float,
        default=4.0,
        help="Windowing width in multiples of the standard deviation (default: 4.0)",
    )
    parser.add_argument(
        "--max-size",
        type=int,
        default=512,
        help="Target length in pixels of the shorter image side (default: 512)",
    )
    args = parser.parse_args()

    img = preprocess_image(args.input, args.width_param, args.max_size)
    # cv2.imwrite signals failure (e.g. a missing output directory) by returning
    # False rather than raising, so check it before reporting success.
    if not cv2.imwrite(args.output, img):
        sys.exit(f"Failed to write preprocessed image to {args.output}")
    print(f"Saved preprocessed image to {args.output} ({img.shape[1]}x{img.shape[0]})")