segspace_app / train.py
functionNormally
Remove speculative band labels; revert class names to English
f4b0cd5
from typing import Dict, List, Optional, Tuple
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
from torch.utils.data import DataLoader
import gradio as gr
from config import (
DEVICE, NUM_CHANNELS, NUM_CLASSES, DEFAULT_PATCH_SIZE,
BAND_NAMES, BAND_DESCRIPTIONS, CLASS_NAMES, IGNORE_INDEX,
COMPOSITE_PRESETS, MAX_EXPERIMENTS,
)
from data import MultiSpectralDataset, load_data
from model import SmallUNet
from baseline import run_knn_baseline
from visualize import (
render_composite, render_single_band,
add_labels_overlay, multispectral_to_rgb,
mask_to_color, overlay_mask, correctness_overlay,
render_full_prediction_overlay,
render_spectral_signatures_chart, render_index_map,
_blank_rgb,
)
from metrics import compute_metrics, metrics_markdown
# ── Inference ────────────────────────────────────────────────
def build_prediction_cache(
    model: nn.Module, images: np.ndarray, batch_size: int = 8
) -> Tuple[np.ndarray, np.ndarray]:
    """Run *model* over *images* and collect per-pixel predictions.

    The images are wrapped in a ``MultiSpectralDataset`` paired with an
    all-zero placeholder mask (the dataset constructor takes one, but the
    masks yielded by the loader are discarded here) and pushed through the
    model in order, without shuffling and without gradient tracking.

    Args:
        model: segmentation network; it is switched to ``eval()`` mode and
            left in that mode afterwards.
        images: stack of input patches; its last two axes give the spatial
            size of the placeholder masks.
        batch_size: number of patches per forward pass.

    Returns:
        ``(preds, probs)`` concatenated over all batches along axis 0:
        ``preds`` holds the arg-max over dim 1 of the softmax, ``probs``
        the softmax itself.
    """
    n_images = len(images)
    height, width = images.shape[-2], images.shape[-1]
    placeholder_masks = np.zeros((n_images, height, width), dtype=np.int64)
    batches = DataLoader(
        MultiSpectralDataset(images, placeholder_masks),
        batch_size=batch_size,
        shuffle=False,
    )

    pred_chunks: List[np.ndarray] = []
    prob_chunks: List[np.ndarray] = []
    model.eval()
    with torch.no_grad():
        for inputs, _unused_masks in batches:
            softmax = F.softmax(model(inputs.to(DEVICE)), dim=1)
            pred_chunks.append(torch.argmax(softmax, dim=1).cpu().numpy())
            prob_chunks.append(softmax.cpu().numpy())
    return np.concatenate(pred_chunks, axis=0), np.concatenate(prob_chunks, axis=0)
# ── Shared render helpers ────────────────────────────────────
def _get_exp_by_name(experiments: List[Dict], name: Optional[str]) -> Optional[Dict]:
if not name:
return None
return next((e for e in experiments if e["name"] == name), None)
def pixel_info_markdown(
    x: int, y: int,
    img7: np.ndarray, gt: np.ndarray,
    pred: Optional[np.ndarray] = None,
    probs: Optional[np.ndarray] = None,
) -> str:
    """Describe one pixel of a validation patch as Markdown paragraphs.

    The coordinates are clamped into the patch (its size is taken from
    ``gt.shape``), so an out-of-range click reports the nearest edge pixel.

    Args:
        x, y: column / row of the pixel.
        img7: band stack indexed as ``img7[band, y, x]``.
        gt: ground-truth class map indexed as ``gt[y, x]``; ``IGNORE_INDEX``
            marks unlabelled pixels.
        pred: optional predicted class map, indexed like *gt*.
        probs: optional class probabilities indexed as ``probs[cls, y, x]``.

    Returns:
        Paragraphs joined by blank lines: ground truth, then (when given)
        the prediction and whether it matches a labelled pixel, the three
        most probable classes, and finally every band value.
    """
    height, width = gt.shape
    col = int(np.clip(x, 0, width - 1))
    row = int(np.clip(y, 0, height - 1))

    truth = int(gt[row, col])
    is_labelled = truth != IGNORE_INDEX
    truth_label = CLASS_NAMES[truth] if is_labelled else "Non étiqueté"

    parts = [f"**Pixel ({col}, {row})**", f"Vérité terrain : **{truth_label}**"]

    if pred is not None:
        predicted = int(pred[row, col])
        parts.append(f"Prédiction : **{CLASS_NAMES[predicted]}**")
        if is_labelled:
            verdict = "Oui" if predicted == truth else "Non"
            parts.append(f"Correct : **{verdict}**")

    if probs is not None:
        # Descending order of probability, keep the three best classes.
        best = np.argsort(probs[:, row, col])[::-1][:3]
        ranked = ", ".join(
            f"{CLASS_NAMES[c]} {probs[c, row, col]*100:.1f}%" for c in best
        )
        parts.append("Meilleures probabilités : " + ranked)

    parts.append("")
    parts.append("**Valeurs de bande**")
    parts.extend(
        f"{BAND_DESCRIPTIONS[band]}: **{float(img7[band, row, col]):.3f}**"
        for band in range(img7.shape[0])
    )
    return "\n\n".join(parts)
def experiments_table_markdown(experiments: List[Dict]) -> str:
    """Render the trained experiments as a Markdown table.

    One row per experiment, numbered from 1 in list order, showing the run
    name, learning rate, epoch count, base channel count, and validation
    overall accuracy / mIoU (stored as fractions, displayed as percents).

    Args:
        experiments: experiment dicts with ``"name"``, ``"config"`` (keys
            ``learning_rate``, ``epochs``, ``base_channels``) and
            ``"global_metrics"`` (keys ``overall_acc``, ``miou``).

    Returns:
        The Markdown table, or a placeholder sentence when *experiments*
        is empty.
    """
    if not experiments:
        return "Aucune expérience entraînée pour l'instant."
    lines = [
        "| # | Nom | Taux d'app. | Époques | Canaux | Préc. val | mIoU |",
        "|---|---|---:|---:|---:|---:|---:|",
    ]
    for i, e in enumerate(experiments, start=1):
        cfg = e["config"]
        metrics = e["global_metrics"]
        # Run names are user-supplied; an unescaped "|" would split the cell
        # and misalign the whole row.
        name = str(e["name"]).replace("|", "\\|")
        # ".4g" keeps small learning rates readable (1e-05 rather than the
        # misleading 0.0000 that a fixed ".4f" produced).
        lines.append(
            f"| {i} | {name} | {cfg['learning_rate']:.4g} | {cfg['epochs']} "
            f"| {cfg['base_channels']} "
            f"| {metrics['overall_acc']*100:.1f}% "
            f"| {metrics['miou']*100:.1f}% |"
        )
    return "\n".join(lines)
# ── Step 1 render helpers ────────────────────────────────────
def _render_step1_image(dataset_state: Dict, composite_choice: str) -> np.ndarray:
    """Render the full scene for step 1 with the label markers drawn on top.

    *composite_choice* is either a key of ``COMPOSITE_PRESETS`` (rendered as
    a three-band composite from the preset's band triple) or an entry of
    ``BAND_DESCRIPTIONS`` (rendered as that single band). Any other value
    makes ``BAND_DESCRIPTIONS.index`` raise ``ValueError``.
    """
    scene = dataset_state["full_image"]
    if composite_choice not in COMPOSITE_PRESETS:
        background = render_single_band(scene, BAND_DESCRIPTIONS.index(composite_choice))
    else:
        red, green, blue = COMPOSITE_PRESETS[composite_choice]
        background = render_composite(scene, red, green, blue)
    return add_labels_overlay(
        background,
        dataset_state["full_train_mask"],
        dataset_state["full_val_mask"],
    )
# ── Step 4 render helpers ────────────────────────────────────
def _render_step4_row(
    dataset_state: Dict,
    baseline_state: Optional[Dict],
    experiments: List[Dict],
    patch_idx: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
    """Build the step-4 comparison row for one validation patch.

    *patch_idx* is clamped to the validation set. The KNN and U-Net slots
    show the corresponding prediction overlay only when a prediction cache
    exists and covers the patch; otherwise they are blank images of the
    patch's size. The U-Net slot always uses the most recent experiment.

    Returns (rgb, gt_overlay, baseline_overlay, unet_overlay).
    """
    patches = dataset_state["val_images"]
    clamped = max(0, min(patch_idx, len(patches) - 1))

    rgb = multispectral_to_rgb(patches[clamped])
    gt_overlay = overlay_mask(rgb, dataset_state["val_masks"][clamped])

    baseline_overlay = (
        overlay_mask(rgb, baseline_state["val_preds"][clamped])
        if baseline_state is not None and clamped < len(baseline_state["val_preds"])
        else _blank_rgb(*rgb.shape[:2])
    )
    unet_overlay = (
        overlay_mask(rgb, experiments[-1]["val_preds"][clamped])
        if experiments and clamped < len(experiments[-1]["val_preds"])
        else _blank_rgb(*rgb.shape[:2])
    )
    return rgb, gt_overlay, baseline_overlay, unet_overlay
# ── Step 5 render helpers ────────────────────────────────────
def render_step5_panel(
    dataset_state: Dict,
    exp: Optional[Dict],
    patch_idx: int,
) -> Tuple[np.ndarray, np.ndarray, np.ndarray, str, np.ndarray]:
    """Build one model's comparison column for step 5.

    *patch_idx* is clamped to the validation set. Per-patch metrics are
    computed from the experiment's cached prediction against the ground
    truth mask of that patch.

    Args:
        dataset_state: loaded dataset, or ``None`` before loading.
        exp: the selected experiment dict, or ``None`` when nothing is
            selected.
        patch_idx: requested validation patch index.

    Returns:
        (rgb, pred_color, overlay, metrics_md, error_map). When no data or
        no experiment is available every image is a default blank. When
        the experiment's prediction cache does not cover the patch, the
        ground truth is shown instead and the user is asked to retrain.
    """
    blank = _blank_rgb()
    if dataset_state is None or exp is None:
        return blank, blank, blank, "Aucun modèle sélectionné.", blank
    val_images = dataset_state["val_images"]
    val_masks = dataset_state["val_masks"]
    idx = max(0, min(patch_idx, len(val_images) - 1))
    rgb = multispectral_to_rgb(val_images[idx])
    gt = val_masks[idx]
    if idx >= len(exp["val_preds"]):
        # Patch-sized placeholder so the error-map slot matches the other
        # images of this row (same convention as the step-4 row renderer).
        return (
            rgb,
            mask_to_color(gt),
            overlay_mask(rgb, gt),
            "Données rechargées — réentraîner.",
            _blank_rgb(*rgb.shape[:2]),
        )
    # Only the hard predictions are needed here; the probability cache is
    # not used by this panel, so it is no longer copied on every render.
    pred = exp["val_preds"][idx].astype(np.int64)
    m = compute_metrics(pred, gt)
    return (
        rgb,
        mask_to_color(pred),
        overlay_mask(rgb, pred),
        metrics_markdown(m, title=f"{exp['name']} · patch {idx}"),
        correctness_overlay(rgb, pred, gt),
    )
# ── Gradio action functions ──────────────────────────────────
def load_dataset_action(patch_size: int):
    """(Re)load the dataset with *patch_size* and reset every tab.

    Any previous baseline and experiments are discarded (returned as
    ``None`` and ``[]``). The step-1 view starts on the "H4 / H3 / H2"
    composite and step 2 on the NDVI map.

    Returns:
        A flat tuple, presumably matching the Gradio output components this
        action is wired to (confirm against the UI definition):
        dataset state, baseline state, experiments state, then
        tab 1 (info, scene image, pixel hint), tab 2 (signature chart,
        NDVI map), tab 3 (KNN hint, blank), tab 4 (hint, patch slider
        update, three blanks), tab 5 (hint, two dropdown resets).
    """
    state = load_data(int(patch_size))
    n_val = len(state["val_images"])
    empty = _blank_rgb()

    info_md = "\n\n".join([
        "**Données chargées.**",
        state["status"],
        f"Bandes : {', '.join(BAND_NAMES)} | Classes : {', '.join(CLASS_NAMES)}",
        "**Carrés** = étiquettes d'entraînement · **Cercles** = étiquettes de validation",
    ])

    tab1 = (
        info_md,
        _render_step1_image(state, "H4 / H3 / H2"),
        "Cliquez sur l'image pour inspecter un pixel.",
    )
    tab2 = (
        render_spectral_signatures_chart(state["signatures"]),
        render_index_map(
            state["ndvi"], "NDVI",
            state["full_train_mask"], state["full_val_mask"],
        ),
    )
    tab3 = ("Lancez la référence KNN après le chargement des données.", empty)
    tab4 = (
        "Entraînez un modèle à l'étape 4.",
        gr.update(maximum=max(0, n_val - 1), value=0),
        empty, empty, empty,
    )
    tab5 = (
        "Aucune expérience pour l'instant.",
        gr.update(choices=[], value=None),
        gr.update(choices=[], value=None),
    )
    return (state, None, []) + tab1 + tab2 + tab3 + tab4 + tab5
def update_step1_composite(dataset_state, composite_choice: str):
    """Re-render the step-1 scene for a new composite or band choice.

    Returns ``(image, hint_markdown)``; before any data is loaded the image
    is a blank placeholder and the hint asks the user to load data first.
    """
    if dataset_state is not None:
        return (
            _render_step1_image(dataset_state, composite_choice),
            "Cliquez sur l'image pour inspecter un pixel.",
        )
    return _blank_rgb(), "Chargez les données d'abord."
def handle_click_step1(evt: gr.SelectData, dataset_state):
    """Describe the clicked pixel of the full scene as Markdown.

    Shows the validation label at that pixel ("Non étiqueté" when the
    validation mask holds ``IGNORE_INDEX``) followed by a table with the
    value of every band of ``full_image`` at that location. The click
    coordinates are clamped to the scene bounds.

    Returns:
        The Markdown text, or a prompt to load data when none is loaded.
    """
    if dataset_state is None:
        return "Chargez les données d'abord."
    x, y = evt.index
    full = dataset_state["full_image"]
    fmask = dataset_state["full_val_mask"]
    H, W = fmask.shape
    x, y = int(np.clip(x, 0, W - 1)), int(np.clip(y, 0, H - 1))
    cls = int(fmask[y, x])
    label = CLASS_NAMES[cls] if cls != IGNORE_INDEX else "Non étiqueté"
    lines = [
        f"**Pixel ({x}, {y})** | Étiquette val : **{label}**", "",
        "| Bande | Valeur |", "|---|---:|",
    ]
    # Iterate over the bands actually present in the image (first axis)
    # rather than a hard-coded count, so the table neither overruns nor
    # truncates when the channel count differs from 7.
    lines += [
        f"| {BAND_DESCRIPTIONS[b]} | {float(full[b, y, x]):.4f} |"
        for b in range(full.shape[0])
    ]
    return "\n".join(lines)
def update_step2_index(dataset_state, index_choice: str):
    """Render the spectral-index map chosen in step 2.

    The index array is read from *dataset_state* under the lower-cased
    choice (e.g. "NDVI" -> ``"ndvi"``) and drawn with the train/validation
    label markers. Before data is loaded a blank placeholder is returned.
    """
    if dataset_state is None:
        return _blank_rgb()
    index_array = dataset_state[index_choice.lower()]
    return render_index_map(
        index_array,
        index_choice,
        dataset_state["full_train_mask"],
        dataset_state["full_val_mask"],
    )
def run_baseline_action(dataset_state, k: int, progress=gr.Progress()):
    """Run the KNN baseline on the full scene and render its prediction.

    Raises:
        gr.Error: if no dataset has been loaded.

    Returns:
        ``(baseline_state, metrics_markdown, full_scene_overlay)``, where
        ``baseline_state`` keeps ``k``, the full-scene prediction, the
        per-validation-patch predictions and the metrics.
    """
    if dataset_state is None:
        raise gr.Error("Chargez les données d'abord.")
    progress(0.1, desc="KNN sur la scène complète...")
    neighbours = int(k)
    scene = dataset_state["full_image"]
    val_mask = dataset_state["full_val_mask"]
    full_pred, val_preds, metrics, metrics_md = run_knn_baseline(
        scene,
        dataset_state["full_train_mask"],
        val_mask,
        dataset_state["val_images"],
        k=neighbours,
    )
    progress(0.9, desc="Rendu en cours...")
    state = dict(k=neighbours, full_pred=full_pred, val_preds=val_preds, metrics=metrics)
    overlay = render_full_prediction_overlay(scene, full_pred, val_mask)
    progress(1.0)
    return state, metrics_md, overlay
def update_step4_patch(dataset_state, baseline_state, experiments, patch_idx: int):
    """Refresh the three step-4 overlays (ground truth, KNN, latest U-Net).

    Returns three blank placeholders when no dataset is loaded yet.
    """
    if dataset_state is not None:
        row = _render_step4_row(dataset_state, baseline_state, experiments, int(patch_idx))
        return row[1], row[2], row[3]
    placeholder = _blank_rgb()
    return placeholder, placeholder, placeholder
def _unique_experiment_name(run_name: Optional[str], experiments: List[Dict]) -> str:
    """Return a display name for a new run that no existing experiment uses.

    A missing, empty or whitespace-only *run_name* falls back to "Run N"
    (N = position of the new run, 1-based) so the run always has a
    non-empty, selectable name; collisions get " (2)", " (3)", ... appended.
    """
    base = (run_name or "").strip() or f"Run {len(experiments) + 1}"
    existing = {e["name"] for e in experiments}
    name, ctr = base, 2
    while name in existing:
        name = f"{base} ({ctr})"
        ctr += 1
    return name


def train_experiment(
    dataset_state: Dict,
    baseline_state: Optional[Dict],
    experiments: List[Dict],
    learning_rate: float,
    batch_size: int,
    epochs: int,
    base_channels: int,
    run_name: str,
    progress=gr.Progress(),
):
    """Train a fresh SmallUNet, evaluate it on validation, and record it.

    The model is trained with Adam and a cross-entropy loss that ignores
    ``IGNORE_INDEX`` pixels; afterwards predictions and probabilities for
    every validation patch are cached in the new experiment record.

    Raises:
        gr.Error: no dataset loaded, the experiment limit is reached, or
            fewer than one epoch is requested.

    Returns:
        A flat tuple: the new experiment list, the training summary
        Markdown, the step-4 patch slider update, the three step-4 overlays
        for patch 0, the step-5 experiments table, and the two step-5
        dropdown updates.
    """
    if dataset_state is None or "train_images" not in dataset_state:
        raise gr.Error("Chargez les données d'abord.")
    if len(experiments) >= MAX_EXPERIMENTS:
        raise gr.Error(
            f"Maximum de {MAX_EXPERIMENTS} expériences atteint. "
            "Comparez à l'étape 5, puis rechargez les données pour recommencer."
        )
    n_epochs = int(epochs)
    if n_epochs < 1:
        # The summary reports the last epoch's loss; with zero epochs the
        # history would be empty and the summary would crash.
        raise gr.Error("Le nombre d'époques doit être au moins 1.")
    # Clamp once so training and validation inference share one batch size.
    bs = max(1, int(batch_size))

    loader = DataLoader(
        MultiSpectralDataset(dataset_state["train_images"], dataset_state["train_masks"]),
        batch_size=bs, shuffle=True,
    )
    model = SmallUNet(NUM_CHANNELS, NUM_CLASSES, int(base_channels)).to(DEVICE)
    optimizer = torch.optim.Adam(model.parameters(), lr=float(learning_rate))
    criterion = nn.CrossEntropyLoss(ignore_index=IGNORE_INDEX)

    history: List[float] = []
    for ep in range(n_epochs):
        progress(ep / n_epochs, desc=f"Epoch {ep+1}/{n_epochs}")
        model.train()
        total, n = 0.0, 0
        for xb, yb in loader:
            xb, yb = xb.to(DEVICE), yb.to(DEVICE)
            # With mean reduction, a batch whose targets are all IGNORE_INDEX
            # gives a NaN loss (0/0) that would poison the reported history;
            # such a batch carries no supervision, so skip it.
            if not bool((yb != IGNORE_INDEX).any()):
                continue
            optimizer.zero_grad(set_to_none=True)
            loss = criterion(model(xb), yb)
            loss.backward()
            optimizer.step()
            total += float(loss.item())
            n += 1
        # Mean loss over the batches that contributed (0.0 if none did).
        history.append(total / max(1, n))

    progress(0.95, desc="Inférence de validation...")
    val_preds, val_probs = build_prediction_cache(
        model, dataset_state["val_images"], batch_size=bs
    )
    global_metrics = compute_metrics(
        val_preds.reshape(-1), dataset_state["val_masks"].reshape(-1)
    )
    progress(1.0)

    name = _unique_experiment_name(run_name, experiments)
    experiment = {
        "name": name,
        "config": {
            "learning_rate": float(learning_rate),
            "batch_size": bs,
            "epochs": n_epochs,
            "base_channels": int(base_channels),
        },
        "train_loss_history": history,
        "global_metrics": global_metrics,
        "val_preds": val_preds.astype(np.int64),
        "val_probs": val_probs.astype(np.float32),
    }
    # A new list is built instead of appending, so the caller's list object
    # is left unmodified.
    experiments = experiments + [experiment]

    summary = "\n\n".join([
        f"**Entraînement terminé — {name}**",
        f"Appareil : **{DEVICE}** | Époques : **{n_epochs}** | Perte finale : **{history[-1]:.4f}**",
        f"Précision val : **{global_metrics['overall_acc']*100:.2f}%** (pixels étiquetés seulement)",
        f"mIoU val : **{global_metrics['miou']*100:.2f}%**",
    ])
    choices = [e["name"] for e in experiments]
    val_count = len(dataset_state["val_images"])
    _, gt_ov, bl_ov, un_ov = _render_step4_row(dataset_state, baseline_state, experiments, 0)
    return (
        experiments,
        summary,
        gr.update(maximum=max(0, val_count - 1), value=0),  # step4 patch slider
        gt_ov, bl_ov, un_ov,
        experiments_table_markdown(experiments),  # step5 table
        gr.update(choices=choices, value=None),  # step5 sel_a
        gr.update(choices=choices, value=None),  # step5 sel_b
    )
def update_step5_comparison(
    dataset_state, experiments, sel_a, sel_b, patch_idx: int
):
    """Render both step-5 comparison panels for the selected experiments.

    Returns a flat tuple: panel A's outputs followed by panel B's, each
    being ``(rgb, pred_color, overlay, metrics_md, error_map)``. An
    unselected or unknown name yields the "no model selected" panel.
    """
    idx = int(patch_idx)
    outputs = []
    for selection in (sel_a, sel_b):
        chosen = _get_exp_by_name(experiments, selection)
        outputs.extend(render_step5_panel(dataset_state, chosen, idx))
    return tuple(outputs)
def handle_click_step5(
    evt: gr.SelectData,
    dataset_state, experiments, model_name, patch_idx: int,
) -> str:
    """Describe the clicked pixel of the current step-5 validation patch.

    The experiment named *model_name* supplies the prediction and class
    probabilities when it exists and its caches cover the patch; otherwise
    only ground truth and band values are shown. This is a UI boundary:
    any exception is reported back as text rather than propagated.
    """
    try:
        if dataset_state is None:
            return "Aucune donnée chargée."
        patch = max(0, min(int(patch_idx), len(dataset_state["val_images"]) - 1))
        chosen = _get_exp_by_name(experiments, model_name)
        click_x, click_y = evt.index
        bands = dataset_state["val_images"][patch]
        truth = dataset_state["val_masks"][patch]
        pred = None
        if chosen and patch < len(chosen["val_preds"]):
            pred = chosen["val_preds"][patch]
        probs = None
        if chosen and patch < len(chosen["val_probs"]):
            probs = chosen["val_probs"][patch]
        return pixel_info_markdown(int(click_x), int(click_y), bands, truth, pred, probs)
    except Exception as err:
        return f"Click error: `{err}`"