HNE2Cell / patchify.py
roobee79's picture
Upload 7 files
7747544 verified
"""
HNE2Cell — Step 2: Patch Extraction
Extract overlapping patches from color-normalized H&E images for cell detection.
Supports both 20x and 40x magnification (40x recommended for best results).
Usage:
# 40x (recommended)
python patchify.py \
--input_dir /path/to/slides \
--patch_size 256 \
--overlap 64 \
--magnification 40 \
--workers 8
# 20x (supported but 40x preferred)
python patchify.py \
--input_dir /path/to/slides \
--patch_size 256 \
--overlap 64 \
--magnification 20 \
--workers 8
Notes:
- 40x magnification is recommended for optimal cell detection accuracy.
- 20x is supported and functional, but fine-grained cell boundaries
(especially small immune cells) may be less precise.
- Input: <section>/Aligned-hne.tif or .tiff (output of normalize.py); the base name can be changed with --input_filename.
- Output: <section>/patches_<mag>x_p<patch>_o<overlap>/<name>_<x>_<y>.png
"""
import os
import argparse
import glob
import time
from multiprocessing import Pool
import numpy as np
from PIL import Image
from tqdm import tqdm
# Setting MAX_IMAGE_PIXELS to None turns off Pillow's decompression-bomb size
# check for every image opened by this process.
# NOTE(review): presumably needed because slide-sized images exceed the default
# limit; it also removes the safeguard, so only run this on trusted inputs.
Image.MAX_IMAGE_PIXELS = None
# ========================== Utility functions ==============================
def black_to_white(pil_img: Image.Image) -> Image.Image:
    """Return a copy of *pil_img* in which pure-black pixels are painted white.

    A pixel counts as pure black when its first three channels are all zero.
    Matching pixels have every channel set to 255. Images whose array form
    is not 3-D with at least three channels are returned with their pixel
    data unchanged. The result is always a new image built from an array copy.
    """
    pixels = np.array(pil_img)  # np.array copies, so the input is never mutated
    has_color_channels = pixels.ndim == 3 and pixels.shape[2] >= 3
    if has_color_channels:
        # "No non-zero value among the first three channels" == all three are 0.
        is_black = ~pixels[..., :3].any(axis=-1)
        pixels[is_black] = 255
    return Image.fromarray(pixels)
def make_start_positions(length: int, patch_size: int, stride: int) -> list[int]:
    """Return patch start offsets along one axis so the final patch reaches the edge.

    Offsets advance by ``stride`` from 0. If the regular grid does not end
    exactly at ``length - patch_size``, that offset is appended so the last
    patch is flush with the far edge (it then overlaps its neighbour by more
    than the nominal amount).

    Args:
        length: Size of the axis in pixels.
        patch_size: Patch extent along the axis in pixels.
        stride: Distance between consecutive starts; must be positive.

    Returns:
        Ascending list of start offsets. ``[0]`` when the axis is shorter
        than one patch.

    Raises:
        ValueError: If ``stride`` is not positive.
    """
    # Fail loudly here: a zero stride makes range() raise an unrelated-looking
    # error and a negative one used to surface as IndexError on an empty list.
    if stride <= 0:
        raise ValueError(f"stride must be positive, got {stride}")

    last = length - patch_size
    if last < 0:
        return [0]

    starts = list(range(0, last + 1, stride))
    if starts[-1] != last:
        starts.append(last)
    return starts
# ========================== Core patching ==================================
def extract_patches(
    image_path: str,
    output_dir: str,
    patch_size: int = 256,
    overlap: int = 64,
    prefix: str = "patch",
) -> int:
    """Crop overlapping square patches from one image and save them as PNG files.

    The image is converted to RGB, tiled on a grid with a stride of
    ``patch_size - overlap`` (the last row/column is shifted so it ends at the
    image edge), and each patch is written to
    ``<output_dir>/<prefix>_<x0>_<y0>.png`` after pure-black pixels have been
    replaced with white.

    Args:
        image_path: Path of the image to tile.
        output_dir: Directory for the PNG patches; created if missing.
        patch_size: Side length of each patch in pixels; must be positive.
        overlap: Pixels shared by neighbouring patches; must be smaller than
            ``patch_size``.
        prefix: File-name prefix for every patch.

    Returns:
        The number of patches saved.

    Raises:
        ValueError: If ``patch_size`` is not positive or ``overlap`` is not
            smaller than ``patch_size``.
    """
    # Validate with real exceptions (asserts vanish under ``python -O``) and do
    # it before touching the filesystem so bad arguments leave no empty folder.
    if patch_size <= 0:
        raise ValueError(f"patch_size must be positive, got {patch_size}")
    stride = patch_size - overlap
    if stride <= 0:
        raise ValueError(f"overlap ({overlap}) must be < patch_size ({patch_size})")

    os.makedirs(output_dir, exist_ok=True)

    # convert() returns an independent in-memory image, so the source file can
    # be closed right away instead of staying open until garbage collection.
    with Image.open(image_path) as src:
        img = src.convert("RGB")
    width, height = img.size

    xs = make_start_positions(width, patch_size, stride)
    ys = make_start_positions(height, patch_size, stride)

    count = 0
    with tqdm(total=len(xs) * len(ys), desc=prefix, unit="patch", leave=False) as pbar:
        for x0 in xs:
            for y0 in ys:
                # NOTE(review): when the image is smaller than patch_size the
                # crop box extends past the image; Pillow appears to fill that
                # area with black, which black_to_white then whitens — confirm.
                patch = img.crop((x0, y0, x0 + patch_size, y0 + patch_size))
                patch = black_to_white(patch)
                patch.save(
                    os.path.join(output_dir, f"{prefix}_{x0}_{y0}.png"),
                    format="PNG",
                )
                count += 1
                pbar.update(1)
    return count
# =================== Per-section processing (for Pool) =====================
# Runtime configuration read by _process_section(). It must be populated with
# the keys "patch_size", "overlap", "magnification" and "input_filename" in
# the process that runs _process_section() before any section is handled.
_ARGS = {}


def _process_section(section_dir: str) -> str:
    """Extract patches for one section directory and return a one-line report.

    Takes a single argument so it can be mapped over directories with
    ``multiprocessing.Pool``; all other settings are read from ``_ARGS``.

    Args:
        section_dir: Directory expected to contain ``<input_filename>.tif``
            or ``<input_filename>.tiff``.

    Returns:
        A ``[SKIP] ...`` message when neither file exists, otherwise an
        ``[OK] ...`` summary with the patch count, elapsed time and output
        directory.
    """
    patch_size = _ARGS["patch_size"]
    overlap = _ARGS["overlap"]
    magnification = _ARGS["magnification"]
    input_filename = _ARGS["input_filename"]

    # Locate input file: .tif is preferred, .tiff is the fallback.
    candidates = [
        os.path.join(section_dir, f"{input_filename}{ext}")
        for ext in (".tif", ".tiff")
    ]
    image_path = next((p for p in candidates if os.path.exists(p)), None)
    if image_path is None:
        # Mention both extensions that were actually probed.
        return f"[SKIP] {section_dir}: {input_filename}.tif/.tiff not found"

    stride = patch_size - overlap
    out_dir = os.path.join(
        section_dir,
        f"patches_{magnification}x_p{patch_size}_o{overlap}",
    )

    # basename() of "slide/" or "." is "" / ".", which would yield patch files
    # named "_<x>_<y>.png"; resolving the path first recovers the folder name.
    section_name = os.path.basename(os.path.abspath(section_dir))

    t0 = time.time()
    n = extract_patches(
        image_path=image_path,
        output_dir=out_dir,
        patch_size=patch_size,
        overlap=overlap,
        prefix=section_name,
    )
    dt = time.time() - t0

    return (
        f"[OK] {section_name} | {magnification}x | "
        f"stride={stride} | {n} patches | {dt:.1f}s → {out_dir}"
    )
# =============================== CLI =======================================
def _init_worker(config: dict) -> None:
    """Store the shared patching settings in this process's ``_ARGS`` global."""
    global _ARGS
    _ARGS = dict(config)


def main():
    """Command-line entry point: find section folders and patchify each one.

    Parses the CLI options, collects the sub-directories of ``--input_dir``
    that match ``--pattern`` (falling back to ``--input_dir`` itself when it
    directly contains the input image), runs ``_process_section`` on each of
    them — sequentially or in a process pool — and prints one report line per
    section.

    Raises:
        SystemExit: On invalid arguments, or when no section folder (and no
            input image in ``--input_dir``) is found.
    """
    parser = argparse.ArgumentParser(
        description="Extract overlapping patches from normalized H&E images"
    )
    parser.add_argument(
        "--input_dir",
        type=str,
        required=True,
        help="Root directory containing section folders with Aligned-hne.tif files",
    )
    parser.add_argument(
        "--input_filename",
        type=str,
        default="Aligned-hne",
        help="Base filename of the normalized image (default: Aligned-hne)",
    )
    parser.add_argument(
        "--patch_size", type=int, default=256, help="Patch size in pixels (default: 256)"
    )
    parser.add_argument(
        "--overlap", type=int, default=64, help="Overlap in pixels (default: 64)"
    )
    parser.add_argument(
        "--magnification",
        type=int,
        default=40,
        choices=[20, 40],
        help="Slide magnification. 40x recommended; 20x supported. (default: 40)",
    )
    parser.add_argument(
        "--pattern",
        type=str,
        default="*",
        help="Glob pattern to match section folders (default: '*')",
    )
    parser.add_argument(
        "--workers", type=int, default=8, help="Number of parallel workers (default: 8)"
    )
    args = parser.parse_args()

    # Reject impossible tiling parameters up front instead of letting every
    # worker fail later in the middle of a run.
    if args.patch_size <= 0:
        parser.error(f"--patch_size must be positive, got {args.patch_size}")
    if args.overlap >= args.patch_size:
        parser.error(
            f"--overlap ({args.overlap}) must be smaller than "
            f"--patch_size ({args.patch_size})"
        )

    if args.magnification == 20:
        print(
            "⚠️ 20x magnification is supported but 40x is recommended for best "
            "cell detection accuracy (especially small immune cells)."
        )

    # Collect section directories
    section_dirs = sorted(
        p
        for p in glob.glob(os.path.join(args.input_dir, args.pattern))
        if os.path.isdir(p)
    )
    if not section_dirs:
        # Maybe input_dir itself contains the image directly
        candidates = [
            os.path.join(args.input_dir, f"{args.input_filename}.tif"),
            os.path.join(args.input_dir, f"{args.input_filename}.tiff"),
        ]
        if any(os.path.exists(c) for c in candidates):
            section_dirs = [args.input_dir]
        else:
            raise SystemExit(
                f"No section folders matching '{args.pattern}' found in {args.input_dir}"
            )

    print(f"Found {len(section_dirs)} section(s) | {args.magnification}x | "
          f"patch={args.patch_size} overlap={args.overlap}")

    config = {
        "patch_size": args.patch_size,
        "overlap": args.overlap,
        "magnification": args.magnification,
        "input_filename": args.input_filename,
    }
    # The parent needs the settings for the sequential path below.
    _init_worker(config)

    if args.workers <= 1 or len(section_dirs) == 1:
        results = [_process_section(d) for d in tqdm(section_dirs, desc="Sections")]
    else:
        # Workers started with the "spawn" or "forkserver" methods re-import
        # this module and do not inherit globals assigned in main(); the
        # initializer hands each worker its own copy of the settings.
        with Pool(
            processes=min(args.workers, len(section_dirs)),
            initializer=_init_worker,
            initargs=(config,),
        ) as pool:
            results = list(
                tqdm(
                    pool.imap_unordered(_process_section, section_dirs),
                    total=len(section_dirs),
                    desc="Sections",
                )
            )

    print("\n".join(results))


if __name__ == "__main__":
    main()