gdubrasquetd's picture
deploy: bundle s2sr_pipe, fix requirements
e854df7 verified
Raw
History Blame Contribute Delete
14.7 kB
"""
Post-processing for SEN2SR: reconstruct the full image from inferred patches.
Overlap handling
----------------
Patches typically overlap (stride < patch_size). Naïve copying causes visible
seam artefacts at patch boundaries. We use **weighted averaging** instead:
each pixel in the output accumulates the weighted sum of all patches that cover
it, then divides by the total weight accumulated.
The default weight is a **2-D Hanning window** — values fall towards 0 at the
patch border and reach 1 at the centre. This smoothly de-emphasises patch edges,
where SR quality is lower (convolutional networks see less spatial context near
patch borders).
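For non-overlapping patches (stride == patch_size) every pixel is covered by
exactly one patch, so dividing by the accumulated weight cancels the window and
the result is equivalent to standard copy-paste (wherever that weight is non-zero).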
"""
from __future__ import annotations
from pathlib import Path
from typing import List, Tuple
import numpy as np
from s2sr_pipe.utils.logging_utils import get_logger
logger = get_logger("postprocessing")

# Factor between model-space reflectance in [0, 1] and the uint16 digital
# numbers written to the per-band GeoTIFFs (DN range [0, 10 000]).
NORM_SCALE: float = 1e4
# TCI stretch gain: reflectance 0.8 (= 1 / 1.25) is mapped to full white.
TCI_GAIN: float = 1.25
# TCI gamma: values are raised to 1 / TCI_GAMMA after the gain (mild brightening).
TCI_GAMMA: float = 1.25
def _ensure_crs(profile: dict) -> dict:
"""Log a warning if CRS is still None after safe_reader recovery."""
if not profile.get("crs"):
logger.warning(
"CRS is None in the output profile - the GeoTIFF will have no spatial "
"reference. Check that safe_reader correctly parsed the MGRS tile code."
)
return profile
# ── Weight window ─────────────────────────────────────────────────────────────
def _hanning_window(patch_size: int) -> np.ndarray:
"""
2-D Hanning window of shape (patch_size, patch_size), float32.
Constructed as the outer product of two 1-D Hanning windows.
Values range from ~0 at corners to 1.0 at the centre.
"""
w1d = np.hanning(patch_size).astype(np.float32)
w2d = np.outer(w1d, w1d)
# Guard against all-zero (degenerate case with patch_size=1)
if w2d.max() > 0:
w2d /= w2d.max()
return w2d
# ── Reconstruction ────────────────────────────────────────────────────────────
def reconstruct_from_patches(
    patches: List[np.ndarray],
    coords: List[Tuple[int, int]],
    image_shape: Tuple[int, int, int],  # (C, H, W)
    patch_size: int,
    use_hanning: bool = True,
) -> np.ndarray:
    """
    Reconstruct the full (C, H, W) float32 image from inferred patches.

    Each patch is multiplied by a (patch_size, patch_size) weight window and
    summed into a float64 accumulator. The per-pixel sum of weights is tracked
    separately, and the accumulator is divided by it at the end.

    Parameters
    ----------
    patches     : List of (C, patch_size, patch_size) float32 arrays
                  (model output, values already in [0, 1]).
    coords      : Matching list of (row, col) top-left positions.
    image_shape : Target (C, H, W) shape.
    patch_size  : Spatial patch size (must be consistent with coords).
    use_hanning : Use Hanning-weighted averaging for smooth seams.
                  Set False for benchmarking / unit tests (uniform weights).

    Returns
    -------
    np.ndarray — (C, H, W) float32, values averaged over overlaps. Pixels not
    covered by any patch (total weight < 1e-9) are 0 and reported via a warning.

    Raises
    ------
    ValueError if len(patches) != len(coords), or if a coordinate is negative.
    """
    if len(patches) != len(coords):
        raise ValueError(
            f"patches ({len(patches)}) and coords ({len(coords)}) must have the same length."
        )

    C, H, W = image_shape
    accumulator = np.zeros((C, H, W), dtype=np.float64)
    weight_map = np.zeros((H, W), dtype=np.float64)
    window = _hanning_window(patch_size) if use_hanning else np.ones(
        (patch_size, patch_size), dtype=np.float32
    )

    logger.info("Reconstructing image from %d patches ...", len(patches))
    skipped = 0
    for patch, (r, c) in zip(patches, coords):
        # Negative offsets would silently wrap around in numpy slicing.
        if r < 0 or c < 0:
            raise ValueError(
                f"Patch coordinate ({r}, {c}) is negative; coords must be "
                "top-left offsets inside the image."
            )
        # Clip to the image bounds AND to the patch's own spatial extent, so
        # that edge patches which were not snapped inward (or are smaller
        # than patch_size) do not cause a broadcast error.
        ph = min(patch_size, H - r, patch.shape[-2])
        pw = min(patch_size, W - c, patch.shape[-1])
        if ph <= 0 or pw <= 0:
            # Patch lies entirely outside the image (or is empty).
            skipped += 1
            continue
        w = window[:ph, :pw]
        accumulator[:, r : r + ph, c : c + pw] += patch[:, :ph, :pw] * w
        weight_map[r : r + ph, c : c + pw] += w

    if skipped:
        logger.warning(
            "%d patch(es) contribute no pixels to the %dx%d image "
            "(outside the image or empty) and were ignored.",
            skipped, H, W,
        )

    # Pixels with (near-)zero total weight are not covered by any patch:
    # divide by 1 there so they stay 0 instead of producing NaN/inf.
    uncovered_mask = weight_map < 1e-9
    uncovered = int(uncovered_mask.sum())
    if uncovered > 0:
        logger.warning(
            "%d pixels ne sont couverts par aucun patch et seront mis a 0. "
            "Verifier la configuration du patch grid.", uncovered
        )
    weight_map[uncovered_mask] = 1.0

    # Broadcast the (H, W) weight across channels.
    reconstructed = (accumulator / weight_map[np.newaxis, :, :]).astype(np.float32)
    # min()/max() raise on zero-size arrays; only report a range if there is one.
    if reconstructed.size:
        logger.info(
            "Reconstruction complete. Output range: [%.4f, %.4f]",
            float(reconstructed.min()),
            float(reconstructed.max()),
        )
    return reconstructed
# ── Band name mapping ─────────────────────────────────────────────────────────
# Canonical SEN2SR band order: position in this list == channel index in the
# reconstructed (C, H, W) array (index → Sentinel-2 band name).
BAND_NAMES: List[str] = [
    "B02", "B03", "B04",
    "B05", "B06", "B07",
    "B08", "B8A",
    "B11", "B12",
]
# Channel indices of Red (B04), Green (B03) and Blue (B02) inside BAND_NAMES,
# in R, G, B order, used for the TCI. Evaluates to (2, 1, 0).
TCI_RGB_INDICES: Tuple[int, int, int] = tuple(
    BAND_NAMES.index(name) for name in ("B04", "B03", "B02")
)
# ── Low-level windowed writer ─────────────────────────────────────────────────
def _write_band_streamed(
    data: np.ndarray,  # (H, W) float in [0, 1], single band
    profile: dict,  # 1-band uint16 rasterio profile
    output_path: Path,
    tile_size: int,
    desc: str,
) -> None:
    """
    Write a single band to a uint16 GeoTIFF, one ``tile_size`` window at a time.

    Reflectance in [0, 1] is converted to digital numbers as
    ``DN = round(value * NORM_SCALE)``, clipped to [0, NORM_SCALE].
    Non-finite values (NaN/inf) are written as 0.

    Parameters
    ----------
    data        : (H, W) array, single band.
    profile     : rasterio profile for the output (expected 1 band, uint16).
    output_path : Destination GeoTIFF path (overwritten).
    tile_size   : Side of the square write windows, in pixels.
    desc        : Label for the tqdm progress bar.
    """
    import rasterio
    from rasterio.windows import Window
    from tqdm import tqdm

    H, W = data.shape
    col_starts = list(range(0, W, tile_size))
    row_starts = list(range(0, H, tile_size))
    n_windows = len(col_starts) * len(row_starts)
    with rasterio.open(output_path, "w", **profile) as dst:
        with tqdm(total=n_windows, desc=desc, unit="tile", leave=False) as pbar:
            for row_off in row_starts:
                row_end = min(row_off + tile_size, H)
                for col_off in col_starts:
                    col_end = min(col_off + tile_size, W)
                    tile = data[row_off:row_end, col_off:col_end]
                    tile = np.where(np.isfinite(tile), tile, 0.0)
                    # Round to the nearest DN: a bare astype() truncates, which
                    # biases every value down (by 0.5 DN on average) and turns
                    # values a hair below an integer into the integer minus one.
                    tile_u16 = np.clip(
                        np.rint(tile * NORM_SCALE), 0, NORM_SCALE
                    ).astype(np.uint16)
                    win = Window(col_off, row_off, col_end - col_off, row_end - row_off)
                    dst.write(tile_u16[np.newaxis, :, :], window=win)
                    pbar.update(1)
def _write_jpeg_preview(
    data: np.ndarray,
    output_path: Path,
    max_size: int = 2048,
    quality: int = 90,
    gain: float = TCI_GAIN,
    gamma: float = TCI_GAMMA,
) -> Path:
    """
    Write a JPEG preview of an RGB stack, resized so the longest side <= max_size.

    Applies the same gain + gamma stretch as the TCI GeoTIFF writer.

    Parameters
    ----------
    data        : (3, H, W) array in R, G, B order, values nominally in [0, 1].
    output_path : Destination JPEG path (overwritten).
    max_size    : Maximum length of the longest side, in pixels.
    quality     : JPEG quality passed to Pillow.
    gain, gamma : Stretch parameters: ``clip(data * gain, 0, 1) ** (1 / gamma)``.

    Returns
    -------
    Path : ``output_path``.
    """
    from PIL import Image

    # Replace NaN/inf with 0, as the TCI writer does; otherwise they survive
    # clip/pow and the uint8 cast below is undefined.
    finite = np.where(np.isfinite(data), data, 0.0)
    inv_gamma = 1.0 / gamma
    rgb_f = np.clip(finite * gain, 0.0, 1.0) ** inv_gamma  # (3, H, W)
    rgb_u8 = (rgb_f * 255.0).astype(np.uint8).transpose(1, 2, 0)  # (H, W, 3)
    img = Image.fromarray(rgb_u8)
    H, W = rgb_u8.shape[:2]
    longest = max(H, W)
    if longest > max_size:
        scale = max_size / longest
        # max(1, ...) stops very elongated images from collapsing to 0 px.
        new_size = (max(1, int(W * scale)), max(1, int(H * scale)))
        img = img.resize(new_size, Image.LANCZOS)
    img.save(output_path, format="JPEG", quality=quality)
    return output_path
def _write_tci_streamed(
    data: np.ndarray,  # (3, H, W) float32 in [0,1]
    profile: dict,  # 3-band uint8 rasterio profile
    output_path: Path,
    tile_size: int,
    gain: float = TCI_GAIN,
    gamma: float = TCI_GAMMA,
) -> None:
    """
    Stream a 3-band RGB stack to a uint8 GeoTIFF, one window at a time.

    Each window is stretched as ``clip(value * gain, 0, 1) ** (1 / gamma)``
    and scaled to [0, 255] (truncated). Non-finite values are treated as 0.

    Parameters
    ----------
    data        : (3, H, W) array in R, G, B order.
    profile     : rasterio profile for the output (expected 3 bands, uint8).
    output_path : Destination GeoTIFF path (overwritten).
    tile_size   : Side of the square write windows, in pixels.
    gain, gamma : Stretch parameters (Sentinel Hub style true colour).
    """
    import rasterio
    from rasterio.windows import Window
    from tqdm import tqdm

    _, height, width = data.shape
    exponent = 1.0 / gamma
    # (row_start, row_stop, col_start, col_stop) for every window, row-major.
    windows = [
        (r0, min(r0 + tile_size, height), c0, min(c0 + tile_size, width))
        for r0 in range(0, height, tile_size)
        for c0 in range(0, width, tile_size)
    ]

    with rasterio.open(output_path, "w", **profile) as dst, tqdm(
        total=len(windows), desc="TCI", unit="tile", leave=False
    ) as progress:
        for r0, r1, c0, c1 in windows:
            block = data[:, r0:r1, c0:c1]
            finite = np.where(np.isfinite(block), block, 0.0)
            stretched = np.clip(finite * gain, 0.0, 1.0) ** exponent
            dst.write(
                (stretched * 255.0).astype(np.uint8),
                window=Window(c0, r0, c1 - c0, r1 - r0),
            )
            progress.update(1)
# ── Public API ────────────────────────────────────────────────────────────────
def _tci_rgb(reconstructed_float: np.ndarray) -> np.ndarray:
    """Return the (3, H, W) R, G, B stack for the TCI/preview (a copy: list indexing)."""
    return reconstructed_float[list(TCI_RGB_INDICES), :, :]


def export_result(
    reconstructed_float: np.ndarray,
    ref_profile: dict,
    output_path: Path | str,
    scale: int = 1,
    compress: str = "deflate",
    tile_size: int = 1024,
    bands: List[str] | None = None,
    preview: bool = False,
    preview_size: int = 2048,
) -> List[Path]:
    """
    Write one GeoTIFF per spectral band, plus an optional TCI and JPEG preview.

    Output files
    ------------
    ``output_path`` is treated as a *directory* (created if missing). Given
    ``output_path = "results/sen2sr_2.5m"``, the function creates:

        results/sen2sr_2.5m/B02.tif      <- Blue (uint16, DN [0, 10000])
        results/sen2sr_2.5m/B03.tif      <- Green
        ...
        results/sen2sr_2.5m/TCI.tif      <- RGB true colour (uint8, gain + gamma)
        results/sen2sr_2.5m/preview.jpg  <- only when ``preview=True``

    Parameters
    ----------
    reconstructed_float : (C, H, W) float32 in [0, 1], channels in BAND_NAMES order.
    ref_profile         : rasterio profile from the 10 m LR reference band.
    output_path         : Output directory for all files.
    scale               : SR upscale factor (4 = 10 m → 2.5 m).
    compress            : GeoTIFF compression codec.
    tile_size           : Spatial tile size for windowed writing.
    bands               : Outputs to write (e.g. ["B02", "TCI"]).
                          None = every band + TCI (default).
                          Valid values: names in BAND_NAMES, plus "TCI".
    preview             : Also write ``preview.jpg`` (RGB), whether or not
                          "TCI" is selected.
    preview_size        : Longest side of the preview, in pixels.

    Returns
    -------
    List[Path] : Paths of all files written, in write order.

    Raises
    ------
    ValueError : if ``bands`` contains unknown names, or if
                 ``reconstructed_float`` is not 3-D.
    """
    from s2sr_pipe.utils.geo_utils import build_output_profile

    # ── Resolve the band selection ───────────────────────────────────────────
    valid = set(BAND_NAMES) | {"TCI"}
    if bands is not None:
        requested = set(bands)
        invalid = sorted(requested - valid)
        if invalid:
            raise ValueError(
                f"Bandes inconnues : {invalid}. "
                f"Valeurs valides : {sorted(valid)}"
            )
        write_spectral = [b for b in BAND_NAMES if b in requested]
        write_tci = "TCI" in requested
    else:
        write_spectral = list(BAND_NAMES)
        write_tci = True

    if reconstructed_float.ndim != 3:
        raise ValueError(
            f"reconstructed_float must be (C, H, W); got shape {reconstructed_float.shape}"
        )

    output_dir = Path(output_path)
    output_dir.mkdir(parents=True, exist_ok=True)

    # Base profile (1 band, uint16). _ensure_crs only *warns* if the CRS is
    # missing; it does not set one.
    base_profile = _ensure_crs(build_output_profile(
        ref_profile, n_bands=1, scale=scale, dtype="uint16", compress=compress
    ))
    written: List[Path] = []
    logger.info(
        "Exporting %d bande(s) spectrale(s)%s -> %s (tile_size=%d px)",
        len(write_spectral),
        " + TCI" if write_tci else "",
        output_dir,
        tile_size,
    )

    # ── Per-band TIFs ─────────────────────────────────────────────────────────
    for band_name in write_spectral:
        idx = BAND_NAMES.index(band_name)
        out_file = output_dir / f"{band_name}.tif"
        logger.info(" Writing band %s -> %s", band_name, out_file.name)
        _write_band_streamed(
            data=reconstructed_float[idx],  # (H, W) basic-index view, no copy
            profile=base_profile,
            output_path=out_file,
            tile_size=tile_size,
            desc=band_name,
        )
        written.append(out_file)

    # The RGB stack is shared by the TCI and the preview; build it only once,
    # and only if one of them is requested.
    rgb_data = _tci_rgb(reconstructed_float) if (write_tci or preview) else None

    # ── TCI (RGB uint8) ───────────────────────────────────────────────────────
    if write_tci:
        tci_profile = build_output_profile(
            ref_profile, n_bands=3, scale=scale, dtype="uint8", compress=compress
        )
        tci_file = output_dir / "TCI.tif"
        logger.info(" Writing TCI (R=B04 G=B03 B=B02, gain=%.1f gamma=%.1f) -> %s", TCI_GAIN, TCI_GAMMA, tci_file.name)
        _write_tci_streamed(
            data=rgb_data,
            profile=tci_profile,
            output_path=tci_file,
            tile_size=tile_size,
        )
        written.append(tci_file)

    # ── JPEG preview (independent of the TCI selection) ───────────────────────
    if preview:
        preview_file = output_dir / "preview.jpg"
        logger.info(" Writing preview -> %s (max_size=%d px)", preview_file.name, preview_size)
        _write_jpeg_preview(rgb_data, preview_file, max_size=preview_size)
        written.append(preview_file)

    logger.info("Export complete - %d files written to %s", len(written), output_dir)
    return written