"""
Preprocessing for SEN2SR inference.

Steps
-----
1. Normalize  : uint16 → float32, divide by 10 000, clip to [0, 1].
2. Patch grid : compute (row, col) top-left coordinates for a sliding window
                with configurable size and stride.
3. Extract    : cut patches from the full array; optionally use
                multiprocessing for the extraction step (useful when patch
                count is large and copy overhead dominates GPU transfer).

Design note
-----------
Patch extraction is kept **separate** from model inference so that large
images (> 10 000 × 10 000 px) never have to be fully loaded into GPU memory.
The caller controls batch assembly.
"""

from __future__ import annotations

import multiprocessing as mp
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import List, Tuple

import numpy as np

from s2sr_pipe.utils.logging_utils import get_logger

logger = get_logger("preprocessing")


# ── Normalisation ─────────────────────────────────────────────────────────────

# Divisor mapping raw integer values to [0, 1]; inputs above it are clipped.
NORM_SCALE: float = 10_000.0


def normalize(array: np.ndarray) -> np.ndarray:
    """
    Normalize a uint16 Sentinel-2 array to float32 in [0, 1].

    Parameters
    ----------
    array : np.ndarray — expected shape (C, H, W), dtype uint16 or float32.

    Returns
    -------
    np.ndarray — new array with the same shape as *array*, dtype float32,
    values in [0, 1] (inputs above NORM_SCALE are clipped to 1, negative
    inputs to 0).

    Warns
    -----
    UserWarning
        If *array* is a non-empty float32 array whose maximum is below 2.0,
        i.e. it looks already normalised and dividing again would push the
        values towards zero.
    """
    # ``size > 0`` guard: ``ndarray.max()`` raises ValueError on empty input.
    if (
        array.dtype == np.float32
        and array.size > 0
        and float(array.max()) < 2.0
    ):
        warnings.warn(
            "normalize() reçoit un tableau float32 avec des valeurs < 2.0 — "
            "il est peut-être déjà normalisé. La division par 10000 va produire "
            "des valeurs quasi nulles.",
            UserWarning,
            stacklevel=2,
        )
    out = array.astype(np.float32) / NORM_SCALE
    # Clip in place to avoid allocating a second full-size array.
    np.clip(out, 0.0, 1.0, out=out)
    return out


def denormalize(array: np.ndarray) -> np.ndarray:
    """
    Reverse normalisation: float32 [0, 1] → uint16 [0, 10 000].

    Values are scaled by NORM_SCALE, rounded to the nearest integer and
    clipped to [0, 10 000] before the cast. Rounding (rather than the
    truncation a bare ``astype`` performs) makes
    ``denormalize(normalize(x))`` reproduce any uint16 ``x`` in
    [0, 10 000] exactly, and avoids a systematic downward bias on model
    outputs.

    Parameters
    ----------
    array : np.ndarray — shape (C, H, W), dtype float32.

    Returns
    -------
    np.ndarray — shape (C, H, W), dtype uint16.
    """
    # float32 products such as fl32(k / 1e4) * 1e4 can land just below the
    # integer k; rint snaps them back instead of truncating to k - 1.
    out = np.rint(array * NORM_SCALE)
    np.clip(out, 0, 10_000, out=out)
    # NOTE(review): NaN inputs survive rint/clip and have no defined uint16
    # value after the cast — sanitise upstream if the model can emit NaN.
    return out.astype(np.uint16)


# ── Patch grid ────────────────────────────────────────────────────────────────

@dataclass
class PatchConfig:
    """
    Configuration for sliding-window patch extraction.

    Attributes
    ----------
    patch_size : Side length, in pixels, of each square patch (> 0).
    stride     : Step, in pixels, between consecutive patch origins;
                 must satisfy 0 < stride ≤ patch_size so that neighbouring
                 patches overlap or abut, never leaving gaps.

    Raises
    ------
    ValueError
        On construction, if either value is non-positive or if
        stride > patch_size.
    """

    patch_size: int = 128
    stride: int = 64

    def __post_init__(self) -> None:
        # Positivity first: for e.g. patch_size=0 the real problem is the
        # non-positive size, not the stride/patch_size ordering.
        if self.patch_size <= 0 or self.stride <= 0:
            raise ValueError("patch_size and stride must be positive integers.")
        if self.stride > self.patch_size:
            raise ValueError(
                f"stride ({self.stride}) must be ≤ patch_size ({self.patch_size})"
            )


def compute_patch_coords(
    height: int,
    width: int,
    cfg: PatchConfig,
) -> List[Tuple[int, int]]:
    """
    Compute (row, col) top-left coordinates for all patches.

    The window always covers the full image: if the last patch would go
    beyond the border it is snapped inward (no border reflection or
    padding), so the final row/column of patches may overlap its
    predecessor by more than ``patch_size - stride`` pixels.

    Parameters
    ----------
    height, width : Image spatial dimensions (after band alignment).
    cfg           : PatchConfig instance.

    Returns
    -------
    List of (row, col) tuples in row-major order.

    Raises
    ------
    ValueError
        If the image is smaller than ``cfg.patch_size`` in either dimension.
    """
    ps, st = cfg.patch_size, cfg.stride

    if height < ps or width < ps:
        raise ValueError(
            f"L'image ({height}×{width} px) est plus petite que patch_size ({ps}). "
            "Réduire patch_size ou utiliser une image plus grande."
        )

    # Regular steps, plus one origin snapped to the far border so the last
    # patch ends exactly at height / width.
    rows = list(range(0, height - ps, st)) + [height - ps]
    cols = list(range(0, width - ps, st)) + [width - ps]

    # Deduplicate while preserving order (defensive: the range values are
    # all strictly below the snapped origin).
    rows = sorted(set(rows))
    cols = sorted(set(cols))

    coords = [(r, c) for r in rows for c in cols]

    logger.debug(
        "Patch grid: %d rows x %d cols = %d patches "
        "(patch_size=%d, stride=%d, image=%dx%d)",
        len(rows), len(cols), len(coords), ps, st, width, height,
    )
    return coords


# ── Patch extraction ──────────────────────────────────────────────────────────

def _extract_one(
    coord: Tuple[int, int],
    array: np.ndarray,
    patch_size: int,
) -> np.ndarray:
    """Extract a single (C, patch_size, patch_size) patch — pickleable worker."""
    r, c = coord
    # .copy() so the patch owns its memory instead of viewing *array*.
    return array[:, r : r + patch_size, c : c + patch_size].copy()


# Per-worker-process reference to the source array, installed once by
# _init_worker. It stays None in the parent process.
_WORKER_ARRAY: np.ndarray | None = None


def _init_worker(array: np.ndarray) -> None:
    """Pool initializer: store *array* in this worker process's globals."""
    global _WORKER_ARRAY
    _WORKER_ARRAY = array


def _extract_in_worker(coord: Tuple[int, int], patch_size: int) -> np.ndarray:
    """Pool task: extract one patch from the array installed by _init_worker."""
    return _extract_one(coord, _WORKER_ARRAY, patch_size)


def extract_patches(
    array: np.ndarray,
    coords: List[Tuple[int, int]],
    patch_size: int,
    num_workers: int = 0,
) -> List[np.ndarray]:
    """
    Extract all patches from *array* at positions given by *coords*.

    Parameters
    ----------
    array       : (C, H, W) float32 normalised array.
    coords      : List of (row, col) top-left positions.
    patch_size  : Spatial size of each patch.
    num_workers : Number of worker processes.
                  ≤ 0 = single-threaded (safe for small images or debugging).
                  > 0 = use multiprocessing.Pool (useful for very large images
                  with hundreds of thousands of patches).

    Returns
    -------
    List[np.ndarray] — one patch per entry of *coords*, in the same order;
    each element has shape (C, patch_size, patch_size).

    Notes
    -----
    * With multiprocessing, *array* is handed to each worker once through the
      pool initializer. Under the 'fork' start method it is inherited by the
      child rather than serialised; under 'spawn' / 'forkserver' (e.g. the
      Windows/macOS defaults) it is pickled once per worker. Every extracted
      patch is still pickled back to the parent, so use a smaller
      num_workers or num_workers=0 if memory is limited.
    * For GPU-bound workloads the extraction is rarely the bottleneck;
      num_workers=0 is the safe default.
    """
    if num_workers > 0:
        logger.info("Extracting %d patches using %d workers ...", len(coords), num_workers)
        # Do NOT bind the array into the task callable: Pool.map pickles the
        # callable together with every chunk of tasks, which would
        # re-serialise the whole array once per chunk.
        task = partial(_extract_in_worker, patch_size=patch_size)
        with mp.Pool(
            num_workers, initializer=_init_worker, initargs=(array,)
        ) as pool:
            patches = pool.map(task, coords)
    else:
        logger.info("Extracting %d patches (single-threaded) ...", len(coords))
        patches = [_extract_one(c, array, patch_size) for c in coords]

    logger.info("Patch extraction complete.")
    return patches