| """ |
| Preprocessing for SEN2SR inference. |
| |
| Steps |
| ----- |
| 1. Normalize : uint16 → float32, divide by 10 000, clip to [0, 1]. |
| 2. Patch grid : compute (row, col) top-left coordinates for a sliding window |
| with configurable size and stride. |
| 3. Extract : cut patches from the full array; optionally use multiprocessing |
| for the extraction step (useful when patch count is large and |
| copy overhead dominates GPU transfer). |
| |
| Design note |
| ----------- |
| Patch extraction is kept **separate** from model inference so that large images |
| (> 10 000 × 10 000 px) never have to be fully loaded into GPU memory. |
| The caller controls batch assembly. |
| """ |
| from __future__ import annotations |
|
|
| import multiprocessing as mp |
| import warnings |
| from dataclasses import dataclass, field |
| from functools import partial |
| from typing import List, Tuple |
|
|
| import numpy as np |
|
|
| from s2sr_pipe.utils.logging_utils import get_logger |
|
|
# Module-level logger, obtained from the project's logging helper.
logger = get_logger("preprocessing")


# Scale factor between raw stored values and the normalised [0, 1] range:
# normalize() divides by it, denormalize() multiplies by it.
NORM_SCALE: float = 10_000.0
|
|
|
|
def normalize(array: np.ndarray) -> np.ndarray:
    """
    Normalize a uint16 Sentinel-2 array to float32 in [0, 1].

    The input is cast to float32, divided by ``NORM_SCALE`` and clipped to
    [0, 1]. The input array itself is never modified.

    A ``UserWarning`` is emitted when a non-empty float32 input has a maximum
    below 2.0: such data is probably already normalised, and dividing it
    again would push every value towards zero.

    Parameters
    ----------
    array : np.ndarray — shape (C, H, W), dtype uint16 or float32.

    Returns
    -------
    np.ndarray — same shape as *array*, dtype float32, values in [0, 1].
    """
    # The size check comes first: max() of a zero-size array raises ValueError.
    if array.dtype == np.float32 and array.size > 0 and float(array.max()) < 2.0:
        warnings.warn(
            "normalize() reçoit un tableau float32 avec des valeurs < 2.0 — "
            "il est peut-être déjà normalisé. La division par 10000 va produire "
            "des valeurs quasi nulles.",
            UserWarning,
            stacklevel=2,
        )

    # astype() always returns a fresh array, so dividing and clipping in place
    # cannot touch the caller's data and avoids a second full-size temporary.
    out = array.astype(np.float32)
    out /= NORM_SCALE
    np.clip(out, 0.0, 1.0, out=out)
    return out
|
|
|
|
def denormalize(array: np.ndarray) -> np.ndarray:
    """
    Reverse normalisation: float32 [0, 1] → uint16 [0, NORM_SCALE].

    Values are multiplied by ``NORM_SCALE``, clipped to [0, NORM_SCALE] and
    rounded to the nearest integer before the cast to uint16. The input array
    is not modified.

    Parameters
    ----------
    array : np.ndarray — shape (C, H, W), dtype float32.

    Returns
    -------
    np.ndarray — same shape as *array*, dtype uint16.
    """
    # The multiplication allocates a new array; later steps reuse it in place.
    out = array * NORM_SCALE
    # Clip against the same constant used for scaling so the two cannot drift.
    np.clip(out, 0, NORM_SCALE, out=out)
    # astype() truncates towards zero, which would turn e.g. 1233.9999 into
    # 1233 and break the normalize()/denormalize() round trip; round first.
    np.rint(out, out=out)
    return out.astype(np.uint16)
|
|
|
|
| |
@dataclass
class PatchConfig:
    """
    Configuration for sliding-window patch extraction.

    Attributes
    ----------
    patch_size : Side length, in pixels, of each square patch.
    stride     : Step, in pixels, between the origins of consecutive patches.
                 It may not exceed ``patch_size``.

    Raises
    ------
    ValueError — if either value is not a positive integer, or if ``stride``
    is larger than ``patch_size``.
    """

    patch_size: int = 128
    stride: int = 64

    def __post_init__(self) -> None:
        # Check type and sign first so that e.g. patch_size=0 is reported as
        # "must be positive" rather than as a stride/patch_size mismatch.
        for value in (self.patch_size, self.stride):
            if not isinstance(value, (int, np.integer)) or value <= 0:
                raise ValueError("patch_size and stride must be positive integers.")
        if self.stride > self.patch_size:
            raise ValueError(
                f"stride ({self.stride}) must be ≤ patch_size ({self.patch_size})"
            )
|
|
|
|
def compute_patch_coords(
    height: int,
    width: int,
    cfg: PatchConfig,
) -> List[Tuple[int, int]]:
    """
    Compute (row, col) top-left coordinates for all patches.

    The window always covers the full image: if the last patch would go
    beyond the border it is snapped inward (no border reflection or padding),
    so the final row/column of patches may overlap its neighbour by more than
    ``patch_size - stride``.

    Parameters
    ----------
    height, width : Image spatial dimensions (after band alignment).
    cfg           : PatchConfig instance; ``patch_size`` and ``stride`` are read.

    Returns
    -------
    List of (row, col) tuples, ordered row by row, then by column.

    Raises
    ------
    ValueError — if the image is smaller than ``patch_size`` along either axis.
    """
    ps, st = cfg.patch_size, cfg.stride

    if height < ps or width < ps:
        raise ValueError(
            f"L'image ({height}×{width} px) est plus petite que patch_size ({ps}). "
            "Réduire patch_size ou utiliser une image plus grande."
        )

    # range() stops strictly before the last valid origin (dimension - ps), so
    # appending that origin yields a list that is already sorted and free of
    # duplicates; no set()/sorted() pass is needed.
    rows = [*range(0, height - ps, st), height - ps]
    cols = [*range(0, width - ps, st), width - ps]

    coords = [(r, c) for r in rows for c in cols]
    # The image size is logged as width x height.
    logger.debug(
        "Patch grid: %d rows x %d cols = %d patches "
        "(patch_size=%d, stride=%d, image=%dx%d)",
        len(rows),
        len(cols),
        len(coords),
        ps,
        st,
        width,
        height,
    )
    return coords
|
|
|
|
| |
def _extract_one(
    coord: Tuple[int, int],
    array: np.ndarray,
    patch_size: int,
) -> np.ndarray:
    """
    Cut one (C, patch_size, patch_size) patch out of *array*.

    *coord* is the (row, col) of the patch's top-left corner. The result is a
    copy, not a view. Defined at module level so it can be pickled and used as
    a multiprocessing worker.
    """
    top, left = coord
    row_window = slice(top, top + patch_size)
    col_window = slice(left, left + patch_size)
    return array[:, row_window, col_window].copy()
|
|
|
|
# Source array for pool workers. Only assigned inside worker processes (by
# _init_patch_worker); it stays None in the parent process.
_WORKER_ARRAY = None


def _init_patch_worker(array: np.ndarray) -> None:
    """Pool initializer: store the source array once per worker process."""
    global _WORKER_ARRAY
    _WORKER_ARRAY = array


def _extract_from_worker_array(coord: Tuple[int, int], patch_size: int) -> np.ndarray:
    """Pool task: cut one patch out of the array stored by the initializer."""
    return _extract_one(coord, _WORKER_ARRAY, patch_size)


def extract_patches(
    array: np.ndarray,
    coords: List[Tuple[int, int]],
    patch_size: int,
    num_workers: int = 0,
) -> List[np.ndarray]:
    """
    Extract all patches from *array* at positions given by *coords*.

    Parameters
    ----------
    array       : (C, H, W) float32 normalised array.
    coords      : List of (row, col) top-left positions.
    patch_size  : Spatial size of each patch.
    num_workers : Number of worker processes.
                  <= 0 = extract in the calling process (safe for small images
                         or debugging).
                  > 0  = use multiprocessing.Pool (useful for very large images
                         with hundreds of thousands of patches).

    Returns
    -------
    List[np.ndarray] — one (C, patch_size, patch_size) copy per entry of
    *coords*, in the same order.

    Notes
    -----
    * The source array is handed to each worker once, through the Pool
      initializer, instead of being bound into the mapped callable (which
      would be pickled again for every task chunk). With the ``fork`` start
      method the workers inherit it without serialisation; with ``spawn`` or
      ``forkserver`` it is pickled once per worker — use a smaller
      num_workers or num_workers=0 if memory is limited on those platforms.
    * Extracted patches are still sent back to the parent process, so for
      GPU-bound workloads the extraction is rarely the bottleneck;
      num_workers=0 is the safe default.
    """
    if num_workers > 0:
        logger.info("Extracting %d patches using %d workers ...", len(coords), num_workers)
        # Only the small patch_size argument travels with each task.
        task = partial(_extract_from_worker_array, patch_size=patch_size)
        with mp.Pool(
            num_workers,
            initializer=_init_patch_worker,
            initargs=(array,),
        ) as pool:
            patches = pool.map(task, coords)
    else:
        logger.info("Extracting %d patches (single-threaded) ...", len(coords))
        patches = [_extract_one(coord, array, patch_size) for coord in coords]

    logger.info("Patch extraction complete.")
    return patches
|
|