import gradio as gr from config import ( APP_TITLE, set_seed, SEED, DEFAULT_PATCH_SIZE, COMPOSITE_PRESETS, BAND_DESCRIPTIONS, MAX_EXPERIMENTS, ) from train import ( load_dataset_action, update_step1_composite, handle_click_step1, update_step2_index, run_baseline_action, update_step4_patch, train_experiment, update_step5_comparison, handle_click_step5, ) set_seed(SEED) _COMPOSITE_CHOICES = list(COMPOSITE_PRESETS.keys()) + BAND_DESCRIPTIONS custom_css = """ #step1-img img, #step3-pred img { image-rendering: pixelated; } .hint { font-size: 0.88rem; color: #666; } """ with gr.Blocks(title=APP_TITLE, css=custom_css) as demo: gr.Markdown(f"# {APP_TITLE}") gr.Markdown( "Un parcours en cinq étapes, des pixels bruts du satellite jusqu'à la segmentation " "par apprentissage profond. Suivez les onglets dans l'ordre — chaque étape s'appuie sur la précédente.", elem_classes="hint", ) dataset_state = gr.State(None) baseline_state = gr.State(None) experiments_state = gr.State([]) # ──────────────────────────────────────────────────────── # Étape 1 — Découvrir les données # ──────────────────────────────────────────────────────── with gr.Tab("Étape 1 · Découvrir les données"): gr.Markdown( "**Commencez ici.** Chargez le jeu de données, puis explorez les 7 bandes spectrales. " "Les carrés sur l'image sont les étiquettes d'entraînement ; les cercles sont les étiquettes de validation.", elem_classes="hint", ) with gr.Row(): with gr.Column(scale=1): patch_size = gr.Slider(64, 512, value=DEFAULT_PATCH_SIZE, step=32, label="Taille du patch (pour l'entraînement)") load_btn = gr.Button("Charger les données", variant="primary") composite_dd = gr.Dropdown( choices=_COMPOSITE_CHOICES, value="H4 / H3 / H2", label="Mode d'affichage", ) step1_info = gr.Markdown("*Chargez les données pour commencer.*") with gr.Column(scale=3, elem_id="step1-img"): step1_image = gr.Image( label="Scène complète — cliquer pour inspecter un pixel", type="numpy", ) step1_click = gr.Markdown("*Cliquez n'importe où sur l'image.*") # ──────────────────────────────────────────────────────── # Étape 2 — Signatures spectrales # ──────────────────────────────────────────────────────── with gr.Tab("Étape 2 · Signatures spectrales"): gr.Markdown( "Chaque type d'occupation du sol possède un profil de luminosité caractéristique " "à travers les 7 bandes — sa **signature spectrale**. " "Le NDVI et le NDWI sont des indices calculés directement à partir des valeurs de bande.", elem_classes="hint", ) with gr.Row(): with gr.Column(scale=1): index_radio = gr.Radio( choices=["NDVI", "NDWI"], value="NDVI", label="Carte d'indice spectral", ) gr.Markdown( "**NDVI** = (H_5 − H_4) / (H_5 + H_4)\n\n" "**NDWI** = (H_3 − H_5) / (H_3 + H_5)", elem_classes="hint", ) with gr.Column(scale=3): step2_sig_chart = gr.Image( label="Signatures spectrales (étiquettes d'entraînement)", type="numpy", ) step2_index_map = gr.Image(label="Carte d'indice", type="numpy") # ──────────────────────────────────────────────────────── # Étape 3 — Référence spectrale (KNN) # ──────────────────────────────────────────────────────── with gr.Tab("Étape 3 · Référence spectrale"): gr.Markdown( "**Aucune convolution ici.** Chaque pixel est classifié uniquement à partir " "de ses 7 valeurs de bande, en cherchant les k pixels d'entraînement les plus " "proches dans l'espace spectral. " "Cela montre ce qui est atteignable sans aucun contexte spatial.", elem_classes="hint", ) with gr.Row(): with gr.Column(scale=1): k_slider = gr.Slider(1, 5, value=3, step=2, label="k (nombre de voisins)") baseline_btn = gr.Button("Lancer la référence KNN", variant="primary") step3_metrics = gr.Markdown("*Lancez la référence pour voir les résultats.*") with gr.Column(scale=3, elem_id="step3-pred"): step3_full_pred = gr.Image( label=( "Prédiction sur la scène complète · superposée sur l'image en couleurs naturelles · " "points colorés = étiquettes de validation (vert=correct, rouge=incorrect)" ), type="numpy", ) # ──────────────────────────────────────────────────────── # Étape 4 — Apprentissage profond (UNet) # ──────────────────────────────────────────────────────── with gr.Tab("Étape 4 · Apprentissage profond"): gr.Markdown( "Un **U-Net** observe un patch de pixels à la fois, pas un seul pixel. " "Son encodeur capture la texture locale ; les connexions de saut préservent " "le détail spatial. Entraînez un modèle et comparez-le patch par patch avec la référence KNN.", elem_classes="hint", ) with gr.Row(): with gr.Column(scale=1): run_name = gr.Textbox(label="Nom de l'expérience", placeholder="ex. : lr-1e-3_ch-16") learning_rate = gr.Slider(1e-4, 5e-3, value=1e-3, step=1e-4, label="Taux d'apprentissage") batch_size = gr.Slider(2, 32, value=8, step=2, label="Taille du lot") epochs = gr.Slider(1, 20, value=5, step=1, label="Époques") base_channels = gr.Slider(8, 64, value=16, step=8, label="Largeur du modèle (canaux de base)") train_btn = gr.Button("Entraîner le modèle", variant="primary") gr.Markdown( f"*Maximum {MAX_EXPERIMENTS} expériences. " "Rechargez les données pour réinitialiser.*", elem_classes="hint", ) step4_summary = gr.Markdown("*Entraînez un modèle pour voir les résultats.*") with gr.Column(scale=3): step4_patch_slider = gr.Slider(0, 59, value=0, step=1, label="Index du patch de validation") with gr.Row(): step4_gt_img = gr.Image(label="Superposition vérité terrain", type="numpy", height=280) step4_bl_img = gr.Image(label="Prédiction référence KNN", type="numpy", height=280) step4_un_img = gr.Image(label="Prédiction UNet", type="numpy", height=280) # ──────────────────────────────────────────────────────── # Étape 5 — Laboratoire d'expériences # ──────────────────────────────────────────────────────── with gr.Tab("Étape 5 · Laboratoire d'expériences"): gr.Markdown( f"Comparez jusqu'à **{MAX_EXPERIMENTS}** expériences UNet côte à côte. " "Essayez différents taux d'apprentissage, nombres d'époques ou largeurs de modèle " "et observez ce qui change.", elem_classes="hint", ) with gr.Row(): step5_sel_a = gr.Dropdown(choices=[], value=None, label="Modèle gauche", interactive=True) step5_sel_b = gr.Dropdown(choices=[], value=None, label="Modèle droit", interactive=True) step5_patch_slider = gr.Slider(0, 59, value=0, step=1, label="Index du patch de validation") step5_table = gr.Markdown("*Aucune expérience pour l'instant.*") gr.Markdown( "**Questions directrices**\n\n" "- Doublez les époques — le mIoU continue-t-il de progresser ou se stabilise-t-il ?\n" "- Divisez le taux d'apprentissage par 2 — l'entraînement devient-il plus stable ?\n" "- Augmentez les canaux de base de 16 à 32 — le gain vaut-il le temps supplémentaire ?", elem_classes="hint", ) with gr.Row(): with gr.Column(scale=1): gr.Markdown("## Gauche") s5_a_rgb = gr.Image(label="RVB", type="numpy", height=240) s5_a_pred = gr.Image(label="Prédiction", type="numpy", height=240) s5_a_overlay = gr.Image(label="Superposition", type="numpy", height=240) s5_a_metrics = gr.Markdown("*Aucun modèle sélectionné.*") s5_a_error = gr.Image(label="Carte de précision", type="numpy", height=240) with gr.Column(scale=1): gr.Markdown("## Droite") s5_b_rgb = gr.Image(label="RVB", type="numpy", height=240) s5_b_pred = gr.Image(label="Prédiction", type="numpy", height=240) s5_b_overlay = gr.Image(label="Superposition", type="numpy", height=240) s5_b_metrics = gr.Markdown("*Aucun modèle sélectionné.*") s5_b_error = gr.Image(label="Carte de précision", type="numpy", height=240) # ── Connexion des événements ────────────────────────────── _load_outputs = [ dataset_state, baseline_state, experiments_state, # Étape 1 step1_info, step1_image, step1_click, # Étape 2 step2_sig_chart, step2_index_map, # Étape 3 step3_metrics, step3_full_pred, # Étape 4 step4_summary, step4_patch_slider, step4_gt_img, step4_bl_img, step4_un_img, # Étape 5 step5_table, step5_sel_a, step5_sel_b, ] load_btn.click(fn=load_dataset_action, inputs=[patch_size], outputs=_load_outputs) composite_dd.change( fn=update_step1_composite, inputs=[dataset_state, composite_dd], outputs=[step1_image, step1_click], ) step1_image.select( fn=handle_click_step1, inputs=[dataset_state], outputs=[step1_click], ) index_radio.change( fn=update_step2_index, inputs=[dataset_state, index_radio], outputs=[step2_index_map], ) baseline_btn.click( fn=run_baseline_action, inputs=[dataset_state, k_slider], outputs=[baseline_state, step3_metrics, step3_full_pred], ) _train_outputs = [ experiments_state, step4_summary, step4_patch_slider, step4_gt_img, step4_bl_img, step4_un_img, step5_table, step5_sel_a, step5_sel_b, ] train_btn.click( fn=train_experiment, inputs=[ dataset_state, baseline_state, experiments_state, learning_rate, batch_size, epochs, base_channels, run_name, ], outputs=_train_outputs, ) step4_patch_slider.change( fn=update_step4_patch, inputs=[dataset_state, baseline_state, experiments_state, step4_patch_slider], outputs=[step4_gt_img, step4_bl_img, step4_un_img], ) _s5_inputs = [dataset_state, experiments_state, step5_sel_a, step5_sel_b, step5_patch_slider] _s5_outputs = [ s5_a_rgb, s5_a_pred, s5_a_overlay, s5_a_metrics, s5_a_error, s5_b_rgb, s5_b_pred, s5_b_overlay, s5_b_metrics, s5_b_error, ] for trigger in [step5_sel_a, step5_sel_b, step5_patch_slider]: trigger.change(fn=update_step5_comparison, inputs=_s5_inputs, outputs=_s5_outputs) for img, sel in [(s5_a_rgb, step5_sel_a), (s5_a_overlay, step5_sel_a), (s5_b_rgb, step5_sel_b), (s5_b_overlay, step5_sel_b)]: img.select( fn=handle_click_step5, inputs=[dataset_state, experiments_state, sel, step5_patch_slider], outputs=[s5_a_metrics if sel == step5_sel_a else s5_b_metrics], ) try: # Gradio 6+: css moved to launch() demo.launch(css=custom_css, server_name="0.0.0.0") except TypeError: # Gradio 4.x: css stays in Blocks(), launch() doesn't accept it demo.launch(server_name="0.0.0.0")