HNE2Cell / normalize.py
roobee79's picture
Upload 7 files
7747544 verified
"""
HNE2Cell — Step 1: Reinhard Color Normalization
Normalize H&E stained whole-slide images (WSI) to a reference color distribution
using the Reinhard method in LAB color space.
Supported input formats: .svs, .tif, .tiff, .ndpi
Output: Aligned-hne.tif (full-resolution normalized), Aligned-hne.jpg (4x downsampled preview), recon.tif (intermediate tissue-only reconstruction)
Usage:
python normalize.py \
--input_dir /path/to/slides \
--target /path/to/standard-ilc.tif \
--patch_size 128 \
--saturation_threshold 0.1
"""
import os
import argparse
import glob
import numpy as np
import tifffile as tiff
from PIL import Image
from skimage import color
Image.MAX_IMAGE_PIXELS = None
os.environ["OPENCV_IO_MAX_IMAGE_PIXELS"] = str(pow(2, 40))
# ---------------------------------------------------------------------------
# Optional: openslide (only needed for .svs / .ndpi)
# ---------------------------------------------------------------------------
try:
import openslide
OPENSLIDE_AVAILABLE = True
except ImportError:
OPENSLIDE_AVAILABLE = False
# ============================= I/O helpers =================================
def load_image(image_path: str, level: int = 0) -> np.ndarray:
    """Load a whole-slide image as an RGB numpy array.

    Supports .svs/.ndpi (via OpenSlide) and .tif/.tiff (via tifffile). The
    reader is chosen from the file extension, compared case-insensitively.

    Args:
        image_path: Path to the slide file.
        level: Pyramid level to read. Only used for the OpenSlide formats;
            ignored for .tif/.tiff.

    Returns:
        The image as a ``uint8`` array. OpenSlide images are converted to
        3-channel RGB. TIFF data is returned as read, except that 2-D
        (grayscale) data is stacked into three channels, a leading singleton
        axis of a 4-D array is dropped, a 4-channel (alpha) last axis is
        reduced to its first three channels, and non-``uint8`` data is
        clipped to [0, 255].

    Raises:
        ImportError: If a .svs/.ndpi file is requested but openslide-python
            is not installed.
        ValueError: If the file extension is not supported.
    """
    ext = os.path.splitext(image_path)[1].lower()

    if ext in (".svs", ".ndpi"):
        if not OPENSLIDE_AVAILABLE:
            raise ImportError(
                "openslide-python is required to read .svs/.ndpi files. "
                "Install it with: pip install openslide-python"
            )
        slide = openslide.OpenSlide(image_path)
        try:
            region = slide.read_region((0, 0), level, slide.level_dimensions[level])
            return np.array(region.convert("RGB"))
        finally:
            # Release the native handle even when reading or converting fails.
            slide.close()

    if ext in (".tif", ".tiff"):
        image = tiff.imread(image_path)
        if image.ndim == 2:
            image = np.stack((image,) * 3, axis=-1)
        elif image.ndim == 4 and image.shape[0] == 1:
            image = image[0]
        # Drop an alpha channel: downstream patch reconstruction only accepts
        # 3-channel patches, so RGBA input would otherwise yield an empty
        # (all-black) result.
        if image.ndim == 3 and image.shape[-1] == 4:
            image = image[..., :3].copy()
        # Ensure uint8.
        # NOTE: wider dtypes are clipped to [0, 255], not rescaled.
        if image.dtype != np.uint8:
            image = np.clip(image, 0, 255).astype(np.uint8)
        return image

    raise ValueError(f"Unsupported file format: {ext}")
# ======================== Saturation filtering =============================
def calculate_saturation(patch: Image.Image) -> float:
    """Return the mean saturation of *patch*, scaled to the range [0, 1].

    The patch is converted to HSV and the saturation channel (index 1) is
    divided by 255 before averaging.
    """
    saturation_plane = np.array(patch.convert("HSV"))[:, :, 1]
    scaled = saturation_plane / 255.0
    return np.mean(scaled)
def extract_high_saturation_patches(
    image: np.ndarray, patch_size: int, saturation_threshold: float
) -> list:
    """Return list of ((x0, y0), patch_array) for patches above the saturation threshold.

    The image is tiled into non-overlapping ``patch_size`` x ``patch_size``
    squares starting at the top-left corner. Partial tiles at the right and
    bottom edges are not visited (floor division), so they never appear in
    the result.

    Args:
        image: Image array accepted by ``PIL.Image.fromarray``.
        patch_size: Side length of each square tile in pixels; must be > 0.
        saturation_threshold: Minimum mean saturation (0-1) for a tile to be
            kept; tiles at exactly the threshold are kept.

    Returns:
        A list of ``((x0, y0), patch_array)`` tuples, where ``(x0, y0)`` is
        the tile's top-left pixel coordinate. Tiles are visited column by
        column (x outer loop, y inner loop).

    Raises:
        ValueError: If ``patch_size`` is not positive.
    """
    # A zero size would raise an opaque ZeroDivisionError below and a negative
    # size would silently produce no tiles, so reject both up front.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be a positive integer, got {patch_size}")

    pil_img = Image.fromarray(image)
    width, height = pil_img.size
    patches = []
    for i in range(width // patch_size):
        for j in range(height // patch_size):
            x0, y0 = i * patch_size, j * patch_size
            patch = pil_img.crop((x0, y0, x0 + patch_size, y0 + patch_size))
            if calculate_saturation(patch) >= saturation_threshold:
                patches.append(((x0, y0), np.array(patch)))
    return patches
def reconstruct_from_patches(
    width: int, height: int, patch_size: int, patches: list
) -> np.ndarray:
    """Paste the kept patches onto a black canvas of the given size.

    Every ``((x0, y0), array)`` entry whose array has shape
    ``(patch_size, patch_size, 3)`` is copied to its original position;
    entries with any other shape are ignored. Pixels not covered by a patch
    remain 0 (black).

    Returns:
        A ``uint8`` array of shape ``(height, width, 3)``.
    """
    expected_shape = (patch_size, patch_size, 3)
    mosaic = np.zeros((height, width, 3), dtype=np.uint8)
    for (left, top), tile in patches:
        if tile.shape != expected_shape:
            continue
        mosaic[top : top + patch_size, left : left + patch_size, :] = tile
    return mosaic
# =================== Reinhard color normalization ==========================
def _color_convert_chunked(image, func, chunk_size=16384):
"""Apply color conversion function in spatial chunks to limit memory."""
h, w, _ = image.shape
out = np.zeros_like(image, dtype=np.float32)
for i in range(0, h, chunk_size):
for j in range(0, w, chunk_size):
out[i : min(i + chunk_size, h), j : min(j + chunk_size, w), :] = func(
image[i : min(i + chunk_size, h), j : min(j + chunk_size, w), :]
)
return out
def reinhard_normalize(source: np.ndarray, target: np.ndarray) -> np.ndarray:
    """Match the LAB statistics of *source* to those of *target* (Reinhard).

    Each LAB channel is handled independently. Only values that are non-zero
    in that channel contribute to the mean and standard deviation, and only
    those values are rescaled; zero values stay zero. A channel is left
    unchanged when either image has no non-zero values in it or when the
    source standard deviation is below 1e-6.

    Returns:
        The normalized image converted back to RGB through
        ``_color_convert_chunked``, i.e. a float32 array (values as produced
        by ``skimage.color.lab2rgb``).
    """
    src_lab = _color_convert_chunked(source, color.rgb2lab)
    tgt_lab = color.rgb2lab(target)

    for ch in range(3):
        src_plane = src_lab[:, :, ch]
        tgt_plane = tgt_lab[:, :, ch]

        src_mask = src_plane != 0
        src_samples = src_plane[src_mask]
        tgt_samples = tgt_plane[tgt_plane != 0]
        if src_samples.size == 0 or tgt_samples.size == 0:
            continue

        src_mean = src_samples.mean()
        src_std = src_samples.std()
        tgt_mean = tgt_samples.mean()
        tgt_std = tgt_samples.std()
        # Avoid dividing by a (near-)zero spread.
        if src_std < 1e-6:
            continue

        rescaled = (src_plane - src_mean) * (tgt_std / src_std) + tgt_mean
        src_lab[:, :, ch] = np.where(src_mask, rescaled, 0)

    return _color_convert_chunked(src_lab, color.lab2rgb)
# ============================= Main pipeline ===============================
def normalize_slide(
slide_path: str,
target_image: np.ndarray,
patch_size: int = 128,
saturation_threshold: float = 0.1,
output_dir: str | None = None,
skip_existing: bool = True,
):
"""Full normalization pipeline for a single slide."""
if output_dir is None:
output_dir = os.path.dirname(slide_path)
output_tif = os.path.join(output_dir, "Aligned-hne.tif")
output_jpg = os.path.join(output_dir, "Aligned-hne.jpg")
if skip_existing and os.path.exists(output_tif):
print(f"[SKIP] {slide_path} — Aligned-hne.tif already exists.")
return
print(f"[LOAD] {slide_path}")
raw = load_image(slide_path)
h, w = raw.shape[:2]
# 1. Saturation-based tissue detection
patches = extract_high_saturation_patches(
raw, patch_size, saturation_threshold
)
reconstructed = reconstruct_from_patches(w, h, patch_size, patches)
# 2. (Optional) save intermediate reconstruction
recon_path = os.path.join(output_dir, "recon.tif")
bigtiff = reconstructed.nbytes > 4 * 1024**3
tiff.imwrite(recon_path, reconstructed, bigtiff=bigtiff)
# 3. Reinhard normalization
normalized = reinhard_normalize(reconstructed, target_image)
normalized_u8 = (normalized * 255).astype(np.uint8)
# 4. Save outputs
tiff.imwrite(output_tif, normalized_u8, bigtiff=bigtiff)
resized = Image.fromarray(normalized_u8).resize(
(w // 4, h // 4), Image.LANCZOS
)
resized.save(output_jpg, quality=90)
print(f"[DONE] {slide_path}{output_tif}")
# =============================== CLI =======================================
def main():
    """Command-line entry point: normalize every slide found under --input_dir.

    Slides are discovered recursively by extension. Files named like this
    script's own outputs, and the target image itself, are excluded. Slides
    are processed in sorted order; a failure on one slide is reported and
    does not stop the remaining slides.

    When --output_dir is given and more than one slide is found, each slide
    gets its own sub-directory (its path relative to --input_dir, without
    the extension). The output file names are fixed, so a single shared
    directory would make every slide after the first collide with it.
    """
    parser = argparse.ArgumentParser(
        description="Reinhard color normalization for H&E WSIs"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Root directory to search for slide files (.svs, .tif, .tiff, .ndpi)",
    )
    parser.add_argument(
        "--target",
        type=str,
        required=True,
        help="Path to the reference/target image (.tif)",
    )
    parser.add_argument("--patch_size", type=int, default=128)
    parser.add_argument("--saturation_threshold", type=float, default=0.1)
    parser.add_argument(
        "--output_dir",
        type=str,
        default=None,
        help=(
            "If set, outputs go here (one sub-directory per slide when several "
            "slides are found). Otherwise, outputs are saved next to each slide."
        ),
    )
    args = parser.parse_args()

    # Load target image once
    target_image = load_image(args.target)

    # Collect slides
    extensions = ("*.svs", "*.tif", "*.tiff", "*.ndpi")
    found = []
    for ext in extensions:
        found.extend(glob.glob(os.path.join(args.input_dir, "**", ext), recursive=True))

    # Exclude files that are already outputs, and the reference image itself
    # in case it is stored under the input directory.
    generated_names = ("Aligned-hne.tif", "Aligned-hne.tiff", "recon.tif")
    target_real = os.path.realpath(args.target)
    slides = sorted(
        s
        for s in found
        if os.path.basename(s) not in generated_names
        and os.path.realpath(s) != target_real
    )
    print(f"Found {len(slides)} slide(s) in {args.input_dir}")

    use_subdirs = args.output_dir is not None and len(slides) > 1
    failures = 0
    for slide_path in slides:
        output_dir = args.output_dir
        if use_subdirs:
            rel_stem = os.path.splitext(os.path.relpath(slide_path, args.input_dir))[0]
            output_dir = os.path.join(args.output_dir, rel_stem)
        try:
            if output_dir is not None:
                os.makedirs(output_dir, exist_ok=True)
            normalize_slide(
                slide_path,
                target_image,
                patch_size=args.patch_size,
                saturation_threshold=args.saturation_threshold,
                output_dir=output_dir,
            )
        except Exception as e:
            # Best-effort batch run: report and continue with the next slide.
            failures += 1
            print(f"[ERROR] {slide_path}: {e}")

    if failures:
        print(f"{failures} of {len(slides)} slide(s) failed.")


if __name__ == "__main__":
    main()