gdubrasquetd's picture
deploy: bundle s2sr_pipe, fix requirements
e2348ed verified
Raw
History Blame Contribute Delete
13.3 kB
"""
Sentinel-2 L2A SAFE folder parser.
Handles both the old (< 2022) and new (>= 2022) ESA naming conventions:
Old: IMG_DATA/R10m/T31UDQ_20210601T103021_B02_10m.jp2
New: IMG_DATA/R10m/T31UDQ_20210601T103021_B02.jp2 (no resolution suffix)
Compact: IMG_DATA/B02.jp2 (some third-party products)
Band ordering returned by this module (index → band name), following ALL_BANDS:
0 B02 | 1 B03 | 2 B04 ← 10 m native
3 B05 | 4 B06 | 5 B07 ← 20 m → resampled to 10 m
6 B08 ← 10 m native
7 B8A | 8 B11 | 9 B12 ← 20 m → resampled to 10 m
"""
from __future__ import annotations
import fnmatch
import re
import zipfile
from pathlib import Path
from typing import Dict, List, Optional, Tuple, Union
import rasterio.windows
import numpy as np
import rasterio
from rasterio.enums import Resampling
from s2sr_pipe.utils.logging_utils import get_logger
# Module-level logger shared by every function in this file.
logger = get_logger("safe_reader")
# ── Band lists ────────────────────────────────────────────────────────────────
# Bands searched under IMG_DATA/R10m (or flat IMG_DATA) and read at native resolution.
BANDS_10M: List[str] = ["B02", "B03", "B04", "B08"]
# Bands searched under IMG_DATA/R20m (or flat IMG_DATA) and resampled to the reference grid.
BANDS_20M: List[str] = ["B05", "B06", "B07", "B8A", "B11", "B12"]
# Canonical SEN2SR input order (MUST match tacofoundation/sen2sr model expectations):
# B02, B03, B04, B05, B06, B07, B08, B8A, B11, B12
# NOTE: B08 is at index 6, NOT index 3 — differs from a naive 10m-first stacking.
ALL_BANDS: List[str] = ["B02", "B03", "B04", "B05", "B06", "B07", "B08", "B8A", "B11", "B12"]
def _search_band(directory: Path, band_name: str) -> Path:
"""
Search *directory* for a jp2 file matching *band_name*.
Tries several glob patterns to handle naming variations.
Raises FileNotFoundError if nothing is found.
"""
patterns = [
f"*_{band_name}_*m.jp2", # T31UDQ_..._B02_10m.jp2
f"*_{band_name}.jp2", # T31UDQ_..._B02.jp2
f"{band_name}.jp2", # B02.jp2 (compact products)
f"*{band_name}*.jp2", # catch-all
]
for pat in patterns:
matches = sorted(directory.glob(pat))
if matches:
logger.debug("Band %s -> %s (pattern: %s)", band_name, matches[0], pat)
return matches[0]
raise FileNotFoundError(
f"Band {band_name} not found in {directory}. "
f"Available jp2 files: {sorted(directory.glob('*.jp2'))}"
)
def _find_safe_in_zip(zip_path: Path) -> str:
"""Trouve le nom du dossier .SAFE Γ  la racine d'une archive zip."""
with zipfile.ZipFile(zip_path, "r") as zf:
names = zf.namelist()
safe_dirs = sorted({
n.split("/")[0]
for n in names
if "/" in n and n.split("/")[0].endswith(".SAFE")
})
if not safe_dirs:
raise FileNotFoundError(
f"Aucun dossier .SAFE trouve dans {zip_path.name}. "
"Format attendu : produit Sentinel-2 standard (.SAFE zippe)."
)
if len(safe_dirs) > 1:
logger.warning(
"%d dossiers .SAFE trouves dans %s - seul le premier (%s) sera traite. "
"Format non standard Sentinel-2.",
len(safe_dirs), zip_path.name, safe_dirs[0],
)
return safe_dirs[0]
def _search_band_in_zip(
zip_entries: List[str],
prefix: str,
band_name: str,
zip_path: Path,
) -> str:
"""Trouve un .jp2 de bande dans les entrees zip et retourne son chemin /vsizip/."""
candidates = [e for e in zip_entries if e.startswith(prefix) and e.endswith(".jp2")]
patterns = [
f"*_{band_name}_*m.jp2",
f"*_{band_name}.jp2",
f"{band_name}.jp2",
f"*{band_name}*.jp2",
]
for pat in patterns:
matched = sorted(e for e in candidates if fnmatch.fnmatch(e.split("/")[-1], pat))
if matched:
logger.debug("Band %s -> %s (pattern: %s)", band_name, matched[0], pat)
return f"/vsizip/{zip_path.as_posix()}/{matched[0]}"
raise FileNotFoundError(
f"Bande {band_name} introuvable dans le zip sous le prefixe {prefix!r}. "
f"Fichiers jp2 disponibles : {sorted(e.split('/')[-1] for e in candidates)}"
)
def find_band_files(safe_path: Path | str) -> Dict[str, Path]:
    """
    Locate every required band file inside an unpacked SAFE folder.

    The alphabetically first sub-directory of ``GRANULE/`` is used (a warning
    is logged when there are several). 10 m bands are searched in
    ``IMG_DATA/R10m`` and 20 m bands in ``IMG_DATA/R20m``; when such a
    sub-directory is absent, ``IMG_DATA`` itself is searched instead.

    Parameters
    ----------
    safe_path : Path to the ``*.SAFE`` directory.

    Returns
    -------
    dict : {band_name: absolute Path to the .jp2 file}, 10 m bands first.

    Raises
    ------
    FileNotFoundError
        If the SAFE path, ``GRANULE/``, a granule, ``IMG_DATA/`` or any band
        file is missing.
    """
    safe_path = Path(safe_path).resolve()
    if not safe_path.exists():
        raise FileNotFoundError(f"SAFE path does not exist: {safe_path}")

    granule_dir = safe_path / "GRANULE"
    if not granule_dir.exists():
        raise FileNotFoundError(f"No GRANULE/ directory in {safe_path}")

    granules = sorted(child for child in granule_dir.iterdir() if child.is_dir())
    if not granules:
        raise FileNotFoundError(f"GRANULE/ is empty in {safe_path}")

    # Most products have a single granule; take the first one alphabetically.
    granule = granules[0]
    if len(granules) > 1:
        logger.warning(
            "%d granules trouves dans %s - seul le premier (%s) sera traite. "
            "Relancer le pipeline separement pour chaque granule si necessaire.",
            len(granules), granule_dir, granule.name,
        )
    logger.info("Using granule: %s", granule.name)

    img_data = granule / "IMG_DATA"
    if not img_data.exists():
        raise FileNotFoundError(f"IMG_DATA/ missing in {granule}")

    def _resolution_dir(subdir_name: str) -> Path:
        # Some products put all bands flat in IMG_DATA (no R10m / R20m sub-dirs).
        candidate = img_data / subdir_name
        return candidate if candidate.exists() else img_data

    search_plan = (
        (BANDS_10M, _resolution_dir("R10m")),
        (BANDS_20M, _resolution_dir("R20m")),
    )
    band_files: Dict[str, Path] = {}
    for band_group, search_dir in search_plan:
        for band in band_group:
            band_files[band] = _search_band(search_dir, band)

    logger.info("All %d band files located.", len(band_files))
    return band_files
def find_band_files_from_zip(zip_path: Path | str) -> Dict[str, str]:
    """
    Locate the 10 band files inside a standard zipped Sentinel-2 product.

    Paths are returned in GDAL's ``/vsizip/`` virtual file system form, so
    nothing is extracted to disk. The alphabetically first granule under
    ``<name>.SAFE/GRANULE/`` is used (a warning is logged when there are
    several). ``R10m/`` and ``R20m/`` sub-folders of ``IMG_DATA/`` are used
    when present; otherwise ``IMG_DATA/`` itself is searched.

    Parameters
    ----------
    zip_path : Path to the downloaded Sentinel-2 ``.zip`` product.

    Returns
    -------
    dict : {band_name: "/vsizip/..." path} — usable with read_and_align_bands().

    Raises
    ------
    FileNotFoundError
        If the zip, its ``.SAFE`` folder, a granule or any band is missing.
    """
    zip_path = Path(zip_path).resolve()
    if not zip_path.exists():
        raise FileNotFoundError(f"Fichier zip introuvable : {zip_path}")

    safe_name = _find_safe_in_zip(zip_path)
    logger.info("SAFE dans le zip : %s", safe_name)

    with zipfile.ZipFile(zip_path, "r") as archive:
        all_entries = archive.namelist()

    # Collect the distinct first path components found below GRANULE/.
    granule_prefix = f"{safe_name}/GRANULE/"
    prefix_len = len(granule_prefix)
    found_granules = set()
    for entry in all_entries:
        if not entry.startswith(granule_prefix):
            continue
        first_component = entry[prefix_len:].split("/", 1)[0]
        if first_component:
            found_granules.add(first_component)

    if not found_granules:
        raise FileNotFoundError(
            f"Aucun granule trouve dans GRANULE/ a l'interieur de {safe_name}."
        )

    granule_names = sorted(found_granules)
    granule_name = granule_names[0]
    if len(granule_names) > 1:
        logger.warning(
            "%d granules trouves dans %s - seul le premier (%s) sera traite. "
            "Relancer le pipeline separement pour chaque granule si necessaire.",
            len(granule_names), safe_name, granule_name,
        )
    logger.info("Using granule: %s", granule_name)

    img_prefix = f"{safe_name}/GRANULE/{granule_name}/IMG_DATA/"

    def _resolution_prefix(subdir: str) -> str:
        # Fall back to flat IMG_DATA/ when the resolution sub-folder is absent.
        nested = img_prefix + subdir
        if any(entry.startswith(nested) for entry in all_entries):
            return nested
        return img_prefix

    search_plan = (
        (BANDS_10M, _resolution_prefix("R10m/")),
        (BANDS_20M, _resolution_prefix("R20m/")),
    )
    band_files: Dict[str, str] = {}
    for band_group, search_prefix in search_plan:
        for band in band_group:
            band_files[band] = _search_band_in_zip(
                all_entries, search_prefix, band, zip_path
            )

    logger.info("All %d band files located in zip.", len(band_files))
    return band_files
def get_ref_profile(
    band_files: Dict[str, Union[Path, str]],
    reference_band: str = "B02",
) -> dict:
    """
    Return a copy of the reference band's rasterio profile.

    Only the dataset header is opened; no pixel data is read. Handy for
    computing a ROI window before calling read_and_align_bands, so the full
    tile never has to be loaded into memory.

    Parameters
    ----------
    band_files     : Mapping {band_name: path} as produced by the finder functions.
    reference_band : Key of the band whose profile is returned (B02 by default).
    """
    with rasterio.open(band_files[reference_band]) as dataset:
        profile = dataset.profile.copy()
    return profile
def read_and_align_bands(
    band_files: Dict[str, Union[Path, str]],
    reference_band: str = "B02",
    resampling: Resampling = Resampling.bilinear,
    window: Optional[rasterio.windows.Window] = None,
) -> Tuple[np.ndarray, dict]:
    """
    Read all 10 bands and bring the 20 m bands onto the 10 m reference grid.

    Resampling is done once, at read time, through rasterio's ``out_shape``
    argument; no reprojection is performed, so the footprint is identical
    across bands.

    Parameters
    ----------
    band_files     : Output of :func:`find_band_files` (or its zip variant).
    reference_band : 10 m band that defines the target grid (B02 by default).
    resampling     : Resampling algorithm used for the 20 m -> 10 m upscaling.
    window         : Optional window expressed on the reference band's grid.
                     When given, only that region is read; for 20 m bands the
                     equivalent window is derived from the geographic bounds.

    Returns
    -------
    stacked : np.ndarray of shape (10, H, W), bands ordered as ``ALL_BANDS``.
              The dtype is whatever the band files store (expected uint16;
              no cast is applied here).
    profile : Copy of the reference band's profile updated with ``count=10``,
              ``dtype='uint16'`` and the output height, width and transform.

    Notes
    -----
    * CRS is taken unchanged from the reference band.
    * With ``window=None`` the transform is the reference transform; otherwise
      it is the transform of the window.
    """
    # -- Reference metadata ---------------------------------------------------
    with rasterio.open(band_files[reference_band]) as ref:
        ref_profile = ref.profile.copy()
        ref_height, ref_width = ref.height, ref.width
        ref_crs = ref.crs
        ref_transform = ref.transform
    logger.info(
        "Reference band %s | grid %dx%d px | CRS: %s | pixel size: %.1f m",
        reference_band,
        ref_width,
        ref_height,
        ref_crs,
        abs(ref_transform.a),
    )

    # -- Output grid: full tile or the requested window ------------------------
    if window is None:
        out_height, out_width = ref_height, ref_width
        out_transform = ref_transform
        win_bounds = None
    else:
        out_height = int(window.height)
        out_width = int(window.width)
        out_transform = rasterio.windows.transform(window, ref_transform)
        # Geographic bounds of the 10 m window (used to derive 20 m windows)
        win_bounds = rasterio.windows.bounds(window, ref_transform)
        logger.info(
            "Windowed read: col_off=%d row_off=%d width=%d height=%d",
            int(window.col_off), int(window.row_off), out_width, out_height,
        )

    # -- Read each band, resampling the 20 m ones ------------------------------
    layers: List[np.ndarray] = []
    for band_name in ALL_BANDS:
        source_path = band_files[band_name]
        needs_resampling = band_name in BANDS_20M
        with rasterio.open(source_path) as src:
            read_kwargs = {}
            if window is not None:
                if needs_resampling:
                    # Same ground footprint, expressed on this band's own grid.
                    read_kwargs["window"] = rasterio.windows.from_bounds(
                        *win_bounds, transform=src.transform
                    )
                else:
                    read_kwargs["window"] = window
            if needs_resampling:
                # Without a window the output grid equals the reference grid.
                read_kwargs["out_shape"] = (out_height, out_width)
                read_kwargs["resampling"] = resampling
            data = src.read(1, **read_kwargs)
        logger.debug(
            "Band %-4s | shape %s | min=%d max=%d | %s",
            band_name,
            data.shape,
            int(data.min()),
            int(data.max()),
            "resampled" if needs_resampling else "native",
        )
        layers.append(data)

    # -- Stack to (C, H, W) ----------------------------------------------------
    stacked_array = np.stack(layers, axis=0)
    logger.info(
        "Stacked array: shape=%s dtype=%s memory=%.1f MB",
        stacked_array.shape,
        stacked_array.dtype,
        stacked_array.nbytes / 1e6,
    )

    out_profile = ref_profile.copy()
    out_profile.update(
        count=len(ALL_BANDS),
        dtype="uint16",
        height=out_height,
        width=out_width,
        transform=out_transform,
    )
    return stacked_array, out_profile