Spaces:
Sleeping
Sleeping
| import gradio as gr | |
| from config import ( | |
| APP_TITLE, set_seed, SEED, DEFAULT_PATCH_SIZE, | |
| COMPOSITE_PRESETS, BAND_DESCRIPTIONS, MAX_EXPERIMENTS, | |
| ) | |
| from train import ( | |
| load_dataset_action, | |
| update_step1_composite, | |
| handle_click_step1, | |
| update_step2_index, | |
| run_baseline_action, | |
| update_step4_patch, | |
| train_experiment, | |
| update_step5_comparison, | |
| handle_click_step5, | |
| ) | |
| set_seed(SEED) | |
| _COMPOSITE_CHOICES = list(COMPOSITE_PRESETS.keys()) + BAND_DESCRIPTIONS | |
| custom_css = """ | |
| #step1-img img, #step3-pred img { image-rendering: pixelated; } | |
| .hint { font-size: 0.88rem; color: #666; } | |
| """ | |
| with gr.Blocks(title=APP_TITLE, css=custom_css) as demo: | |
| gr.Markdown(f"# {APP_TITLE}") | |
| gr.Markdown( | |
| "Un parcours en cinq étapes, des pixels bruts du satellite jusqu'à la segmentation " | |
| "par apprentissage profond. Suivez les onglets dans l'ordre — chaque étape s'appuie sur la précédente.", | |
| elem_classes="hint", | |
| ) | |
| dataset_state = gr.State(None) | |
| baseline_state = gr.State(None) | |
| experiments_state = gr.State([]) | |
| # ──────────────────────────────────────────────────────── | |
| # Étape 1 — Découvrir les données | |
| # ──────────────────────────────────────────────────────── | |
| with gr.Tab("Étape 1 · Découvrir les données"): | |
| gr.Markdown( | |
| "**Commencez ici.** Chargez le jeu de données, puis explorez les 7 bandes spectrales. " | |
| "Les carrés sur l'image sont les étiquettes d'entraînement ; les cercles sont les étiquettes de validation.", | |
| elem_classes="hint", | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| patch_size = gr.Slider(64, 512, value=DEFAULT_PATCH_SIZE, step=32, | |
| label="Taille du patch (pour l'entraînement)") | |
| load_btn = gr.Button("Charger les données", variant="primary") | |
| composite_dd = gr.Dropdown( | |
| choices=_COMPOSITE_CHOICES, | |
| value="H4 / H3 / H2", | |
| label="Mode d'affichage", | |
| ) | |
| step1_info = gr.Markdown("*Chargez les données pour commencer.*") | |
| with gr.Column(scale=3, elem_id="step1-img"): | |
| step1_image = gr.Image( | |
| label="Scène complète — cliquer pour inspecter un pixel", | |
| type="numpy", | |
| ) | |
| step1_click = gr.Markdown("*Cliquez n'importe où sur l'image.*") | |
| # ──────────────────────────────────────────────────────── | |
| # Étape 2 — Signatures spectrales | |
| # ──────────────────────────────────────────────────────── | |
| with gr.Tab("Étape 2 · Signatures spectrales"): | |
| gr.Markdown( | |
| "Chaque type d'occupation du sol possède un profil de luminosité caractéristique " | |
| "à travers les 7 bandes — sa **signature spectrale**. " | |
| "Le NDVI et le NDWI sont des indices calculés directement à partir des valeurs de bande.", | |
| elem_classes="hint", | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| index_radio = gr.Radio( | |
| choices=["NDVI", "NDWI"], | |
| value="NDVI", | |
| label="Carte d'indice spectral", | |
| ) | |
| gr.Markdown( | |
| "**NDVI** = (H_5 − H_4) / (H_5 + H_4)\n\n" | |
| "**NDWI** = (H_3 − H_5) / (H_3 + H_5)", | |
| elem_classes="hint", | |
| ) | |
| with gr.Column(scale=3): | |
| step2_sig_chart = gr.Image( | |
| label="Signatures spectrales (étiquettes d'entraînement)", | |
| type="numpy", | |
| ) | |
| step2_index_map = gr.Image(label="Carte d'indice", type="numpy") | |
| # ──────────────────────────────────────────────────────── | |
| # Étape 3 — Référence spectrale (KNN) | |
| # ──────────────────────────────────────────────────────── | |
| with gr.Tab("Étape 3 · Référence spectrale"): | |
| gr.Markdown( | |
| "**Aucune convolution ici.** Chaque pixel est classifié uniquement à partir " | |
| "de ses 7 valeurs de bande, en cherchant les k pixels d'entraînement les plus " | |
| "proches dans l'espace spectral. " | |
| "Cela montre ce qui est atteignable sans aucun contexte spatial.", | |
| elem_classes="hint", | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| k_slider = gr.Slider(1, 5, value=3, step=2, | |
| label="k (nombre de voisins)") | |
| baseline_btn = gr.Button("Lancer la référence KNN", variant="primary") | |
| step3_metrics = gr.Markdown("*Lancez la référence pour voir les résultats.*") | |
| with gr.Column(scale=3, elem_id="step3-pred"): | |
| step3_full_pred = gr.Image( | |
| label=( | |
| "Prédiction sur la scène complète · superposée sur l'image en couleurs naturelles · " | |
| "points colorés = étiquettes de validation (vert=correct, rouge=incorrect)" | |
| ), | |
| type="numpy", | |
| ) | |
| # ──────────────────────────────────────────────────────── | |
| # Étape 4 — Apprentissage profond (UNet) | |
| # ──────────────────────────────────────────────────────── | |
| with gr.Tab("Étape 4 · Apprentissage profond"): | |
| gr.Markdown( | |
| "Un **U-Net** observe un patch de pixels à la fois, pas un seul pixel. " | |
| "Son encodeur capture la texture locale ; les connexions de saut préservent " | |
| "le détail spatial. Entraînez un modèle et comparez-le patch par patch avec la référence KNN.", | |
| elem_classes="hint", | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| run_name = gr.Textbox(label="Nom de l'expérience", | |
| placeholder="ex. : lr-1e-3_ch-16") | |
| learning_rate = gr.Slider(1e-4, 5e-3, value=1e-3, step=1e-4, | |
| label="Taux d'apprentissage") | |
| batch_size = gr.Slider(2, 32, value=8, step=2, | |
| label="Taille du lot") | |
| epochs = gr.Slider(1, 20, value=5, step=1, | |
| label="Époques") | |
| base_channels = gr.Slider(8, 64, value=16, step=8, | |
| label="Largeur du modèle (canaux de base)") | |
| train_btn = gr.Button("Entraîner le modèle", variant="primary") | |
| gr.Markdown( | |
| f"*Maximum {MAX_EXPERIMENTS} expériences. " | |
| "Rechargez les données pour réinitialiser.*", | |
| elem_classes="hint", | |
| ) | |
| step4_summary = gr.Markdown("*Entraînez un modèle pour voir les résultats.*") | |
| with gr.Column(scale=3): | |
| step4_patch_slider = gr.Slider(0, 59, value=0, step=1, | |
| label="Index du patch de validation") | |
| with gr.Row(): | |
| step4_gt_img = gr.Image(label="Superposition vérité terrain", | |
| type="numpy", height=280) | |
| step4_bl_img = gr.Image(label="Prédiction référence KNN", | |
| type="numpy", height=280) | |
| step4_un_img = gr.Image(label="Prédiction UNet", | |
| type="numpy", height=280) | |
| # ──────────────────────────────────────────────────────── | |
| # Étape 5 — Laboratoire d'expériences | |
| # ──────────────────────────────────────────────────────── | |
| with gr.Tab("Étape 5 · Laboratoire d'expériences"): | |
| gr.Markdown( | |
| f"Comparez jusqu'à **{MAX_EXPERIMENTS}** expériences UNet côte à côte. " | |
| "Essayez différents taux d'apprentissage, nombres d'époques ou largeurs de modèle " | |
| "et observez ce qui change.", | |
| elem_classes="hint", | |
| ) | |
| with gr.Row(): | |
| step5_sel_a = gr.Dropdown(choices=[], value=None, label="Modèle gauche", | |
| interactive=True) | |
| step5_sel_b = gr.Dropdown(choices=[], value=None, label="Modèle droit", | |
| interactive=True) | |
| step5_patch_slider = gr.Slider(0, 59, value=0, step=1, | |
| label="Index du patch de validation") | |
| step5_table = gr.Markdown("*Aucune expérience pour l'instant.*") | |
| gr.Markdown( | |
| "**Questions directrices**\n\n" | |
| "- Doublez les époques — le mIoU continue-t-il de progresser ou se stabilise-t-il ?\n" | |
| "- Divisez le taux d'apprentissage par 2 — l'entraînement devient-il plus stable ?\n" | |
| "- Augmentez les canaux de base de 16 à 32 — le gain vaut-il le temps supplémentaire ?", | |
| elem_classes="hint", | |
| ) | |
| with gr.Row(): | |
| with gr.Column(scale=1): | |
| gr.Markdown("## Gauche") | |
| s5_a_rgb = gr.Image(label="RVB", type="numpy", height=240) | |
| s5_a_pred = gr.Image(label="Prédiction", type="numpy", height=240) | |
| s5_a_overlay = gr.Image(label="Superposition", type="numpy", height=240) | |
| s5_a_metrics = gr.Markdown("*Aucun modèle sélectionné.*") | |
| s5_a_error = gr.Image(label="Carte de précision", type="numpy", height=240) | |
| with gr.Column(scale=1): | |
| gr.Markdown("## Droite") | |
| s5_b_rgb = gr.Image(label="RVB", type="numpy", height=240) | |
| s5_b_pred = gr.Image(label="Prédiction", type="numpy", height=240) | |
| s5_b_overlay = gr.Image(label="Superposition", type="numpy", height=240) | |
| s5_b_metrics = gr.Markdown("*Aucun modèle sélectionné.*") | |
| s5_b_error = gr.Image(label="Carte de précision", type="numpy", height=240) | |
| # ── Connexion des événements ────────────────────────────── | |
| _load_outputs = [ | |
| dataset_state, baseline_state, experiments_state, | |
| # Étape 1 | |
| step1_info, step1_image, step1_click, | |
| # Étape 2 | |
| step2_sig_chart, step2_index_map, | |
| # Étape 3 | |
| step3_metrics, step3_full_pred, | |
| # Étape 4 | |
| step4_summary, step4_patch_slider, | |
| step4_gt_img, step4_bl_img, step4_un_img, | |
| # Étape 5 | |
| step5_table, step5_sel_a, step5_sel_b, | |
| ] | |
| load_btn.click(fn=load_dataset_action, inputs=[patch_size], outputs=_load_outputs) | |
| composite_dd.change( | |
| fn=update_step1_composite, | |
| inputs=[dataset_state, composite_dd], | |
| outputs=[step1_image, step1_click], | |
| ) | |
| step1_image.select( | |
| fn=handle_click_step1, | |
| inputs=[dataset_state], | |
| outputs=[step1_click], | |
| ) | |
| index_radio.change( | |
| fn=update_step2_index, | |
| inputs=[dataset_state, index_radio], | |
| outputs=[step2_index_map], | |
| ) | |
| baseline_btn.click( | |
| fn=run_baseline_action, | |
| inputs=[dataset_state, k_slider], | |
| outputs=[baseline_state, step3_metrics, step3_full_pred], | |
| ) | |
| _train_outputs = [ | |
| experiments_state, | |
| step4_summary, step4_patch_slider, | |
| step4_gt_img, step4_bl_img, step4_un_img, | |
| step5_table, step5_sel_a, step5_sel_b, | |
| ] | |
| train_btn.click( | |
| fn=train_experiment, | |
| inputs=[ | |
| dataset_state, baseline_state, experiments_state, | |
| learning_rate, batch_size, epochs, base_channels, run_name, | |
| ], | |
| outputs=_train_outputs, | |
| ) | |
| step4_patch_slider.change( | |
| fn=update_step4_patch, | |
| inputs=[dataset_state, baseline_state, experiments_state, step4_patch_slider], | |
| outputs=[step4_gt_img, step4_bl_img, step4_un_img], | |
| ) | |
| _s5_inputs = [dataset_state, experiments_state, step5_sel_a, step5_sel_b, step5_patch_slider] | |
| _s5_outputs = [ | |
| s5_a_rgb, s5_a_pred, s5_a_overlay, s5_a_metrics, s5_a_error, | |
| s5_b_rgb, s5_b_pred, s5_b_overlay, s5_b_metrics, s5_b_error, | |
| ] | |
| for trigger in [step5_sel_a, step5_sel_b, step5_patch_slider]: | |
| trigger.change(fn=update_step5_comparison, inputs=_s5_inputs, outputs=_s5_outputs) | |
| for img, sel in [(s5_a_rgb, step5_sel_a), (s5_a_overlay, step5_sel_a), | |
| (s5_b_rgb, step5_sel_b), (s5_b_overlay, step5_sel_b)]: | |
| img.select( | |
| fn=handle_click_step5, | |
| inputs=[dataset_state, experiments_state, sel, step5_patch_slider], | |
| outputs=[s5_a_metrics if sel == step5_sel_a else s5_b_metrics], | |
| ) | |
| try: | |
| # Gradio 6+: css moved to launch() | |
| demo.launch(css=custom_css, server_name="0.0.0.0") | |
| except TypeError: | |
| # Gradio 4.x: css stays in Blocks(), launch() doesn't accept it | |
| demo.launch(server_name="0.0.0.0") | |