"""
Sentinel-2 L2A SAFE folder parser.

Handles both the old (< 2022) and new (>= 2022) ESA naming conventions:

    Old:     IMG_DATA/R10m/T31UDQ_20210601T103021_B02_10m.jp2
    New:     IMG_DATA/R10m/T31UDQ_20210601T103021_B02.jp2   (no resolution suffix)
    Compact: IMG_DATA/B02.jp2                                (some third-party products)

Products can be read from an extracted ``*.SAFE`` directory
(:func:`find_band_files`) or straight from the downloaded ``.zip`` archive
through GDAL's ``/vsizip/`` virtual file system (:func:`find_band_files_from_zip`).

Band ordering returned by :func:`read_and_align_bands` (index → band name).
This is ``ALL_BANDS``, the canonical SEN2SR input order:

    0 B02 | 1 B03 | 2 B04    ← 10 m native
    3 B05 | 4 B06 | 5 B07    ← 20 m → resampled to 10 m
    6 B08                    ← 10 m native
    7 B8A | 8 B11 | 9 B12    ← 20 m → resampled to 10 m
"""

from __future__ import annotations

import fnmatch
import logging
import re
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union

import numpy as np
import rasterio
import rasterio.windows
from rasterio.enums import Resampling

from s2sr_pipe.utils.logging_utils import get_logger

logger = get_logger("safe_reader")

# ── Band lists ────────────────────────────────────────────────────────────────
BANDS_10M: List[str] = ["B02", "B03", "B04", "B08"]
BANDS_20M: List[str] = ["B05", "B06", "B07", "B8A", "B11", "B12"]

# Canonical SEN2SR input order (MUST match tacofoundation/sen2sr model expectations):
#   B02, B03, B04, B05, B06, B07, B08, B8A, B11, B12
# NOTE: B08 is at index 6, NOT index 3 — differs from a naive 10m-first stacking.
ALL_BANDS: List[str] = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"]


def _band_patterns(band_name: str) -> List[str]:
    """
    Return the filename glob patterns tried, in priority order, for *band_name*.

    Shared by the directory search and the zip search so both resolve band
    files identically.
    """
    return [
        f"*_{band_name}_*m.jp2",  # T31UDQ_..._B02_10m.jp2
        f"*_{band_name}.jp2",     # T31UDQ_..._B02.jp2
        f"{band_name}.jp2",       # B02.jp2 (compact products)
        f"*{band_name}*.jp2",     # catch-all
    ]


def _search_band(directory: Path, band_name: str) -> Path:
    """
    Search *directory* (non-recursively) for a jp2 file matching *band_name*.

    Patterns from :func:`_band_patterns` are tried in order; within the first
    pattern that matches, the alphabetically first file wins.

    Raises
    ------
    FileNotFoundError : if no pattern matches any file in *directory*.
    """
    for pat in _band_patterns(band_name):
        matches = sorted(directory.glob(pat))
        if matches:
            logger.debug("Band %s -> %s (pattern: %s)", band_name, matches[0], pat)
            return matches[0]

    raise FileNotFoundError(
        f"Band {band_name} not found in {directory}. "
        f"Available jp2 files: {sorted(directory.glob('*.jp2'))}"
    )


def _find_safe_in_zip(zip_path: Path) -> str:
    """
    Return the name of the top-level ``*.SAFE`` folder inside *zip_path*.

    If several ``.SAFE`` folders are present, a warning is logged and the
    alphabetically first one is returned.

    Raises
    ------
    FileNotFoundError : if the archive has no top-level ``.SAFE`` folder.
    """
    with zipfile.ZipFile(zip_path, "r") as zf:
        names = zf.namelist()

    safe_dirs = sorted({
        n.split("/")[0]
        for n in names
        if "/" in n and n.split("/")[0].endswith(".SAFE")
    })

    if not safe_dirs:
        raise FileNotFoundError(
            f"Aucun dossier .SAFE trouve dans {zip_path.name}. "
            "Format attendu : produit Sentinel-2 standard (.SAFE zippe)."
        )
    if len(safe_dirs) > 1:
        logger.warning(
            "%d dossiers .SAFE trouves dans %s - seul le premier (%s) sera traite. "
            "Format non standard Sentinel-2.",
            len(safe_dirs), zip_path.name, safe_dirs[0],
        )
    return safe_dirs[0]


def _search_band_in_zip(
    zip_entries: List[str],
    prefix: str,
    band_name: str,
    zip_path: Path,
) -> str:
    """
    Find the jp2 file of *band_name* among *zip_entries* and return its
    GDAL ``/vsizip/`` path.

    Only files located directly under *prefix* are considered. This mirrors
    the non-recursive :func:`_search_band`. Otherwise a flat ``IMG_DATA/``
    prefix would also match files in ``R20m/`` or ``R60m/`` and could
    silently return a band at the wrong resolution.

    Raises
    ------
    FileNotFoundError : if no pattern matches any candidate file.
    """
    candidates = [
        e for e in zip_entries
        if e.startswith(prefix)
        and e.endswith(".jp2")
        and "/" not in e[len(prefix):]  # direct children only
    ]

    for pat in _band_patterns(band_name):
        matched = sorted(e for e in candidates if fnmatch.fnmatch(e.split("/")[-1], pat))
        if matched:
            logger.debug("Band %s -> %s (pattern: %s)", band_name, matched[0], pat)
            return f"/vsizip/{zip_path.as_posix()}/{matched[0]}"

    raise FileNotFoundError(
        f"Bande {band_name} introuvable dans le zip sous le prefixe {prefix!r}. "
        f"Fichiers jp2 disponibles : {sorted(e.split('/')[-1] for e in candidates)}"
    )


def find_band_files(safe_path: Path | str) -> Dict[str, Path]:
    """
    Locate all required band files inside a SAFE folder.

    Only the alphabetically first granule under ``GRANULE/`` is used; a
    warning is logged when there are several. If ``IMG_DATA/`` has no
    ``R10m``/``R20m`` sub-directories, bands are searched in ``IMG_DATA/``
    itself.

    Parameters
    ----------
    safe_path : Path to the *.SAFE directory.

    Returns
    -------
    dict : {band_name: absolute Path to .jp2 file}

    Raises
    ------
    FileNotFoundError : if the SAFE folder, ``GRANULE/``, a granule,
        ``IMG_DATA/`` or any band file is missing.
    """
    safe_path = Path(safe_path).resolve()
    if not safe_path.exists():
        raise FileNotFoundError(f"SAFE path does not exist: {safe_path}")

    granule_dir = safe_path / "GRANULE"
    if not granule_dir.exists():
        raise FileNotFoundError(f"No GRANULE/ directory in {safe_path}")

    # Most products have a single granule; take the first one alphabetically.
    granules = sorted(d for d in granule_dir.iterdir() if d.is_dir())
    if not granules:
        raise FileNotFoundError(f"GRANULE/ is empty in {safe_path}")
    if len(granules) > 1:
        logger.warning(
            "%d granules trouves dans %s - seul le premier (%s) sera traite. "
            "Relancer le pipeline separement pour chaque granule si necessaire.",
            len(granules), granule_dir, granules[0].name,
        )
    granule = granules[0]
    logger.info("Using granule: %s", granule.name)

    img_data = granule / "IMG_DATA"
    if not img_data.exists():
        raise FileNotFoundError(f"IMG_DATA/ missing in {granule}")

    # Some products put all bands flat in IMG_DATA (no R10m / R20m sub-dirs).
    r10m_dir = img_data / "R10m" if (img_data / "R10m").exists() else img_data
    r20m_dir = img_data / "R20m" if (img_data / "R20m").exists() else img_data

    band_files: Dict[str, Path] = {}
    for band in BANDS_10M:
        band_files[band] = _search_band(r10m_dir, band)
    for band in BANDS_20M:
        band_files[band] = _search_band(r20m_dir, band)

    logger.info("All %d band files located.", len(band_files))
    return band_files


def find_band_files_from_zip(zip_path: Path | str) -> Dict[str, str]:
    """
    Locate the 10 band files inside a standard Sentinel-2 zip archive.

    Uses GDAL's ``/vsizip/`` virtual file system — nothing is extracted to
    disk. Only the alphabetically first granule is used; a warning is logged
    when there are several.

    Parameters
    ----------
    zip_path : Path to the downloaded Sentinel-2 product (.zip).

    Returns
    -------
    dict : {band_name: ``/vsizip/...`` path} — accepted by read_and_align_bands().

    Raises
    ------
    FileNotFoundError : if the zip, its ``.SAFE`` folder, a granule or any
        band file is missing.
    """
    zip_path = Path(zip_path).resolve()
    if not zip_path.exists():
        raise FileNotFoundError(f"Fichier zip introuvable : {zip_path}")

    safe_name = _find_safe_in_zip(zip_path)
    logger.info("SAFE dans le zip : %s", safe_name)

    with zipfile.ZipFile(zip_path, "r") as zf:
        all_entries = zf.namelist()

    # Granule names are the first path component below <SAFE>/GRANULE/.
    granule_prefix = f"{safe_name}/GRANULE/"
    granule_names = sorted({
        n[len(granule_prefix):].split("/")[0]
        for n in all_entries
        if n.startswith(granule_prefix)
        and len(n) > len(granule_prefix)
        and n[len(granule_prefix):].split("/")[0]
    })
    if not granule_names:
        raise FileNotFoundError(
            f"Aucun granule trouve dans GRANULE/ a l'interieur de {safe_name}."
        )
    if len(granule_names) > 1:
        logger.warning(
            "%d granules trouves dans %s - seul le premier (%s) sera traite. "
            "Relancer le pipeline separement pour chaque granule si necessaire.",
            len(granule_names), safe_name, granule_names[0],
        )
    granule_name = granule_names[0]
    logger.info("Using granule: %s", granule_name)

    img_prefix = f"{safe_name}/GRANULE/{granule_name}/IMG_DATA/"
    # Fall back to flat IMG_DATA/ when the resolution sub-directories are absent.
    has_r10m = any(e.startswith(img_prefix + "R10m/") for e in all_entries)
    has_r20m = any(e.startswith(img_prefix + "R20m/") for e in all_entries)
    r10m_prefix = img_prefix + ("R10m/" if has_r10m else "")
    r20m_prefix = img_prefix + ("R20m/" if has_r20m else "")

    band_files: Dict[str, str] = {}
    for band in BANDS_10M:
        band_files[band] = _search_band_in_zip(all_entries, r10m_prefix, band, zip_path)
    for band in BANDS_20M:
        band_files[band] = _search_band_in_zip(all_entries, r20m_prefix, band, zip_path)

    logger.info("All %d band files located in zip.", len(band_files))
    return band_files


def get_ref_profile(
    band_files: Dict[str, Union[Path, str]],
    reference_band: str = "B02",
) -> dict:
    """
    Return the rasterio profile of the reference band without reading pixel data.

    Useful for computing a ROI window before calling read_and_align_bands,
    avoiding loading the full tile into memory.
    """
    with rasterio.open(band_files[reference_band]) as ref:
        return ref.profile.copy()


def _read_band_on_grid(
    fpath: Union[Path, str],
    is_20m: bool,
    out_shape: Tuple[int, int],
    window: Optional[rasterio.windows.Window],
    win_bounds: Optional[Tuple[float, float, float, float]],
    resampling: Resampling,
) -> np.ndarray:
    """
    Read band 1 of *fpath* onto the 10 m output grid of shape *out_shape*.

    20 m bands are resampled with *resampling*. In a windowed read, a 20 m
    band uses a window derived from the geographic bounds *win_bounds* of the
    10 m *window*.
    """
    with rasterio.open(fpath) as src:
        if window is not None and is_20m:
            # Derive the equivalent 20 m window from the geographic bounds.
            win_20m = rasterio.windows.from_bounds(*win_bounds, transform=src.transform)
            return src.read(1, window=win_20m, out_shape=out_shape, resampling=resampling)
        if window is not None:
            return src.read(1, window=window)
        if is_20m:
            return src.read(1, out_shape=out_shape, resampling=resampling)
        return src.read(1)


def read_and_align_bands(
    band_files: Dict[str, Union[Path, str]],
    reference_band: str = "B02",
    resampling: Resampling = Resampling.bilinear,
    window: Optional[rasterio.windows.Window] = None,
) -> Tuple[np.ndarray, dict]:
    """
    Read all 10 bands and upsample 20 m bands to the 10 m reference grid.

    The resampling happens **once** here (not in the inference loop), using
    rasterio's ``out_shape`` parameter. This avoids any re-projection and
    keeps the spatial footprint identical across bands.

    Parameters
    ----------
    band_files : Output of :func:`find_band_files` or
        :func:`find_band_files_from_zip`.
    reference_band : 10 m band used to define the target grid (B02 by default).
    resampling : Resampling algorithm for 20 m → 10 m upscaling.
    window : Optional window in pixel coordinates of the reference band.
        Offsets and lengths are truncated to whole pixels before reading. 20 m
        bands are read over the same geographic bounds and resampled to the
        window size. ``None`` reads the full tile.

    Returns
    -------
    stacked : np.ndarray, shape (10, H, W), bands in ``ALL_BANDS`` order.
        The dtype is that of the first band read (uint16 for standard L2A
        products).
    profile : rasterio profile of the reference band (10 m grid, 10 bands).

    Raises
    ------
    KeyError : if a band of ``ALL_BANDS`` is missing from *band_files*.
    ValueError : if the bands do not all come out with the same shape.

    Notes
    -----
    * The returned profile contains ``count=10`` and ``dtype='uint16'``, and
      ``height``/``width``/``transform`` describe the returned (possibly
      windowed) grid.
    * CRS is taken unchanged from the 10 m reference band — no reprojection
      is performed.
    * The output array is preallocated, so peak memory is about the stack
      plus one band, not twice the stack.
    """
    # ── Reference metadata ────────────────────────────────────────────────────
    with rasterio.open(band_files[reference_band]) as ref:
        ref_profile = ref.profile.copy()
        ref_height = ref.height
        ref_width = ref.width
        ref_crs = ref.crs
        ref_transform = ref.transform

    logger.info(
        "Reference band %s | grid %dx%d px | CRS: %s | pixel size: %.1f m",
        reference_band, ref_width, ref_height, ref_crs, abs(ref_transform.a),
    )

    win_bounds: Optional[Tuple[float, float, float, float]]
    if window is not None:
        # Snap to whole reference pixels, using the same truncation the output
        # size has always used. With a fractional window, the 10 m reads could
        # come back with a shape different from the (out_height, out_width)
        # that 20 m bands are resampled to. out_transform would also be off the
        # pixel grid.
        window = rasterio.windows.Window(
            col_off=int(window.col_off),
            row_off=int(window.row_off),
            width=int(window.width),
            height=int(window.height),
        )
        out_height = int(window.height)
        out_width = int(window.width)
        out_transform = rasterio.windows.transform(window, ref_transform)
        # Geographic bounds of the 10 m window (used to derive 20 m windows)
        win_bounds = rasterio.windows.bounds(window, ref_transform)
        logger.info(
            "Windowed read: col_off=%d row_off=%d width=%d height=%d",
            int(window.col_off), int(window.row_off), out_width, out_height,
        )
    else:
        out_height = ref_height
        out_width = ref_width
        out_transform = ref_transform
        win_bounds = None

    # ── Read, (optionally) resample, and write each band into the stack ──────
    stacked_array: Optional[np.ndarray] = None
    for idx, band_name in enumerate(ALL_BANDS):
        is_20m = band_name in BANDS_20M
        data = _read_band_on_grid(
            band_files[band_name], is_20m, (out_height, out_width),
            window, win_bounds, resampling,
        )

        if stacked_array is None:
            stacked_array = np.empty((len(ALL_BANDS), *data.shape), dtype=data.dtype)
        elif data.shape != stacked_array.shape[1:]:
            raise ValueError(
                f"Band {band_name} has shape {data.shape}, "
                f"expected {stacked_array.shape[1:]}"
            )

        # min()/max() are full passes over a ~120 Mpx band: only compute them
        # when DEBUG output is enabled (and the band is non-empty). This assumes
        # get_logger returns a stdlib logging.Logger, as the %-style calls
        # throughout this module already do.
        if data.size and logger.isEnabledFor(logging.DEBUG):
            logger.debug(
                "Band %-4s | shape %s | min=%d max=%d | %s",
                band_name, data.shape, int(data.min()), int(data.max()),
                "resampled" if is_20m else "native",
            )

        stacked_array[idx] = data
        del data  # release the per-band buffer before reading the next one

    logger.info(
        "Stacked array: shape=%s dtype=%s memory=%.1f MB",
        stacked_array.shape, stacked_array.dtype, stacked_array.nbytes / 1e6,
    )

    out_profile = ref_profile.copy()
    out_profile.update({
        "count": len(ALL_BANDS),
        "dtype": "uint16",
        "height": out_height,
        "width": out_width,
        "transform": out_transform,
    })

    return stacked_array, out_profile