Spaces:
Sleeping
Sleeping
| from typing import Dict, List, Optional, Tuple | |
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
| from torch.utils.data import DataLoader | |
| import gradio as gr | |
| from config import ( | |
| DEVICE, NUM_CHANNELS, NUM_CLASSES, DEFAULT_PATCH_SIZE, | |
| BAND_NAMES, BAND_DESCRIPTIONS, CLASS_NAMES, IGNORE_INDEX, | |
| COMPOSITE_PRESETS, MAX_EXPERIMENTS, | |
| ) | |
| from data import MultiSpectralDataset, load_data | |
| from model import SmallUNet | |
| from baseline import run_knn_baseline | |
| from visualize import ( | |
| render_composite, render_single_band, | |
| add_labels_overlay, multispectral_to_rgb, | |
| mask_to_color, overlay_mask, correctness_overlay, | |
| render_full_prediction_overlay, | |
| render_spectral_signatures_chart, render_index_map, | |
| _blank_rgb, | |
| ) | |
| from metrics import compute_metrics, metrics_markdown | |
| # ── Inference ──────────────────────────────────────────────── | |
def build_prediction_cache(
    model: nn.Module, images: np.ndarray, batch_size: int = 8
) -> Tuple[np.ndarray, np.ndarray]:
    """Run `model` over every patch in `images` and cache its predictions.

    Args:
        model: segmentation network; it is switched to eval mode and run under
            ``torch.no_grad()``.
        images: stack of multispectral patches; the last two axes are taken
            as (height, width).
        batch_size: inference batch size (values below 1 are clamped to 1,
            since DataLoader rejects non-positive batch sizes).

    Returns:
        (preds, probs): ``preds`` is the argmax over dim 1 of the softmaxed
        model output (one class index per pixel), and ``probs`` is that
        softmax itself. Both are concatenated along the patch axis. For an
        empty ``images`` stack, correctly-shaped empty arrays are returned
        instead of letting ``np.concatenate`` fail on an empty list.
    """
    n_images, height, width = len(images), images.shape[-2], images.shape[-1]
    if n_images == 0:
        # Assumes the model emits NUM_CLASSES channels, as the SmallUNet built
        # in train_experiment does.
        return (
            np.zeros((0, height, width), dtype=np.int64),
            np.zeros((0, NUM_CLASSES, height, width), dtype=np.float32),
        )

    # MultiSpectralDataset needs masks; labels are irrelevant for inference.
    dummy = np.zeros((n_images, height, width), dtype=np.int64)
    ds = MultiSpectralDataset(images, dummy)
    loader = DataLoader(ds, batch_size=max(1, int(batch_size)), shuffle=False)

    preds, probs = [], []
    model.eval()
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(DEVICE)
            pb = F.softmax(model(xb), dim=1)
            preds.append(torch.argmax(pb, dim=1).cpu().numpy())
            probs.append(pb.cpu().numpy())
    return np.concatenate(preds, axis=0), np.concatenate(probs, axis=0)
| # ── Shared render helpers ──────────────────────────────────── | |
| def _get_exp_by_name(experiments: List[Dict], name: Optional[str]) -> Optional[Dict]: | |
| if not name: | |
| return None | |
| return next((e for e in experiments if e["name"] == name), None) | |
def pixel_info_markdown(
    x: int, y: int,
    img7: np.ndarray, gt: np.ndarray,
    pred: Optional[np.ndarray] = None,
    probs: Optional[np.ndarray] = None,
) -> str:
    """Describe a single patch pixel as Markdown paragraphs.

    The click coordinates are clamped into the bounds of `gt`. The output lists
    the ground-truth class and, when given, the predicted class, whether it is
    correct (only for labelled pixels), and the three most probable classes
    (`probs` is indexed as ``probs[class, row, col]``). It ends with the
    value of every band of `img7` (indexed ``img7[band, row, col]``).
    """
    height, width = gt.shape
    col = int(np.clip(x, 0, width - 1))
    row = int(np.clip(y, 0, height - 1))

    true_cls = int(gt[row, col])
    labelled = true_cls != IGNORE_INDEX
    true_name = CLASS_NAMES[true_cls] if labelled else "Non étiqueté"
    out = [f"**Pixel ({col}, {row})**", f"Vérité terrain : **{true_name}**"]

    if pred is not None:
        guess = int(pred[row, col])
        out.append(f"Prédiction : **{CLASS_NAMES[guess]}**")
        if labelled:
            verdict = "Oui" if guess == true_cls else "Non"
            out.append(f"Correct : **{verdict}**")

    if probs is not None:
        pixel_probs = probs[:, row, col]
        # Indices of the three highest probabilities, best first.
        best = np.argsort(pixel_probs)[::-1][:3]
        ranked = ", ".join(
            f"{CLASS_NAMES[c]} {pixel_probs[c]*100:.1f}%" for c in best
        )
        out.append("Meilleures probabilités : " + ranked)

    out.append("")
    out.append("**Valeurs de bande**")
    out.extend(
        f"{BAND_DESCRIPTIONS[band]}: **{float(img7[band, row, col]):.3f}**"
        for band in range(img7.shape[0])
    )
    return "\n\n".join(out)
def experiments_table_markdown(experiments: List[Dict]) -> str:
    """Summarise trained experiments as a Markdown table.

    One row per experiment, in list order, showing the hyper-parameters read
    from ``e["config"]`` and the validation accuracy / mIoU read from
    ``e["global_metrics"]`` (stored as fractions, rendered as percentages).

    Returns a placeholder sentence when `experiments` is empty.
    """
    if not experiments:
        return "Aucune expérience entraînée pour l'instant."
    lines = [
        "| # | Nom | Taux d'app. | Époques | Canaux | Préc. val | mIoU |",
        "|---|---|---:|---:|---:|---:|---:|",
    ]
    for i, e in enumerate(experiments, start=1):
        cfg = e["config"]
        gm = e["global_metrics"]
        # Run names are user-supplied; an unescaped "|" would split the cell
        # and break the whole table.
        name = str(e["name"]).replace("|", r"\|")
        # `.3g` keeps small learning rates legible (5e-05); the previous `.4f`
        # rounded them to 0.0001 / 0.0000.
        lines.append(
            f"| {i} | {name} | {cfg['learning_rate']:.3g} | {cfg['epochs']} "
            f"| {cfg['base_channels']} "
            f"| {gm['overall_acc']*100:.1f}% "
            f"| {gm['miou']*100:.1f}% |"
        )
    return "\n".join(lines)
| # ── Step 1 render helpers ──────────────────────────────────── | |
def _render_step1_image(dataset_state: Dict, composite_choice: str) -> np.ndarray:
    """Render the full scene for step 1 with train/val label markers on top.

    `composite_choice` is either a key of COMPOSITE_PRESETS (rendered as an
    RGB composite of the preset's three band indices) or an entry of
    BAND_DESCRIPTIONS (rendered as that single band). Any other value makes
    ``BAND_DESCRIPTIONS.index`` raise ValueError.
    """
    scene = dataset_state["full_image"]
    if composite_choice not in COMPOSITE_PRESETS:
        base = render_single_band(scene, BAND_DESCRIPTIONS.index(composite_choice))
    else:
        red, green, blue = COMPOSITE_PRESETS[composite_choice]
        base = render_composite(scene, red, green, blue)
    train_mask = dataset_state["full_train_mask"]
    val_mask = dataset_state["full_val_mask"]
    return add_labels_overlay(base, train_mask, val_mask)
| # ── Step 4 render helpers ──────────────────────────────────── | |
def _render_step4_row(
    dataset_state: Dict,
    baseline_state: Optional[Dict],
    experiments: List[Dict],
    patch_idx: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build the step-4 comparison row for one validation patch.

    `patch_idx` is clamped into the validation range. The UNet overlay uses
    the most recently trained experiment (``experiments[-1]``). A blank image
    of the patch's size replaces the baseline / UNet overlay when that model
    is missing or has no prediction at this index.

    Returns (rgb, gt_overlay, baseline_overlay, unet_overlay).
    """
    images = dataset_state["val_images"]
    masks = dataset_state["val_masks"]
    idx = max(0, min(patch_idx, len(images) - 1))

    rgb = multispectral_to_rgb(images[idx])
    gt_ov = overlay_mask(rgb, masks[idx])
    height, width = rgb.shape[:2]

    baseline_preds = baseline_state["val_preds"] if baseline_state is not None else None
    unet_preds = experiments[-1]["val_preds"] if experiments else None

    def _overlay_or_blank(preds):
        # Only overlay when a prediction exists for this patch index.
        if preds is not None and idx < len(preds):
            return overlay_mask(rgb, preds[idx])
        return _blank_rgb(height, width)

    return rgb, gt_ov, _overlay_or_blank(baseline_preds), _overlay_or_blank(unet_preds)
| # ── Step 5 render helpers ──────────────────────────────────── | |
def render_step5_panel(
    dataset_state: Dict,
    exp: Optional[Dict],
    patch_idx: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, str, np.ndarray]:
    """Render one model's step-5 comparison column for a validation patch.

    `patch_idx` is clamped into the validation range. When no data or no
    experiment is selected, blank images and a message are returned. When the
    experiment has no prediction at this index, the ground truth is shown in
    place of the prediction together with a "retrain" message.

    Returns (rgb, pred_color, overlay, metrics_md, error_map), where
    `error_map` comes from ``correctness_overlay`` (presumably marks
    correct/incorrect pixels — confirm in visualize.py).
    """
    blank = _blank_rgb()
    if dataset_state is None or exp is None:
        return blank, blank, blank, "Aucun modèle sélectionné.", blank
    val_images = dataset_state["val_images"]
    val_masks = dataset_state["val_masks"]
    idx = max(0, min(patch_idx, len(val_images) - 1))
    rgb = multispectral_to_rgb(val_images[idx])
    gt = val_masks[idx]
    if idx >= len(exp["val_preds"]):
        return rgb, mask_to_color(gt), overlay_mask(rgb, gt), "Données rechargées — réentraîner.", blank
    # Only the hard predictions are needed here; the cached probabilities are
    # not rendered, so they are no longer copied to float32 on every call.
    pred = exp["val_preds"][idx].astype(np.int64)
    m = compute_metrics(pred, gt)
    return (
        rgb,
        mask_to_color(pred),
        overlay_mask(rgb, pred),
        metrics_markdown(m, title=f"{exp['name']} · patch {idx}"),
        correctness_overlay(rgb, pred, gt),
    )
| # ── Gradio action functions ────────────────────────────────── | |
def load_dataset_action(patch_size: int):
    """(Re)load the dataset for `patch_size` and reset every tab.

    Returns, in order: the new dataset state, a reset baseline state (None), a
    reset experiment list ([]), then the widget values for tab 1 (info,
    composite image, pixel hint), tab 2 (spectral chart, NDVI map), tab 3
    (hint, blank image), tab 4 (hint, patch slider update, three blank
    images) and tab 5 (hint, two emptied dropdown updates).
    """
    state = load_data(int(patch_size))
    n_val = len(state["val_images"])
    train_mask = state["full_train_mask"]
    val_mask = state["full_val_mask"]

    composite = _render_step1_image(state, "H4 / H3 / H2")
    signatures = render_spectral_signatures_chart(state["signatures"])
    ndvi = render_index_map(state["ndvi"], "NDVI", train_mask, val_mask)
    empty = _blank_rgb()

    info = "\n\n".join([
        "**Données chargées.**",
        state["status"],
        f"Bandes : {', '.join(BAND_NAMES)} | Classes : {', '.join(CLASS_NAMES)}",
        "**Carrés** = étiquettes d'entraînement · **Cercles** = étiquettes de validation",
    ])
    slider = gr.update(maximum=max(0, n_val - 1), value=0)

    tab1 = (info, composite, "Cliquez sur l'image pour inspecter un pixel.")
    tab2 = (signatures, ndvi)
    tab3 = ("Lancez la référence KNN après le chargement des données.", empty)
    tab4 = ("Entraînez un modèle à l'étape 4.", slider, empty, empty, empty)
    tab5 = (
        "Aucune expérience pour l'instant.",
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
    )
    # Baseline state -> None and experiments -> [] reset downstream results.
    return (state, None, [], *tab1, *tab2, *tab3, *tab4, *tab5)
def update_step1_composite(dataset_state, composite_choice: str):
    """Re-render the step-1 scene for a new composite / band choice.

    Returns (image, hint text); a blank image and a "load data first" hint
    when no dataset is loaded.
    """
    if dataset_state is not None:
        rendered = _render_step1_image(dataset_state, composite_choice)
        return rendered, "Cliquez sur l'image pour inspecter un pixel."
    return _blank_rgb(), "Chargez les données d'abord."
def handle_click_step1(evt: gr.SelectData, dataset_state):
    """Describe the clicked pixel of the step-1 scene as Markdown.

    The click position (``evt.index`` as x, y) is clamped into the scene. The
    output shows the pixel's validation label (from ``full_val_mask``) and a
    table with the value of every band of ``full_image`` at that pixel.
    """
    if dataset_state is None:
        return "Chargez les données d'abord."
    x, y = evt.index
    full = dataset_state["full_image"]
    fmask = dataset_state["full_val_mask"]
    H, W = fmask.shape
    x, y = int(np.clip(x, 0, W - 1)), int(np.clip(y, 0, H - 1))
    cls = int(fmask[y, x])
    label = CLASS_NAMES[cls] if cls != IGNORE_INDEX else "Non étiqueté"
    lines = [
        f"**Pixel ({x}, {y})** | Étiquette val : **{label}**", "",
        "| Bande | Valeur |", "|---|---:|",
    ]
    # Iterate over the bands actually present (axis 0 of the scene) rather
    # than a hard-coded 7, so the table follows the data's channel count.
    lines += [
        f"| {BAND_DESCRIPTIONS[b]} | {float(full[b, y, x]):.4f} |"
        for b in range(full.shape[0])
    ]
    return "\n".join(lines)
def update_step2_index(dataset_state, index_choice: str):
    """Render the spectral-index map selected in step 2.

    The index array is read from the dataset state under the lower-cased
    choice (e.g. "NDVI" -> "ndvi") and drawn with the train/val label masks.
    Returns a blank image when no dataset is loaded.
    """
    if dataset_state is None:
        return _blank_rgb()
    index_map = dataset_state[index_choice.lower()]
    train_mask = dataset_state["full_train_mask"]
    val_mask = dataset_state["full_val_mask"]
    return render_index_map(index_map, index_choice, train_mask, val_mask)
def run_baseline_action(dataset_state, k: int, progress=gr.Progress()):
    """Run the KNN baseline over the full scene and render its predictions.

    Raises gr.Error when no dataset is loaded. Returns
    (baseline_state, metrics markdown, full-scene prediction overlay), where
    baseline_state holds ``k``, ``full_pred``, ``val_preds`` and ``metrics``.
    """
    if dataset_state is None:
        raise gr.Error("Chargez les données d'abord.")
    progress(0.1, desc="KNN sur la scène complète...")
    n_neighbors = int(k)
    scene = dataset_state["full_image"]
    val_mask = dataset_state["full_val_mask"]
    full_pred, val_preds, metrics, metrics_md = run_knn_baseline(
        scene,
        dataset_state["full_train_mask"],
        val_mask,
        dataset_state["val_images"],
        k=n_neighbors,
    )
    progress(0.9, desc="Rendu en cours...")
    state = dict(k=n_neighbors, full_pred=full_pred, val_preds=val_preds, metrics=metrics)
    overlay = render_full_prediction_overlay(scene, full_pred, val_mask)
    progress(1.0)
    return state, metrics_md, overlay
def update_step4_patch(dataset_state, baseline_state, experiments, patch_idx: int):
    """Re-render the step-4 row (GT, baseline, UNet overlays) for a patch.

    Returns three images; all three are blank when no dataset is loaded.
    """
    if dataset_state is not None:
        _rgb, gt_overlay, base_overlay, unet_overlay = _render_step4_row(
            dataset_state, baseline_state, experiments, int(patch_idx)
        )
        return gt_overlay, base_overlay, unet_overlay
    placeholder = _blank_rgb()
    return placeholder, placeholder, placeholder
def _unique_experiment_name(run_name: Optional[str], experiments: List[Dict]) -> str:
    """Return a non-empty run name, suffixed " (2)", " (3)", ... until unused."""
    # Fall back to the default when the field is blank *or* whitespace-only;
    # an empty name could never be selected again (_get_exp_by_name ignores it).
    base = (run_name or "").strip() or f"Run {len(experiments)+1}"
    existing = {e["name"] for e in experiments}
    name, ctr = base, 2
    while name in existing:
        name = f"{base} ({ctr})"
        ctr += 1
    return name


def _fit_small_unet(
    loader: DataLoader,
    learning_rate: float,
    base_channels: int,
    n_epochs: int,
    progress,
) -> Tuple[nn.Module, List[float]]:
    """Train a fresh SmallUNet with Adam; return it and the mean batch loss per epoch."""
    model = SmallUNet(NUM_CHANNELS, NUM_CLASSES, base_channels).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=learning_rate)
    # Pixels labelled IGNORE_INDEX do not contribute to the loss.
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)
    history: List[float] = []
    for ep in range(n_epochs):
        progress(ep / n_epochs, desc=f"Epoch {ep+1}/{n_epochs}")
        model.train()
        total, n_batches = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            total += float(loss.item())
            n_batches += 1
        history.append(total / max(1, n_batches))
    return model, history


def train_experiment(
    dataset_state: Dict,
    baseline_state: Optional[Dict],
    experiments: List[Dict],
    learning_rate: float,
    batch_size: int,
    epochs: int,
    base_channels: int,
    run_name: str,
    progress=gr.Progress(),
):
    """Train one SmallUNet experiment and refresh the step-4 / step-5 widgets.

    Raises gr.Error when:
      - no data is loaded;
      - MAX_EXPERIMENTS is reached;
      - `epochs` < 1 (previously crashed on an empty loss history);
      - the training set is empty.

    Returns, in order:
      - the new experiment list (a fresh list, the input is not mutated);
      - the training summary markdown;
      - the step-4 patch-slider update;
      - the step-4 GT / baseline / UNet overlays;
      - the step-5 experiments table;
      - the two step-5 dropdown updates.
    """
    if dataset_state is None or "train_images" not in dataset_state:
        raise gr.Error("Chargez les données d'abord.")
    if len(experiments) >= MAX_EXPERIMENTS:
        raise gr.Error(
            f"Maximum de {MAX_EXPERIMENTS} expériences atteint. "
            "Comparez à l'étape 5, puis rechargez les données pour recommencer."
        )
    n_epochs = int(epochs)
    if n_epochs < 1:
        raise gr.Error("Le nombre d'époques doit être au moins 1.")
    if len(dataset_state["train_images"]) == 0:
        raise gr.Error("Aucun patch d'entraînement disponible.")

    # DataLoader rejects batch_size < 1; clamp once and use it for both
    # training and validation inference.
    bs = max(1, int(batch_size))
    loader = DataLoader(
        MultiSpectralDataset(dataset_state["train_images"], dataset_state["train_masks"]),
        batch_size=bs, shuffle=True,
    )
    model, history = _fit_small_unet(
        loader, float(learning_rate), int(base_channels), n_epochs, progress
    )

    progress(0.95, desc="Inférence de validation...")
    val_preds, val_probs = build_prediction_cache(
        model, dataset_state["val_images"], batch_size=bs
    )
    global_metrics = compute_metrics(
        val_preds.reshape(-1), dataset_state["val_masks"].reshape(-1)
    )
    progress(1.0)

    name = _unique_experiment_name(run_name, experiments)
    experiment = {
        "name": name,
        "config": {
            "learning_rate": float(learning_rate),
            "batch_size": bs,
            "epochs": n_epochs,
            "base_channels": int(base_channels),
        },
        "train_loss_history": history,
        "global_metrics": global_metrics,
        "val_preds": val_preds.astype(np.int64),
        "val_probs": val_probs.astype(np.float32),
    }
    # Build a new list rather than appending, so the incoming state object is
    # left untouched.
    experiments = experiments + [experiment]

    summary = "\n\n".join([
        f"**Entraînement terminé — {name}**",
        f"Appareil : **{DEVICE}** | Époques : **{n_epochs}** | Perte finale : **{history[-1]:.4f}**",
        f"Précision val : **{global_metrics['overall_acc']*100:.2f}%** (pixels étiquetés seulement)",
        f"mIoU val : **{global_metrics['miou']*100:.2f}%**",
    ])
    choices = [e["name"] for e in experiments]
    val_count = len(dataset_state["val_images"])
    _, gt_ov, bl_ov, un_ov = _render_step4_row(dataset_state, baseline_state, experiments, 0)
    return (
        experiments,
        summary,
        gr.update(maximum=max(0, val_count-1), value=0),  # step4 patch slider
        gt_ov, bl_ov, un_ov,
        experiments_table_markdown(experiments),  # step5 table
        gr.update(choices=choices, value=None),  # step5 sel_a
        gr.update(choices=choices, value=None),  # step5 sel_b
    )
def update_step5_comparison(
    dataset_state, experiments, sel_a, sel_b, patch_idx: int
):
    """Render both step-5 comparison columns for the same validation patch.

    Returns the five outputs of render_step5_panel for `sel_a` followed by
    the five for `sel_b` (ten values in total).
    """
    patch = int(patch_idx)
    chosen = [_get_exp_by_name(experiments, sel) for sel in (sel_a, sel_b)]
    return tuple(
        output
        for exp in chosen
        for output in render_step5_panel(dataset_state, exp, patch)
    )
def handle_click_step5(
    evt: gr.SelectData,
    dataset_state, experiments, model_name, patch_idx: int,
) -> str:
    """Describe the clicked pixel of a step-5 patch as Markdown.

    Uses the experiment named `model_name` (if any) for the prediction and
    class probabilities at that pixel; `patch_idx` is clamped into the
    validation range. Any failure is reported inline rather than raised, so
    a bad click never breaks the UI event.
    """
    import logging  # local import keeps this handler self-contained

    try:
        if dataset_state is None:
            return "Aucune donnée chargée."
        idx = max(0, min(int(patch_idx), len(dataset_state["val_images"])-1))
        exp = _get_exp_by_name(experiments, model_name)
        x, y = evt.index
        img7 = dataset_state["val_images"][idx]
        gt = dataset_state["val_masks"][idx]
        pred = exp["val_preds"][idx] if (exp and idx < len(exp["val_preds"])) else None
        probs = exp["val_probs"][idx] if (exp and idx < len(exp["val_probs"])) else None
        return pixel_info_markdown(int(x), int(y), img7, gt, pred, probs)
    except Exception as e:
        # UI boundary: keep the best-effort inline message, but log the full
        # traceback instead of silently discarding it.
        logging.getLogger(__name__).exception("Step-5 click handler failed")
        return f"Click error: `{e}`"