import gradio as gr
import numpy as np
from skimage.segmentation import slic
from skimage.transform import resize
from scipy.cluster.hierarchy import linkage, fcluster
import lpips
import torch
from skimage.data import astronaut
from PIL import Image
import warnings

warnings.filterwarnings("ignore")

# LPIPS perceptual metric (AlexNet backbone); moved to GPU when available.
loss_fn = lpips.LPIPS(net='alex', spatial=False, verbose=False)
if torch.cuda.is_available():
    loss_fn.cuda()

MAX_PIXELS = 300_000    # cap on processed pixel count (images are downscaled above this)
QUANT_STEP = 4 / 255.0  # color quantization step applied before SLIC

# Module-level cache filled by preprocess() and read by get_segmented().
state = {}


def get_default_image():
    """Return the default demo image (skimage's astronaut) as a PIL image."""
    return Image.fromarray(astronaut())


def _to_lpips_tensor(img):
    """Convert a float image in [0, 1] to an LPIPS input tensor on the right device.

    BUG FIX: lpips.im2tensor expects pixel values in [0, 255] (it maps them to
    [-1, 1] as x/127.5 - 1); the original fed [0, 1] floats, collapsing every
    input to ~-1 and flattening the LPIPS curve.  The image is rescaled here,
    and the tensor follows the LPIPS network onto CUDA when available (the
    original built CPU tensors in get_segmented, crashing on GPU machines).
    """
    t = lpips.im2tensor(img * 255.0)
    if torch.cuda.is_available():
        t = t.cuda()
    return t


# === AUTOMATIC DETECTION OF THE PERCEPTUAL BREAKING POINT ===
def find_perceptual_knee(lpips_curve, ns):
    """Locate the 'knee' of the LPIPS curve -> optimal region count.

    Uses a simplified discrete second derivative: the knee is where the
    perceptual loss accelerates the most as the region count changes.

    Parameters
    ----------
    lpips_curve : sequence of float — LPIPS score per tested region count.
    ns : sequence of int — the tested region counts (same length).

    Returns
    -------
    int — the detected knee, never below 15.
    """
    if len(ns) < 10:
        return ns[-1] // 2
    diffs = np.diff(lpips_curve)
    if len(diffs) < 2:
        return ns[-1]
    accel = np.diff(diffs)
    if len(accel) == 0:
        return ns[-1]
    knee_idx = np.argmax(accel) + 1
    knee_n = ns[knee_idx]
    return max(15, knee_n)  # reasonable minimum


def _reconstruct(n, Z, means, counts, segments):
    """Cut the Ward dendrogram at `n` clusters and paint each superpixel
    with its cluster's pixel-count-weighted mean color.

    Returns the reconstructed image as float RGB in [0, 1].
    (Shared by preprocess() and get_segmented(); the original duplicated
    this loop in both places.)
    """
    labels = fcluster(Z, n, criterion='maxclust') - 1
    cluster_means = np.zeros((n, 3))
    cluster_weights = np.zeros(n)
    for i, mean in enumerate(means):
        cl = labels[i]
        cluster_means[cl] += mean * counts[i]
        cluster_weights[cl] += counts[i]
    # Guard against empty clusters (fcluster may return fewer than n).
    cluster_means = cluster_means / np.maximum(cluster_weights[:, None], 1)
    out = cluster_means[labels][segments]
    return np.clip(out, 0, 1)


def preprocess(image, n_segments=800, compactness=10):
    """Segment `image` into superpixels, build the Ward hierarchy and
    pre-compute the LPIPS curve used to find the perceptual knee.

    Fills the module-level `state` dict and returns the number of valid
    superpixels (the maximum number of final regions).
    """
    global state
    state.clear()
    if image is None:
        image = get_default_image()

    img = np.asarray(image, dtype=np.float32) / 255.0
    h, w = img.shape[:2]
    if h * w > MAX_PIXELS:
        scale = (MAX_PIXELS / (h * w)) ** 0.5
        img = resize(img, (int(h * scale), int(w * scale)), anti_aliasing=True)

    # Light color quantization stabilizes SLIC and the hierarchy.
    img = np.round(img / QUANT_STEP) * QUANT_STEP
    img = np.clip(img, 0, 1)

    segments = slic(img, n_segments=n_segments, compactness=compactness, start_label=0)
    n_sp = int(segments.max()) + 1

    # Per-superpixel mean colors via bincount (vectorized per channel).
    flat_img = img.reshape(-1, 3)
    flat_seg = segments.ravel()
    counts = np.bincount(flat_seg, minlength=n_sp)
    means = np.zeros((n_sp, 3))
    for c in range(3):
        means[:, c] = np.bincount(flat_seg, weights=flat_img[:, c], minlength=n_sp) / np.maximum(counts, 1)

    # Drop empty superpixels and remap labels to a dense 0..k-1 range.
    valid = counts > 0
    means = means[valid]
    valid_counts = counts[valid]  # hoisted: the original re-indexed counts[valid] inside the loop
    valid_map = np.zeros(n_sp, dtype=int)
    valid_map[valid] = np.arange(len(means))
    segments = valid_map[segments]

    Z = linkage(means, method='ward')

    # Pre-compute the full LPIPS curve to detect the breaking point.
    ns_test = np.unique(np.linspace(10, len(means), 15, dtype=int))
    lpips_curve = []
    orig_tensor = _to_lpips_tensor(img)
    for n in ns_test:
        out = _reconstruct(int(n), Z, means, valid_counts, segments)
        with torch.no_grad():
            lpips_curve.append(loss_fn(orig_tensor, _to_lpips_tensor(out)).item())

    knee = find_perceptual_knee(lpips_curve, ns_test)

    state.update({
        "original": img,
        "img": img,
        "segments": segments,
        "means": means,
        "counts": valid_counts,
        "Z": Z,
        "max_clusters": len(means),
        "knee": knee,
        "ns_test": ns_test,
        "lpips_curve": lpips_curve,
    })
    return len(means)


def get_segmented(n_clusters):
    """Rebuild the image with `n_clusters` regions and return
    (uint8 image, markdown feedback) relating n to the detected knee."""
    if "Z" not in state:
        return None, "**Cliquez sur Lancer**"

    n = max(2, min(int(n_clusters), state["max_clusters"]))
    knee = state["knee"]

    out_img = _reconstruct(n, state["Z"], state["means"], state["counts"], state["segments"])

    # Adaptive quality message (unused `color` locals from the original removed).
    if n >= knee * 1.5:
        msg = f"Excellente fidélité ({n} régions ≥ {knee} optimal)"
    elif n >= knee:
        msg = f"Très bonne qualité ({n} régions ≈ point optimal {knee})"
    elif n >= knee * 0.6:
        msg = f"Acceptable mais dégradé ({n} < {knee} → perte visible)"
    elif n >= knee * 0.3:
        msg = f"Fortement stylisé ({n} ≪ {knee} → plus réaliste)"
    else:
        msg = f"Extrêmement simplifié / posterisé ({n} régions)"

    # BUG FIX: the original built CPU tensors here while loss_fn could be on
    # CUDA (device-mismatch crash), and skipped no_grad on this hot path.
    with torch.no_grad():
        score = loss_fn(_to_lpips_tensor(state["original"]), _to_lpips_tensor(out_img)).item()
    lpips_small = f"LPIPS ≈ {score:.4f}"

    full_msg = f"**{n} régions** → **{msg}**\n\n{lpips_small}"
    return (out_img * 255).astype(np.uint8), full_msg


# === Interface ===
with gr.Blocks() as demo:
    gr.Markdown("# Segmentation hiérarchique + point de rupture perceptif automatique")
    gr.Markdown("Glissez le curseur → l’app détecte **pour cette image précise** à partir de quand ça dégrade trop")
    with gr.Row():
        inp = gr.Image(type="pil", label="Image", value=get_default_image(), height=450)
        out = gr.Image(label="Résultat", height=450)
    with gr.Row():
        n_seg = gr.Slider(300, 1500, 800, step=100, label="Superpixels initiaux")
        comp = gr.Slider(1, 30, 10, step=1, label="Compacité")
    btn = gr.Button("Lancer l’analyse", variant="primary")
    slider = gr.Slider(2, 1000, 100, step=1, label="Nombre de régions finales", interactive=False)
    feedback = gr.Markdown("**Cliquez sur Lancer → glissez le curseur**")

    def launch(img, ns, c):
        """Run the full analysis, render the initial result and enable the slider."""
        n_sp = preprocess(img, ns, c)
        res, txt = get_segmented(n_sp)
        knee = state.get("knee", 100)
        return (
            res,
            gr.update(maximum=n_sp, value=min(n_sp, 300), interactive=True),
            f"**Analyse terminée !** Point de rupture détecté ≈ **{knee} régions**\n\n{txt}",
        )

    btn.click(launch, [inp, n_seg, comp], [out, slider, feedback])
    slider.change(get_segmented, slider, [out, feedback])

demo.launch()