"""Mask manipulation: scaling organ regions and resolving overlaps."""

from __future__ import annotations

from pathlib import Path

import numpy as np
from PIL import Image
from scipy import ndimage
from scipy.ndimage import map_coordinates


def resolve_overlaps(
    mask: np.ndarray,
    priority: tuple[int, int, int] = (2, 0, 1),
    threshold: int = 10,
) -> np.ndarray:
    """Assign overlapping pixels to the highest-priority channel.

    A channel is "active" at a pixel when its value is strictly greater than
    ``threshold``. Wherever more than one channel is active, every active
    channel except the one listed earliest in ``priority`` is set to 0.
    Values at or below ``threshold`` are never modified.

    Default priority: heart (2) > left lung (0) > right lung (1).

    Args:
        mask: Array indexed as ``[row, col, channel]``. Only the first three
            channels compete for a pixel.
        priority: Channel indices ordered from highest to lowest priority.
        threshold: Values strictly above this count as active.

    Returns:
        A modified copy of ``mask``; the input array is left untouched.

    Raises:
        ValueError: If a channel taking part in an overlap is not listed in
            ``priority``.
    """
    result = mask.copy()
    active = mask > threshold
    overlap = active.sum(axis=2) > 1
    if not overlap.any():
        return result

    # Channels competing for each contested pixel; all False elsewhere.
    contenders = active[:, :, :3] & overlap[:, :, np.newaxis]
    n_channels = contenders.shape[2]

    involved = contenders.any(axis=(0, 1))
    unranked = [
        ch for ch in range(n_channels) if involved[ch] and ch not in priority
    ]
    if unranked:
        raise ValueError(
            f"channels {unranked} overlap but are missing from "
            f"priority {priority!r}"
        )

    # Lower rank wins. Non-contending entries get a rank worse than any real
    # one so argmin never selects them at a contested pixel.
    worst = len(priority)
    rank = np.array(
        [
            priority.index(ch) if ch in priority else worst
            for ch in range(n_channels)
        ]
    )
    ranked = np.where(contenders, rank, worst)
    winner = ranked.argmin(axis=2)

    losers = contenders & (np.arange(n_channels) != winner[:, :, np.newaxis])
    # The basic slice is a view, so the masked assignment writes into result.
    result[:, :, :n_channels][losers] = 0
    return result


def scale_mask_channel(
    mask: np.ndarray,
    channel: int,
    scale_factor: float,
    threshold: int = 10,
) -> np.ndarray:
    """Scale a single channel's region around its centroid.

    The centroid is computed from the pixels of ``channel`` whose value is
    strictly greater than ``threshold``. Each output pixel is then sampled
    from the original channel at the position obtained by shrinking its
    offset from the centroid by ``scale_factor`` (linear interpolation,
    zeros outside the image), so factors above 1 enlarge the region and
    factors below 1 shrink it. Other channels are copied unchanged.

    ``channel``: 0 = left lung (red), 1 = right lung (green), 2 = heart (blue).

    Args:
        mask: Array indexed as ``[row, col, channel]``.
        channel: Index of the channel to scale.
        scale_factor: Positive magnification factor.
        threshold: Values strictly above this define the region whose
            centroid is used as the scaling origin.

    Returns:
        A modified copy of ``mask``. If the channel has no pixel above
        ``threshold`` the copy is returned unchanged.

    Raises:
        ValueError: If ``scale_factor`` is not a positive number.
    """
    # ``not x > 0`` also rejects NaN, which ``x <= 0`` would let through.
    if not scale_factor > 0:
        raise ValueError(
            f"scale_factor must be positive, got {scale_factor!r}"
        )

    result = mask.copy()
    channel_data = mask[:, :, channel]
    region = channel_data > threshold
    if not region.any():
        return result

    cy, cx = ndimage.center_of_mass(region)
    height, width = mask.shape[:2]
    rows, cols = np.mgrid[0:height, 0:width]

    # Inverse mapping: for every output pixel, locate the source position.
    src_rows = ((rows - cy) / scale_factor + cy).astype(np.float32)
    src_cols = ((cols - cx) / scale_factor + cx).astype(np.float32)

    scaled = map_coordinates(
        channel_data.astype(np.float32),
        [src_rows, src_cols],
        order=1,
        mode="constant",
        cval=0,
    )
    result[:, :, channel] = np.clip(scaled, 0, 255).astype(np.uint8)
    return result


def modify_mask(
    input_path: Path,
    output_path: Path,
    heart_scale: float = 1.0,
    left_lung_scale: float = 1.0,
    right_lung_scale: float = 1.0,
) -> None:
    """Load a conditioning mask, apply scale factors, and save.

    The image is read as RGB. Each channel whose scale differs from 1.0 is
    scaled in the order left lung (0), right lung (1), heart (2). Overlaps
    are then resolved with priority heart > left lung > right lung, and the
    result is written to ``output_path``; missing parent directories are
    created.

    Args:
        input_path: Path of the mask image to read.
        output_path: Path the modified mask is written to.
        heart_scale: Scale factor for channel 2.
        left_lung_scale: Scale factor for channel 0.
        right_lung_scale: Scale factor for channel 1.
    """
    with Image.open(input_path) as img:
        mask = np.array(img.convert("RGB"))

    channel_scales = (
        (0, left_lung_scale),
        (1, right_lung_scale),
        (2, heart_scale),
    )
    for channel, scale in channel_scales:
        if scale != 1.0:
            mask = scale_mask_channel(mask, channel=channel, scale_factor=scale)

    mask = resolve_overlaps(mask, priority=(2, 0, 1))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(mask).save(output_path)
    print(f"[INFO] Saved modified mask to {output_path}")