| | import gradio as gr |
| | import numpy as np |
| | from skimage.segmentation import slic |
| | from skimage.transform import resize |
| | from scipy.cluster.hierarchy import linkage, fcluster |
| | import lpips |
| | import torch |
| | from skimage.data import astronaut |
| | from PIL import Image |
| | import warnings |
# Silence every warning process-wide. Because this runs before the LPIPS model below
# is constructed, warnings emitted during that construction are hidden too.
warnings.filterwarnings("ignore")


# Perceptual distance model shared by the whole app: LPIPS with an AlexNet backbone.
# spatial=False asks lpips for a single averaged score per image pair rather than a
# spatial map (callers use `.item()` on the result). The model is moved to the GPU
# when one is available, so any tensor passed to it must live on the same device.
loss_fn = lpips.LPIPS(net='alex', spatial=False, verbose=False)
if torch.cuda.is_available():
    loss_fn.cuda()


# Images with more pixels than this are downscaled (aspect ratio kept) before segmentation.
MAX_PIXELS = 300_000
# Pixel values in [0, 1] are rounded to multiples of this step (steps of 4 in 0-255 units).
QUANT_STEP = 4 / 255.0
# Results of the most recent analysis (segments, linkage tree, knee, ...), filled by
# preprocess() and read by the slider callback.
# NOTE(review): this is a module-level dict, so concurrent user sessions share and
# overwrite it; a per-session gr.State would be needed to isolate users.
state = {}
| |
|
def get_default_image():
    """Build the fallback input image: scikit-image's 'astronaut' sample as a PIL image."""
    pixels = astronaut()
    return Image.fromarray(pixels)
| |
|
| | |
def find_perceptual_knee(lpips_curve, ns, min_knee=15):
    """Return the region count at the 'knee' of an LPIPS-vs-region-count curve.

    The knee is the sample with the largest second finite difference of
    ``lpips_curve``. For a decreasing curve, this is where the drop flattens out
    most sharply.

    Args:
        lpips_curve: LPIPS distances, one per entry of ``ns``.
        ns: region counts at which ``lpips_curve`` was sampled. The second
            difference ignores the spacing, so the samples are assumed to be
            roughly evenly spaced (true for the ``linspace`` grid built in
            ``preprocess``).
        min_knee: lower bound applied to the detected knee. It defaults to the
            previously hard-coded value of 15.

    Returns:
        A plain ``int`` region count:
        - ``ns[-1] // 2`` when fewer than 10 samples are given;
        - ``ns[-1]`` when ``lpips_curve`` has fewer than 3 points;
        - otherwise ``max(min_knee, ns[knee])``.
    """
    ns = np.asarray(ns)
    if len(ns) < 10:
        # Too few samples for a trustworthy curvature estimate.
        return int(ns[-1] // 2)

    # Second finite difference; accel[i] is centred on sample i + 1.
    accel = np.diff(np.asarray(lpips_curve, dtype=float), n=2)
    if accel.size == 0:
        # A second difference needs at least three curve points.
        return int(ns[-1])

    knee_idx = int(np.argmax(accel)) + 1
    return max(int(min_knee), int(ns[knee_idx]))
| |
|
| | def preprocess(image, n_segments=800, compactness=10): |
| | global state |
| | state.clear() |
| | |
| | if image is None: |
| | image = get_default_image() |
| | |
| | img = np.asarray(image, dtype=np.float32) / 255.0 |
| | h, w = img.shape[:2] |
| | if h * w > MAX_PIXELS: |
| | scale = (MAX_PIXELS / (h * w)) ** 0.5 |
| | img = resize(img, (int(h * scale), int(w * scale)), anti_aliasing=True) |
| | |
| | img = np.round(img / QUANT_STEP) * QUANT_STEP |
| | img = np.clip(img, 0, 1) |
| | |
| | segments = slic(img, n_segments=n_segments, compactness=compactness, start_label=0) |
| | n_sp = int(segments.max()) + 1 |
| | |
| | flat_img = img.reshape(-1, 3) |
| | flat_seg = segments.ravel() |
| | counts = np.bincount(flat_seg, minlength=n_sp) |
| | means = np.zeros((n_sp, 3)) |
| | for c in range(3): |
| | means[:, c] = np.bincount(flat_seg, weights=flat_img[:, c], minlength=n_sp) / np.maximum(counts, 1) |
| | |
| | valid = counts > 0 |
| | means = means[valid] |
| | valid_map = np.zeros(segments.max() + 1, dtype=int) |
| | valid_map[valid] = np.arange(len(means)) |
| | segments = valid_map[segments] |
| | |
| | Z = linkage(means, method='ward') |
| | |
| | |
| | ns_test = np.linspace(10, len(means), 15, dtype=int) |
| | ns_test = np.unique(ns_test) |
| | lpips_curve = [] |
| | |
| | orig_tensor = lpips.im2tensor(img) |
| | if torch.cuda.is_available(): |
| | orig_tensor = orig_tensor.cuda() |
| | |
| | for n in ns_test: |
| | labels = fcluster(Z, n, criterion='maxclust') - 1 |
| | cluster_means = np.zeros((n, 3)) |
| | cluster_weights = np.zeros(n) |
| | for i, mean in enumerate(means): |
| | cl = labels[i] |
| | cluster_means[cl] += mean * counts[valid][i] |
| | cluster_weights[cl] += counts[valid][i] |
| | cluster_means = cluster_means / np.maximum(cluster_weights[:, None], 1) |
| | out = cluster_means[labels][segments] |
| | t2 = lpips.im2tensor(out) |
| | if torch.cuda.is_available(): |
| | t2 = t2.cuda() |
| | with torch.no_grad(): |
| | lpips_curve.append(loss_fn(orig_tensor, t2).item()) |
| | |
| | knee = find_perceptual_knee(lpips_curve, ns_test) |
| | |
| | state.update({ |
| | "original": img, |
| | "img": img, |
| | "segments": segments, |
| | "means": means, |
| | "counts": counts[valid], |
| | "Z": Z, |
| | "max_clusters": len(means), |
| | "knee": knee, |
| | "ns_test": ns_test, |
| | "lpips_curve": lpips_curve |
| | }) |
| | return len(means) |
| |
|
| | def get_segmented(n_clusters): |
| | if "Z" not in state: |
| | return None, "**Cliquez sur Lancer**" |
| |
|
| | n = max(2, min(int(n_clusters), state["max_clusters"])) |
| | knee = state["knee"] |
| | |
| | |
| | labels = fcluster(state["Z"], n, criterion='maxclust') - 1 |
| | cluster_means = np.zeros((n, 3)) |
| | cluster_weights = np.zeros(n) |
| | for i, mean in enumerate(state["means"]): |
| | cl = labels[i] |
| | cluster_means[cl] += mean * state["counts"][i] |
| | cluster_weights[cl] += state["counts"][i] |
| | cluster_means = cluster_means / np.maximum(cluster_weights[:, None], 1) |
| | out_img = cluster_means[labels][state["segments"]] |
| | out_img = np.clip(out_img, 0, 1) |
| | |
| | |
| | if n >= knee * 1.5: |
| | msg = f"Excellente fidélité ({n} régions ≥ {knee} optimal)" |
| | color = "darkgreen" |
| | elif n >= knee: |
| | msg = f"Très bonne qualité ({n} régions ≈ point optimal {knee})" |
| | color = "green" |
| | elif n >= knee * 0.6: |
| | msg = f"Acceptable mais dégradé ({n} < {knee} → perte visible)" |
| | color = "orange" |
| | elif n >= knee * 0.3: |
| | msg = f"Fortement stylisé ({n} ≪ {knee} → plus réaliste)" |
| | color = "red" |
| | else: |
| | msg = f"Extrêmement simplifié / posterisé ({n} régions)" |
| | color = "darkred" |
| | |
| | lpips_small = f"<small>LPIPS ≈ {loss_fn(lpips.im2tensor(state['original']), lpips.im2tensor(out_img)).item():.4f}</small>" |
| | |
| | full_msg = f"**{n} régions** → **{msg}**<br>{lpips_small}" |
| | |
| | return (out_img * 255).astype(np.uint8), full_msg |
| |
|
| | |
with gr.Blocks() as demo:
    gr.Markdown("# Segmentation hiérarchique + point de rupture perceptif automatique")
    gr.Markdown("Glissez le curseur → l’app détecte **pour cette image précise** à partir de quand ça dégrade trop")

    with gr.Row():
        inp = gr.Image(type="pil", label="Image", value=get_default_image(), height=450)
        out = gr.Image(label="Résultat", height=450)

    with gr.Row():
        n_seg = gr.Slider(300, 1500, 800, step=100, label="Superpixels initiaux")
        comp = gr.Slider(1, 30, 10, step=1, label="Compacité")
        btn = gr.Button("Lancer l’analyse", variant="primary")

    slider = gr.Slider(2, 1000, 100, step=1, label="Nombre de régions finales", interactive=False)
    feedback = gr.Markdown("**Cliquez sur Lancer → glissez le curseur**")

    def launch(img, ns, c):
        """Analyse ``img``, then show the preview at the slider's new starting value.

        Returns the preview image, a slider update (range capped at the superpixel
        count, enabled) and the feedback Markdown announcing the detected knee.
        """
        n_sp = preprocess(img, ns, c)
        # Render at the same value the slider is set to, so image and slider agree.
        start = min(n_sp, 300)
        res, txt = get_segmented(start)
        knee = state.get("knee", 100)
        return res, gr.update(maximum=n_sp, value=start, interactive=True), f"**Analyse terminée !** Point de rupture détecté ≈ **{knee} régions**<br>{txt}"

    btn.click(launch, [inp, n_seg, comp], [out, slider, feedback])
    # `.input` fires only on user interaction. `.change` would also fire on launch's
    # programmatic slider update, recomputing the preview and overwriting the
    # "Analyse terminée" message.
    slider.input(get_segmented, slider, [out, feedback])


if __name__ == "__main__":
    demo.launch()