"""
Post-processing for SEN2SR: reconstruct the full image from inferred patches.

Overlap handling
----------------
Patches typically overlap (stride < patch_size). Naïve copying causes visible
seam artefacts at patch boundaries. We use **weighted averaging** instead:
each pixel in the output accumulates the weighted sum of all patches that cover
it, then divides by the total weight accumulated.

The default weight is a **2-D Hanning window**: values taper towards 0 at the
border and reach 1 at the centre. This smoothly de-emphasises patch edges,
where SR quality is lower (convolutional networks see less spatial context
near patch borders).

For non-overlapping patches (stride == patch_size) every pixel is covered by
exactly one patch. Each pixel with non-zero weight therefore takes the value
of that single patch. With ``use_hanning=False`` the weight map is all ones,
and the result is plain copy-paste.
"""
| from __future__ import annotations |
|
|
| from pathlib import Path |
| from typing import List, Tuple |
|
|
| import numpy as np |
|
|
| from s2sr_pipe.utils.logging_utils import get_logger |
|
|
# Module-level logger (named "postprocessing") used by the functions below.
logger = get_logger("postprocessing")
|
|
# Scale between model reflectance in [0, 1] and Sentinel-2 digital numbers
# (DN = reflectance * NORM_SCALE); used when writing the uint16 band GeoTIFFs.
NORM_SCALE: float = 1e4

# True-colour stretch for the RGB composite: multiply by the gain, clip to
# [0, 1], then raise to the power 1 / gamma before scaling to uint8.
TCI_GAIN: float = 1.25
TCI_GAMMA: float = 1.25
|
|
def _ensure_crs(profile: dict) -> dict:
    """Warn when *profile* carries no CRS, then hand the same dict back.

    The profile object itself is returned untouched, so the call can wrap a
    profile-building expression inline.
    """
    has_crs = bool(profile.get("crs"))
    if has_crs:
        return profile
    logger.warning(
        "CRS is None in the output profile - the GeoTIFF will have no spatial "
        "reference. Check that safe_reader correctly parsed the MGRS tile code."
    )
    return profile
|
|
|
|
| |
def _hanning_window(patch_size: int) -> np.ndarray:
    """
    2-D Hanning-style taper of shape (patch_size, patch_size), float32.

    Built as the outer product of a 1-D Hanning window whose zero-valued
    endpoints have been removed (``np.hanning(patch_size + 2)[1:-1]``). The
    result is normalised so its maximum is exactly 1.0. Values taper towards
    0 at the borders but stay strictly positive.

    Why the endpoints are dropped: ``np.hanning(n)`` is exactly 0 at its first
    and last sample. Image pixels on row 0 / column 0 are only ever covered by
    patch borders; the same holds for the last row/column when a patch ends
    exactly there. With a zero-edged window those pixels accumulate zero total
    weight, and ``reconstruct_from_patches`` outputs 0 for them. A strictly
    positive window keeps the smooth blending in overlaps and gives every
    covered pixel a non-zero weight.

    Raises
    ------
    ValueError if patch_size < 1.
    """
    if patch_size < 1:
        raise ValueError(f"patch_size must be >= 1, got {patch_size}.")
    w1d = np.hanning(patch_size + 2)[1:-1].astype(np.float32)
    w2d = np.outer(w1d, w1d)
    # Normalise so the peak weight is exactly 1.0. For even sizes, the peak of
    # the trimmed window is slightly below 1. All entries are > 0 here, so the
    # maximum is never 0.
    w2d /= w2d.max()
    return w2d
|
|
|
|
| |
def reconstruct_from_patches(
    patches: List[np.ndarray],
    coords: List[Tuple[int, int]],
    image_shape: Tuple[int, int, int],
    patch_size: int,
    use_hanning: bool = True,
) -> np.ndarray:
    """
    Reconstruct the full (C, H, W) float32 image from inferred patches.

    Each patch is multiplied by a (patch_size, patch_size) weight window and
    added into a float64 accumulator. A parallel (H, W) map sums the weights.
    The output is accumulator / weight_map, i.e. the weighted mean of all
    patches covering each pixel. Patches that overhang the bottom/right edge
    of the image are cropped.

    Parameters
    ----------
    patches     : List of (C, patch_size, patch_size) float32 arrays
                  (model output, values already in [0, 1]).
    coords      : Matching list of (row, col) top-left positions; each must
                  lie inside the image (0 <= row < H, 0 <= col < W).
    image_shape : Target (C, H, W) shape.
    patch_size  : Spatial patch size (must be consistent with coords).
    use_hanning : Use Hanning-weighted averaging for smooth seams.
                  Set False (uniform weights) for benchmarking / unit tests.

    Returns
    -------
    np.ndarray: (C, H, W) float32, values averaged over overlaps. Pixels with
    (near-)zero accumulated weight are reported with a warning and come out
    as (approximately) 0.

    Raises
    ------
    ValueError if len(patches) != len(coords), if a coordinate lies outside
    the image, or if a patch does not have C channels and enough spatial
    extent for its footprint.
    """
    if len(patches) != len(coords):
        raise ValueError(
            f"patches ({len(patches)}) and coords ({len(coords)}) must have the same length."
        )

    C, H, W = image_shape

    # Validate the patch grid up front, before allocating the (possibly very
    # large) float64 buffers. Without this, a bad coordinate fails later with
    # an opaque numpy broadcasting error. A patch starting exactly on the
    # bottom/right edge would instead be silently dropped.
    for i, (patch, (r, c)) in enumerate(zip(patches, coords)):
        if not (0 <= r < H and 0 <= c < W):
            raise ValueError(
                f"Patch {i}: top-left corner ({r}, {c}) lies outside the "
                f"{H}x{W} image."
            )
        ph = min(patch_size, H - r)
        pw = min(patch_size, W - c)
        shape = np.shape(patch)
        if len(shape) != 3 or shape[0] != C or shape[1] < ph or shape[2] < pw:
            raise ValueError(
                f"Patch {i}: shape {shape} is incompatible with {C} channel(s) "
                f"and a {ph}x{pw} footprint at ({r}, {c})."
            )

    accumulator = np.zeros((C, H, W), dtype=np.float64)
    weight_map = np.zeros((H, W), dtype=np.float64)

    if use_hanning:
        window = _hanning_window(patch_size)
    else:
        window = np.ones((patch_size, patch_size), dtype=np.float32)

    logger.info("Reconstructing image from %d patches ...", len(patches))

    for patch, (r, c) in zip(patches, coords):
        # Patches on the last row/column may overhang the image: crop them.
        ph = min(patch_size, H - r)
        pw = min(patch_size, W - c)
        w = window[:ph, :pw]
        accumulator[:, r : r + ph, c : c + pw] += patch[:, :ph, :pw] * w
        weight_map[r : r + ph, c : c + pw] += w

    # Pixels with (near-)zero total weight cannot be averaged. Report them and
    # divide by 1 instead, which leaves them at (approximately) 0.
    uncovered_mask = weight_map < 1e-9
    uncovered = int(uncovered_mask.sum())
    if uncovered > 0:
        logger.warning(
            "%d pixels ne sont couverts par aucun patch et seront mis a 0. "
            "Verifier la configuration du patch grid.", uncovered
        )
        weight_map[uncovered_mask] = 1.0

    reconstructed = (accumulator / weight_map[np.newaxis, :, :]).astype(np.float32)

    # min()/max() raise on zero-size arrays, so only report a range when
    # there is data.
    if reconstructed.size:
        logger.info(
            "Reconstruction complete. Output range: [%.4f, %.4f]",
            float(reconstructed.min()),
            float(reconstructed.max()),
        )
    else:
        logger.info("Reconstruction complete. Output is empty.")
    return reconstructed
|
|
|
|
| |
| |
# Spectral bands in the channel order of the reconstructed (C, H, W) array:
# channel i holds band BAND_NAMES[i].
BAND_NAMES: List[str] = [
    "B02", "B03", "B04", "B05", "B06",
    "B07", "B08", "B8A", "B11", "B12",
]

# Channel indices (R, G, B) of the true-colour composite: R = B04, G = B03,
# B = B02. Looked up in BAND_NAMES so they follow the channel order above.
TCI_RGB_INDICES: Tuple[int, int, int] = (
    BAND_NAMES.index("B04"),
    BAND_NAMES.index("B03"),
    BAND_NAMES.index("B02"),
)
|
|
|
|
| |
def _write_band_streamed(
    data: np.ndarray,
    profile: dict,
    output_path: Path,
    tile_size: int,
    desc: str,
) -> None:
    """
    Write one (H, W) band as a single-band uint16 GeoTIFF, tile by tile.

    ``data`` holds reflectance in [0, 1]. Each tile is converted to digital
    numbers as ``round(value * NORM_SCALE)``, clipped to [0, NORM_SCALE].
    Non-finite values (NaN/inf) are written as 0. Writing in ``tile_size``
    windows keeps the temporary conversion buffers small.

    Parameters
    ----------
    data        : (H, W) float array of reflectance values.
    profile     : rasterio profile passed to ``rasterio.open(..., "w")``;
                  expected to describe a 1-band uint16 raster of shape (H, W).
    output_path : Destination GeoTIFF path.
    tile_size   : Side length (px) of the square write windows.
    desc        : Label shown on the tqdm progress bar.
    """
    import rasterio
    from rasterio.windows import Window
    from tqdm import tqdm

    H, W = data.shape
    row_starts = range(0, H, tile_size)
    col_starts = range(0, W, tile_size)
    n_windows = len(row_starts) * len(col_starts)

    with rasterio.open(output_path, "w", **profile) as dst:
        with tqdm(total=n_windows, desc=desc, unit="tile", leave=False) as pbar:
            for row_off in row_starts:
                row_end = min(row_off + tile_size, H)
                for col_off in col_starts:
                    col_end = min(col_off + tile_size, W)

                    tile = data[row_off:row_end, col_off:col_end]
                    tile = np.where(np.isfinite(tile), tile, 0.0)
                    # Round to the nearest DN before the integer cast. A bare
                    # astype() truncates, and float32 products can land just
                    # below an integer: float32(0.0007) * 1e4 evaluates to
                    # 6.9999995, which would be written as 6 instead of 7.
                    tile_u16 = np.clip(
                        np.rint(tile * NORM_SCALE), 0, NORM_SCALE
                    ).astype(np.uint16)

                    win = Window(col_off, row_off, col_end - col_off, row_end - row_off)
                    dst.write(tile_u16[np.newaxis, :, :], window=win)
                    pbar.update(1)
|
|
|
|
def _write_jpeg_preview(
    data: np.ndarray,
    output_path: Path,
    max_size: int = 2048,
    quality: int = 90,
    gain: float = TCI_GAIN,
    gamma: float = TCI_GAMMA,
) -> Path:
    """
    Write a JPEG preview of an RGB composite, resized so the longest side <= max_size.

    Uses the same stretch as the TCI GeoTIFF writer. Non-finite values are
    set to 0, then ``clip(value * gain, 0, 1) ** (1 / gamma)`` is scaled by
    255 and truncated to uint8. Images larger than ``max_size`` are resized
    with Lanczos resampling, keeping the aspect ratio; each side is at least
    1 px.

    Parameters
    ----------
    data        : (3, H, W) float array; channel order R, G, B.
    output_path : Destination .jpg path.
    max_size    : Maximum length (px) of the longest side.
    quality     : JPEG quality passed to Pillow.
    gain, gamma : Stretch parameters (defaults TCI_GAIN / TCI_GAMMA).

    Returns
    -------
    Path : ``output_path``.
    """
    from PIL import Image

    inv_gamma = 1.0 / gamma
    # NaN/inf survive np.clip and would cast to undefined uint8 values; zero
    # them first, as the TCI GeoTIFF writer does.
    finite = np.where(np.isfinite(data), data, 0.0)
    rgb_f = np.clip(finite * gain, 0.0, 1.0) ** inv_gamma
    rgb_u8 = (rgb_f * 255.0).astype(np.uint8).transpose(1, 2, 0)
    img = Image.fromarray(rgb_u8)
    H, W = rgb_u8.shape[:2]
    longest = max(H, W)
    if longest > max_size:
        scale = max_size / longest
        # Very elongated images could otherwise round their short side to 0 px.
        new_size = (max(1, int(W * scale)), max(1, int(H * scale)))
        img = img.resize(new_size, Image.LANCZOS)
    img.save(output_path, format="JPEG", quality=quality)
    return output_path
|
|
|
|
def _write_tci_streamed(
    data: np.ndarray,
    profile: dict,
    output_path: Path,
    tile_size: int,
    gain: float = TCI_GAIN,
    gamma: float = TCI_GAMMA,
) -> None:
    """
    Stream a 3-band RGB composite to a uint8 GeoTIFF, one window at a time.

    Each window uses the Sentinel Hub true-colour formula. Non-finite values
    become 0, then ``clip(value * gain, 0, 1) ** (1 / gamma)`` is scaled by
    255 and truncated to uint8.

    Parameters
    ----------
    data        : (3, H, W) float array (channels R, G, B).
    profile     : rasterio profile passed to ``rasterio.open(..., "w")``.
    output_path : Destination GeoTIFF path.
    tile_size   : Side length (px) of the square write windows.
    gain, gamma : Stretch parameters (defaults TCI_GAIN / TCI_GAMMA).
    """
    import rasterio
    from rasterio.windows import Window
    from tqdm import tqdm

    _, height, width = data.shape
    exponent = 1.0 / gamma

    # (row_start, row_stop, col_start, col_stop) for every window, row-major.
    spans = [
        (r0, min(r0 + tile_size, height), c0, min(c0 + tile_size, width))
        for r0 in range(0, height, tile_size)
        for c0 in range(0, width, tile_size)
    ]

    with rasterio.open(output_path, "w", **profile) as dst, tqdm(
        total=len(spans), desc="TCI", unit="tile", leave=False
    ) as pbar:
        for r0, r1, c0, c1 in spans:
            chunk = data[:, r0:r1, c0:c1]
            chunk = np.where(np.isfinite(chunk), chunk, 0.0)
            stretched = np.clip(chunk * gain, 0.0, 1.0) ** exponent
            dst.write(
                (stretched * 255.0).astype(np.uint8),
                window=Window(c0, r0, c1 - c0, r1 - r0),
            )
            pbar.update(1)
|
|
|
|
| |
def export_result(
    reconstructed_float: np.ndarray,
    ref_profile: dict,
    output_path: Path | str,
    scale: int = 1,
    compress: str = "deflate",
    tile_size: int = 1024,
    bands: List[str] | None = None,
    preview: bool = False,
    preview_size: int = 2048,
) -> List[Path]:
    """
    Write one GeoTIFF per spectral band + one TCI (True Color Image).

    Output files
    ------------
    ``output_path`` is treated as a *directory* and is created if missing.
    Given ``output_path = "results/sen2sr_2.5m"``, the function creates:

        results/sen2sr_2.5m/B02.tif     : Blue (uint16, DN [0,10000])
        results/sen2sr_2.5m/B03.tif     : Green
        ...
        results/sen2sr_2.5m/TCI.tif     : RGB true color (uint8, stretched [0,255])
        results/sen2sr_2.5m/preview.jpg : optional JPEG preview (``preview=True``)

    Parameters
    ----------
    reconstructed_float : (C, H, W) float32 in [0, 1], channels ordered as
                          BAND_NAMES.
    ref_profile         : rasterio profile from the 10 m LR reference band.
    output_path         : Output directory for the per-band files.
    scale               : SR upscale factor (4 = 10 m -> 2.5 m).
    compress            : GeoTIFF compression codec.
    tile_size           : Spatial tile size for windowed writing.
    bands               : Outputs to write (e.g. ["B02", "TCI"]); a single
                          name string is also accepted.
                          None = all bands + TCI (default behaviour).
                          Valid values: band names (B02..B12) + "TCI".
    preview             : Also write a downscaled RGB JPEG preview.
    preview_size        : Longest side (px) of the preview.

    Returns
    -------
    List[Path] : Paths of all files written (bands, then TCI, then preview).

    Raises
    ------
    ValueError : if ``bands`` contains unknown names, if
                 ``reconstructed_float`` is not 3-D, or if it has too few
                 channels for the requested outputs. All checks run before
                 anything is written.
    """
    # --- Resolve which outputs to write ------------------------------------
    _valid = set(BAND_NAMES) | {"TCI"}
    if isinstance(bands, str):
        # Accept a single name ("TCI"); iterating the bare string would
        # otherwise split it into characters.
        bands = [bands]
    if bands is not None:
        invalid = sorted(set(bands) - _valid)
        if invalid:
            raise ValueError(
                f"Bandes inconnues : {invalid}. "
                f"Valeurs valides : {sorted(_valid)}"
            )
        requested = set(bands)
        write_spectral = [b for b in BAND_NAMES if b in requested]
        write_tci = "TCI" in requested
    else:
        write_spectral = list(BAND_NAMES)
        write_tci = True
    need_rgb = write_tci or preview

    # --- Validate the array before touching the filesystem ------------------
    # A short array would otherwise raise IndexError halfway through, after
    # some band files had already been written.
    if np.ndim(reconstructed_float) != 3:
        raise ValueError(
            "reconstructed_float must be a (C, H, W) array, "
            f"got shape {np.shape(reconstructed_float)}."
        )
    n_channels = reconstructed_float.shape[0]
    needed = [BAND_NAMES.index(b) for b in write_spectral]
    if need_rgb:
        needed.extend(TCI_RGB_INDICES)
    if needed and max(needed) >= n_channels:
        raise ValueError(
            f"reconstructed_float has {n_channels} channel(s) but the requested "
            f"outputs need channel index {max(needed)} (channel order: {BAND_NAMES})."
        )

    from s2sr_pipe.utils.geo_utils import build_output_profile

    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    base_profile = _ensure_crs(build_output_profile(
        ref_profile, n_bands=1, scale=scale, dtype="uint16", compress=compress
    ))

    written: List[Path] = []

    logger.info(
        "Exporting %d bande(s) spectrale(s)%s -> %s (tile_size=%d px)",
        len(write_spectral),
        " + TCI" if write_tci else "",
        output_dir,
        tile_size,
    )

    # --- Spectral bands ------------------------------------------------------
    for band_name in write_spectral:
        idx = BAND_NAMES.index(band_name)
        out_file = output_dir / f"{band_name}.tif"
        logger.info(" Writing band %s -> %s", band_name, out_file.name)
        _write_band_streamed(
            data=reconstructed_float[idx],
            profile=base_profile,
            output_path=out_file,
            tile_size=tile_size,
            desc=band_name,
        )
        written.append(out_file)

    # The RGB stack is shared by the TCI GeoTIFF and the JPEG preview;
    # extract it once.
    rgb_data = reconstructed_float[list(TCI_RGB_INDICES), :, :] if need_rgb else None

    # --- True-colour image ---------------------------------------------------
    if write_tci:
        tci_profile = build_output_profile(
            ref_profile, n_bands=3, scale=scale, dtype="uint8", compress=compress
        )
        tci_file = output_dir / "TCI.tif"
        # %.2f so the default 1.25 is not logged as "1.2".
        logger.info(
            " Writing TCI (R=B04 G=B03 B=B02, gain=%.2f gamma=%.2f) -> %s",
            TCI_GAIN,
            TCI_GAMMA,
            tci_file.name,
        )
        _write_tci_streamed(
            data=rgb_data,
            profile=tci_profile,
            output_path=tci_file,
            tile_size=tile_size,
        )
        written.append(tci_file)

    # --- Optional JPEG preview (independent of whether TCI.tif is written) --
    if preview:
        preview_file = output_dir / "preview.jpg"
        logger.info(" Writing preview -> %s (max_size=%d px)", preview_file.name, preview_size)
        _write_jpeg_preview(rgb_data, preview_file, max_size=preview_size)
        written.append(preview_file)

    logger.info("Export complete - %d files written to %s", len(written), output_dir)
    return written