""" Post-processing for SEN2SR: reconstruct the full image from inferred patches. Overlap handling ---------------- Patches typically overlap (stride < patch_size). Naïve copying causes visible seam artefacts at patch boundaries. We use **weighted averaging** instead: each pixel in the output accumulates the weighted sum of all patches that cover it, then divides by the total weight accumulated. The default weight is a **2-D Hanning window** — values are 0 at the border and 1 at the centre. This smoothly de-emphasises edges where SR quality is lower (convolution networks have weaker receptive fields near borders). For non-overlapping patches (stride == patch_size) the weight map degenerates to all-ones, giving standard copy-paste. """ from __future__ import annotations from pathlib import Path from typing import List, Tuple import numpy as np from s2sr_pipe.utils.logging_utils import get_logger logger = get_logger("postprocessing") NORM_SCALE: float = 10_000.0 TCI_GAIN: float = 1.25 # white point at 0.8 reflectance (1/0.8) TCI_GAMMA: float = 1.25 # Mild gamma correction def _ensure_crs(profile: dict) -> dict: """Log a warning if CRS is still None after safe_reader recovery.""" if not profile.get("crs"): logger.warning( "CRS is None in the output profile - the GeoTIFF will have no spatial " "reference. Check that safe_reader correctly parsed the MGRS tile code." ) return profile # ── Weight window ───────────────────────────────────────────────────────────── def _hanning_window(patch_size: int) -> np.ndarray: """ 2-D Hanning window of shape (patch_size, patch_size), float32. Constructed as the outer product of two 1-D Hanning windows. Values range from ~0 at corners to 1.0 at the centre. """ w1d = np.hanning(patch_size).astype(np.float32) w2d = np.outer(w1d, w1d) # Guard against all-zero (degenerate case with patch_size=1) if w2d.max() > 0: w2d /= w2d.max() return w2d # ── Reconstruction ──────────────────────────────────────────────────────────── def reconstruct_from_patches( patches: List[np.ndarray], coords: List[Tuple[int, int]], image_shape: Tuple[int, int, int], # (C, H, W) patch_size: int, use_hanning: bool = True, ) -> np.ndarray: """ Reconstruct the full (C, H, W) float32 image from inferred patches. Parameters ---------- patches : List of (C, patch_size, patch_size) float32 arrays (model output, values already in [0, 1]). coords : Matching list of (row, col) top-left positions. image_shape : Target (C, H, W) shape. patch_size : Spatial patch size (must be consistent with coords). use_hanning : Use Hanning-weighted averaging for smooth seams. Set False for benchmarking / unit tests. Returns ------- np.ndarray — (C, H, W) float32, values averaged over overlaps. Raises ------ ValueError if len(patches) != len(coords). """ if len(patches) != len(coords): raise ValueError( f"patches ({len(patches)}) and coords ({len(coords)}) must have the same length." ) C, H, W = image_shape accumulator = np.zeros((C, H, W), dtype=np.float64) weight_map = np.zeros((H, W), dtype=np.float64) window = _hanning_window(patch_size) if use_hanning else np.ones( (patch_size, patch_size), dtype=np.float32 ) logger.info("Reconstructing image from %d patches ...", len(patches)) for patch, (r, c) in zip(patches, coords): # Clip patch to image bounds (edge patches may be snapped inward but # should always be exactly patch_size — guard anyway) ph = min(patch_size, H - r) pw = min(patch_size, W - c) accumulator[:, r : r + ph, c : c + pw] += patch[:, :ph, :pw] * window[:ph, :pw] weight_map[r : r + ph, c : c + pw] += window[:ph, :pw] # Avoid division by zero for any pixel not covered by any patch (including near-zero floats) uncovered = int((weight_map < 1e-9).sum()) if uncovered > 0: logger.warning( "%d pixels ne sont couverts par aucun patch et seront mis a 0. " "Verifier la configuration du patch grid.", uncovered ) weight_map = np.where(weight_map < 1e-9, 1.0, weight_map) # Broadcast weight across channels reconstructed = (accumulator / weight_map[np.newaxis, :, :]).astype(np.float32) logger.info( "Reconstruction complete. Output range: [%.4f, %.4f]", float(reconstructed.min()), float(reconstructed.max()), ) return reconstructed # ── Band name mapping ───────────────────────────────────────────────────────── # Canonical SEN2SR band order (index → Sentinel-2 band name) BAND_NAMES: List[str] = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"] # Indices of B04 (Red), B03 (Green), B02 (Blue) within BAND_NAMES for TCI TCI_RGB_INDICES: Tuple[int, int, int] = (2, 1, 0) # R=B04, G=B03, B=B02 # ── Low-level windowed writer ───────────────────────────────────────────────── def _write_band_streamed( data: np.ndarray, # (H, W) float32 in [0, 1], single band profile: dict, # 1-band rasterio profile output_path: Path, tile_size: int, desc: str, ) -> None: """Write a single band as float32 [0,1] directly — no denormalisation.""" import rasterio from rasterio.windows import Window from tqdm import tqdm H, W = data.shape col_starts = list(range(0, W, tile_size)) row_starts = list(range(0, H, tile_size)) n_windows = len(col_starts) * len(row_starts) with rasterio.open(output_path, "w", **profile) as dst: with tqdm(total=n_windows, desc=desc, unit="tile", leave=False) as pbar: for row_off in row_starts: row_end = min(row_off + tile_size, H) for col_off in col_starts: col_end = min(col_off + tile_size, W) tile = data[row_off:row_end, col_off:col_end] tile = np.where(np.isfinite(tile), tile, 0.0) tile_u16 = np.clip(tile * NORM_SCALE, 0, 10_000).astype(np.uint16) win = Window(col_off, row_off, col_end - col_off, row_end - row_off) dst.write(tile_u16[np.newaxis, :, :], window=win) pbar.update(1) def _write_jpeg_preview( data: np.ndarray, output_path: Path, max_size: int = 2048, quality: int = 90, gain: float = TCI_GAIN, gamma: float = TCI_GAMMA, ) -> Path: """Write a JPEG preview of the TCI, resized so the longest side <= max_size.""" from PIL import Image inv_gamma = 1.0 / gamma rgb_f = np.clip(data * gain, 0.0, 1.0) ** inv_gamma # (3, H, W) rgb_u8 = (rgb_f * 255.0).astype(np.uint8).transpose(1, 2, 0) # (H, W, 3) img = Image.fromarray(rgb_u8) H, W = rgb_u8.shape[:2] if max(H, W) > max_size: scale = max_size / max(H, W) img = img.resize((int(W * scale), int(H * scale)), Image.LANCZOS) img.save(output_path, format="JPEG", quality=quality) return output_path def _write_tci_streamed( data: np.ndarray, # (3, H, W) float32 in [0,1] profile: dict, # 3-band uint8 rasterio profile output_path: Path, tile_size: int, gain: float = TCI_GAIN, gamma: float = TCI_GAMMA, ) -> None: """Write TCI as uint8 using Sentinel Hub true color formula: gain + gamma correction.""" import rasterio from rasterio.windows import Window from tqdm import tqdm _, H, W = data.shape col_starts = list(range(0, W, tile_size)) row_starts = list(range(0, H, tile_size)) n_windows = len(col_starts) * len(row_starts) inv_gamma = 1.0 / gamma with rasterio.open(output_path, "w", **profile) as dst: with tqdm(total=n_windows, desc="TCI", unit="tile", leave=False) as pbar: for row_off in row_starts: row_end = min(row_off + tile_size, H) for col_off in col_starts: col_end = min(col_off + tile_size, W) tile = data[:, row_off:row_end, col_off:col_end] tile = np.where(np.isfinite(tile), tile, 0.0) # gain + gamma: matches Sentinel Hub EO Browser default True Color tile_f = np.clip(tile * gain, 0.0, 1.0) ** inv_gamma tile_u8 = (tile_f * 255.0).astype(np.uint8) win = Window(col_off, row_off, col_end - col_off, row_end - row_off) dst.write(tile_u8, window=win) pbar.update(1) # ── Public API ──────────────────────────────────────────────────────────────── def export_result( reconstructed_float: np.ndarray, ref_profile: dict, output_path: Path | str, scale: int = 1, compress: str = "deflate", tile_size: int = 1024, bands: List[str] | None = None, preview: bool = False, preview_size: int = 2048, ) -> List[Path]: """ Write one GeoTIFF per spectral band + one TCI (True Color Image). Output files ------------ Given ``output_path = "results/sen2sr_2.5m.tif"``, the function creates: results/sen2sr_2.5m_B02.tif ← Blue (uint16, DN [0,10000]) results/sen2sr_2.5m_B03.tif ← Green ... results/sen2sr_2.5m_TCI.tif ← RGB true color (uint8, stretched [0,255]) Parameters ---------- reconstructed_float : (C, H, W) float32 in [0, 1]. ref_profile : rasterio profile from the 10 m LR reference band. output_path : Base path used to derive per-band filenames. scale : SR upscale factor (4 = 10 m → 2.5 m). compress : GeoTIFF compression codec. tile_size : Spatial tile size for windowed writing. bands : Liste des sorties à écrire (ex: ["B02", "TCI"]). None = toutes les bandes + TCI (comportement par défaut). Valeurs valides : noms de bandes (B02..B12) + "TCI". Returns ------- List[Path] : Paths of all files written. Raises ------ ValueError : si bands contient des noms inconnus. """ from s2sr_pipe.utils.geo_utils import build_output_profile from s2sr_pipe.utils.logging_utils import get_logger log = get_logger("postprocessing") # ── Résoudre la sélection de bandes ────────────────────────────────────── _valid = set(BAND_NAMES) | {"TCI"} if bands is not None: invalid = sorted(set(bands) - _valid) if invalid: raise ValueError( f"Bandes inconnues : {invalid}. " f"Valeurs valides : {sorted(_valid)}" ) write_spectral = [b for b in BAND_NAMES if b in bands] write_tci = "TCI" in bands else: write_spectral = list(BAND_NAMES) write_tci = True C, H, W = reconstructed_float.shape output_dir = Path(output_path) output_dir.mkdir(parents=True, exist_ok=True) # Build base profile (1 band, uint16) — force CRS if absent base_profile = _ensure_crs(build_output_profile( ref_profile, n_bands=1, scale=scale, dtype="uint16", compress=compress )) written: List[Path] = [] log.info( "Exporting %d bande(s) spectrale(s)%s -> %s (tile_size=%d px)", len(write_spectral), " + TCI" if write_tci else "", output_dir, tile_size, ) # ── Per-band TIFs ───────────────────────────────────────────────────────── for band_name in write_spectral: idx = BAND_NAMES.index(band_name) out_file = output_dir / f"{band_name}.tif" log.info(" Writing band %s -> %s", band_name, out_file.name) _write_band_streamed( data=reconstructed_float[idx], # (H, W) view — zero copy profile=base_profile, output_path=out_file, tile_size=tile_size, desc=band_name, ) written.append(out_file) # ── TCI (RGB uint8) ─────────────────────────────────────────────────────── if write_tci: r_idx, g_idx, b_idx = TCI_RGB_INDICES rgb_data = reconstructed_float[[r_idx, g_idx, b_idx], :, :] # (3, H, W) view tci_profile = build_output_profile( ref_profile, n_bands=3, scale=scale, dtype="uint8", compress=compress ) tci_file = output_dir / "TCI.tif" log.info(" Writing TCI (R=B04 G=B03 B=B02, gain=%.1f gamma=%.1f) -> %s", TCI_GAIN, TCI_GAMMA, tci_file.name) _write_tci_streamed( data=rgb_data, profile=tci_profile, output_path=tci_file, tile_size=tile_size, ) written.append(tci_file) if preview: preview_file = output_dir / "preview.jpg" log.info(" Writing preview -> %s (max_size=%d px)", preview_file.name, preview_size) _write_jpeg_preview(rgb_data, preview_file, max_size=preview_size) written.append(preview_file) elif preview: # --preview sans TCI dans --bands : générer le preview en mémoire seulement r_idx, g_idx, b_idx = TCI_RGB_INDICES rgb_data = reconstructed_float[[r_idx, g_idx, b_idx], :, :] preview_file = output_dir / "preview.jpg" log.info(" Writing preview -> %s (max_size=%d px)", preview_file.name, preview_size) _write_jpeg_preview(rgb_data, preview_file, max_size=preview_size) written.append(preview_file) log.info("Export complete - %d files written to %s", len(written), output_dir) return written