# Provenance (captured from a file-viewer page; not part of the module):
# author gdubrasquetd, commit e854df7 (verified),
# "deploy: bundle s2sr_pipe, fix requirements", 6.53 kB.
"""
Preprocessing for SEN2SR inference.
Steps
-----
1. Normalize : uint16 → float32, divide by 10 000, clip to [0, 1].
2. Patch grid : compute (row, col) top-left coordinates for a sliding window
with configurable size and stride.
3. Extract : cut patches from the full array; optionally use multiprocessing
for the extraction step (useful when patch count is large and
copy overhead dominates GPU transfer).
Design note
-----------
Patch extraction is kept **separate** from model inference so that large images
(> 10 000 × 10 000 px) never have to be fully loaded into GPU memory.
The caller controls batch assembly.
"""
from __future__ import annotations
import multiprocessing as mp
import warnings
from dataclasses import dataclass, field
from functools import partial
from typing import List, Tuple
import numpy as np
from s2sr_pipe.utils.logging_utils import get_logger
# Module-level logger used by compute_patch_coords() and extract_patches().
# get_logger comes from s2sr_pipe.utils.logging_utils; its handler/level
# configuration is not visible from this module.
logger = get_logger("preprocessing")
# ── Normalisation ─────────────────────────────────────────────────────────────
# Divisor that maps raw Sentinel-2 digital numbers onto [0, 1]; values
# above this are clipped by normalize().
NORM_SCALE: float = 10_000.0


def normalize(array: np.ndarray) -> np.ndarray:
    """
    Normalize a uint16 Sentinel-2 array to float32 in [0, 1].

    Values are divided by ``NORM_SCALE`` (10 000) and clipped to [0, 1].
    The input array is never modified.

    Parameters
    ----------
    array : np.ndarray — shape (C, H, W), dtype uint16 or float32.

    Returns
    -------
    np.ndarray — same shape as *array*, dtype float32, values in [0, 1].

    Warns
    -----
    UserWarning
        If *array* is a non-empty float32 array whose maximum is below 2.0,
        which suggests it has already been normalised.
    """
    # Heuristic guard against double normalisation. Skipped for empty
    # arrays, where max() has no identity and would raise ValueError.
    if array.dtype == np.float32 and array.size > 0 and float(array.max()) < 2.0:
        warnings.warn(
            "normalize() reçoit un tableau float32 avec des valeurs < 2.0 — "
            "il est peut-être déjà normalisé. La division par 10000 va produire "
            "des valeurs quasi nulles.",
            UserWarning,
            stacklevel=2,
        )
    # astype() always returns a fresh copy, so dividing in place is safe and
    # avoids a second full-size temporary on large scenes.
    out = array.astype(np.float32)
    out /= NORM_SCALE
    np.clip(out, 0.0, 1.0, out=out)
    return out


def denormalize(array: np.ndarray) -> np.ndarray:
    """
    Reverse normalisation: float32 [0, 1] → uint16 [0, 10 000].

    Values are multiplied by ``NORM_SCALE``, rounded to the nearest integer,
    clipped to [0, 10 000] and cast to uint16.

    Parameters
    ----------
    array : np.ndarray — shape (C, H, W), dtype float32.

    Returns
    -------
    np.ndarray — same shape as *array*, dtype uint16.
    """
    out = array * NORM_SCALE
    # Round before casting: astype() truncates toward zero, so a float32
    # product landing just below an integer (e.g. 4999.9995) would lose one
    # DN and break the normalize() -> denormalize() round trip.
    np.rint(out, out=out)
    np.clip(out, 0, 10_000, out=out)
    return out.astype(np.uint16)
# ── Patch grid ────────────────────────────────────────────────────────────────
@dataclass
class PatchConfig:
    """
    Configuration for sliding-window patch extraction.

    Attributes
    ----------
    patch_size : int — side length, in pixels, of each square patch.
    stride     : int — step, in pixels, between consecutive patch origins.
        Must satisfy 0 < stride <= patch_size so consecutive windows
        overlap or abut without leaving gaps.

    Raises
    ------
    ValueError
        If either value is not positive, or if stride exceeds patch_size.
    """

    patch_size: int = 128
    stride: int = 64

    def __post_init__(self) -> None:
        # Validate positivity first: with a non-positive patch_size the
        # ordering check below would fire with a misleading message.
        if self.patch_size <= 0 or self.stride <= 0:
            raise ValueError("patch_size and stride must be positive integers.")
        if self.stride > self.patch_size:
            raise ValueError(
                f"stride ({self.stride}) must be ≤ patch_size ({self.patch_size})"
            )
def compute_patch_coords(
    height: int,
    width: int,
    cfg: PatchConfig,
) -> List[Tuple[int, int]]:
    """
    Compute (row, col) top-left coordinates for all patches.

    Origins are placed every ``cfg.stride`` pixels along each axis. The
    window always covers the full image: if the last patch would go beyond
    the border it is snapped inward to ``size - patch_size`` (no border
    reflection or padding), so the final row/column of patches may overlap
    its neighbour by more than the regular stride overlap.

    Parameters
    ----------
    height, width : Image spatial dimensions (after band alignment).
    cfg           : PatchConfig instance.

    Returns
    -------
    List of (row, col) tuples in row-major order.

    Raises
    ------
    ValueError
        If the image is smaller than ``cfg.patch_size`` along either axis.
    """
    ps, st = cfg.patch_size, cfg.stride
    if height < ps or width < ps:
        raise ValueError(
            f"L'image ({height}×{width} px) est plus petite que patch_size ({ps}). "
            "Réduire patch_size ou utiliser une image plus grande."
        )
    # range() stops strictly before ``size - ps``, so appending the snapped
    # final origin yields a sorted, duplicate-free list without a set pass.
    rows = [*range(0, height - ps, st), height - ps]
    cols = [*range(0, width - ps, st), width - ps]
    coords = [(r, c) for r in rows for c in cols]
    logger.debug(
        "Patch grid: %d rows x %d cols = %d patches "
        "(patch_size=%d, stride=%d, image=%dx%d)",
        len(rows),
        len(cols),
        len(coords),
        ps,
        st,
        width,
        height,
    )
    return coords
# ── Patch extraction ──────────────────────────────────────────────────────────
def _extract_one(
coord: Tuple[int, int],
array: np.ndarray,
patch_size: int,
) -> np.ndarray:
"""Extract a single (C, patch_size, patch_size) patch β€” pickleable worker."""
r, c = coord
return array[:, r : r + patch_size, c : c + patch_size].copy()
# Source array for the current pool worker process. It is installed once per
# worker by _init_extract_worker so it is not re-sent with every task chunk.
_worker_array = None


def _init_extract_worker(array: np.ndarray) -> None:
    """Pool initializer: keep a reference to *array* in this worker process."""
    global _worker_array
    _worker_array = array


def _extract_from_worker_array(coord: Tuple[int, int], patch_size: int) -> np.ndarray:
    """Pool task: cut one patch from the array installed by the initializer."""
    return _extract_one(coord, _worker_array, patch_size)


def extract_patches(
    array: np.ndarray,
    coords: List[Tuple[int, int]],
    patch_size: int,
    num_workers: int = 0,
) -> List[np.ndarray]:
    """
    Extract all patches from *array* at positions given by *coords*.

    Parameters
    ----------
    array       : (C, H, W) float32 normalised array.
    coords      : List of (row, col) top-left positions.
    patch_size  : Spatial size of each patch.
    num_workers : Number of worker processes.
                  0 (or negative) = single-process (safe for small images or
                  debugging).
                  >0 = use multiprocessing.Pool (useful for very large images
                  with hundreds of thousands of patches).

    Returns
    -------
    List[np.ndarray] — each element has shape (C, patch_size, patch_size),
    in the same order as *coords*.

    Notes
    -----
    * With ``num_workers > 0`` the array reaches each worker once, through the
      pool initializer. Under the ``fork`` start method it is inherited by the
      child; under ``spawn``/``forkserver`` it is pickled once per worker.
      Every extracted patch is still pickled back to the parent.
    * For GPU-bound workloads the extraction is rarely the bottleneck;
      num_workers=0 is the safe default.
    """
    if num_workers > 0:
        logger.info("Extracting %d patches using %d workers ...", len(coords), num_workers)
        # Deliver the array via the initializer. Binding it into the mapped
        # callable (e.g. with functools.partial) would pickle the whole
        # array again for every task chunk that Pool.map dispatches.
        worker = partial(_extract_from_worker_array, patch_size=patch_size)
        with mp.Pool(
            num_workers,
            initializer=_init_extract_worker,
            initargs=(array,),
        ) as pool:
            patches = pool.map(worker, coords)
    else:
        logger.info("Extracting %d patches (single-threaded) ...", len(coords))
        patches = [_extract_one(c, array, patch_size) for c in coords]
    logger.info("Patch extraction complete.")
    return patches