gdubrasquetd's picture
fix: add roi_warning_text param to _make_map_data (missing from previous deploy)
2afa4db
Raw
History Blame Contribute Delete
27.2 kB
"""
HuggingFace Space — SEN2SR Super-Resolution Sentinel-2.
Two modes: pre-loaded example or custom ZIP upload.
"""
from __future__ import annotations
import json
import os
import tempfile
from pathlib import Path
import gradio as gr
import numpy as np
import rasterio
import rasterio.windows
import torch
from PIL import Image
from rasterio.crs import CRS
from rasterio.transform import Affine
from rasterio.warp import transform_bounds
from s2sr_pipe.data_processing.postprocessing import TCI_GAIN, TCI_GAMMA
from s2sr_pipe.data_processing.preprocessing import normalize
from s2sr_pipe.data_processing.safe_reader import (
ALL_BANDS, find_band_files_from_zip, get_ref_profile, read_and_align_bands,
)
from s2sr_pipe.model.architecture import build_model, get_scale
from s2sr_pipe.model.inference import infer_large
from s2sr_pipe.utils.geo_utils import roi_wgs84_to_pixel_window
# ── Constants ─────────────────────────────────────────────────────────────────
# Directory holding the model files; can be overridden with the MODEL_DIR env var.
MODEL_DIR = Path(os.getenv("MODEL_DIR", "sen2sr_model"))

# Use the GPU when torch can see one, otherwise run on the CPU.
DEVICE = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

# Band indices used to build a true-colour image: R=B04(idx2), G=B03(idx1), B=B02(idx0).
TCI_BANDS = (2, 1, 0)

# Longest side, in pixels, of the preview images shown in the gallery.
MAX_DISP_PX = 800

# UI label -> base name of the bundled example files.
PRESETS = {
    "Zone industrielle (Tatarstan)": "industrial",
    "Terres agricoles (Tatarstan)": "farmland",
}

# Model variant -> indices (along the band axis) of the bands fed to that model.
VARIANT_BANDS = {
    "main": list(range(10)),
    "rgbn_x4": [ALL_BANDS.index(name) for name in ("B02", "B03", "B04", "B08")],
    "rswir_x2": list(range(10)),
}
# ── Model cache ───────────────────────────────────────────────────────────────
# Built models keyed by variant name, so each variant is constructed at most once.
_model_cache: dict = {}


def _get_model(variant: str):
    """Return ``(model, scale)`` for *variant*, building the model on first use."""
    try:
        model = _model_cache[variant]
    except KeyError:
        model = build_model(variant=variant, model_dir=MODEL_DIR, device=DEVICE)
        _model_cache[variant] = model
    return model, get_scale(variant)
# ── TCI rendering ─────────────────────────────────────────────────────────────
def _render_tci(rgb_float: np.ndarray) -> np.ndarray:
    """Turn a float RGB stack into a displayable uint8 image.

    The input is multiplied by ``TCI_GAIN``, clipped to [0, 1], raised to
    ``1 / TCI_GAMMA`` and scaled to 0-255. Axes are reordered from
    channel-first to channel-last.
    """
    stretched = np.clip(rgb_float * TCI_GAIN, 0.0, 1.0)
    corrected = stretched ** (1.0 / TCI_GAMMA)
    as_bytes = (corrected * 255.0).astype(np.uint8)
    return as_bytes.transpose(1, 2, 0)  # (H, W, 3)
def _write_tci_tif(rgb_float: np.ndarray, crs, transform: Affine, path: str) -> None:
    """Render *rgb_float* as a TCI and save it at *path* as a 3-band uint8 GeoTIFF."""
    image = _render_tci(rgb_float)  # (H, W, 3)
    height, width = image.shape[:2]
    profile = dict(
        driver="GTiff",
        width=width, height=height, count=3,
        dtype="uint8", crs=crs, transform=transform,
        compress="deflate",
    )
    with rasterio.open(path, "w", **profile) as dst:
        # GeoTIFF bands are 1-based; channel k of the image goes to band k + 1.
        for band_no in (1, 2, 3):
            dst.write(image[:, :, band_no - 1], band_no)
def _make_comparison(lr_array: np.ndarray, sr_array: np.ndarray):
    """Return ``(img_before, img_after)`` as PIL images of identical size.

    Both images take the size of the super-resolved render, scaled down so
    its longest side does not exceed ``MAX_DISP_PX``. The low-resolution
    image is resized with nearest-neighbour so its pixels stay visible.
    """
    rgb_index = list(TCI_BANDS)
    lr_u8 = _render_tci(lr_array[rgb_index])
    sr_u8 = _render_tci(sr_array[rgb_index])

    sr_h, sr_w = sr_u8.shape[:2]
    longest = max(sr_h, sr_w)
    target = (sr_w, sr_h)
    if longest > MAX_DISP_PX:
        ratio = MAX_DISP_PX / longest
        target = (int(sr_w * ratio), int(sr_h * ratio))

    img_after = Image.fromarray(sr_u8).resize(target, Image.LANCZOS)
    img_before = Image.fromarray(lr_u8).resize(target, Image.NEAREST)
    return img_before, img_after
# ── Inference helper ──────────────────────────────────────────────────────────
def _infer(lr_array: np.ndarray, variant: str, progress=None) -> np.ndarray:
    """Run tiled super-resolution on *lr_array* with the model for *variant*.

    Args:
        lr_array: Band stack indexed along its first axis. Only the bands
            listed in ``VARIANT_BANDS[variant]`` are passed to the model.
        variant: Model variant name, a key of ``VARIANT_BANDS``.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``.

    Returns:
        Whatever ``infer_large`` returns for the selected bands.
    """
    # Explicit None check: a progress tracker object may define __len__ and
    # therefore be falsy (or raise) when truth-tested.
    _p = progress if progress is not None else (lambda *a, **k: None)
    model, _ = _get_model(variant)
    lr_input = lr_array[VARIANT_BANDS[variant]]
    _p(0.4, desc="Inference SR en cours (CPU, patience)...")
    return infer_large(
        model=model, array=lr_input, device=DEVICE,
        overlap=32, use_amp=False, batch_size=1,
    )
# ── Map helpers ───────────────────────────────────────────────────────────────
def _profile_to_wgs84_bbox(profile: dict) -> tuple:
    """Return the raster footprint as (min_lon, min_lat, max_lon, max_lat) in EPSG:4326.

    The native-CRS extent is derived from the transform origin and pixel
    sizes found in *profile*, then reprojected with ``transform_bounds``.
    """
    affine = profile["transform"]
    x_min, y_max = affine.c, affine.f
    x_max = x_min + profile["width"] * affine.a
    # The original author notes that the row pixel size (e) is negative.
    y_min = y_max + profile["height"] * affine.e
    min_lon, min_lat, max_lon, max_lat = transform_bounds(
        profile["crs"], CRS.from_epsg(4326), x_min, y_min, x_max, y_max
    )
    return (min_lon, min_lat, max_lon, max_lat)
def _make_map_data(
    tile_bbox=None, tile_px=None,
    roi=None, roi_too_large: bool = False, roi_warning_text: str = "",
    interactive: bool = False,
) -> str:
    """Build the JSON string carried by the Python→JS bridge textbox.

    Sequence arguments are emitted as JSON arrays; empty or missing ones
    become ``null``.
    """
    def _as_list(seq):
        return list(seq) if seq else None

    state = {}
    # Key order is kept stable so the serialised string is reproducible.
    state["tile_bbox"] = _as_list(tile_bbox)
    state["tile_px"] = _as_list(tile_px)
    state["roi"] = _as_list(roi)
    state["roi_too_large"] = roi_too_large
    state["roi_warning_text"] = roi_warning_text
    state["interactive"] = interactive
    return json.dumps(state)
def _clamp_roi(roi: list, tile_bbox: list) -> list:
    """Restrict *roi* to *tile_bbox*; both are [min_lon, min_lat, max_lon, max_lat].

    Minimum corners are raised to the tile minimum and maximum corners are
    lowered to the tile maximum. No check is made that the result is non-empty.
    """
    lower = [max(roi[i], tile_bbox[i]) for i in (0, 1)]  # min_lon, min_lat
    upper = [min(roi[i], tile_bbox[i]) for i in (2, 3)]  # max_lon, max_lat
    return lower + upper
def _estimate_roi_px(roi: list, tile_bbox: list, tile_px: list):
    """Estimate the ROI size in pixels from its share of the tile extent.

    The ROI's width and height, as fractions of the tile's extent in degrees,
    are multiplied by the tile's pixel dimensions and truncated to ints.

    Returns:
        ``(w_px, h_px)``, or ``(None, None)`` when the tile bbox or pixel
        size is missing or the tile extent is not strictly positive.
    """
    if not (tile_bbox and tile_px):
        return None, None

    span_lon = tile_bbox[2] - tile_bbox[0]
    span_lat = tile_bbox[3] - tile_bbox[1]
    if span_lon <= 0 or span_lat <= 0:
        return None, None

    frac_w = (roi[2] - roi[0]) / span_lon
    frac_h = (roi[3] - roi[1]) / span_lat
    return int(frac_w * tile_px[0]), int(frac_h * tile_px[1])
# Static HTML/JS for the Leaflet map, loaded once into a gr.HTML component.
#
# Python -> JS: the script polls the hidden #map_data textbox every 400 ms and,
# when its value changes, parses it as JSON and applies the keys `tile_bbox`,
# `roi`, `roi_too_large`, `interactive` and `roi_warning_text` (tile rectangle,
# ROI rectangle, drawing control / overlay, warning bar).
#
# JS -> Python: when the user draws a rectangle, the ROI is clamped to the tile
# bbox and written into the hidden #roi_drawn textbox as
# "min_lon,min_lat,max_lon,max_lat" (6 decimals), followed by synthetic
# `input` and `change` events.
#
# Leaflet 1.9.4 and leaflet.draw 1.0.4 are fetched from public CDNs.
_MAP_HTML = """
<style>
#s2sr-map-wrap .leaflet-container { border-radius: 8px; }
.roi-warn-tip {
background: #fee2e2;
border: 1px solid #ef4444; border-radius: 5px;
color: #991b1b; font-weight: 700; font-size: 0.88em;
padding: 4px 10px; white-space: nowrap;
box-shadow: 0 1px 4px rgba(0,0,0,0.15);
}
.roi-warn-tip::before { display: none; }
.roi-map-warn-bar {
position: absolute; top: 8px; left: 50%; transform: translateX(-50%);
z-index: 1001; pointer-events: none;
background: #fee2e2; border: 1px solid #ef4444; border-radius: 6px;
color: #991b1b; font-weight: 700; font-size: 0.85em;
padding: 5px 14px; white-space: nowrap;
box-shadow: 0 1px 6px rgba(239,68,68,0.25);
}
</style>
<div id="s2sr-map-wrap" style="position:relative;width:100%;height:400px;">
<div id="s2sr-map" style="width:100%;height:100%;"></div>
<div id="s2sr-map-overlay"
style="position:absolute;top:0;left:0;width:100%;height:100%;
background:rgba(0,0,0,0.42);display:flex;align-items:center;
justify-content:center;z-index:1000;border-radius:8px;pointer-events:none;">
<div style="color:#fff;font-size:1.05em;font-weight:600;
background:rgba(0,0,0,0.62);padding:12px 28px;
border-radius:8px;text-align:center;">
&#128194; Uploadez un fichier ZIP pour activer la carte
</div>
</div>
</div>
<link rel="stylesheet" href="https://unpkg.com/leaflet@1.9.4/dist/leaflet.css"/>
<link rel="stylesheet"
href="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.css"/>
<script src="https://unpkg.com/leaflet@1.9.4/dist/leaflet.js"></script>
<script src="https://cdnjs.cloudflare.com/ajax/libs/leaflet.draw/1.0.4/leaflet.draw.js"></script>
<script>
(function () {
'use strict';
var _map = null, _tileLyr = null, _roiLyr = null, _drawCtrl = null, _warnBar = null;
var _cur = { tile_bbox: null, tile_px: null, roi: null, roi_too_large: false, interactive: false };
var _lastMapDataVal = null;
function _initMap() {
if (_map) return;
_map = L.map('s2sr-map').setView([20, 0], 2);
L.tileLayer('https://{s}.tile.openstreetmap.org/{z}/{x}/{y}.png', {
attribution: '&copy; <a href="https://www.openstreetmap.org/copyright">OpenStreetMap</a>',
maxZoom: 18
}).addTo(_map);
}
function _setTile(bbox) {
if (_tileLyr) { _map.removeLayer(_tileLyr); _tileLyr = null; }
if (!bbox) return;
_tileLyr = L.rectangle([[bbox[1], bbox[0]], [bbox[3], bbox[2]]],
{ color: '#4488ff', weight: 2, fillOpacity: 0.08 }).addTo(_map);
_map.fitBounds(_tileLyr.getBounds(), { padding: [30, 30] });
}
function _setRoi(roi, tooLarge) {
if (_roiLyr) { _map.removeLayer(_roiLyr); _roiLyr = null; }
if (!roi) return;
var color = tooLarge ? '#ef4444' : '#22c55e';
_roiLyr = L.rectangle([[roi[1], roi[0]], [roi[3], roi[2]]],
{ color: color, weight: 2.5, fillOpacity: 0.15 }).addTo(_map);
if (tooLarge) {
_roiLyr.bindTooltip('⚠️ ROI trop grand', {
permanent: true, direction: 'center', className: 'roi-warn-tip'
}).openTooltip();
}
}
function _clampRoiToTile(roi) {
if (!_cur.tile_bbox) return roi;
var tb = _cur.tile_bbox;
return [
Math.max(roi[0], tb[0]), Math.max(roi[1], tb[1]),
Math.min(roi[2], tb[2]), Math.min(roi[3], tb[3]),
];
}
function _enableDraw() {
if (_drawCtrl) return;
var drawn = new L.FeatureGroup();
_map.addLayer(drawn);
_drawCtrl = new L.Control.Draw({
draw: {
rectangle: { shapeOptions: { color: '#ff4444' } },
polyline: false,
polygon: false,
circle: false,
circlemarker: false,
marker: false
},
edit: false
});
_map.addControl(_drawCtrl);
_map.on(L.Draw.Event.CREATED, function (e) {
var b = e.layer.getBounds();
var sw = b.getSouthWest(), ne = b.getNorthEast();
var roi = _clampRoiToTile([sw.lng, sw.lat, ne.lng, ne.lat]);
_setRoi(roi, false); // green placeholder; Python will confirm final color
var val = roi[0].toFixed(6) + ',' + roi[1].toFixed(6) + ','
+ roi[2].toFixed(6) + ',' + roi[3].toFixed(6);
// Write into the hidden Gradio textbox #roi_drawn to notify Python
function _push() {
var el = document.querySelector('#roi_drawn textarea')
|| document.querySelector('#roi_drawn input[type=text]');
if (!el) { setTimeout(_push, 200); return; }
// Use native setter so React/Svelte picks up the change
var proto = el.tagName === 'TEXTAREA'
? window.HTMLTextAreaElement.prototype
: window.HTMLInputElement.prototype;
var setter = Object.getOwnPropertyDescriptor(proto, 'value');
if (setter && setter.set) setter.set.call(el, val); else el.value = val;
el.dispatchEvent(new Event('input', { bubbles: true }));
el.dispatchEvent(new Event('change', { bubbles: true }));
}
_push();
});
}
function _setWarningBar(text) {
if (!_warnBar) {
_warnBar = document.createElement('div');
_warnBar.className = 'roi-map-warn-bar';
document.getElementById('s2sr-map-wrap').appendChild(_warnBar);
}
_warnBar.textContent = text || '';
_warnBar.style.display = text ? 'block' : 'none';
}
function _setInteractive(on) {
var overlay = document.getElementById('s2sr-map-overlay');
if (overlay) overlay.style.display = on ? 'none' : 'flex';
if (on) _enableDraw();
}
function _apply(data) {
_initMap();
if (JSON.stringify(data.tile_bbox) !== JSON.stringify(_cur.tile_bbox))
_setTile(data.tile_bbox);
if (JSON.stringify(data.roi) !== JSON.stringify(_cur.roi) ||
data.roi_too_large !== _cur.roi_too_large)
_setRoi(data.roi, data.roi_too_large);
if (data.interactive !== _cur.interactive)
_setInteractive(data.interactive);
if ((data.roi_warning_text || '') !== (_cur.roi_warning_text || ''))
_setWarningBar(data.roi_warning_text);
_cur = data;
}
function _poll() {
var el = document.querySelector('#map_data textarea')
|| document.querySelector('#map_data input[type=text]');
if (!el || el.value === _lastMapDataVal) return;
_lastMapDataVal = el.value;
if (el.value) { try { _apply(JSON.parse(el.value)); } catch (e) {} }
}
function _boot() {
if (!document.getElementById('s2sr-map')) { setTimeout(_boot, 150); return; }
_initMap();
setInterval(_poll, 400);
}
if (document.readyState === 'loading')
document.addEventListener('DOMContentLoaded', _boot);
else
_boot();
})();
</script>
"""
# ── Map event handlers ────────────────────────────────────────────────────────
def _on_zip_upload(zip_file) -> str:
    """Handle a ZIP upload and return the map-state JSON for the new tile.

    Args:
        zip_file: Path string, or an object exposing the path as ``.name``,
            or ``None`` when nothing is uploaded.

    Returns:
        Map-state JSON carrying the tile's WGS84 bbox and pixel size with the
        map set to interactive, or the empty default state when there is no
        file or the archive cannot be read.
    """
    if zip_file is None:
        return _make_map_data()

    import logging  # local import keeps this handler self-contained

    try:
        zip_path = Path(zip_file if isinstance(zip_file, str) else zip_file.name)
        band_files = find_band_files_from_zip(zip_path)
        profile = get_ref_profile(band_files)
        bbox = _profile_to_wgs84_bbox(profile)
        tile_px = (profile["width"], profile["height"])
    except Exception:
        # Best-effort: a bad archive must not break the UI, but the reason is
        # logged instead of being silently discarded.
        logging.getLogger(__name__).exception(
            "Could not read tile footprint from uploaded ZIP %r", zip_file
        )
        return _make_map_data()
    return _make_map_data(tile_bbox=bbox, tile_px=tile_px, interactive=True)
# Largest ROI side, in pixels, accepted before the map flags the ROI as too
# large (run_zip rejects ROIs above the same 1500 px limit).
_ROI_MAX_SIDE_PX = 1500


def _load_map_state(map_data_str):
    """Decode the map-state JSON; return an empty dict for missing or malformed input."""
    if not map_data_str:
        return {}
    try:
        state = json.loads(map_data_str)
    except (TypeError, ValueError):
        return {}
    # Valid JSON that is not an object (e.g. "[]") would break the .get() calls.
    return state if isinstance(state, dict) else {}


def _refresh_roi_state(roi_list, cur):
    """Clamp *roi_list* to the tile in *cur*, size it, and serialise the new map state.

    Returns:
        ``(roi, map_data_json)`` where *roi* is the clamped ROI, the ROI as
        given when no tile bbox is known, or ``None`` when *roi_list* is ``None``.
    """
    tile_bbox = cur.get("tile_bbox")
    tile_px = cur.get("tile_px")
    w_px = h_px = None
    if roi_list is not None:
        try:
            if tile_bbox:
                roi_list = _clamp_roi(roi_list, tile_bbox)
            if tile_px:
                w_px, h_px = _estimate_roi_px(roi_list, tile_bbox, tile_px)
        except (TypeError, ValueError):
            # Malformed tile metadata: keep the ROI and skip the size check.
            w_px = h_px = None

    # Test against None, not truthiness: a 0-px side must not hide an
    # oversized other side.
    too_large = (
        w_px is not None and h_px is not None
        and (w_px > _ROI_MAX_SIDE_PX or h_px > _ROI_MAX_SIDE_PX)
    )
    warn_text = ""
    if too_large:
        # U+202F (narrow no-break space) as thousands separator, U+00D7 as "×";
        # written as escapes so the text survives re-encoding of this file.
        limit = f"{_ROI_MAX_SIDE_PX:,}".replace(",", "\u202f")
        warn_text = (
            f"\u26a0\ufe0f ROI trop grand : ~{w_px}\u202f\u00d7\u202f{h_px}\u202fpx "
            f"(limite\u202f:\u202f{limit}\u202f\u00d7\u202f{limit}\u202fpx)"
        )

    map_data = _make_map_data(
        tile_bbox=tile_bbox, tile_px=tile_px,
        roi=roi_list, roi_too_large=too_large, roi_warning_text=warn_text,
        interactive=cur.get("interactive", False),
    )
    return roi_list, map_data


def _on_roi_drawn(roi_str: str, map_data_str: str):
    """Handle a rectangle drawn on the map.

    Args:
        roi_str: ``"min_lon,min_lat,max_lon,max_lat"`` as written by the JS
            side; anything that does not parse to four floats is ignored.
        map_data_str: Current map-state JSON.

    Returns:
        ``(min_lon, min_lat, max_lon, max_lat, map_data_json)``. The four
        coordinates are clamped to the tile when one is loaded, and are all
        ``None`` when *roi_str* could not be parsed.
    """
    cur = _load_map_state(map_data_str)

    roi_list = None
    if roi_str and roi_str.strip():
        try:
            parts = [float(x) for x in roi_str.strip().split(',')]
        except ValueError:
            parts = []
        if len(parts) == 4:
            roi_list = parts

    roi_list, new_map_data = _refresh_roi_state(roi_list, cur)
    if roi_list is None:
        return None, None, None, None, new_map_data
    return roi_list[0], roi_list[1], roi_list[2], roi_list[3], new_map_data


def _on_coords_change(mn_lon, mn_lat, mx_lon, mx_lat, map_data_str: str):
    """Handle a manual edit of the ROI coordinate fields.

    Returns:
        ``(min_lon, min_lat, max_lon, max_lat, map_data_json)``. When all four
        values are set and numeric they are returned as floats, clamped to the
        tile if one is loaded; otherwise the inputs are returned unchanged and
        the map state carries no ROI.
    """
    cur = _load_map_state(map_data_str)

    roi_list = None
    if all(v is not None for v in (mn_lon, mn_lat, mx_lon, mx_lat)):
        try:
            roi_list = [float(mn_lon), float(mn_lat), float(mx_lon), float(mx_lat)]
        except (TypeError, ValueError):
            roi_list = None

    roi_list, new_map_data = _refresh_roi_state(roi_list, cur)
    if roi_list is None:
        return mn_lon, mn_lat, mx_lon, mx_lat, new_map_data
    return roi_list[0], roi_list[1], roi_list[2], roi_list[3], new_map_data
# ── Mode Exemple ──────────────────────────────────────────────────────────────
def run_example(preset_label: str, variant: str, progress=None):
    """Super-resolve one of the bundled example tiles.

    Args:
        preset_label: UI label, a key of ``PRESETS``; selects
            ``examples/<name>.npy`` and ``examples/<name>.json``.
        variant: Model variant name.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``.

    Returns:
        ``(gallery, jpg_before, jpg_after, tif_before, tif_after)`` where
        *gallery* is a list of ``(PIL image, caption)`` pairs and the other
        items are paths of files written to a fresh temporary directory.

    Raises:
        KeyError: If *preset_label* is not in ``PRESETS``.
    """
    # Explicit None check: a progress tracker object may define __len__ and
    # therefore be falsy (or raise) when truth-tested.
    _p = progress if progress is not None else (lambda *a, **k: None)
    preset_name = PRESETS[preset_label]

    _p(0.1, desc="Chargement de l'exemple...")
    lr_array = np.load(f"examples/{preset_name}.npy")
    with open(f"examples/{preset_name}.json", encoding="utf-8") as f:
        meta = json.load(f)
    a, b, c, d, e, ff = meta["transform"]
    crs = CRS.from_string(meta["crs"])
    T_lr = Affine(a, b, c, d, e, ff)

    sr_array = _infer(lr_array, variant, progress)

    _p(0.9, desc="Rendu TCI...")
    rgb_index = list(TCI_BANDS)
    img_before, img_after = _make_comparison(lr_array, sr_array)
    # The SR grid keeps the same origin; only the pixel size shrinks by `scale`.
    scale = get_scale(variant)
    T_sr = Affine(a / scale, b, c, d, e / scale, ff)

    # The directory is intentionally left in place: the returned paths are
    # handed to the UI as downloadable files.
    tmpdir = Path(tempfile.mkdtemp())
    p_before = str(tmpdir / "avant_10m.jpg")
    p_after = str(tmpdir / "apres_2p5m.jpg")
    p_tif_before = str(tmpdir / "TCI_avant_10m.tif")
    p_tif_after = str(tmpdir / "TCI_apres_2p5m.tif")
    img_before.save(p_before, quality=90)
    img_after.save(p_after, quality=90)
    _write_tci_tif(lr_array[rgb_index], crs, T_lr, p_tif_before)
    _write_tci_tif(sr_array[rgb_index], crs, T_sr, p_tif_after)

    _p(1.0, desc="Termine !")
    return (
        [(img_before, "Avant 10m"), (img_after, "Apres 2.5m SR")],
        p_before, p_after, p_tif_before, p_tif_after,
    )
# ── Mode ZIP ──────────────────────────────────────────────────────────────────
def run_zip(
    zip_file, min_lon, min_lat, max_lon, max_lat,
    variant: str, progress=None,
):
    """Super-resolve a WGS84 ROI read from an uploaded Sentinel-2 L2A ZIP.

    Args:
        zip_file: Path string, or an object exposing the path as ``.name``.
        min_lon, min_lat, max_lon, max_lat: ROI corners in WGS84 degrees.
        variant: Model variant name.
        progress: Optional callable invoked as ``progress(fraction, desc=...)``.

    Returns:
        ``(gallery, jpg_before, jpg_after, tif_before, tif_after)`` where
        *gallery* is a list of ``(PIL image, caption)`` pairs and the other
        items are paths of files written to a fresh temporary directory.

    Raises:
        gr.Error: If no file or an incomplete ROI is given, if the ROI exceeds
            1500 px on a side, or if a ``ValueError`` is raised while reading
            or processing the data.
    """
    if zip_file is None:
        raise gr.Error("Uploadez d'abord un fichier ZIP Sentinel-2 L2A.")
    if any(v is None for v in [min_lon, min_lat, max_lon, max_lat]):
        raise gr.Error("Renseignez les 4 coordonnees ROI.")

    # Explicit None check: a progress tracker object may define __len__ and
    # therefore be falsy (or raise) when truth-tested.
    _p = progress if progress is not None else (lambda *a, **k: None)
    zip_path = Path(zip_file if isinstance(zip_file, str) else zip_file.name)
    roi = (float(min_lon), float(min_lat), float(max_lon), float(max_lat))

    try:
        _p(0.05, desc="Lecture structure S2...")
        band_files = find_band_files_from_zip(zip_path)
        ref_profile = get_ref_profile(band_files)

        _p(0.1, desc="Calcul fenetre ROI...")
        roi_win = roi_wgs84_to_pixel_window(roi, ref_profile, buffer_px=128)
        roi_w, roi_h = roi_win["roi_width"], roi_win["roi_height"]
        if roi_w > 1500 or roi_h > 1500:
            raise gr.Error(
                f"ROI trop grand pour cette demo : {roi_w} x {roi_h} px "
                f"(limite demo HuggingFace CPU : 1500 x 1500 px). "
                f"Reduisez la zone ou utilisez le pipeline complet en local "
                f"(https://github.com/gdubrasquetd/Sentinel-2-Better-Resolution)."
            )

        # "read" is the window actually read from disk (requested with
        # buffer_px=128 above); "inner" locates the exact ROI inside it.
        rr_min, rr_max, rc_min, rc_max = roi_win["read"]
        read_window = rasterio.windows.Window(
            col_off=rc_min, row_off=rr_min,
            width=rc_max - rc_min, height=rr_max - rr_min,
        )

        _p(0.2, desc="Lecture fenetre ROI (quelques secondes)...")
        raw_array, _ = read_and_align_bands(band_files, window=read_window)
        norm_array = normalize(raw_array)
        del raw_array  # release the raw copy before inference

        ir_start, ir_end, ic_start, ic_end = roi_win["inner"]
        lr_exact = norm_array[:, ir_start:ir_end, ic_start:ic_end].copy()

        # Inference runs on the whole read window; the exact ROI is cut out
        # of the result afterwards.
        sr_array = _infer(norm_array, variant, progress)
        del norm_array

        scale = get_scale(variant)
        sr_exact = sr_array[
            :,
            ir_start * scale : ir_end * scale,
            ic_start * scale : ic_end * scale,
        ]

        _p(0.9, desc="Rendu TCI...")
        r, g, b_idx = TCI_BANDS
        img_before, img_after = _make_comparison(lr_exact, sr_exact)

        crs = ref_profile["crs"]
        T_lr = roi_win["roi_transform"]
        # Same origin as the low-resolution ROI; pixel size divided by `scale`.
        T_sr = Affine(
            T_lr.a / scale, T_lr.b, T_lr.c,
            T_lr.d, T_lr.e / scale, T_lr.f,
        )
    except ValueError as e:
        # Surface the message in the UI while keeping the original traceback.
        raise gr.Error(str(e)) from e

    # The directory is intentionally left in place: the returned paths are
    # handed to the UI as downloadable files.
    tmpdir = Path(tempfile.mkdtemp())
    p_before = str(tmpdir / "avant_10m.jpg")
    p_after = str(tmpdir / "apres_2p5m.jpg")
    p_tif_before = str(tmpdir / "TCI_avant_10m.tif")
    p_tif_after = str(tmpdir / "TCI_apres_2p5m.tif")
    img_before.save(p_before, quality=90)
    img_after.save(p_after, quality=90)
    _write_tci_tif(lr_exact[[r, g, b_idx]], crs, T_lr, p_tif_before)
    _write_tci_tif(sr_exact[[r, g, b_idx]], crs, T_sr, p_tif_after)

    _p(1.0, desc="Termine !")
    return (
        [(img_before, "Avant 10m"), (img_after, "Apres 2.5m SR")],
        p_before, p_after, p_tif_before, p_tif_after,
    )
# ── Gradio UI ─────────────────────────────────────────────────────────────────
# Layout: source/variant selectors, one panel per source (bundled example or
# uploaded ZIP with map + ROI fields), the run button, and the result widgets.
# User-facing strings below had mis-decoded punctuation; they now carry the
# intended characters (em dash, arrow, multiplication sign).
with gr.Blocks(title="SEN2SR — Super-Resolution Sentinel-2") as demo:
    gr.Markdown("""
# SEN2SR — Super-Resolution Sentinel-2
**10 m/pixel → 2.5 m/pixel (×4)** sur des images Sentinel-2 L2A.
Modele officiel ESA : [tacofoundation/sen2sr](https://huggingface.co/tacofoundation/sen2sr)
— Code pipeline : [GitHub](https://github.com/gdubrasquetd/Sentinel-2-Better-Resolution)
> L'inference tourne sur CPU — comptez **1 a 5 minutes** selon la taille du ROI.
""")
    with gr.Row():
        source = gr.Radio(
            choices=["Exemple Tatarstan", "Mon ZIP Sentinel-2"],
            value="Exemple Tatarstan",
            label="Source des donnees",
        )
        variant = gr.Dropdown(
            choices=["main", "rgbn_x4", "rswir_x2"],
            value="main",
            label="Variante du modele",
            info="main: 10 bandes x4 | rgbn_x4: RGB+NIR x4 | rswir_x2: 10 bandes x2",
        )
    with gr.Group(visible=True) as panel_example:
        preset = gr.Radio(
            choices=list(PRESETS.keys()),
            value="Zone industrielle (Tatarstan)",
            label="Zone de demonstration",
        )
    with gr.Group(visible=False) as panel_zip:
        with gr.Row(equal_height=True):
            with gr.Column(scale=1, min_width=180):
                zip_input = gr.File(
                    label="Fichier ZIP Sentinel-2 L2A",
                    file_types=[".zip"],
                    height=440,
                )
            with gr.Column(scale=3):
                # ── OSM map ──────────────────────────────────────────────────
                gr.Markdown("**Carte — emprise de la dalle et ROI**")
                map_html = gr.HTML(value=_MAP_HTML)
                # Hidden bridge textboxes (not shown to the user)
                map_data = gr.Textbox(visible=False, elem_id="map_data", value=_make_map_data())
                roi_drawn = gr.Textbox(visible=False, elem_id="roi_drawn", value="")
                # ── ROI coordinates ──────────────────────────────────────────
                gr.Markdown("**ROI — coordonnees WGS84** (dessinez sur la carte ou saisissez manuellement)")
                with gr.Row():
                    min_lon = gr.Number(label="Min longitude", value=52.030)
                    min_lat = gr.Number(label="Min latitude", value=55.828)
                    max_lon = gr.Number(label="Max longitude", value=52.048)
                    max_lat = gr.Number(label="Max latitude", value=55.843)

    def _toggle(src):
        """Show the panel matching the selected data source and hide the other."""
        return (
            gr.update(visible=(src == "Exemple Tatarstan")),
            gr.update(visible=(src == "Mon ZIP Sentinel-2")),
        )

    source.change(_toggle, inputs=source, outputs=[panel_example, panel_zip])
    run_btn = gr.Button("Lancer la super-resolution", variant="primary", size="lg")
    gallery = gr.Gallery(
        label="Avant (10 m) / Apres (2.5 m SR)",
        columns=2,
        height=520,
    )
    with gr.Row():
        dl_before = gr.File(label="Telecharger — Avant 10m (JPEG)")
        dl_after = gr.File(label="Telecharger — Apres 2.5m SR (JPEG)")
    with gr.Row():
        dl_tif_before = gr.File(label="Telecharger — TCI Avant GeoTIFF")
        dl_tif_after = gr.File(label="Telecharger — TCI Apres SR GeoTIFF")

    # ── Map events ────────────────────────────────────────────────────────────
    zip_input.upload(fn=_on_zip_upload, inputs=[zip_input], outputs=[map_data])
    zip_input.clear(fn=lambda: _make_map_data(), inputs=[], outputs=[map_data])
    roi_drawn.change(
        fn=_on_roi_drawn,
        inputs=[roi_drawn, map_data],
        outputs=[min_lon, min_lat, max_lon, max_lat, map_data],
    )
    # Every coordinate field re-validates the whole ROI when it changes.
    _coord_inputs = [min_lon, min_lat, max_lon, max_lat]
    for _c in _coord_inputs:
        _c.change(
            fn=_on_coords_change,
            inputs=_coord_inputs + [map_data],
            outputs=[min_lon, min_lat, max_lon, max_lat, map_data],
        )

    # ── Run button ────────────────────────────────────────────────────────────
    def _on_run(src, p, zf, mn_lon, mn_lat, mx_lon, mx_lat, var):
        """Dispatch to the example or ZIP pipeline depending on the selected source."""
        if src == "Exemple Tatarstan":
            return run_example(p, var)
        return run_zip(zf, mn_lon, mn_lat, mx_lon, mx_lat, var)

    run_btn.click(
        _on_run,
        inputs=[source, preset, zip_input, min_lon, min_lat, max_lon, max_lat, variant],
        outputs=[gallery, dl_before, dl_after, dl_tif_before, dl_tif_after],
        api_name="run_sr",
    )

if __name__ == "__main__":
    demo.launch()