"""Gradio callbacks and render helpers for the multispectral segmentation demo."""

from typing import Dict, List, Optional, Tuple

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import gradio as gr

from config import (
    DEVICE, NUM_CHANNELS, NUM_CLASSES, DEFAULT_PATCH_SIZE,
    BAND_NAMES, BAND_DESCRIPTIONS, CLASS_NAMES, IGNORE_INDEX,
    COMPOSITE_PRESETS, MAX_EXPERIMENTS,
)
from data import MultiSpectralDataset, load_data
from model import SmallUNet
from baseline import run_knn_baseline
from visualize import (
    render_composite, render_single_band, add_labels_overlay,
    multispectral_to_rgb, mask_to_color, overlay_mask, correctness_overlay,
    render_full_prediction_overlay, render_spectral_signatures_chart,
    render_index_map, _blank_rgb,
)
from metrics import compute_metrics, metrics_markdown


# ── Inference ────────────────────────────────────────────────

def build_prediction_cache(
    model: nn.Module, images: np.ndarray, batch_size: int = 8
) -> Tuple[np.ndarray, np.ndarray]:
    """Run ``model`` over ``images`` in batches and cache its outputs.

    Args:
        model: Network to evaluate; it is switched to eval mode here.
        images: Stack of patches. The last two axes are used as the spatial
            size of the placeholder masks.
        batch_size: Number of patches per forward pass.

    Returns:
        ``(preds, probs)``: the per-pixel argmax over the class axis and the
        softmax probabilities, each concatenated over all batches on axis 0.
    """
    # MultiSpectralDataset is constructed from (images, masks); no labels are
    # needed for inference, so an all-zero mask stack stands in for them.
    dummy = np.zeros(
        (len(images), images.shape[-2], images.shape[-1]), dtype=np.int64
    )
    ds = MultiSpectralDataset(images, dummy)
    loader = DataLoader(ds, batch_size=batch_size, shuffle=False)

    preds, probs = [], []
    model.eval()
    with torch.no_grad():
        for xb, _ in loader:
            xb = xb.to(DEVICE)
            pb = F.softmax(model(xb), dim=1)
            preds.append(torch.argmax(pb, dim=1).cpu().numpy())
            probs.append(pb.cpu().numpy())
    return np.concatenate(preds, axis=0), np.concatenate(probs, axis=0)


# ── Shared render helpers ────────────────────────────────────

def _get_exp_by_name(experiments: List[Dict], name: Optional[str]) -> Optional[Dict]:
    """Return the first experiment whose ``"name"`` equals ``name``.

    Returns ``None`` when ``name`` is empty/``None`` or nothing matches.
    """
    if not name:
        return None
    return next((e for e in experiments if e["name"] == name), None)


def pixel_info_markdown(
    x: int, y: int,
    img7: np.ndarray,
    gt: np.ndarray,
    pred: Optional[np.ndarray] = None,
    probs: Optional[np.ndarray] = None,
) -> str:
    """Build a markdown description of one pixel of a patch.

    Args:
        x: Column index; clipped to the width of ``gt``.
        y: Row index; clipped to the height of ``gt``.
        img7: Patch indexed as ``[band, y, x]``.
        gt: 2-D ground-truth mask; ``IGNORE_INDEX`` marks unlabeled pixels.
        pred: Optional 2-D predicted class mask.
        probs: Optional class probabilities indexed as ``[class, y, x]``.

    Returns:
        Markdown with the ground truth, the prediction and its correctness
        (when available), the three most probable classes, and every band
        value at the pixel.
    """
    h, w = gt.shape
    x = int(np.clip(x, 0, w - 1))
    y = int(np.clip(y, 0, h - 1))

    gt_class = int(gt[y, x])
    gt_name = CLASS_NAMES[gt_class] if gt_class != IGNORE_INDEX else "Non étiqueté"
    lines = [f"**Pixel ({x}, {y})**", f"Vérité terrain : **{gt_name}**"]

    if pred is not None:
        pred_class = int(pred[y, x])
        lines.append(f"Prédiction : **{CLASS_NAMES[pred_class]}**")
        # Correctness is only meaningful where a label exists.
        if gt_class != IGNORE_INDEX:
            lines.append(f"Correct : **{'Oui' if pred_class == gt_class else 'Non'}**")

    if probs is not None:
        # Indices of the three highest-probability classes, best first.
        top = np.argsort(probs[:, y, x])[::-1][:3]
        lines.append("Meilleures probabilités : " + ", ".join(
            f"{CLASS_NAMES[i]} {probs[i, y, x]*100:.1f}%" for i in top
        ))

    lines += ["", "**Valeurs de bande**"] + [
        f"{BAND_DESCRIPTIONS[b]}: **{float(img7[b, y, x]):.3f}**"
        for b in range(img7.shape[0])
    ]
    return "\n\n".join(lines)


def experiments_table_markdown(experiments: List[Dict]) -> str:
    """Render the trained experiments as a markdown table.

    Returns a placeholder sentence when ``experiments`` is empty; otherwise
    one row per experiment with its configuration and global metrics.
    """
    if not experiments:
        return "Aucune expérience entraînée pour l'instant."
    lines = [
        "| # | Nom | Taux d'app. | Époques | Canaux | Préc. val | mIoU |",
        "|---|---|---:|---:|---:|---:|---:|",
    ]
    for i, e in enumerate(experiments):
        cfg = e["config"]
        lines.append(
            f"| {i+1} | {e['name']} | {cfg['learning_rate']:.4f} | {cfg['epochs']} "
            f"| {cfg['base_channels']} "
            f"| {e['global_metrics']['overall_acc']*100:.1f}% "
            f"| {e['global_metrics']['miou']*100:.1f}% |"
        )
    return "\n".join(lines)


# ── Step 1 render helpers ────────────────────────────────────

def _render_step1_image(dataset_state: Dict, composite_choice: str) -> np.ndarray:
    """Render the full scene for step 1 with the label overlay applied.

    ``composite_choice`` is either a key of ``COMPOSITE_PRESETS`` (rendered
    as a three-band composite) or an entry of ``BAND_DESCRIPTIONS`` (rendered
    as a single band). Any other value raises ``ValueError`` from
    ``list.index``.
    """
    full = dataset_state["full_image"]
    if composite_choice in COMPOSITE_PRESETS:
        r, g, b = COMPOSITE_PRESETS[composite_choice]
        base = render_composite(full, r, g, b)
    else:
        band_idx = BAND_DESCRIPTIONS.index(composite_choice)
        base = render_single_band(full, band_idx)
    return add_labels_overlay(
        base, dataset_state["full_train_mask"], dataset_state["full_val_mask"]
    )


# ── Step 4 render helpers ────────────────────────────────────

def _render_step4_row(
    dataset_state: Dict,
    baseline_state: Optional[Dict],
    experiments: List[Dict],
    patch_idx: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Returns (rgb, gt_overlay, baseline_overlay, unet_overlay).

    ``patch_idx`` is clamped to the available validation patches. The
    baseline and U-Net overlays fall back to a blank image when no
    prediction exists for the patch; the U-Net overlay always uses the most
    recent experiment.
    """
    val_images = dataset_state["val_images"]
    val_masks = dataset_state["val_masks"]
    idx = max(0, min(patch_idx, len(val_images) - 1))

    rgb = multispectral_to_rgb(val_images[idx])
    gt = val_masks[idx]
    gt_ov = overlay_mask(rgb, gt)

    if baseline_state is not None and idx < len(baseline_state["val_preds"]):
        bl_ov = overlay_mask(rgb, baseline_state["val_preds"][idx])
    else:
        bl_ov = _blank_rgb(*rgb.shape[:2])

    if experiments and idx < len(experiments[-1]["val_preds"]):
        un_ov = overlay_mask(rgb, experiments[-1]["val_preds"][idx])
    else:
        un_ov = _blank_rgb(*rgb.shape[:2])

    return rgb, gt_ov, bl_ov, un_ov


# ── Step 5 render helpers ────────────────────────────────────

def render_step5_panel(
    dataset_state: Dict,
    exp: Optional[Dict],
    patch_idx: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, str, np.ndarray]:
    """Returns (rgb, pred_color, overlay, metrics_md, error_map).

    Blank images and a placeholder message are returned when there is no
    dataset or no selected experiment. When the experiment has no prediction
    for the (clamped) patch index, the ground truth is shown instead along
    with a message asking to retrain.
    """
    blank = _blank_rgb()
    if dataset_state is None or exp is None:
        return blank, blank, blank, "Aucun modèle sélectionné.", blank

    val_images = dataset_state["val_images"]
    val_masks = dataset_state["val_masks"]
    idx = max(0, min(patch_idx, len(val_images) - 1))
    rgb = multispectral_to_rgb(val_images[idx])
    gt = val_masks[idx]

    # The cached predictions are shorter than the current validation set.
    if idx >= len(exp["val_preds"]):
        return (
            rgb, mask_to_color(gt), overlay_mask(rgb, gt),
            "Données rechargées — réentraîner.", blank,
        )

    pred = exp["val_preds"][idx].astype(np.int64)
    m = compute_metrics(pred, gt)
    return (
        rgb,
        mask_to_color(pred),
        overlay_mask(rgb, pred),
        metrics_markdown(m, title=f"{exp['name']} · patch {idx}"),
        correctness_overlay(rgb, pred, gt),
    )


# ── Gradio action functions ──────────────────────────────────

def load_dataset_action(patch_size: int):
    """Load the dataset and reset every tab to its initial state.

    Returns an 18-element tuple: the new dataset state, a cleared baseline
    state, an empty experiment list, followed by the initial widget values
    for tabs 1 through 5 in the order laid out below.
    """
    patch_size = int(patch_size)
    dataset_state = load_data(patch_size)
    val_count = len(dataset_state["val_images"])

    step1_img = _render_step1_image(dataset_state, "H4 / H3 / H2")
    sig_chart = render_spectral_signatures_chart(dataset_state["signatures"])
    ndvi_map = render_index_map(
        dataset_state["ndvi"], "NDVI",
        dataset_state["full_train_mask"], dataset_state["full_val_mask"],
    )
    blank = _blank_rgb()

    dataset_info = "\n\n".join([
        "**Données chargées.**",
        dataset_state["status"],
        f"Bandes : {', '.join(BAND_NAMES)} | Classes : {', '.join(CLASS_NAMES)}",
        "**Carrés** = étiquettes d'entraînement · **Cercles** = étiquettes de validation",
    ])

    slider_upd = gr.update(maximum=max(0, val_count - 1), value=0)

    return (
        dataset_state,
        None,   # baseline_state reset
        [],     # experiments_state reset
        # Tab 1
        dataset_info,
        step1_img,
        "Cliquez sur l'image pour inspecter un pixel.",
        # Tab 2
        sig_chart,
        ndvi_map,
        # Tab 3
        "Lancez la référence KNN après le chargement des données.",
        blank,
        # Tab 4
        "Entraînez un modèle à l'étape 4.",
        slider_upd,
        blank, blank, blank,
        # Tab 5
        "Aucune expérience pour l'instant.",
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
    )


def update_step1_composite(dataset_state, composite_choice: str):
    """Re-render the step 1 image for a new composite/band selection.

    Returns ``(image, hint_text)``; a blank image and a prompt to load data
    when no dataset is loaded.
    """
    if dataset_state is None:
        return _blank_rgb(), "Chargez les données d'abord."
    img = _render_step1_image(dataset_state, composite_choice)
    return img, "Cliquez sur l'image pour inspecter un pixel."


def handle_click_step1(evt: gr.SelectData, dataset_state):
    """Describe the clicked pixel of the full scene as a markdown table.

    The click position is clamped to the validation mask extent. The label
    shown comes from the validation mask only; the table lists the value of
    every band of the full image at that pixel.
    """
    if dataset_state is None:
        return "Chargez les données d'abord."
    x, y = evt.index
    full = dataset_state["full_image"]
    fmask = dataset_state["full_val_mask"]
    H, W = fmask.shape
    x, y = int(np.clip(x, 0, W - 1)), int(np.clip(y, 0, H - 1))

    cls = int(fmask[y, x])
    label = CLASS_NAMES[cls] if cls != IGNORE_INDEX else "Non étiqueté"
    lines = [
        f"**Pixel ({x}, {y})** | Étiquette val : **{label}**",
        "",
        "| Bande | Valeur |",
        "|---|---:|",
    ] + [
        # Iterate over the bands actually present rather than a fixed count.
        f"| {BAND_DESCRIPTIONS[b]} | {float(full[b, y, x]):.4f} |"
        for b in range(full.shape[0])
    ]
    return "\n".join(lines)


def update_step2_index(dataset_state, index_choice: str):
    """Render the spectral index map selected in step 2.

    The index array is looked up in ``dataset_state`` under the lower-cased
    ``index_choice``. Returns a blank image when no dataset is loaded.
    """
    if dataset_state is None:
        return _blank_rgb()
    key = index_choice.lower()
    arr = dataset_state[key]
    return render_index_map(
        arr, index_choice,
        dataset_state["full_train_mask"], dataset_state["full_val_mask"],
    )


def run_baseline_action(dataset_state, k: int, progress=gr.Progress()):
    """Run the KNN baseline and render its full-scene prediction.

    Returns:
        ``(baseline_state, metrics_md, full_overlay)`` where
        ``baseline_state`` stores ``k``, the full-scene prediction, the
        per-patch validation predictions and the metrics.

    Raises:
        gr.Error: If no dataset has been loaded.
    """
    if dataset_state is None:
        raise gr.Error("Chargez les données d'abord.")

    progress(0.1, desc="KNN sur la scène complète...")
    k = int(k)
    full_pred, val_preds, metrics, metrics_md = run_knn_baseline(
        dataset_state["full_image"],
        dataset_state["full_train_mask"],
        dataset_state["full_val_mask"],
        dataset_state["val_images"],
        k=k,
    )

    progress(0.9, desc="Rendu en cours...")
    baseline_state = {
        "k": k,
        "full_pred": full_pred,
        "val_preds": val_preds,
        "metrics": metrics,
    }
    full_ov = render_full_prediction_overlay(
        dataset_state["full_image"], full_pred, dataset_state["full_val_mask"],
    )
    progress(1.0)
    return baseline_state, metrics_md, full_ov


def update_step4_patch(dataset_state, baseline_state, experiments, patch_idx: int):
    """Return (gt_overlay, baseline_overlay, unet_overlay) for a patch.

    Three blank images are returned when no dataset is loaded.
    """
    if dataset_state is None:
        blank = _blank_rgb()
        return blank, blank, blank  # no data loaded yet
    _, gt_ov, bl_ov, un_ov = _render_step4_row(
        dataset_state, baseline_state, experiments, int(patch_idx)
    )
    return gt_ov, bl_ov, un_ov


def train_experiment(
    dataset_state: Dict,
    baseline_state: Optional[Dict],
    experiments: List[Dict],
    learning_rate: float,
    batch_size: int,
    epochs: int,
    base_channels: int,
    run_name: str,
    progress=gr.Progress(),
):
    """Train a ``SmallUNet`` and append the run to the experiment list.

    Args:
        dataset_state: Loaded dataset; must contain ``"train_images"``.
        baseline_state: Baseline results used for the step 4 preview, if any.
        experiments: Existing experiments; not mutated, a new list is returned.
        learning_rate: Adam learning rate.
        batch_size: Batch size for training and validation inference (>= 1).
        epochs: Number of training epochs (>= 1).
        base_channels: Width argument passed to ``SmallUNet``.
        run_name: Desired experiment name. Blank names fall back to
            ``"Run N"``; a ``" (k)"`` suffix is added to keep names unique.

    Returns:
        A 9-tuple: the updated experiment list, a markdown summary, the
        step 4 slider update, the three step 4 overlays for patch 0, the
        step 5 table and the two step 5 selector updates.

    Raises:
        gr.Error: If data is not loaded, the experiment limit is reached, or
            ``epochs`` / ``batch_size`` is below 1.
    """
    if dataset_state is None or "train_images" not in dataset_state:
        raise gr.Error("Chargez les données d'abord.")
    if len(experiments) >= MAX_EXPERIMENTS:
        raise gr.Error(
            f"Maximum de {MAX_EXPERIMENTS} expériences atteint. "
            "Comparez à l'étape 5, puis rechargez les données pour recommencer."
        )

    n_epochs = int(epochs)
    batch = int(batch_size)
    # With zero epochs the loss history would be empty and the summary below
    # could not report a final loss; a non-positive batch size is rejected by
    # DataLoader. Surface both as user-facing errors instead.
    if n_epochs < 1:
        raise gr.Error("Le nombre d'époques doit être au moins 1.")
    if batch < 1:
        raise gr.Error("La taille de lot doit être au moins 1.")

    loader = DataLoader(
        MultiSpectralDataset(dataset_state["train_images"], dataset_state["train_masks"]),
        batch_size=batch,
        shuffle=True,
    )
    model = SmallUNet(NUM_CHANNELS, NUM_CLASSES, int(base_channels)).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(learning_rate))
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    history = []
    for ep in range(n_epochs):
        progress(ep / n_epochs, desc=f"Epoch {ep+1}/{n_epochs}")
        model.train()
        total, n = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            optimizer.zero_grad(set_to_none=True)
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            total += float(loss.item())
            n += 1
        # Mean batch loss for the epoch; max() guards an empty loader.
        history.append(total / max(1, n))

    progress(0.95, desc="Inférence de validation...")
    val_preds, val_probs = build_prediction_cache(
        model, dataset_state["val_images"], batch_size=batch
    )
    global_metrics = compute_metrics(
        val_preds.reshape(-1), dataset_state["val_masks"].reshape(-1)
    )
    progress(1.0)

    # Strip before falling back: a whitespace-only name must not become an
    # empty name, which _get_exp_by_name could never select.
    base = (run_name or "").strip() or f"Run {len(experiments)+1}"
    existing = {e["name"] for e in experiments}
    name, ctr = base, 2
    while name in existing:
        name = f"{base} ({ctr})"
        ctr += 1

    experiment = {
        "name": name,
        "config": {
            "learning_rate": float(learning_rate),
            "batch_size": batch,
            "epochs": n_epochs,
            "base_channels": int(base_channels),
        },
        "train_loss_history": history,
        "global_metrics": global_metrics,
        "val_preds": val_preds.astype(np.int64),
        "val_probs": val_probs.astype(np.float32),
    }
    experiments = experiments + [experiment]

    summary = "\n\n".join([
        f"**Entraînement terminé — {name}**",
        f"Appareil : **{DEVICE}** | Époques : **{n_epochs}** | Perte finale : **{history[-1]:.4f}**",
        f"Précision val : **{global_metrics['overall_acc']*100:.2f}%** (pixels étiquetés seulement)",
        f"mIoU val : **{global_metrics['miou']*100:.2f}%**",
    ])

    choices = [e["name"] for e in experiments]
    val_count = len(dataset_state["val_images"])
    _, gt_ov, bl_ov, un_ov = _render_step4_row(dataset_state, baseline_state, experiments, 0)

    return (
        experiments,
        summary,
        gr.update(maximum=max(0, val_count - 1), value=0),  # step4 patch slider
        gt_ov, bl_ov, un_ov,
        experiments_table_markdown(experiments),            # step5 table
        gr.update(choices=choices, value=None),             # step5 sel_a
        gr.update(choices=choices, value=None),             # step5 sel_b
    )


def update_step5_comparison(
    dataset_state, experiments, sel_a, sel_b, patch_idx: int
):
    """Render the two step 5 comparison panels side by side.

    Returns the five outputs of ``render_step5_panel`` for selection A
    followed by the five outputs for selection B.
    """
    idx = int(patch_idx)
    exp_a = _get_exp_by_name(experiments, sel_a)
    exp_b = _get_exp_by_name(experiments, sel_b)
    a_outs = render_step5_panel(dataset_state, exp_a, idx)
    b_outs = render_step5_panel(dataset_state, exp_b, idx)
    return (*a_outs, *b_outs)


def handle_click_step5(
    evt: gr.SelectData,
    dataset_state,
    experiments,
    model_name,
    patch_idx: int,
) -> str:
    """Describe the clicked pixel of a validation patch in step 5.

    Prediction and probabilities are included only when the named experiment
    exists and has cached outputs for the (clamped) patch index. Any failure
    is reported as a markdown message rather than raised, so a bad click
    never breaks the UI.
    """
    try:
        if dataset_state is None:
            return "Aucune donnée chargée."
        idx = max(0, min(int(patch_idx), len(dataset_state["val_images"]) - 1))
        exp = _get_exp_by_name(experiments, model_name)
        x, y = evt.index
        img7 = dataset_state["val_images"][idx]
        gt = dataset_state["val_masks"][idx]
        pred = exp["val_preds"][idx] if (exp and idx < len(exp["val_preds"])) else None
        probs = exp["val_probs"][idx] if (exp and idx < len(exp["val_probs"])) else None
        return pixel_info_markdown(int(x), int(y), img7, gt, pred, probs)
    except Exception as e:  # UI boundary: report instead of crashing the callback
        return f"Click error: `{e}`"