""" HNE2Cell — Step 1: Reinhard Color Normalization Normalize H&E stained whole-slide images (WSI) to a reference color distribution using the Reinhard method in LAB color space. Supported input formats: .svs, .tif, .tiff, .ndpi Output: Aligned-hne.tif (full-resolution normalized), Aligned-hne.jpg (4x downsampled preview) Usage: python normalize.py \ --input_dir /path/to/slides \ --target /path/to/standard-ilc.tif \ --patch_size 128 \ --saturation_threshold 0.1 """ import os import argparse import glob import numpy as np import tifffile as tiff from PIL import Image from skimage import color Image.MAX_IMAGE_PIXELS = None os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(pow(2, 40)) # --------------------------------------------------------------------------- # Optional: openslide (only needed for .svs / .ndpi) # --------------------------------------------------------------------------- try: import openslide OPENSLIDE_AVAILABLE = True except ImportError: OPENSLIDE_AVAILABLE = False # ============================= I/O helpers ================================= def load_image(image_path: str, level: int = 0) -> np.ndarray: """Load a whole-slide image as an RGB numpy array. Supports .svs/.ndpi (via OpenSlide) and .tif/.tiff (via tifffile). """ ext = os.path.splitext(image_path)[1].lower() if ext in (".svs", ".ndpi"): if not OPENSLIDE_AVAILABLE: raise ImportError( "openslide-python is required to read .svs/.ndpi files. " "Install it with: pip install openslide-python" ) slide = openslide.OpenSlide(image_path) image = slide.read_region((0, 0), level, slide.level_dimensions[level]) image = image.convert("RGB") slide.close() return np.array(image) if ext in (".tif", ".tiff"): image = tiff.imread(image_path) if image.ndim == 2: image = np.stack((image,) * 3, axis=-1) elif image.ndim == 4 and image.shape[0] == 1: image = image[0] # Ensure RGB uint8 if image.dtype != np.uint8: image = np.clip(image, 0, 255).astype(np.uint8) return image raise ValueError(f"Unsupported file format: {ext}") # ======================== Saturation filtering ============================= def calculate_saturation(patch: Image.Image) -> float: hsv = patch.convert("HSV") return np.mean(np.array(hsv)[:, :, 1] / 255.0) def extract_high_saturation_patches( image: np.ndarray, patch_size: int, saturation_threshold: float ) -> list: """Return list of ((x0, y0), patch_array) for patches above the saturation threshold.""" pil_img = Image.fromarray(image) width, height = pil_img.size patches = [] for i in range(width // patch_size): for j in range(height // patch_size): x0, y0 = i * patch_size, j * patch_size patch = pil_img.crop((x0, y0, x0 + patch_size, y0 + patch_size)) if calculate_saturation(patch) >= saturation_threshold: patches.append(((x0, y0), np.array(patch))) return patches def reconstruct_from_patches( width: int, height: int, patch_size: int, patches: list ) -> np.ndarray: """Place high-saturation patches back into a blank canvas (background = black).""" canvas = np.zeros((height, width, 3), dtype=np.uint8) for (x0, y0), arr in patches: if arr.shape == (patch_size, patch_size, 3): canvas[y0 : y0 + patch_size, x0 : x0 + patch_size, :] = arr return canvas # =================== Reinhard color normalization ========================== def _color_convert_chunked(image, func, chunk_size=16384): """Apply color conversion function in spatial chunks to limit memory.""" h, w, _ = image.shape out = np.zeros_like(image, dtype=np.float32) for i in range(0, h, chunk_size): for j in range(0, w, chunk_size): out[i : min(i + chunk_size, h), j : min(j + chunk_size, w), :] = func( image[i : min(i + chunk_size, h), j : min(j + chunk_size, w), :] ) return out def reinhard_normalize(source: np.ndarray, target: np.ndarray) -> np.ndarray: """Reinhard color normalization in LAB space. Only non-zero (tissue) pixels are used for statistics. Returns float64 image in [0, 1] range. """ src_lab = _color_convert_chunked(source, color.rgb2lab) tgt_lab = color.rgb2lab(target) for ch in range(3): src_ch = src_lab[:, :, ch] tgt_ch = tgt_lab[:, :, ch] src_vals = src_ch[src_ch != 0] tgt_vals = tgt_ch[tgt_ch != 0] if len(src_vals) == 0 or len(tgt_vals) == 0: continue src_mean, src_std = src_vals.mean(), src_vals.std() tgt_mean, tgt_std = tgt_vals.mean(), tgt_vals.std() if src_std < 1e-6: continue src_lab[:, :, ch] = np.where( src_ch != 0, (src_ch - src_mean) * (tgt_std / src_std) + tgt_mean, 0, ) return _color_convert_chunked(src_lab, color.lab2rgb) # ============================= Main pipeline =============================== def normalize_slide( slide_path: str, target_image: np.ndarray, patch_size: int = 128, saturation_threshold: float = 0.1, output_dir: str | None = None, skip_existing: bool = True, ): """Full normalization pipeline for a single slide.""" if output_dir is None: output_dir = os.path.dirname(slide_path) output_tif = os.path.join(output_dir, "Aligned-hne.tif") output_jpg = os.path.join(output_dir, "Aligned-hne.jpg") if skip_existing and os.path.exists(output_tif): print(f"[SKIP] {slide_path} — Aligned-hne.tif already exists.") return print(f"[LOAD] {slide_path}") raw = load_image(slide_path) h, w = raw.shape[:2] # 1. Saturation-based tissue detection patches = extract_high_saturation_patches( raw, patch_size, saturation_threshold ) reconstructed = reconstruct_from_patches(w, h, patch_size, patches) # 2. (Optional) save intermediate reconstruction recon_path = os.path.join(output_dir, "recon.tif") bigtiff = reconstructed.nbytes > 4 * 1024**3 tiff.imwrite(recon_path, reconstructed, bigtiff=bigtiff) # 3. Reinhard normalization normalized = reinhard_normalize(reconstructed, target_image) normalized_u8 = (normalized * 255).astype(np.uint8) # 4. Save outputs tiff.imwrite(output_tif, normalized_u8, bigtiff=bigtiff) resized = Image.fromarray(normalized_u8).resize( (w // 4, h // 4), Image.LANCZOS ) resized.save(output_jpg, quality=90) print(f"[DONE] {slide_path} → {output_tif}") # =============================== CLI ======================================= def main(): parser = argparse.ArgumentParser( description="Reinhard color normalization for H&E WSIs" ) parser.add_argument( "--input_dir", type=str, required=True, help="Root directory to search for slide files (.svs, .tif, .tiff, .ndpi)", ) parser.add_argument( "--target", type=str, required=True, help="Path to the reference/target image (.tif)", ) parser.add_argument("--patch_size", type=int, default=128) parser.add_argument("--saturation_threshold", type=float, default=0.1) parser.add_argument( "--output_dir", type=str, default=None, help="If set, all outputs go here. Otherwise, outputs are saved next to each slide.", ) args = parser.parse_args() # Load target image once target_image = load_image(args.target) # Collect slides extensions = ("*.svs", "*.tif", "*.tiff", "*.ndpi") slides = [] for ext in extensions: slides.extend(glob.glob(os.path.join(args.input_dir, "**", ext), recursive=True)) # Exclude files that are already outputs slides = [ s for s in slides if os.path.basename(s) not in ("Aligned-hne.tif", "Aligned-hne.tiff", "recon.tif") ] print(f"Found {len(slides)} slide(s) in {args.input_dir}") for slide_path in slides: try: normalize_slide( slide_path, target_image, patch_size=args.patch_size, saturation_threshold=args.saturation_threshold, output_dir=args.output_dir, ) except Exception as e: print(f"[ERROR] {slide_path}: {e}") if __name__ == "__main__": main()