segspace_app / app.py
functionNormally
Remove speculative band labels; revert class names to English
f4b0cd5
import gradio as gr
from config import (
APP_TITLE, set_seed, SEED, DEFAULT_PATCH_SIZE,
COMPOSITE_PRESETS, BAND_DESCRIPTIONS, MAX_EXPERIMENTS,
)
from train import (
load_dataset_action,
update_step1_composite,
handle_click_step1,
update_step2_index,
run_baseline_action,
update_step4_patch,
train_experiment,
update_step5_comparison,
handle_click_step5,
)
set_seed(SEED)
_COMPOSITE_CHOICES = list(COMPOSITE_PRESETS.keys()) + BAND_DESCRIPTIONS
custom_css = """
#step1-img img, #step3-pred img { image-rendering: pixelated; }
.hint { font-size: 0.88rem; color: #666; }
"""
with gr.Blocks(title=APP_TITLE, css=custom_css) as demo:
gr.Markdown(f"# {APP_TITLE}")
gr.Markdown(
"Un parcours en cinq étapes, des pixels bruts du satellite jusqu'à la segmentation "
"par apprentissage profond. Suivez les onglets dans l'ordre — chaque étape s'appuie sur la précédente.",
elem_classes="hint",
)
dataset_state = gr.State(None)
baseline_state = gr.State(None)
experiments_state = gr.State([])
# ────────────────────────────────────────────────────────
# Étape 1 — Découvrir les données
# ────────────────────────────────────────────────────────
with gr.Tab("Étape 1 · Découvrir les données"):
gr.Markdown(
"**Commencez ici.** Chargez le jeu de données, puis explorez les 7 bandes spectrales. "
"Les carrés sur l'image sont les étiquettes d'entraînement ; les cercles sont les étiquettes de validation.",
elem_classes="hint",
)
with gr.Row():
with gr.Column(scale=1):
patch_size = gr.Slider(64, 512, value=DEFAULT_PATCH_SIZE, step=32,
label="Taille du patch (pour l'entraînement)")
load_btn = gr.Button("Charger les données", variant="primary")
composite_dd = gr.Dropdown(
choices=_COMPOSITE_CHOICES,
value="H4 / H3 / H2",
label="Mode d'affichage",
)
step1_info = gr.Markdown("*Chargez les données pour commencer.*")
with gr.Column(scale=3, elem_id="step1-img"):
step1_image = gr.Image(
label="Scène complète — cliquer pour inspecter un pixel",
type="numpy",
)
step1_click = gr.Markdown("*Cliquez n'importe où sur l'image.*")
# ────────────────────────────────────────────────────────
# Étape 2 — Signatures spectrales
# ────────────────────────────────────────────────────────
with gr.Tab("Étape 2 · Signatures spectrales"):
gr.Markdown(
"Chaque type d'occupation du sol possède un profil de luminosité caractéristique "
"à travers les 7 bandes — sa **signature spectrale**. "
"Le NDVI et le NDWI sont des indices calculés directement à partir des valeurs de bande.",
elem_classes="hint",
)
with gr.Row():
with gr.Column(scale=1):
index_radio = gr.Radio(
choices=["NDVI", "NDWI"],
value="NDVI",
label="Carte d'indice spectral",
)
gr.Markdown(
"**NDVI** = (H_5 − H_4) / (H_5 + H_4)\n\n"
"**NDWI** = (H_3 − H_5) / (H_3 + H_5)",
elem_classes="hint",
)
with gr.Column(scale=3):
step2_sig_chart = gr.Image(
label="Signatures spectrales (étiquettes d'entraînement)",
type="numpy",
)
step2_index_map = gr.Image(label="Carte d'indice", type="numpy")
# ────────────────────────────────────────────────────────
# Étape 3 — Référence spectrale (KNN)
# ────────────────────────────────────────────────────────
with gr.Tab("Étape 3 · Référence spectrale"):
gr.Markdown(
"**Aucune convolution ici.** Chaque pixel est classifié uniquement à partir "
"de ses 7 valeurs de bande, en cherchant les k pixels d'entraînement les plus "
"proches dans l'espace spectral. "
"Cela montre ce qui est atteignable sans aucun contexte spatial.",
elem_classes="hint",
)
with gr.Row():
with gr.Column(scale=1):
k_slider = gr.Slider(1, 5, value=3, step=2,
label="k (nombre de voisins)")
baseline_btn = gr.Button("Lancer la référence KNN", variant="primary")
step3_metrics = gr.Markdown("*Lancez la référence pour voir les résultats.*")
with gr.Column(scale=3, elem_id="step3-pred"):
step3_full_pred = gr.Image(
label=(
"Prédiction sur la scène complète · superposée sur l'image en couleurs naturelles · "
"points colorés = étiquettes de validation (vert=correct, rouge=incorrect)"
),
type="numpy",
)
# ────────────────────────────────────────────────────────
# Étape 4 — Apprentissage profond (UNet)
# ────────────────────────────────────────────────────────
with gr.Tab("Étape 4 · Apprentissage profond"):
gr.Markdown(
"Un **U-Net** observe un patch de pixels à la fois, pas un seul pixel. "
"Son encodeur capture la texture locale ; les connexions de saut préservent "
"le détail spatial. Entraînez un modèle et comparez-le patch par patch avec la référence KNN.",
elem_classes="hint",
)
with gr.Row():
with gr.Column(scale=1):
run_name = gr.Textbox(label="Nom de l'expérience",
placeholder="ex. : lr-1e-3_ch-16")
learning_rate = gr.Slider(1e-4, 5e-3, value=1e-3, step=1e-4,
label="Taux d'apprentissage")
batch_size = gr.Slider(2, 32, value=8, step=2,
label="Taille du lot")
epochs = gr.Slider(1, 20, value=5, step=1,
label="Époques")
base_channels = gr.Slider(8, 64, value=16, step=8,
label="Largeur du modèle (canaux de base)")
train_btn = gr.Button("Entraîner le modèle", variant="primary")
gr.Markdown(
f"*Maximum {MAX_EXPERIMENTS} expériences. "
"Rechargez les données pour réinitialiser.*",
elem_classes="hint",
)
step4_summary = gr.Markdown("*Entraînez un modèle pour voir les résultats.*")
with gr.Column(scale=3):
step4_patch_slider = gr.Slider(0, 59, value=0, step=1,
label="Index du patch de validation")
with gr.Row():
step4_gt_img = gr.Image(label="Superposition vérité terrain",
type="numpy", height=280)
step4_bl_img = gr.Image(label="Prédiction référence KNN",
type="numpy", height=280)
step4_un_img = gr.Image(label="Prédiction UNet",
type="numpy", height=280)
# ────────────────────────────────────────────────────────
# Étape 5 — Laboratoire d'expériences
# ────────────────────────────────────────────────────────
with gr.Tab("Étape 5 · Laboratoire d'expériences"):
gr.Markdown(
f"Comparez jusqu'à **{MAX_EXPERIMENTS}** expériences UNet côte à côte. "
"Essayez différents taux d'apprentissage, nombres d'époques ou largeurs de modèle "
"et observez ce qui change.",
elem_classes="hint",
)
with gr.Row():
step5_sel_a = gr.Dropdown(choices=[], value=None, label="Modèle gauche",
interactive=True)
step5_sel_b = gr.Dropdown(choices=[], value=None, label="Modèle droit",
interactive=True)
step5_patch_slider = gr.Slider(0, 59, value=0, step=1,
label="Index du patch de validation")
step5_table = gr.Markdown("*Aucune expérience pour l'instant.*")
gr.Markdown(
"**Questions directrices**\n\n"
"- Doublez les époques — le mIoU continue-t-il de progresser ou se stabilise-t-il ?\n"
"- Divisez le taux d'apprentissage par 2 — l'entraînement devient-il plus stable ?\n"
"- Augmentez les canaux de base de 16 à 32 — le gain vaut-il le temps supplémentaire ?",
elem_classes="hint",
)
with gr.Row():
with gr.Column(scale=1):
gr.Markdown("## Gauche")
s5_a_rgb = gr.Image(label="RVB", type="numpy", height=240)
s5_a_pred = gr.Image(label="Prédiction", type="numpy", height=240)
s5_a_overlay = gr.Image(label="Superposition", type="numpy", height=240)
s5_a_metrics = gr.Markdown("*Aucun modèle sélectionné.*")
s5_a_error = gr.Image(label="Carte de précision", type="numpy", height=240)
with gr.Column(scale=1):
gr.Markdown("## Droite")
s5_b_rgb = gr.Image(label="RVB", type="numpy", height=240)
s5_b_pred = gr.Image(label="Prédiction", type="numpy", height=240)
s5_b_overlay = gr.Image(label="Superposition", type="numpy", height=240)
s5_b_metrics = gr.Markdown("*Aucun modèle sélectionné.*")
s5_b_error = gr.Image(label="Carte de précision", type="numpy", height=240)
# ── Connexion des événements ──────────────────────────────
_load_outputs = [
dataset_state, baseline_state, experiments_state,
# Étape 1
step1_info, step1_image, step1_click,
# Étape 2
step2_sig_chart, step2_index_map,
# Étape 3
step3_metrics, step3_full_pred,
# Étape 4
step4_summary, step4_patch_slider,
step4_gt_img, step4_bl_img, step4_un_img,
# Étape 5
step5_table, step5_sel_a, step5_sel_b,
]
load_btn.click(fn=load_dataset_action, inputs=[patch_size], outputs=_load_outputs)
composite_dd.change(
fn=update_step1_composite,
inputs=[dataset_state, composite_dd],
outputs=[step1_image, step1_click],
)
step1_image.select(
fn=handle_click_step1,
inputs=[dataset_state],
outputs=[step1_click],
)
index_radio.change(
fn=update_step2_index,
inputs=[dataset_state, index_radio],
outputs=[step2_index_map],
)
baseline_btn.click(
fn=run_baseline_action,
inputs=[dataset_state, k_slider],
outputs=[baseline_state, step3_metrics, step3_full_pred],
)
_train_outputs = [
experiments_state,
step4_summary, step4_patch_slider,
step4_gt_img, step4_bl_img, step4_un_img,
step5_table, step5_sel_a, step5_sel_b,
]
train_btn.click(
fn=train_experiment,
inputs=[
dataset_state, baseline_state, experiments_state,
learning_rate, batch_size, epochs, base_channels, run_name,
],
outputs=_train_outputs,
)
step4_patch_slider.change(
fn=update_step4_patch,
inputs=[dataset_state, baseline_state, experiments_state, step4_patch_slider],
outputs=[step4_gt_img, step4_bl_img, step4_un_img],
)
_s5_inputs = [dataset_state, experiments_state, step5_sel_a, step5_sel_b, step5_patch_slider]
_s5_outputs = [
s5_a_rgb, s5_a_pred, s5_a_overlay, s5_a_metrics, s5_a_error,
s5_b_rgb, s5_b_pred, s5_b_overlay, s5_b_metrics, s5_b_error,
]
for trigger in [step5_sel_a, step5_sel_b, step5_patch_slider]:
trigger.change(fn=update_step5_comparison, inputs=_s5_inputs, outputs=_s5_outputs)
for img, sel in [(s5_a_rgb, step5_sel_a), (s5_a_overlay, step5_sel_a),
(s5_b_rgb, step5_sel_b), (s5_b_overlay, step5_sel_b)]:
img.select(
fn=handle_click_step5,
inputs=[dataset_state, experiments_state, sel, step5_patch_slider],
outputs=[s5_a_metrics if sel == step5_sel_a else s5_b_metrics],
)
try:
# Gradio 6+: css moved to launch()
demo.launch(css=custom_css, server_name="0.0.0.0")
except TypeError:
# Gradio 4.x: css stays in Blocks(), launch() doesn't accept it
demo.launch(server_name="0.0.0.0")