# SynthCXR / synthcxr/mask_utils.py
# Source: gradientguild, "Upload folder using huggingface_hub" (commit a4aa5c5, verified).
"""Mask manipulation: scaling organ regions and resolving overlaps."""
from __future__ import annotations
from pathlib import Path
import numpy as np
from PIL import Image
from scipy import ndimage
from scipy.ndimage import map_coordinates
def resolve_overlaps(
    mask: np.ndarray,
    priority: tuple[int, ...] = (2, 0, 1),
    threshold: int = 10,
) -> np.ndarray:
    """Assign overlapping pixels to the highest-priority channel.

    A channel is "active" at a pixel when its value is strictly greater than
    ``threshold``. Wherever two or more of channels 0-2 are active, every
    active channel except the winning one is set to 0. Values at or below
    ``threshold`` are never modified.

    Default priority: heart (2) > left lung (0) > right lung (1).

    Args:
        mask: ``(H, W, C)`` array with ``C >= 3``; only channels 0-2 take
            part in overlap resolution. The input array is not modified.
        priority: Channel indices ordered from highest to lowest priority.
            Channels not listed rank below every listed channel. Ties are
            broken in favour of the lower channel index.
        threshold: Values strictly greater than this count as active.

    Returns:
        A copy of ``mask`` in which each pixel has at most one active
        channel among 0-2.
    """
    n_ch = 3
    result = mask.copy()
    active = mask[:, :, :n_ch] > threshold
    overlap = active.sum(axis=2) > 1
    if not overlap.any():
        return result

    # rank[ch] is the position of ch in ``priority`` (lower wins). Unlisted
    # channels rank just after all listed ones rather than raising.
    unlisted = len(priority)
    rank = np.array(
        [priority.index(ch) if ch in priority else unlisted for ch in range(n_ch)]
    )
    # Inactive channels get a rank that can never win.
    ranks = np.where(active, rank, unlisted + 1)
    # argmin returns the first minimum, so ties go to the lowest channel,
    # exactly like min() over channels iterated in ascending order.
    best = ranks.argmin(axis=2)
    losers = active & overlap[..., None] & (np.arange(n_ch) != best[..., None])
    # result[:, :, :n_ch] is a view, so the boolean assignment writes into result.
    result[:, :, :n_ch][losers] = 0
    return result
def scale_mask_channel(
    mask: np.ndarray,
    channel: int,
    scale_factor: float,
    threshold: int = 10,
) -> np.ndarray:
    """Scale a single channel's region around its centroid.

    ``channel``: 0 = left lung (red), 1 = right lung (green), 2 = heart (blue).

    The centroid is the centre of mass of the pixels whose value is strictly
    greater than ``threshold``. Each output pixel ``p`` is filled by bilinear
    sampling (``order=1``) of the original channel at
    ``c + (p - c) / scale_factor``; samples falling outside the image read 0.
    All other channels are copied unchanged.

    Args:
        mask: ``(H, W, C)`` array. The input array is not modified.
        channel: Index of the channel to scale.
        scale_factor: ``> 1`` enlarges the region, ``< 1`` shrinks it.
            Must be a positive finite number.
        threshold: Values strictly greater than this define the region.

    Returns:
        A copy of ``mask`` with ``channel`` scaled. If no pixel of
        ``channel`` exceeds ``threshold``, an unchanged copy is returned.

    Raises:
        ValueError: If ``scale_factor`` is zero, negative, or not finite.
    """
    if not (np.isfinite(scale_factor) and scale_factor > 0):
        raise ValueError(
            f"scale_factor must be a positive finite number, got {scale_factor!r}"
        )

    result = mask.copy()
    channel_data = mask[:, :, channel]
    binary = channel_data > threshold
    if not binary.any():
        return result

    cy, cx = ndimage.center_of_mass(binary)
    h, w = mask.shape[:2]
    y_coords, x_coords = np.mgrid[0:h, 0:w]
    # Inverse mapping: pull sample coordinates toward the centroid so the
    # region appears scaled by scale_factor in the output.
    y_t = ((y_coords - cy) / scale_factor + cy).astype(np.float32)
    x_t = ((x_coords - cx) / scale_factor + cx).astype(np.float32)
    scaled = map_coordinates(
        channel_data.astype(np.float32),
        [y_t, x_t],
        order=1,
        mode="constant",
        cval=0,
    )
    # Round (rather than truncate) interpolated values and clamp them to the
    # integer range of the mask's own dtype before writing back.
    if np.issubdtype(result.dtype, np.integer):
        info = np.iinfo(result.dtype)
        scaled = np.clip(np.rint(scaled), info.min, info.max)
    result[:, :, channel] = scaled.astype(result.dtype)
    return result
def modify_mask(
    input_path: Path | str,
    output_path: Path | str,
    heart_scale: float = 1.0,
    left_lung_scale: float = 1.0,
    right_lung_scale: float = 1.0,
) -> None:
    """Load a conditioning mask, apply scale factors, and save.

    The image is converted to RGB. Each channel whose factor differs from
    exactly 1.0 is scaled with ``scale_mask_channel``, in channel order
    0 (left lung), 1 (right lung), 2 (heart). Overlaps are then always
    resolved with priority heart > left lung > right lung, even when no
    scaling was applied. Missing parent directories of ``output_path`` are
    created.

    Args:
        input_path: Path (or path string) of the mask image to read.
        output_path: Path (or path string) to write the modified mask to.
        heart_scale: Scale factor for channel 2 (blue).
        left_lung_scale: Scale factor for channel 0 (red).
        right_lung_scale: Scale factor for channel 1 (green).
    """
    # Accept plain strings as well as Path objects (output_path.parent below
    # would fail on a str).
    input_path = Path(input_path)
    output_path = Path(output_path)

    with Image.open(input_path) as img:
        mask = np.array(img.convert("RGB"))

    channel_scales = (
        (0, left_lung_scale),
        (1, right_lung_scale),
        (2, heart_scale),
    )
    for channel, factor in channel_scales:
        if factor != 1.0:
            mask = scale_mask_channel(mask, channel=channel, scale_factor=factor)

    mask = resolve_overlaps(mask, priority=(2, 0, 1))
    output_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(mask).save(output_path)
    print(f"[INFO] Saved modified mask to {output_path}")