Spaces:
Running on Zero
Running on Zero
| """Mask manipulation: scaling organ regions and resolving overlaps.""" | |
| from __future__ import annotations | |
| from pathlib import Path | |
| import numpy as np | |
| from PIL import Image | |
| from scipy import ndimage | |
| from scipy.ndimage import map_coordinates | |
def resolve_overlaps(
    mask: np.ndarray,
    priority: tuple[int, int, int] = (2, 0, 1),
    threshold: int = 10,
) -> np.ndarray:
    """Assign overlapping pixels to the highest-priority channel.

    A channel is "active" at a pixel when its value is strictly greater than
    ``threshold``. At every pixel where more than one channel is active, all
    active channels among 0-2 except the one listed earliest in ``priority``
    are set to 0. Inactive channels and non-overlapping pixels keep their
    original values.

    Default priority: heart (2) > left lung (0) > right lung (1).

    Args:
        mask: Array of shape (H, W, 3); channels 0-2 are resolved.
        priority: Channel indices ordered from highest to lowest priority.
        threshold: Values strictly above this count as active.

    Returns:
        A new array with overlaps resolved; ``mask`` is not modified.

    Raises:
        ValueError: If a channel that is active at an overlapping pixel does
            not appear in ``priority``.
    """
    result = mask.copy()
    active = mask > threshold
    overlap = active.sum(axis=2) > 1
    if not overlap.any():
        return result

    channels = range(3)

    # Without a rank there is no way to decide who wins such a pixel.
    unranked = [
        ch
        for ch in channels
        if ch not in priority and (overlap & active[:, :, ch]).any()
    ]
    if unranked:
        raise ValueError(
            f"channels {unranked} overlap other channels but are missing "
            f"from priority {priority!r}"
        )

    # Per-pixel winning channel. Walk priority from lowest to highest so the
    # earliest-listed active channel is written last and therefore wins
    # (this also matches tuple.index semantics if priority has duplicates).
    winner = np.full(overlap.shape, -1, dtype=np.intp)
    for ch in reversed(priority):
        if ch in channels:
            winner = np.where(active[:, :, ch], ch, winner)

    # Clear only the channels that were active at an overlap but lost.
    for ch in channels:
        losers = overlap & active[:, :, ch] & (winner != ch)
        result[:, :, ch][losers] = 0
    return result
def scale_mask_channel(
    mask: np.ndarray,
    channel: int,
    scale_factor: float,
    threshold: int = 10,
) -> np.ndarray:
    """Scale a single channel's region around its centroid.

    ``channel``: 0 = left lung (red), 1 = right lung (green), 2 = heart (blue).

    The centroid is the centre of mass of the pixels whose value is strictly
    greater than ``threshold``. The whole channel is then resampled with
    bilinear interpolation so that content at distance ``d`` from the
    centroid ends up at distance ``d * scale_factor``. Samples taken from
    outside the image become 0. Other channels are copied unchanged.

    Args:
        mask: Array of shape (H, W, C) holding 0-255 values.
        channel: Index of the channel to scale.
        scale_factor: Magnification about the centroid (>1 grows, <1 shrinks).
        threshold: Values strictly above this define the region whose
            centroid is used.

    Returns:
        A new array. If no pixel of ``channel`` exceeds ``threshold``, it is
        an unmodified copy of ``mask``.

    Raises:
        ValueError: If ``scale_factor`` is not a finite positive number.
    """
    # Zero or non-finite factors would produce inf/NaN sample coordinates.
    if not (np.isfinite(scale_factor) and scale_factor > 0):
        raise ValueError(
            f"scale_factor must be a finite positive number, got {scale_factor!r}"
        )

    result = mask.copy()
    channel_data = mask[:, :, channel]
    binary = channel_data > threshold
    if not binary.any():
        return result

    cy, cx = ndimage.center_of_mass(binary)
    h, w = channel_data.shape
    yy, xx = np.mgrid[0:h, 0:w]
    # Inverse mapping: output pixel p samples the source at c + (p - c) / s,
    # which magnifies (s > 1) or shrinks (s < 1) the region about c.
    src_y = ((yy - cy) / scale_factor + cy).astype(np.float32)
    src_x = ((xx - cx) / scale_factor + cx).astype(np.float32)

    scaled = map_coordinates(
        channel_data.astype(np.float32),
        [src_y, src_x],
        order=1,
        mode="constant",
        cval=0,
    )
    # Round instead of truncating. A plain astype() floors values such as
    # 254.9999 (float round-off) down to 254, which darkens the mask.
    result[:, :, channel] = np.clip(np.rint(scaled), 0, 255).astype(np.uint8)
    return result
def modify_mask(
    input_path: str | Path,
    output_path: str | Path,
    heart_scale: float = 1.0,
    left_lung_scale: float = 1.0,
    right_lung_scale: float = 1.0,
) -> None:
    """Load a conditioning mask, apply per-organ scale factors, and save it.

    The image is converted to RGB. Each channel whose scale factor differs
    from 1.0 is scaled about its own centroid, in this order: left lung
    (channel 0), right lung (channel 1), heart (channel 2). Overlaps are then
    always resolved with priority heart > left lung > right lung. This happens
    even when every factor is 1.0, so pre-existing overlaps are also cleaned.
    The output directory is created if needed.

    Args:
        input_path: Path to the source mask image (``str`` or ``Path``).
        output_path: Destination path for the modified mask (``str`` or ``Path``).
        heart_scale: Scale factor for channel 2 (blue).
        left_lung_scale: Scale factor for channel 0 (red).
        right_lung_scale: Scale factor for channel 1 (green).
    """
    # Accept plain strings too; ``output_path.parent`` below requires a Path.
    input_path = Path(input_path)
    output_path = Path(output_path)

    with Image.open(input_path) as img:
        mask = np.array(img.convert("RGB"))

    channel_scales = (
        (0, left_lung_scale),
        (1, right_lung_scale),
        (2, heart_scale),
    )
    for channel, factor in channel_scales:
        if factor != 1.0:
            mask = scale_mask_channel(mask, channel=channel, scale_factor=factor)

    mask = resolve_overlaps(mask, priority=(2, 0, 1))

    output_path.parent.mkdir(parents=True, exist_ok=True)
    Image.fromarray(mask).save(output_path)
    print(f"[INFO] Saved modified mask to {output_path}")