# Gradio Space by Eric2mangel — "Update app.py", commit 101b838 (verified)
import gradio as gr
import numpy as np
from skimage.segmentation import slic
from skimage.transform import resize
from scipy.cluster.hierarchy import linkage, fcluster
import lpips
import torch
from skimage.data import astronaut
from PIL import Image
import warnings
# Silence third-party warnings (keeps the Space logs readable).
warnings.filterwarnings("ignore")

# LPIPS perceptual-distance network (AlexNet backbone, scalar output);
# moved to the GPU when one is available.
loss_fn = lpips.LPIPS(net='alex', spatial=False, verbose=False)
if torch.cuda.is_available():
    loss_fn.cuda()

MAX_PIXELS = 300_000   # downscale cap so SLIC + the LPIPS sweep stay interactive
QUANT_STEP = 4 / 255.0  # colour quantization step applied before segmentation
# Module-level cache filled by preprocess() and read by get_segmented().
state = {}
def get_default_image():
    """Return scikit-image's astronaut sample as a PIL image (demo fallback)."""
    demo_array = astronaut()
    return Image.fromarray(demo_array)
# === Automatic detection of the perceptual breaking point ===
def find_perceptual_knee(lpips_curve, ns):
    """Locate the 'knee' of the LPIPS curve — the optimal region count.

    The knee is the grid point where the discrete second derivative of the
    perceptual loss is largest, i.e. where degradation starts accelerating.
    """
    if len(ns) < 10:
        # Too few samples to estimate curvature: fall back to half the range.
        return ns[-1] // 2
    first_diff = np.diff(lpips_curve)
    if len(first_diff) < 2:
        return ns[-1]
    second_diff = np.diff(first_diff)
    if second_diff.size == 0:
        return ns[-1]
    # +1 re-centres the second-difference index onto the curve's grid.
    knee_candidate = ns[np.argmax(second_diff) + 1]
    return max(15, knee_candidate)  # reasonable lower bound
def preprocess(image, n_segments=800, compactness=10):
    """Build the superpixel hierarchy and LPIPS degradation curve for *image*.

    Pipeline: optional downscale → light colour quantization → SLIC
    superpixels → Ward linkage over superpixel mean colours → LPIPS curve
    over a coarse grid of cluster counts → perceptual knee detection.

    Side effects: repopulates the module-level ``state`` dict.
    Returns the number of non-empty superpixels (used as the slider maximum).
    """
    global state
    state.clear()
    if image is None:
        image = get_default_image()

    img = np.asarray(image, dtype=np.float32) / 255.0
    h, w = img.shape[:2]
    # Cap the pixel count so SLIC and the LPIPS sweep stay interactive.
    if h * w > MAX_PIXELS:
        scale = (MAX_PIXELS / (h * w)) ** 0.5
        img = resize(img, (int(h * scale), int(w * scale)), anti_aliasing=True)
    # Quantize colours to QUANT_STEP to suppress imperceptible noise.
    img = np.round(img / QUANT_STEP) * QUANT_STEP
    img = np.clip(img, 0, 1)

    segments = slic(img, n_segments=n_segments, compactness=compactness, start_label=0)
    n_sp = int(segments.max()) + 1

    # Per-superpixel pixel counts and mean colours via bincount.
    flat_img = img.reshape(-1, 3)
    flat_seg = segments.ravel()
    counts = np.bincount(flat_seg, minlength=n_sp)
    means = np.zeros((n_sp, 3))
    for c in range(3):
        means[:, c] = np.bincount(flat_seg, weights=flat_img[:, c], minlength=n_sp) / np.maximum(counts, 1)

    # Drop empty superpixels and densely relabel the segment map.
    valid = counts > 0
    means = means[valid]
    valid_counts = counts[valid]  # hoisted: was re-sliced on every loop iteration
    valid_map = np.zeros(n_sp, dtype=int)
    valid_map[valid] = np.arange(len(means))
    segments = valid_map[segments]

    Z = linkage(means, method='ward')

    # LPIPS curve over a coarse grid of cluster counts — used to detect the
    # perceptual breaking point for this specific image.
    ns_test = np.unique(np.linspace(10, len(means), 15, dtype=int))
    lpips_curve = []
    # BUG FIX: lpips.im2tensor expects a 0-255 image; the original passed
    # 0-1 floats, which collapses both tensors to near-black and flattens
    # the curve, so the knee was detected on numerical noise.
    orig_tensor = lpips.im2tensor(img * 255.0)
    if torch.cuda.is_available():
        orig_tensor = orig_tensor.cuda()
    for n in ns_test:
        labels = fcluster(Z, n, criterion='maxclust') - 1
        # Pixel-count-weighted mean colour per cluster, vectorized with
        # bincount (replaces the original O(n_superpixels) Python loop).
        sums = np.zeros((n, 3))
        for c in range(3):
            sums[:, c] = np.bincount(labels, weights=valid_counts * means[:, c], minlength=n)
        totals = np.bincount(labels, weights=valid_counts, minlength=n)
        cluster_means = sums / np.maximum(totals[:, None], 1)
        out = cluster_means[labels][segments]
        t2 = lpips.im2tensor(out * 255.0)
        if torch.cuda.is_available():
            t2 = t2.cuda()
        with torch.no_grad():
            lpips_curve.append(loss_fn(orig_tensor, t2).item())

    knee = find_perceptual_knee(lpips_curve, ns_test)
    state.update({
        "original": img,
        "img": img,
        "segments": segments,
        "means": means,
        "counts": valid_counts,
        "Z": Z,
        "max_clusters": len(means),
        "knee": knee,
        "ns_test": ns_test,
        "lpips_curve": lpips_curve,
    })
    return len(means)
def get_segmented(n_clusters):
    """Render the segmentation cut at *n_clusters* regions.

    Returns ``(uint8 HxWx3 image, markdown feedback)``; before
    ``preprocess()`` has populated ``state`` it returns ``(None, hint)``.
    """
    if "Z" not in state:
        return None, "**Cliquez sur Lancer**"
    n = max(2, min(int(n_clusters), state["max_clusters"]))
    knee = state["knee"]

    # Cut the dendrogram and recolour each region with the pixel-count
    # weighted mean of its superpixels — vectorized with bincount (the
    # original looped in Python over all superpixels on every slider move).
    labels = fcluster(state["Z"], n, criterion='maxclust') - 1
    weights = state["counts"]
    sp_means = state["means"]
    sums = np.zeros((n, 3))
    for c in range(3):
        sums[:, c] = np.bincount(labels, weights=weights * sp_means[:, c], minlength=n)
    totals = np.bincount(labels, weights=weights, minlength=n)
    cluster_means = sums / np.maximum(totals[:, None], 1)
    out_img = np.clip(cluster_means[labels][state["segments"]], 0, 1)

    # Adaptive quality message relative to the detected perceptual knee.
    # (The original also set an unused ``color`` local per branch; removed.)
    if n >= knee * 1.5:
        msg = f"Excellente fidélité ({n} régions ≥ {knee} optimal)"
    elif n >= knee:
        msg = f"Très bonne qualité ({n} régions ≈ point optimal {knee})"
    elif n >= knee * 0.6:
        msg = f"Acceptable mais dégradé ({n} < {knee} → perte visible)"
    elif n >= knee * 0.3:
        # NOTE(review): the comparison glyph between {n} and {knee} was lost
        # in the original source ("{n}{knee}"); restored as "≪" — confirm.
        msg = f"Fortement stylisé ({n} ≪ {knee} → plus réaliste)"
    else:
        msg = f"Extrêmement simplifié / posterisé ({n} régions)"

    # BUG FIX: im2tensor expects 0-255 input (the original passed 0-1
    # floats), and the tensors must follow loss_fn onto the GPU — the
    # original left them on the CPU, crashing whenever CUDA is available.
    t1 = lpips.im2tensor(state['original'] * 255.0)
    t2 = lpips.im2tensor(out_img * 255.0)
    if torch.cuda.is_available():
        t1, t2 = t1.cuda(), t2.cuda()
    with torch.no_grad():
        lpips_val = loss_fn(t1, t2).item()
    lpips_small = f"<small>LPIPS ≈ {lpips_val:.4f}</small>"
    full_msg = f"**{n} régions** → **{msg}**<br>{lpips_small}"
    return (out_img * 255).astype(np.uint8), full_msg
# === Interface ===
with gr.Blocks() as demo:
    gr.Markdown("# Segmentation hiérarchique + point de rupture perceptif automatique")
    gr.Markdown("Glissez le curseur → l’app détecte **pour cette image précise** à partir de quand ça dégrade trop")
    with gr.Row():
        inp = gr.Image(type="pil", label="Image", value=get_default_image(), height=450)
        out = gr.Image(label="Résultat", height=450)
    with gr.Row():
        n_seg = gr.Slider(300, 1500, 800, step=100, label="Superpixels initiaux")
        comp = gr.Slider(1, 30, 10, step=1, label="Compacité")
    btn = gr.Button("Lancer l’analyse", variant="primary")
    # Region-count slider stays disabled until the analysis has run; its
    # maximum is updated to the superpixel count returned by preprocess().
    slider = gr.Slider(2, 1000, 100, step=1, label="Nombre de régions finales", interactive=False)
    feedback = gr.Markdown("**Cliquez sur Lancer → glissez le curseur**")

    def launch(img, ns, c):
        # Run the full pipeline, then render an initial result at the full
        # superpixel count and enable the slider.
        n_sp = preprocess(img, ns, c)
        res, txt = get_segmented(n_sp)
        knee = state.get("knee", 100)
        # NOTE(review): the slider value resets to min(n_sp, 300) while the
        # displayed image uses n_sp regions — confirm this mismatch is intended.
        return res, gr.update(maximum=n_sp, value=min(n_sp, 300), interactive=True), f"**Analyse terminée !** Point de rupture détecté ≈ **{knee} régions**<br>{txt}"

    btn.click(launch, [inp, n_seg, comp], [out, slider, feedback])
    slider.change(get_segmented, slider, [out, feedback])

demo.launch()