| """ |
HuggingFace Space — SEN2SR Super-Resolution Sentinel-2.
| Two modes: pre-loaded example or custom ZIP upload. |
| """ |
| from __future__ import annotations |
|
|
| import json |
| import os |
| import tempfile |
| from pathlib import Path |
|
|
| import gradio as gr |
| import numpy as np |
| import rasterio |
| import rasterio.windows |
| import torch |
| from PIL import Image |
| from rasterio.crs import CRS |
| from rasterio.transform import Affine |
| from rasterio.warp import transform_bounds |
|
|
| from s2sr_pipe.data_processing.postprocessing import TCI_GAIN, TCI_GAMMA |
| from s2sr_pipe.data_processing.preprocessing import normalize |
| from s2sr_pipe.data_processing.safe_reader import ( |
| ALL_BANDS, find_band_files_from_zip, get_ref_profile, read_and_align_bands, |
| ) |
| from s2sr_pipe.model.architecture import build_model, get_scale |
| from s2sr_pipe.model.inference import infer_large |
| from s2sr_pipe.utils.geo_utils import roi_wgs84_to_pixel_window |
|
|
| |
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Checkpoint directory; may be overridden with the MODEL_DIR environment variable.
MODEL_DIR = Path(os.getenv("MODEL_DIR", "sen2sr_model"))
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# Positions on the band axis rendered as R, G, B in the previews/GeoTIFFs.
# NOTE(review): presumably B04, B03, B02, assuming ALL_BANDS starts with
# B02, B03, B04 — confirm against safe_reader.ALL_BANDS.
TCI_BANDS = (2, 1, 0)
# Longest side (pixels) of the gallery preview images.
MAX_DISP_PX = 800

# UI label -> file stem of the bundled example under examples/.
PRESETS = {
    "Zone industrielle (Tatarstan)": "industrial",
    "Terres agricoles (Tatarstan)": "farmland",
}

# Model variant -> band-axis indices selected from the LR stack before inference.
VARIANT_BANDS = {
    "main": [*range(10)],
    "rgbn_x4": [ALL_BANDS.index(band) for band in ("B02", "B03", "B04", "B08")],
    "rswir_x2": [*range(10)],
}
|
|
| |
# ---------------------------------------------------------------------------
# Model cache
# ---------------------------------------------------------------------------
# variant name -> built model; filled lazily by _get_model.
_model_cache: dict = {}


def _get_model(variant: str):
    """Return ``(model, scale)`` for *variant*.

    The model is built with ``build_model`` on the first request for a variant
    and memoised in ``_model_cache``; later calls reuse the same object. The
    upscaling factor always comes from ``get_scale(variant)``.
    """
    if variant in _model_cache:
        model = _model_cache[variant]
    else:
        model = build_model(variant=variant, model_dir=MODEL_DIR, device=DEVICE)
        _model_cache[variant] = model
    scale = get_scale(variant)
    return model, scale
|
|
|
|
| |
| def _render_tci(rgb_float: np.ndarray) -> np.ndarray: |
| inv_gamma = 1.0 / TCI_GAMMA |
| rgb = np.clip(rgb_float * TCI_GAIN, 0.0, 1.0) ** inv_gamma |
| return (rgb * 255.0).astype(np.uint8).transpose(1, 2, 0) |
|
|
|
|
def _write_tci_tif(rgb_float: np.ndarray, crs, transform: Affine, path: str) -> None:
    """Save an RGB stack as a georeferenced, deflate-compressed 3-band uint8 GeoTIFF.

    The stack goes through the same TCI stretch as the previews
    (``_render_tci``); band 1/2/3 of the file receive R/G/B respectively.

    Args:
        rgb_float: Band-first float stack in (R, G, B) order.
        crs: CRS written to the file (passed straight to rasterio).
        transform: Affine geotransform of the stack's pixel grid.
        path: Destination file path (overwritten).
    """
    tci = _render_tci(rgb_float)  # (rows, cols, 3) uint8
    n_rows, n_cols = tci.shape[0], tci.shape[1]
    profile = dict(
        driver="GTiff",
        width=n_cols,
        height=n_rows,
        count=3,
        dtype="uint8",
        crs=crs,
        transform=transform,
        compress="deflate",
    )
    with rasterio.open(path, "w", **profile) as dst:
        for band in (1, 2, 3):
            dst.write(tci[:, :, band - 1], band)
|
|
|
|
def _make_comparison(lr_array: np.ndarray, sr_array: np.ndarray):
    """Build the before/after preview pair shown in the gallery.

    Both inputs are band-first stacks; the bands at ``TCI_BANDS`` are rendered
    as R, G, B. The SR image defines the display size: its longest side is
    capped at ``MAX_DISP_PX`` (aspect ratio kept). The LR image is then resized
    with nearest-neighbour to exactly the same size, so its native pixels stay
    visible and the two images line up.

    NOTE(review): assumes every variant's SR output keeps R, G, B at the
    ``TCI_BANDS`` positions — confirm for ``rswir_x2``.

    Returns:
        ``(img_before, img_after)`` as PIL images of identical size.
    """
    rgb_idx = list(TCI_BANDS)
    lr_u8 = _render_tci(lr_array[rgb_idx])
    sr_u8 = _render_tci(sr_array[rgb_idx])

    h_sr, w_sr = sr_u8.shape[:2]
    longest = max(h_sr, w_sr)
    if longest > MAX_DISP_PX:
        s = MAX_DISP_PX / longest
        # max(1, ...): a very elongated ROI could otherwise round its short
        # side down to 0 px, which PIL's resize rejects.
        w_d, h_d = max(1, int(w_sr * s)), max(1, int(h_sr * s))
    else:
        w_d, h_d = w_sr, h_sr

    img_after = Image.fromarray(sr_u8).resize((w_d, h_d), Image.LANCZOS)
    img_before = Image.fromarray(lr_u8).resize((w_d, h_d), Image.NEAREST)
    return img_before, img_after
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Inference
# ---------------------------------------------------------------------------
def _infer(lr_array: np.ndarray, variant: str, progress=None) -> np.ndarray:
    """Run the cached SR model of *variant* on a band-first LR stack.

    Only the band indices listed in ``VARIANT_BANDS[variant]`` are fed to the
    model. Tiling is delegated to ``infer_large`` with a 32 px overlap, batch
    size 1 and AMP disabled.

    Args:
        lr_array: Band-first LR stack indexed on its first axis.
        variant: Key of ``VARIANT_BANDS`` / model variant name.
        progress: Optional ``callable(fraction, desc=...)``; ignored when falsy.

    Returns:
        Whatever ``infer_large`` returns (the super-resolved stack).
    """
    def _silent(*_args, **_kwargs):
        return None

    notify = progress if progress else _silent
    model, _scale = _get_model(variant)
    selected = VARIANT_BANDS[variant]
    lr_input = lr_array[selected]
    notify(0.4, desc="Inference SR en cours (CPU, patience)...")
    return infer_large(
        model=model,
        array=lr_input,
        device=DEVICE,
        overlap=32,
        use_amp=False,
        batch_size=1,
    )
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Geometry helpers
# ---------------------------------------------------------------------------
def _profile_to_wgs84_bbox(profile: dict) -> tuple:
    """Project a raster profile's footprint to ``(min_lon, min_lat, max_lon, max_lat)``.

    The footprint is derived from the affine origin (``c``, ``f``) and pixel
    sizes (``a``, ``e``) over ``width`` x ``height`` pixels; rotation terms are
    ignored, i.e. a north-up grid is assumed. The box is reprojected from
    ``profile["crs"]`` to EPSG:4326 with ``transform_bounds``.
    """
    aff = profile["transform"]
    n_cols, n_rows = profile["width"], profile["height"]
    x_min, y_max = aff.c, aff.f
    x_max = x_min + n_cols * aff.a
    y_min = y_max + n_rows * aff.e
    wgs84 = CRS.from_epsg(4326)
    return tuple(
        transform_bounds(profile["crs"], wgs84, x_min, y_min, x_max, y_max)
    )
|
|
|
|
| def _make_map_data( |
| tile_bbox=None, tile_px=None, |
| roi=None, roi_too_large: bool = False, roi_warning_text: str = "", |
| interactive: bool = False, |
| ) -> str: |
| """Serialise map state as JSON for the PythonβJS bridge textbox.""" |
| return json.dumps({ |
| "tile_bbox": list(tile_bbox) if tile_bbox else None, |
| "tile_px": list(tile_px) if tile_px else None, |
| "roi": list(roi) if roi else None, |
| "roi_too_large": roi_too_large, |
| "roi_warning_text": roi_warning_text, |
| "interactive": interactive, |
| }) |
|
|
|
|
| def _clamp_roi(roi: list, tile_bbox: list) -> list: |
| """Clip ROI to tile bounding box.""" |
| return [ |
| max(roi[0], tile_bbox[0]), |
| max(roi[1], tile_bbox[1]), |
| min(roi[2], tile_bbox[2]), |
| min(roi[3], tile_bbox[3]), |
| ] |
|
|
|
|
| def _estimate_roi_px(roi: list, tile_bbox: list, tile_px: list): |
| """ |
| Estimate ROI pixel dimensions by proportional mapping over the tile extent. |
| Returns (w_px, h_px) or (None, None) if data is missing. |
| """ |
| if not tile_bbox or not tile_px: |
| return None, None |
| tile_w_deg = tile_bbox[2] - tile_bbox[0] |
| tile_h_deg = tile_bbox[3] - tile_bbox[1] |
| if tile_w_deg <= 0 or tile_h_deg <= 0: |
| return None, None |
| w_px = int((roi[2] - roi[0]) / tile_w_deg * tile_px[0]) |
| h_px = int((roi[3] - roi[1]) / tile_h_deg * tile_px[1]) |
| return w_px, h_px |
|
|
|
|
|
|
| |
| _MAP_HTML = """ |
| <style> |
| #s2sr-map-wrap .leaflet-container { border-radius: 8px; } |
| .roi-warn-tip { |
| background: #fee2e2; |
| border: 1px solid #ef4444; border-radius: 5px; |
| color: #991b1b; font-weight: 700; font-size: 0.88em; |
| padding: 4px 10px; white-space: nowrap; |
| box-shadow: 0 1px 4px rgba(0,0,0,0.15); |
| } |
| .roi-warn-tip::before { display: none; } |
| .roi-map-warn-bar { |
| position: absolute; top: 8px; left: 50%; transform: translateX(-50%); |
| z-index: 1001; pointer-events: none; |
| background: #fee2e2; border: 1px solid #ef4444; border-radius: 6px; |
| color: #991b1b; font-weight: 700; font-size: 0.85em; |
| padding: 5px 14px; white-space: nowrap; |
| box-shadow: 0 1px 6px rgba(239,68,68,0.25); |
| } |
| </style> |
| <div id="s2sr-map-wrap" style="position:relative;width:100%;height:400px;"> |
| <div id="s2sr-map" style="width:100%;height:100%;"></div> |
| <div id="s2sr-map-overlay" |
| style="position:absolute;top:0;left:0;width:100%;height:100%; |
| background:rgba(0,0,0,0.42);display:flex;align-items:center; |
| justify-content:center;z-index:1000;border-radius:8px;pointer-events:none;"> |
| <div style="color:#fff;font-size:1.05em;font-weight:600; |
| background:rgba(0,0,0,0.62);padding:12px 28px; |
| border-radius:8px;text-align:center;"> |
| 📂 Uploadez un fichier ZIP pour activer la carte |
| </div> |
| </div> |
| </div> |
| <link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css"/> |
| <link rel="stylesheet" |
| href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.css"/> |
| <script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script> |
| <script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.js"></script> |
| <script> |
| (function () { |
| 'use strict'; |
| var _map = null, _tileLyr = null, _roiLyr = null, _drawCtrl = null, _warnBar = null; |
| var _cur = { tile_bbox: null, tile_px: null, roi: null, roi_too_large: false, interactive: false }; |
| var _lastMapDataVal = null; |
| |
| function _initMap() { |
| if (_map) return; |
| _map = L.map('s2sr-map').setView([20, 0], 2); |
| L.tileLayer('https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png', { |
| attribution: '© <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a>', |
| maxZoom: 18 |
| }).addTo(_map); |
| } |
| |
| function _setTile(bbox) { |
| if (_tileLyr) { _map.removeLayer(_tileLyr); _tileLyr = null; } |
| if (!bbox) return; |
| _tileLyr = L.rectangle([[bbox[1], bbox[0]], [bbox[3], bbox[2]]], |
| { color: '#4488ff', weight: 2, fillOpacity: 0.08 }).addTo(_map); |
| _map.fitBounds(_tileLyr.getBounds(), { padding: [30, 30] }); |
| } |
| |
| function _setRoi(roi, tooLarge) { |
| if (_roiLyr) { _map.removeLayer(_roiLyr); _roiLyr = null; } |
| if (!roi) return; |
| var color = tooLarge ? '#ef4444' : '#22c55e'; |
| _roiLyr = L.rectangle([[roi[1], roi[0]], [roi[3], roi[2]]], |
| { color: color, weight: 2.5, fillOpacity: 0.15 }).addTo(_map); |
| if (tooLarge) { |
| _roiLyr.bindTooltip('β οΈ ROI trop grand', { |
| permanent: true, direction: 'center', className: 'roi-warn-tip' |
| }).openTooltip(); |
| } |
| } |
| |
| function _clampRoiToTile(roi) { |
| if (!_cur.tile_bbox) return roi; |
| var tb = _cur.tile_bbox; |
| return [ |
| Math.max(roi[0], tb[0]), Math.max(roi[1], tb[1]), |
| Math.min(roi[2], tb[2]), Math.min(roi[3], tb[3]), |
| ]; |
| } |
| |
| function _enableDraw() { |
| if (_drawCtrl) return; |
| var drawn = new L.FeatureGroup(); |
| _map.addLayer(drawn); |
| _drawCtrl = new L.Control.Draw({ |
| draw: { |
| rectangle: { shapeOptions: { color: '#ff4444' } }, |
| polyline: false, |
| polygon: false, |
| circle: false, |
| circlemarker: false, |
| marker: false |
| }, |
| edit: false |
| }); |
| _map.addControl(_drawCtrl); |
| |
| _map.on(L.Draw.Event.CREATED, function (e) { |
| var b = e.layer.getBounds(); |
| var sw = b.getSouthWest(), ne = b.getNorthEast(); |
| var roi = _clampRoiToTile([sw.lng, sw.lat, ne.lng, ne.lat]); |
| _setRoi(roi, false); // green placeholder; Python will confirm final color |
| |
| var val = roi[0].toFixed(6) + ',' + roi[1].toFixed(6) + ',' |
| + roi[2].toFixed(6) + ',' + roi[3].toFixed(6); |
| |
| // Write into the hidden Gradio textbox #roi_drawn to notify Python |
| function _push() { |
| var el = document.querySelector('#roi_drawn textarea') |
| || document.querySelector('#roi_drawn input[type=text]'); |
| if (!el) { setTimeout(_push, 200); return; } |
| // Use native setter so React/Svelte picks up the change |
| var proto = el.tagName === 'TEXTAREA' |
| ? window.HTMLTextAreaElement.prototype |
| : window.HTMLInputElement.prototype; |
| var setter = Object.getOwnPropertyDescriptor(proto, 'value'); |
| if (setter && setter.set) setter.set.call(el, val); else el.value = val; |
| el.dispatchEvent(new Event('input', { bubbles: true })); |
| el.dispatchEvent(new Event('change', { bubbles: true })); |
| } |
| _push(); |
| }); |
| } |
| |
| function _setWarningBar(text) { |
| if (!_warnBar) { |
| _warnBar = document.createElement('div'); |
| _warnBar.className = 'roi-map-warn-bar'; |
| document.getElementById('s2sr-map-wrap').appendChild(_warnBar); |
| } |
| _warnBar.textContent = text || ''; |
| _warnBar.style.display = text ? 'block' : 'none'; |
| } |
| |
| function _setInteractive(on) { |
| var overlay = document.getElementById('s2sr-map-overlay'); |
| if (overlay) overlay.style.display = on ? 'none' : 'flex'; |
| if (on) _enableDraw(); |
| } |
| |
| function _apply(data) { |
| _initMap(); |
| if (JSON.stringify(data.tile_bbox) !== JSON.stringify(_cur.tile_bbox)) |
| _setTile(data.tile_bbox); |
| if (JSON.stringify(data.roi) !== JSON.stringify(_cur.roi) || |
| data.roi_too_large !== _cur.roi_too_large) |
| _setRoi(data.roi, data.roi_too_large); |
| if (data.interactive !== _cur.interactive) |
| _setInteractive(data.interactive); |
| if ((data.roi_warning_text || '') !== (_cur.roi_warning_text || '')) |
| _setWarningBar(data.roi_warning_text); |
| _cur = data; |
| } |
| |
| function _poll() { |
| var el = document.querySelector('#map_data textarea') |
| || document.querySelector('#map_data input[type=text]'); |
| if (!el || el.value === _lastMapDataVal) return; |
| _lastMapDataVal = el.value; |
| if (el.value) { try { _apply(JSON.parse(el.value)); } catch (e) {} } |
| } |
| |
| function _boot() { |
| if (!document.getElementById('s2sr-map')) { setTimeout(_boot, 150); return; } |
| _initMap(); |
| setInterval(_poll, 400); |
| } |
| |
| if (document.readyState === 'loading') |
| document.addEventListener('DOMContentLoaded', _boot); |
| else |
| _boot(); |
| })(); |
| </script> |
| """ |
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Map / ROI callbacks
# ---------------------------------------------------------------------------
def _on_zip_upload(zip_file) -> str:
    """Enable the map for a freshly uploaded Sentinel-2 ZIP.

    Reads the reference band profile from the archive. It returns map-state JSON
    holding the tile's WGS84 bounding box, its pixel size ``(width, height)`` and
    ``interactive=True``, which enables the drawing tools.

    Args:
        zip_file: Value of the ``gr.File`` component — either a path string or
            an object exposing ``.name``; ``None`` when nothing is uploaded.

    Returns:
        JSON for the hidden ``map_data`` textbox. On ``None`` or any read
        failure, the default (empty, non-interactive) state is returned.
    """
    if zip_file is None:
        return _make_map_data()
    try:
        zip_path = Path(zip_file if isinstance(zip_file, str) else zip_file.name)
        band_files = find_band_files_from_zip(zip_path)
        profile = get_ref_profile(band_files)
        bbox = _profile_to_wgs84_bbox(profile)
        tile_px = (profile["width"], profile["height"])
    except Exception as exc:  # best effort: an unreadable archive keeps the map disabled
        # Tell the user why the map stays inactive instead of failing silently.
        gr.Warning(f"Impossible de lire l'emprise de la dalle depuis le ZIP : {exc}")
        return _make_map_data()
    return _make_map_data(tile_bbox=bbox, tile_px=tile_px, interactive=True)
|
|
|
|
def _on_roi_drawn(roi_str: str, map_data_str: str):
    """Handle a rectangle drawn on the Leaflet map.

    The map script writes the rectangle into the hidden ``roi_drawn`` textbox
    as ``"min_lon,min_lat,max_lon,max_lat"`` (already clamped client-side).
    This callback only parses that string. Server-side clamping, the pixel
    estimate, the size warning and the new map state are delegated to
    ``_on_coords_change``, so drawn and typed ROIs go through the same logic.

    Args:
        roi_str: Comma-separated bounds from the map; empty or malformed values
            are treated as "no ROI".
        map_data_str: Current JSON map state.

    Returns:
        ``(mn_lon, mn_lat, mx_lon, mx_lat, map_data_json)`` as produced by
        ``_on_coords_change``; all four bounds are ``None`` when ``roi_str``
        could not be parsed into exactly four numbers.
    """
    bounds = (None, None, None, None)
    text = (roi_str or "").strip()
    if text:
        try:
            parts = [float(x) for x in text.split(",")]
        except ValueError:
            parts = []
        if len(parts) == 4:
            bounds = tuple(parts)
    return _on_coords_change(*bounds, map_data_str)
|
|
|
|
def _on_coords_change(mn_lon, mn_lat, mx_lon, mx_lat, map_data_str: str):
    """Clamp ROI coordinates to the tile, estimate their pixel size, refresh the map.

    Args:
        mn_lon, mn_lat, mx_lon, mx_lat: ROI bounds in WGS84 degrees; any may be
            ``None`` (empty ``gr.Number`` field), in which case no ROI is drawn.
        map_data_str: Current JSON map state (see ``_make_map_data``). Empty,
            unparsable or non-object JSON is treated as "no tile loaded".

    Returns:
        ``(mn_lon, mn_lat, mx_lon, mx_lat, map_data_json)``. The bounds are
        clamped to the tile when its bbox is known; otherwise they are echoed
        back unchanged. The map state carries the "ROI too large" flag and
        warning text.
    """
    max_roi_px = 1500  # keep in sync with the hard limit enforced in run_zip

    try:
        cur = json.loads(map_data_str) if map_data_str else {}
    except (TypeError, ValueError):
        cur = {}
    if not isinstance(cur, dict):
        # e.g. the textbox held "null" or a bare list
        cur = {}

    tile_bbox = cur.get("tile_bbox")
    tile_px = cur.get("tile_px")

    roi_list = None
    w_px = h_px = None
    too_large = False
    try:
        if all(v is not None for v in (mn_lon, mn_lat, mx_lon, mx_lat)):
            roi_list = [float(mn_lon), float(mn_lat), float(mx_lon), float(mx_lat)]
            if tile_bbox:
                roi_list = _clamp_roi(roi_list, tile_bbox)
                mn_lon, mn_lat, mx_lon, mx_lat = roi_list
                if tile_px:
                    w_px, h_px = _estimate_roi_px(roi_list, tile_bbox, tile_px)
                    # Explicit None checks: a 0-px side must not mask an
                    # oversized other side (plain truthiness would).
                    too_large = (
                        w_px is not None
                        and h_px is not None
                        and max(w_px, h_px) > max_roi_px
                    )
    except (TypeError, ValueError):
        # Non-numeric input or a malformed bbox in the map state: keep what was
        # computed so far instead of failing the UI event.
        pass

    # U+202F narrow no-break spaces keep "1 500 x 1 500 px" on one line.
    limit = f"{max_roi_px:,}".replace(",", "\u202f")
    warn_text = (
        f"\u26a0\ufe0f ROI trop grand : ~{w_px}\u202f\u00d7\u202f{h_px}\u202fpx "
        f"(limite\u202f:\u202f{limit}\u202f\u00d7\u202f{limit}\u202fpx)"
        if too_large else ""
    )
    new_map_data = _make_map_data(
        tile_bbox=tile_bbox, tile_px=tile_px,
        roi=roi_list, roi_too_large=too_large, roi_warning_text=warn_text,
        interactive=cur.get("interactive", False),
    )
    return mn_lon, mn_lat, mx_lon, mx_lat, new_map_data
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Pipelines
# ---------------------------------------------------------------------------
def run_example(preset_label: str, variant: str, progress=None):
    """Super-resolve one of the bundled example scenes.

    Loads ``examples/<stem>.npy`` (band-first LR stack) and
    ``examples/<stem>.json``, which holds a ``crs`` string and a 6-term affine
    ``transform`` (a, b, c, d, e, f). It then runs the model and writes JPEG
    previews plus TCI GeoTIFFs into a fresh temporary directory.

    Args:
        preset_label: Key of ``PRESETS``.
        variant: Model variant name.
        progress: Optional ``callable(fraction, desc=...)``.

    Returns:
        ``(gallery_items, jpeg_before, jpeg_after, tif_before, tif_after)``.
        The "after" labels and file names reflect the variant's real output
        resolution (10 m / scale, e.g. 2.5 m for x4 and 5 m for x2).
    """
    _p = progress if progress is not None else (lambda *a, **k: None)
    preset_name = PRESETS[preset_label]
    _p(0.1, desc="Chargement de l'exemple...")

    lr_array = np.load(f"examples/{preset_name}.npy")
    with open(f"examples/{preset_name}.json", encoding="utf-8") as f:
        meta = json.load(f)
    a, b, c, d, e, ff = meta["transform"]
    crs = CRS.from_string(meta["crs"])
    T_lr = Affine(a, b, c, d, e, ff)

    sr_array = _infer(lr_array, variant, _p)
    _p(0.9, desc="Rendu TCI...")

    r, g, b_idx = TCI_BANDS
    img_before, img_after = _make_comparison(lr_array, sr_array)

    scale = get_scale(variant)
    # Same origin, pixel size divided by the upscaling factor.
    T_sr = Affine(a / scale, b, c, d, e / scale, ff)
    # LR inputs are 10 m: "2.5" for x4, "5" for x2 (previously always "2.5").
    res_label = f"{10 / scale:g}"
    res_tag = res_label.replace(".", "p")

    tmpdir = Path(tempfile.mkdtemp())
    p_before = str(tmpdir / "avant_10m.jpg")
    p_after = str(tmpdir / f"apres_{res_tag}m.jpg")
    p_tif_before = str(tmpdir / "TCI_avant_10m.tif")
    p_tif_after = str(tmpdir / f"TCI_apres_{res_tag}m.tif")

    img_before.save(p_before, quality=90)
    img_after.save(p_after, quality=90)
    _write_tci_tif(lr_array[[r, g, b_idx]], crs, T_lr, p_tif_before)
    _write_tci_tif(sr_array[[r, g, b_idx]], crs, T_sr, p_tif_after)

    _p(1.0, desc="Termine !")
    return (
        [(img_before, "Avant 10m"), (img_after, f"Apres {res_label}m SR")],
        p_before, p_after, p_tif_before, p_tif_after,
    )
|
|
|
|
| |
def run_zip(
    zip_file, min_lon, min_lat, max_lon, max_lat,
    variant: str, progress=None,
):
    """Super-resolve a WGS84 ROI of an uploaded Sentinel-2 L2A ZIP.

    Steps:
    1. Locate the band files in the ZIP.
    2. Convert the ROI to a pixel window with a 128 px context buffer.
    3. Read and normalise only that window.
    4. Run the model on the buffered window.
    5. Crop both the LR input and the SR output back to the exact ROI
       (``roi_win["inner"]``, scaled for SR).
    6. Write the previews and TCI GeoTIFFs.

    Args:
        zip_file: ``gr.File`` value (path string or object with ``.name``).
        min_lon, min_lat, max_lon, max_lat: ROI bounds in WGS84 degrees.
        variant: Model variant name.
        progress: Optional ``callable(fraction, desc=...)``.

    Returns:
        ``(gallery_items, jpeg_before, jpeg_after, tif_before, tif_after)``.

    Raises:
        gr.Error: Raised in any of these cases:
            - no ZIP or a missing coordinate;
            - an inverted or zero-area ROI;
            - an ROI wider or taller than 1500 px;
            - a ``ValueError`` from the reading pipeline.
    """
    if zip_file is None:
        raise gr.Error("Uploadez d'abord un fichier ZIP Sentinel-2 L2A.")
    if any(v is None for v in [min_lon, min_lat, max_lon, max_lat]):
        raise gr.Error("Renseignez les 4 coordonnees ROI.")

    _p = progress if progress is not None else (lambda *a, **k: None)
    zip_path = Path(zip_file if isinstance(zip_file, str) else zip_file.name)
    roi = (float(min_lon), float(min_lat), float(max_lon), float(max_lat))
    # An inverted/empty ROI (typed backwards, or clamped to a tile it does not
    # overlap) is rejected up front rather than handed to the window helpers.
    if roi[0] >= roi[2] or roi[1] >= roi[3]:
        raise gr.Error(
            "Coordonnees ROI invalides : min longitude/latitude doivent etre "
            "strictement inferieures a max longitude/latitude."
        )

    try:
        _p(0.05, desc="Lecture structure S2...")
        band_files = find_band_files_from_zip(zip_path)
        ref_profile = get_ref_profile(band_files)

        _p(0.1, desc="Calcul fenetre ROI...")
        roi_win = roi_wgs84_to_pixel_window(roi, ref_profile, buffer_px=128)

        roi_w, roi_h = roi_win["roi_width"], roi_win["roi_height"]
        if roi_w > 1500 or roi_h > 1500:
            raise gr.Error(
                f"ROI trop grand pour cette demo : {roi_w} x {roi_h} px "
                f"(limite demo HuggingFace CPU : 1500 x 1500 px). "
                f"Reduisez la zone ou utilisez le pipeline complet en local "
                f"(https://github.com/gdubrasquetd/Sentinel-2-Better-Resolution)."
            )

        # Buffered window actually read from disk.
        rr_min, rr_max, rc_min, rc_max = roi_win["read"]
        read_window = rasterio.windows.Window(
            col_off=rc_min, row_off=rr_min,
            width=rc_max - rc_min, height=rr_max - rr_min,
        )

        _p(0.2, desc="Lecture fenetre ROI (quelques secondes)...")
        raw_array, _ = read_and_align_bands(band_files, window=read_window)
        norm_array = normalize(raw_array)
        del raw_array  # free the raw window before inference

        # Exact ROI inside the buffered window (LR pixel indices).
        ir_start, ir_end, ic_start, ic_end = roi_win["inner"]
        lr_exact = norm_array[:, ir_start:ir_end, ic_start:ic_end].copy()

        # Inference runs on the buffered window so ROI borders get context.
        sr_array = _infer(norm_array, variant, _p)
        del norm_array

        scale = get_scale(variant)
        sr_exact = sr_array[
            :,
            ir_start * scale : ir_end * scale,
            ic_start * scale : ic_end * scale,
        ]

        _p(0.9, desc="Rendu TCI...")
        r, g, b_idx = TCI_BANDS
        img_before, img_after = _make_comparison(lr_exact, sr_exact)

        crs = ref_profile["crs"]
        T_lr = roi_win["roi_transform"]
        # Same origin, pixel size divided by the upscaling factor.
        T_sr = Affine(
            T_lr.a / scale, T_lr.b, T_lr.c,
            T_lr.d, T_lr.e / scale, T_lr.f,
        )

    except ValueError as e:
        raise gr.Error(str(e)) from e

    # LR inputs are 10 m: "2.5" for x4, "5" for x2 (previously always "2.5").
    res_label = f"{10 / scale:g}"
    res_tag = res_label.replace(".", "p")

    tmpdir = Path(tempfile.mkdtemp())
    p_before = str(tmpdir / "avant_10m.jpg")
    p_after = str(tmpdir / f"apres_{res_tag}m.jpg")
    p_tif_before = str(tmpdir / "TCI_avant_10m.tif")
    p_tif_after = str(tmpdir / f"TCI_apres_{res_tag}m.tif")

    img_before.save(p_before, quality=90)
    img_after.save(p_after, quality=90)
    _write_tci_tif(lr_exact[[r, g, b_idx]], crs, T_lr, p_tif_before)
    _write_tci_tif(sr_exact[[r, g, b_idx]], crs, T_sr, p_tif_after)

    _p(1.0, desc="Termine !")
    return (
        [(img_before, "Avant 10m"), (img_after, f"Apres {res_label}m SR")],
        p_before, p_after, p_tif_before, p_tif_after,
    )
|
|
|
|
| |
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
with gr.Blocks(title="SEN2SR — Super-Resolution Sentinel-2") as demo:

    gr.Markdown("""
# SEN2SR — Super-Resolution Sentinel-2
**10 m/pixel → 2.5 m/pixel (×4)** sur des images Sentinel-2 L2A.
Modele officiel ESA : [tacofoundation/sen2sr](https://huggingface.co/tacofoundation/sen2sr)
— Code pipeline : [GitHub](https://github.com/gdubrasquetd/Sentinel-2-Better-Resolution)

> L'inference tourne sur CPU — comptez **1 a 5 minutes** selon la taille du ROI.
""")

    with gr.Row():
        source = gr.Radio(
            choices=["Exemple Tatarstan", "Mon ZIP Sentinel-2"],
            value="Exemple Tatarstan",
            label="Source des donnees",
        )
        variant = gr.Dropdown(
            choices=["main", "rgbn_x4", "rswir_x2"],
            value="main",
            label="Variante du modele",
            info="main: 10 bandes x4 | rgbn_x4: RGB+NIR x4 | rswir_x2: 10 bandes x2",
        )

    # Example mode: pick a bundled scene.
    with gr.Group(visible=True) as panel_example:
        preset = gr.Radio(
            choices=list(PRESETS.keys()),
            value="Zone industrielle (Tatarstan)",
            label="Zone de demonstration",
        )

    # ZIP mode: upload + map + ROI coordinates.
    with gr.Group(visible=False) as panel_zip:
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=180):
                zip_input = gr.File(
                    label="Fichier ZIP Sentinel-2 L2A",
                    file_types=[".zip"],
                    height=440,
                )
            with gr.Column(scale=3):
                gr.Markdown("**Carte — emprise de la dalle et ROI**")
                map_html = gr.HTML(value=_MAP_HTML)

        # Hidden Python <-> JS bridge: the map script polls #map_data and
        # writes drawn rectangles into #roi_drawn.
        # NOTE(review): some Gradio versions do not mount visible=False
        # components in the DOM; if the bridge never fires, keep them visible
        # and hide them with CSS instead — confirm for the deployed version.
        map_data = gr.Textbox(visible=False, elem_id="map_data", value=_make_map_data())
        roi_drawn = gr.Textbox(visible=False, elem_id="roi_drawn", value="")

        gr.Markdown("**ROI — coordonnees WGS84** (dessinez sur la carte ou saisissez manuellement)")
        with gr.Row():
            min_lon = gr.Number(label="Min longitude", value=52.030)
            min_lat = gr.Number(label="Min latitude", value=55.828)
            max_lon = gr.Number(label="Max longitude", value=52.048)
            max_lat = gr.Number(label="Max latitude", value=55.843)

    def _toggle(src):
        """Show the panel matching the selected data source."""
        return (
            gr.update(visible=(src == "Exemple Tatarstan")),
            gr.update(visible=(src == "Mon ZIP Sentinel-2")),
        )

    source.change(_toggle, inputs=source, outputs=[panel_example, panel_zip])

    run_btn = gr.Button("Lancer la super-resolution", variant="primary", size="lg")

    gallery = gr.Gallery(
        label="Avant (10 m) / Apres (2.5 m SR)",
        columns=2,
        height=520,
    )
    with gr.Row():
        dl_before = gr.File(label="Telecharger — Avant 10m (JPEG)")
        dl_after = gr.File(label="Telecharger — Apres 2.5m SR (JPEG)")
    with gr.Row():
        dl_tif_before = gr.File(label="Telecharger — TCI Avant GeoTIFF")
        dl_tif_after = gr.File(label="Telecharger — TCI Apres SR GeoTIFF")

    # Map / ROI wiring.
    zip_input.upload(fn=_on_zip_upload, inputs=[zip_input], outputs=[map_data])
    zip_input.clear(fn=lambda: _make_map_data(), inputs=[], outputs=[map_data])

    roi_drawn.change(
        fn=_on_roi_drawn,
        inputs=[roi_drawn, map_data],
        outputs=[min_lon, min_lat, max_lon, max_lat, map_data],
    )

    _coord_inputs = [min_lon, min_lat, max_lon, max_lat]
    for _c in _coord_inputs:
        _c.change(
            fn=_on_coords_change,
            inputs=_coord_inputs + [map_data],
            outputs=[min_lon, min_lat, max_lon, max_lat, map_data],
        )

    def _on_run(src, p, zf, mn_lon, mn_lat, mx_lon, mx_lat, var,
                progress=gr.Progress()):
        """Dispatch the run button to the example or ZIP pipeline.

        Gradio injects a live tracker into ``progress`` because its default is a
        ``gr.Progress`` instance. Without this parameter, every progress call in
        run_example/run_zip was a no-op.
        """
        def _report(*args, **kwargs):
            # Hand the pipelines a plain callable so their ``progress or ...``
            # style fallbacks never truth-test the Progress object itself.
            progress(*args, **kwargs)

        if src == "Exemple Tatarstan":
            return run_example(p, var, progress=_report)
        return run_zip(zf, mn_lon, mn_lat, mx_lon, mx_lat, var, progress=_report)

    run_btn.click(
        _on_run,
        inputs=[source, preset, zip_input, min_lon, min_lat, max_lon, max_lat, variant],
        outputs=[gallery, dl_before, dl_after, dl_tif_before, dl_tif_after],
        api_name="run_sr",
    )


if __name__ == "__main__":
    demo.launch()
|
|