""" HNE2Cell — Step 2: Patch Extraction Extract overlapping patches from color-normalized H&E images for cell detection. Supports both 20x and 40x magnification (40x recommended for best results). Usage: # 40x (recommended) python patchify.py \ --input_dir /path/to/slides \ --patch_size 256 \ --overlap 64 \ --magnification 40 \ --workers 8 # 20x (supported but 40x preferred) python patchify.py \ --input_dir /path/to/slides \ --patch_size 256 \ --overlap 64 \ --magnification 20 \ --workers 8 Notes: - 40x magnification is recommended for optimal cell detection accuracy. - 20x is supported and functional, but fine-grained cell boundaries (especially small immune cells) may be less precise. - Input: Aligned-hne.tif (output of normalize.py) - Output:
/patches_x_p_o/__.png """ import os import argparse import glob import time from multiprocessing import Pool import numpy as np from PIL import Image from tqdm import tqdm Image.MAX_IMAGE_PIXELS = None # ========================== Utility functions ============================== def black_to_white(pil_img: Image.Image) -> Image.Image: """Replace pure-black (0,0,0) pixels with white — avoids dark-border artifacts.""" arr = np.array(pil_img) if arr.ndim == 3 and arr.shape[2] >= 3: mask = (arr[..., :3] == 0).all(axis=-1) arr[mask] = 255 return Image.fromarray(arr) def make_start_positions(length: int, patch_size: int, stride: int) -> list[int]: """Generate start positions so the last patch always reaches the edge.""" if length < patch_size: return [0] starts = list(range(0, length - patch_size + 1, stride)) last = length - patch_size if starts[-1] != last: starts.append(last) return starts # ========================== Core patching ================================== def extract_patches( image_path: str, output_dir: str, patch_size: int = 256, overlap: int = 64, prefix: str = "patch", ) -> int: """Crop overlapping patches from a single image and save as PNG. Returns the number of patches saved. """ os.makedirs(output_dir, exist_ok=True) stride = patch_size - overlap assert stride > 0, f"overlap ({overlap}) must be < patch_size ({patch_size})" img = Image.open(image_path).convert("RGB") width, height = img.size xs = make_start_positions(width, patch_size, stride) ys = make_start_positions(height, patch_size, stride) count = 0 with tqdm(total=len(xs) * len(ys), desc=prefix, unit="patch", leave=False) as pbar: for x0 in xs: for y0 in ys: patch = img.crop((x0, y0, x0 + patch_size, y0 + patch_size)) patch = black_to_white(patch) patch.save( os.path.join(output_dir, f"{prefix}_{x0}_{y0}.png"), format="PNG", ) count += 1 pbar.update(1) return count # =================== Per-section processing (for Pool) ===================== # These will be set once in main() before the pool is created _ARGS = {} def _process_section(section_dir: str) -> str: """Process a single section directory. Designed for multiprocessing.Pool.""" patch_size = _ARGS["patch_size"] overlap = _ARGS["overlap"] magnification = _ARGS["magnification"] input_filename = _ARGS["input_filename"] # Locate input file candidates = [ os.path.join(section_dir, f"{input_filename}.tif"), os.path.join(section_dir, f"{input_filename}.tiff"), ] image_path = next((p for p in candidates if os.path.exists(p)), None) if image_path is None: return f"[SKIP] {section_dir}: {input_filename}.tif not found" stride = patch_size - overlap out_dir = os.path.join( section_dir, f"patches_{magnification}x_p{patch_size}_o{overlap}", ) section_name = os.path.basename(section_dir) t0 = time.time() n = extract_patches( image_path=image_path, output_dir=out_dir, patch_size=patch_size, overlap=overlap, prefix=section_name, ) dt = time.time() - t0 return ( f"[OK] {section_name} | {magnification}x | " f"stride={stride} | {n} patches | {dt:.1f}s → {out_dir}" ) # =============================== CLI ======================================= def main(): parser = argparse.ArgumentParser( description="Extract overlapping patches from normalized H&E images" ) parser.add_argument( "--input_dir", type=str, required=True, help="Root directory containing section folders with Aligned-hne.tif files", ) parser.add_argument( "--input_filename", type=str, default="Aligned-hne", help="Base filename of the normalized image (default: Aligned-hne)", ) parser.add_argument( "--patch_size", type=int, default=256, help="Patch size in pixels (default: 256)" ) parser.add_argument( "--overlap", type=int, default=64, help="Overlap in pixels (default: 64)" ) parser.add_argument( "--magnification", type=int, default=40, choices=[20, 40], help="Slide magnification. 40x recommended; 20x supported. (default: 40)", ) parser.add_argument( "--pattern", type=str, default="*", help="Glob pattern to match section folders (default: '*')", ) parser.add_argument( "--workers", type=int, default=8, help="Number of parallel workers (default: 8)" ) args = parser.parse_args() if args.magnification == 20: print( "⚠️ 20x magnification is supported but 40x is recommended for best " "cell detection accuracy (especially small immune cells)." ) # Collect section directories section_dirs = sorted( p for p in glob.glob(os.path.join(args.input_dir, args.pattern)) if os.path.isdir(p) ) if not section_dirs: # Maybe input_dir itself contains the image directly candidates = [ os.path.join(args.input_dir, f"{args.input_filename}.tif"), os.path.join(args.input_dir, f"{args.input_filename}.tiff"), ] if any(os.path.exists(c) for c in candidates): section_dirs = [args.input_dir] else: raise SystemExit( f"No section folders matching '{args.pattern}' found in {args.input_dir}" ) print(f"Found {len(section_dirs)} section(s) | {args.magnification}x | " f"patch={args.patch_size} overlap={args.overlap}") # Set global args for worker processes global _ARGS _ARGS = { "patch_size": args.patch_size, "overlap": args.overlap, "magnification": args.magnification, "input_filename": args.input_filename, } if args.workers <= 1 or len(section_dirs) == 1: results = [_process_section(d) for d in tqdm(section_dirs, desc="Sections")] else: with Pool(processes=min(args.workers, len(section_dirs))) as pool: results = list( tqdm( pool.imap_unordered(_process_section, section_dirs), total=len(section_dirs), desc="Sections", ) ) print("\n".join(results)) if __name__ == "__main__": main()