""" HuggingFace Space — SEN2SR Super-Resolution Sentinel-2. Two modes: pre-loaded example or custom ZIP upload. """ from __future__ import annotations import json import os import tempfile from pathlib import Path import gradio as gr import numpy as np import rasterio import rasterio.windows import torch from PIL import Image from rasterio.crs import CRS from rasterio.transform import Affine from rasterio.warp import transform_bounds from s2sr_pipe.data_processing.postprocessing import TCI_GAIN, TCI_GAMMA from s2sr_pipe.data_processing.preprocessing import normalize from s2sr_pipe.data_processing.safe_reader import ( ALL_BANDS, find_band_files_from_zip, get_ref_profile, read_and_align_bands, ) from s2sr_pipe.model.architecture import build_model, get_scale from s2sr_pipe.model.inference import infer_large from s2sr_pipe.utils.geo_utils import roi_wgs84_to_pixel_window # ── Constants ───────────────────────────────────────────────────────────────── MODEL_DIR = Path(os.environ.get("MODEL_DIR", "sen2sr_model")) DEVICE = torch.device("cuda" if torch.cuda.is_available() else "cpu") TCI_BANDS = (2, 1, 0) # R=B04(idx2), G=B03(idx1), B=B02(idx0) MAX_DISP_PX = 800 PRESETS = { "Zone industrielle (Tatarstan)": "industrial", "Terres agricoles (Tatarstan)": "farmland", } VARIANT_BANDS = { "main": list(range(10)), "rgbn_x4": [ALL_BANDS.index(b) for b in ["B02", "B03", "B04", "B08"]], "rswir_x2": list(range(10)), } # ── Model cache ─────────────────────────────────────────────────────────────── _model_cache: dict = {} def _get_model(variant: str): if variant not in _model_cache: _model_cache[variant] = build_model( variant=variant, model_dir=MODEL_DIR, device=DEVICE ) return _model_cache[variant], get_scale(variant) # ── TCI rendering ───────────────────────────────────────────────────────────── def _render_tci(rgb_float: np.ndarray) -> np.ndarray: inv_gamma = 1.0 / TCI_GAMMA rgb = np.clip(rgb_float * TCI_GAIN, 0.0, 1.0) ** inv_gamma return (rgb * 255.0).astype(np.uint8).transpose(1, 2, 0) # (H, W, 3) def _write_tci_tif(rgb_float: np.ndarray, crs, transform: Affine, path: str) -> None: """Write a 3-band uint8 TCI GeoTIFF.""" rgb_u8 = _render_tci(rgb_float) # (H, W, 3) H, W = rgb_u8.shape[:2] with rasterio.open( path, "w", driver="GTiff", width=W, height=H, count=3, dtype="uint8", crs=crs, transform=transform, compress="deflate", ) as dst: dst.write(rgb_u8[:, :, 0], 1) dst.write(rgb_u8[:, :, 1], 2) dst.write(rgb_u8[:, :, 2], 3) def _make_comparison(lr_array: np.ndarray, sr_array: np.ndarray): """Return (img_before, img_after) as PIL Images, both at MAX_DISP_PX.""" r, g, b = TCI_BANDS lr_u8 = _render_tci(lr_array[[r, g, b]]) sr_u8 = _render_tci(sr_array[[r, g, b]]) H_sr, W_sr = sr_u8.shape[:2] if max(H_sr, W_sr) > MAX_DISP_PX: s = MAX_DISP_PX / max(H_sr, W_sr) W_d, H_d = int(W_sr * s), int(H_sr * s) else: W_d, H_d = W_sr, H_sr img_after = Image.fromarray(sr_u8).resize((W_d, H_d), Image.LANCZOS) img_before = Image.fromarray(lr_u8).resize((W_d, H_d), Image.NEAREST) return img_before, img_after # ── Inference helper ────────────────────────────────────────────────────────── def _infer(lr_array: np.ndarray, variant: str, progress=None) -> np.ndarray: _p = progress or (lambda *a, **k: None) model, _ = _get_model(variant) band_idx = VARIANT_BANDS[variant] lr_input = lr_array[band_idx] _p(0.4, desc="Inference SR en cours (CPU, patience)...") return infer_large( model=model, array=lr_input, device=DEVICE, overlap=32, use_amp=False, batch_size=1, ) # ── Map helpers ─────────────────────────────────────────────────────────────── def _profile_to_wgs84_bbox(profile: dict) -> tuple: """Return tile bounding box as (min_lon, min_lat, max_lon, max_lat) WGS84.""" T = profile["transform"] w, h = profile["width"], profile["height"] left, top = T.c, T.f right = left + w * T.a bottom = top + h * T.e # T.e is negative west, south, east, north = transform_bounds( profile["crs"], CRS.from_epsg(4326), left, bottom, right, top ) return (west, south, east, north) def _make_map_data( tile_bbox=None, tile_px=None, roi=None, roi_too_large: bool = False, roi_warning_text: str = "", interactive: bool = False, ) -> str: """Serialise map state as JSON for the Python→JS bridge textbox.""" return json.dumps({ "tile_bbox": list(tile_bbox) if tile_bbox else None, "tile_px": list(tile_px) if tile_px else None, "roi": list(roi) if roi else None, "roi_too_large": roi_too_large, "roi_warning_text": roi_warning_text, "interactive": interactive, }) def _clamp_roi(roi: list, tile_bbox: list) -> list: """Clip ROI to tile bounding box.""" return [ max(roi[0], tile_bbox[0]), # min_lon max(roi[1], tile_bbox[1]), # min_lat min(roi[2], tile_bbox[2]), # max_lon min(roi[3], tile_bbox[3]), # max_lat ] def _estimate_roi_px(roi: list, tile_bbox: list, tile_px: list): """ Estimate ROI pixel dimensions by proportional mapping over the tile extent. Returns (w_px, h_px) or (None, None) if data is missing. """ if not tile_bbox or not tile_px: return None, None tile_w_deg = tile_bbox[2] - tile_bbox[0] tile_h_deg = tile_bbox[3] - tile_bbox[1] if tile_w_deg <= 0 or tile_h_deg <= 0: return None, None w_px = int((roi[2] - roi[0]) / tile_w_deg * tile_px[0]) h_px = int((roi[3] - roi[1]) / tile_h_deg * tile_px[1]) return w_px, h_px # Static HTML — loaded once; JS polls the hidden #map_data textbox every 400 ms. _MAP_HTML = """
📂 Uploadez un fichier ZIP pour activer la carte
""" # ── Map event handlers ──────────────────────────────────────────────────────── def _on_zip_upload(zip_file) -> str: """Called when ZIP is uploaded; returns map_data JSON with tile bbox + pixel dims.""" if zip_file is None: return _make_map_data() try: zip_path = Path(zip_file if isinstance(zip_file, str) else zip_file.name) band_files = find_band_files_from_zip(zip_path) profile = get_ref_profile(band_files) bbox = _profile_to_wgs84_bbox(profile) tile_px = (profile["width"], profile["height"]) return _make_map_data(tile_bbox=bbox, tile_px=tile_px, interactive=True) except Exception: return _make_map_data() def _on_roi_drawn(roi_str: str, map_data_str: str): """ Parse drawn ROI from JS (already clamped client-side), clamp again server-side, compute pixel estimate, and return updated fields + map_data + warning. """ try: cur = json.loads(map_data_str) if map_data_str else {} except Exception: cur = {} tile_bbox = cur.get("tile_bbox") tile_px = cur.get("tile_px") mn_lon = mn_lat = mx_lon = mx_lat = None if roi_str and roi_str.strip(): try: parts = [float(x) for x in roi_str.strip().split(',')] if len(parts) == 4: mn_lon, mn_lat, mx_lon, mx_lat = parts except ValueError: pass roi_list = None w_px = h_px = None too_large = False if mn_lon is not None: roi_list = [mn_lon, mn_lat, mx_lon, mx_lat] if tile_bbox: roi_list = _clamp_roi(roi_list, tile_bbox) mn_lon, mn_lat, mx_lon, mx_lat = roi_list if tile_px: w_px, h_px = _estimate_roi_px(roi_list, tile_bbox, tile_px) too_large = bool(w_px and h_px and (w_px > 1500 or h_px > 1500)) warn_text = ( f"⚠️ ROI trop grand : ~{w_px} × {h_px} px (limite : 1 500 × 1 500 px)" if too_large else "" ) new_map_data = _make_map_data( tile_bbox=tile_bbox, tile_px=tile_px, roi=roi_list, roi_too_large=too_large, roi_warning_text=warn_text, interactive=cur.get("interactive", False), ) return mn_lon, mn_lat, mx_lon, mx_lat, new_map_data def _on_coords_change(mn_lon, mn_lat, mx_lon, mx_lat, map_data_str: str): """Clamp coordinates to tile, estimate pixel size, update map + warning.""" try: cur = json.loads(map_data_str) if map_data_str else {} except Exception: cur = {} tile_bbox = cur.get("tile_bbox") tile_px = cur.get("tile_px") roi_list = None w_px = h_px = None too_large = False try: if all(v is not None for v in [mn_lon, mn_lat, mx_lon, mx_lat]): roi_list = [float(mn_lon), float(mn_lat), float(mx_lon), float(mx_lat)] if tile_bbox: roi_list = _clamp_roi(roi_list, tile_bbox) mn_lon, mn_lat, mx_lon, mx_lat = roi_list if tile_px: w_px, h_px = _estimate_roi_px(roi_list, tile_bbox, tile_px) too_large = bool(w_px and h_px and (w_px > 1500 or h_px > 1500)) except (TypeError, ValueError): pass warn_text = ( f"⚠️ ROI trop grand : ~{w_px} × {h_px} px (limite : 1 500 × 1 500 px)" if too_large else "" ) new_map_data = _make_map_data( tile_bbox=tile_bbox, tile_px=tile_px, roi=roi_list, roi_too_large=too_large, roi_warning_text=warn_text, interactive=cur.get("interactive", False), ) return mn_lon, mn_lat, mx_lon, mx_lat, new_map_data # ── Mode Exemple ────────────────────────────────────────────────────────────── def run_example(preset_label: str, variant: str, progress=None): _p = progress or (lambda *a, **k: None) preset_name = PRESETS[preset_label] _p(0.1, desc="Chargement de l'exemple...") lr_array = np.load(f"examples/{preset_name}.npy") with open(f"examples/{preset_name}.json") as f: meta = json.load(f) a, b, c, d, e, ff = meta["transform"] crs = CRS.from_string(meta["crs"]) T_lr = Affine(a, b, c, d, e, ff) sr_array = _infer(lr_array, variant, progress) _p(0.9, desc="Rendu TCI...") r, g, b_idx = TCI_BANDS img_before, img_after = _make_comparison(lr_array, sr_array) scale = get_scale(variant) T_sr = Affine(a / scale, b, c, d, e / scale, ff) tmpdir = Path(tempfile.mkdtemp()) p_before = str(tmpdir / "avant_10m.jpg") p_after = str(tmpdir / "apres_2p5m.jpg") p_tif_before = str(tmpdir / "TCI_avant_10m.tif") p_tif_after = str(tmpdir / "TCI_apres_2p5m.tif") img_before.save(p_before, quality=90) img_after.save(p_after, quality=90) _write_tci_tif(lr_array[[r, g, b_idx]], crs, T_lr, p_tif_before) _write_tci_tif(sr_array[[r, g, b_idx]], crs, T_sr, p_tif_after) _p(1.0, desc="Termine !") return ( [(img_before, "Avant 10m"), (img_after, "Apres 2.5m SR")], p_before, p_after, p_tif_before, p_tif_after, ) # ── Mode ZIP ────────────────────────────────────────────────────────────────── def run_zip( zip_file, min_lon, min_lat, max_lon, max_lat, variant: str, progress=None, ): if zip_file is None: raise gr.Error("Uploadez d'abord un fichier ZIP Sentinel-2 L2A.") if any(v is None for v in [min_lon, min_lat, max_lon, max_lat]): raise gr.Error("Renseignez les 4 coordonnees ROI.") _p = progress or (lambda *a, **k: None) zip_path = Path(zip_file if isinstance(zip_file, str) else zip_file.name) roi = (float(min_lon), float(min_lat), float(max_lon), float(max_lat)) try: _p(0.05, desc="Lecture structure S2...") band_files = find_band_files_from_zip(zip_path) ref_profile = get_ref_profile(band_files) _p(0.1, desc="Calcul fenetre ROI...") roi_win = roi_wgs84_to_pixel_window(roi, ref_profile, buffer_px=128) roi_w, roi_h = roi_win["roi_width"], roi_win["roi_height"] if roi_w > 1500 or roi_h > 1500: raise gr.Error( f"ROI trop grand pour cette demo : {roi_w} x {roi_h} px " f"(limite demo HuggingFace CPU : 1500 x 1500 px). " f"Reduisez la zone ou utilisez le pipeline complet en local " f"(https://github.com/gdubrasquetd/Sentinel-2-Better-Resolution)." ) rr_min, rr_max, rc_min, rc_max = roi_win["read"] read_window = rasterio.windows.Window( col_off=rc_min, row_off=rr_min, width=rc_max - rc_min, height=rr_max - rr_min, ) _p(0.2, desc="Lecture fenetre ROI (quelques secondes)...") raw_array, _ = read_and_align_bands(band_files, window=read_window) norm_array = normalize(raw_array) del raw_array ir_start, ir_end, ic_start, ic_end = roi_win["inner"] lr_exact = norm_array[:, ir_start:ir_end, ic_start:ic_end].copy() sr_array = _infer(norm_array, variant, progress) del norm_array scale = get_scale(variant) sr_exact = sr_array[ :, ir_start * scale : ir_end * scale, ic_start * scale : ic_end * scale, ] _p(0.9, desc="Rendu TCI...") r, g, b_idx = TCI_BANDS img_before, img_after = _make_comparison(lr_exact, sr_exact) crs = ref_profile["crs"] T_lr = roi_win["roi_transform"] T_sr = Affine( T_lr.a / scale, T_lr.b, T_lr.c, T_lr.d, T_lr.e / scale, T_lr.f, ) except ValueError as e: raise gr.Error(str(e)) tmpdir = Path(tempfile.mkdtemp()) p_before = str(tmpdir / "avant_10m.jpg") p_after = str(tmpdir / "apres_2p5m.jpg") p_tif_before = str(tmpdir / "TCI_avant_10m.tif") p_tif_after = str(tmpdir / "TCI_apres_2p5m.tif") img_before.save(p_before, quality=90) img_after.save(p_after, quality=90) _write_tci_tif(lr_exact[[r, g, b_idx]], crs, T_lr, p_tif_before) _write_tci_tif(sr_exact[[r, g, b_idx]], crs, T_sr, p_tif_after) _p(1.0, desc="Termine !") return ( [(img_before, "Avant 10m"), (img_after, "Apres 2.5m SR")], p_before, p_after, p_tif_before, p_tif_after, ) # ── Gradio UI ───────────────────────────────────────────────────────────────── with gr.Blocks(title="SEN2SR — Super-Resolution Sentinel-2") as demo: gr.Markdown(""" # SEN2SR — Super-Resolution Sentinel-2 **10 m/pixel → 2.5 m/pixel (×4)** sur des images Sentinel-2 L2A. Modele officiel ESA : [tacofoundation/sen2sr](https://huggingface.co/tacofoundation/sen2sr) — Code pipeline : [GitHub](https://github.com/gdubrasquetd/Sentinel-2-Better-Resolution) > L'inference tourne sur CPU — comptez **1 a 5 minutes** selon la taille du ROI. """) with gr.Row(): source = gr.Radio( choices=["Exemple Tatarstan", "Mon ZIP Sentinel-2"], value="Exemple Tatarstan", label="Source des donnees", ) variant = gr.Dropdown( choices=["main", "rgbn_x4", "rswir_x2"], value="main", label="Variante du modele", info="main: 10 bandes x4 | rgbn_x4: RGB+NIR x4 | rswir_x2: 10 bandes x2", ) with gr.Group(visible=True) as panel_example: preset = gr.Radio( choices=list(PRESETS.keys()), value="Zone industrielle (Tatarstan)", label="Zone de demonstration", ) with gr.Group(visible=False) as panel_zip: with gr.Row(equal_height=True): with gr.Column(scale=1, min_width=180): zip_input = gr.File( label="Fichier ZIP Sentinel-2 L2A", file_types=[".zip"], height=440, ) with gr.Column(scale=3): # ── OSM map ────────────────────────────────────────────────── gr.Markdown("**Carte — emprise de la dalle et ROI**") map_html = gr.HTML(value=_MAP_HTML) # Hidden bridge textboxes (not shown to the user) map_data = gr.Textbox(visible=False, elem_id="map_data", value=_make_map_data()) roi_drawn = gr.Textbox(visible=False, elem_id="roi_drawn", value="") # ── ROI coordinates ────────────────────────────────────────── gr.Markdown("**ROI — coordonnees WGS84** (dessinez sur la carte ou saisissez manuellement)") with gr.Row(): min_lon = gr.Number(label="Min longitude", value=52.030) min_lat = gr.Number(label="Min latitude", value=55.828) max_lon = gr.Number(label="Max longitude", value=52.048) max_lat = gr.Number(label="Max latitude", value=55.843) def _toggle(src): return ( gr.update(visible=(src == "Exemple Tatarstan")), gr.update(visible=(src == "Mon ZIP Sentinel-2")), ) source.change(_toggle, inputs=source, outputs=[panel_example, panel_zip]) run_btn = gr.Button("Lancer la super-resolution", variant="primary", size="lg") gallery = gr.Gallery( label="Avant (10 m) / Apres (2.5 m SR)", columns=2, height=520, ) with gr.Row(): dl_before = gr.File(label="Telecharger — Avant 10m (JPEG)") dl_after = gr.File(label="Telecharger — Apres 2.5m SR (JPEG)") with gr.Row(): dl_tif_before = gr.File(label="Telecharger — TCI Avant GeoTIFF") dl_tif_after = gr.File(label="Telecharger — TCI Apres SR GeoTIFF") # ── Map events ──────────────────────────────────────────────────────────── zip_input.upload(fn=_on_zip_upload, inputs=[zip_input], outputs=[map_data]) zip_input.clear(fn=lambda: _make_map_data(), inputs=[], outputs=[map_data]) roi_drawn.change( fn=_on_roi_drawn, inputs=[roi_drawn, map_data], outputs=[min_lon, min_lat, max_lon, max_lat, map_data], ) _coord_inputs = [min_lon, min_lat, max_lon, max_lat] for _c in _coord_inputs: _c.change( fn=_on_coords_change, inputs=_coord_inputs + [map_data], outputs=[min_lon, min_lat, max_lon, max_lat, map_data], ) # ── Run button ──────────────────────────────────────────────────────────── def _on_run(src, p, zf, mn_lon, mn_lat, mx_lon, mx_lat, var): if src == "Exemple Tatarstan": return run_example(p, var) return run_zip(zf, mn_lon, mn_lat, mx_lon, mx_lat, var) run_btn.click( _on_run, inputs=[source, preset, zip_input, min_lon, min_lat, max_lon, max_lat, variant], outputs=[gallery, dl_before, dl_after, dl_tif_before, dl_tif_after], api_name="run_sr", ) if __name__ == "__main__": demo.launch()