Segment / app.py
Ynvers's picture
ok
1e4c8f8
import numpy as np
import gradio as gr
import cv2
import colorsys
def generate_distinct_colors(k):
colors = []
for i in range(k):
hue = i / k
saturation = 0.8
value = 0.8
rgb = colorsys.hsv_to_rgb(hue, saturation, value)
colors.append(np.array(rgb) * 255)
return np.array(colors, dtype=np.uint8)
def k_means_segmentation(image, k, max_iters=100, tol=1e-4):
pixels = image.reshape(-1, 3).astype(np.float32)
np.random.seed(42)
centroids = pixels[np.random.choice(pixels.shape[0], k, replace=False)]
for iteration in range(max_iters):
distances = np.linalg.norm(pixels[:, np.newaxis] - centroids, axis=2)
labels = np.argmin(distances, axis=1)
new_centroids = np.array([
pixels[labels == i].mean(axis=0) if np.sum(labels == i) > 0
else centroids[i]
for i in range(k)
])
if np.linalg.norm(new_centroids - centroids) < tol:
print(f"K-Means convergé après {iteration + 1} itérations.")
break
centroids = new_centroids
distinct_colors = generate_distinct_colors(k)
segmented_pixels = distinct_colors[labels].astype(np.uint8)
segmented_image = segmented_pixels.reshape(image.shape)
return segmented_image
def mean_shift_segmentation(image, bandwidth=30, max_iters=20, tol=1e-3, max_image_size=200):
"""
Version optimisée de la segmentation Mean Shift
Args:
image: Image d'entrée BGR
bandwidth: Rayon de recherche
max_iters: Nombre maximum d'itérations
tol: Seuil de convergence
max_image_size: Taille maximale de l'image (le plus grand côté)
"""
# Redimensionnement de l'image si nécessaire
h, w = image.shape[:2]
scale = max_image_size / max(h, w)
if scale < 1:
new_size = (int(w * scale), int(h * scale))
image = cv2.resize(image, new_size)
# Conversion en LAB
lab_image = cv2.cvtColor(image, cv2.COLOR_BGR2LAB)
# Préparation des données
pixels = lab_image.reshape(-1, 3).astype(np.float32)
# Sous-échantillonnage pour les centres initiaux (1 pixel sur 10)
step = 10
centers = pixels[::step].copy()
# Boucle principale de Mean Shift
for iteration in range(max_iters):
# Calcul vectorisé des distances entre tous les pixels et les centres
distances = np.sqrt(np.sum((pixels[:, np.newaxis] - centers) ** 2, axis=2))
# Pour chaque centre, trouver les pixels dans son rayon et calculer la nouvelle position
new_centers = []
for i in range(len(centers)):
in_bandwidth = distances[:, i] < bandwidth
if np.sum(in_bandwidth) > 0:
new_centers.append(np.mean(pixels[in_bandwidth], axis=0))
else:
new_centers.append(centers[i])
new_centers = np.array(new_centers)
# Vérifier la convergence
center_shifts = np.sqrt(np.sum((centers - new_centers) ** 2, axis=1))
centers = new_centers
if np.all(center_shifts < tol):
print(f"Mean Shift convergé après {iteration + 1} itérations.")
break
# Attribution des labels
distances = np.sqrt(np.sum((pixels[:, np.newaxis] - centers) ** 2, axis=2))
labels = np.argmin(distances, axis=1)
# Génération des couleurs distinctes
n_clusters = len(centers)
distinct_colors = np.zeros((n_clusters, 3), dtype=np.uint8)
for i in range(n_clusters):
hue = i / n_clusters
saturation = 0.8
value = 0.8
rgb = colorsys.hsv_to_rgb(hue, saturation, value)
distinct_colors[i] = np.array(rgb) * 255
# Création de l'image segmentée
segmented_pixels = distinct_colors[labels]
segmented_image = segmented_pixels.reshape(image.shape)
# Redimensionnement au format original si nécessaire
if scale < 1:
segmented_image = cv2.resize(segmented_image, (w, h))
return segmented_image
def segment_image(image, k=None):
"""
Fonction principale qui gère la segmentation en fonction des paramètres
"""
if image is None:
return None
# Conversion de l'image au format BGR si nécessaire
if len(image.shape) == 2: # Image en niveaux de gris
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
elif image.shape[2] == 4: # Image avec canal alpha
image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
try:
if k is not None and k > 0:
return k_means_segmentation(image, int(k))
else:
return mean_shift_segmentation(image)
except Exception as e:
print(f"Erreur lors de la segmentation: {str(e)}")
return None
def segment_image(image, k=None):
"""
Fonction principale qui gère la segmentation en fonction des paramètres
"""
if image is None:
return None
# Conversion de l'image au format BGR si nécessaire
if len(image.shape) == 2: # Image en niveaux de gris
image = cv2.cvtColor(image, cv2.COLOR_GRAY2BGR)
elif image.shape[2] == 4: # Image avec canal alpha
image = cv2.cvtColor(image, cv2.COLOR_RGBA2BGR)
try:
if k is not None and k > 0:
return k_means_segmentation(image, int(k))
else:
return mean_shift_segmentation(image)
except Exception as e:
print(f"Erreur lors de la segmentation: {str(e)}")
return None
with gr.Blocks(title="Art par Segmentation d'Images", theme=gr.themes.Soft()) as app:
gr.Markdown("""
# 🎨 Studio de Segmentation Artistique
Transformez vos photos en œuvres d'art segmentées !
### Mode d'emploi :
1. Téléchargez une image
2. Choisissez votre style :
- **K-means** : Entrez un nombre K pour définir le nombre exact de couleurs
- **Mean Shift** : Laissez K vide pour une segmentation automatique
3. Cliquez sur "Transformer" et admirez le résultat !
""")
with gr.Row():
with gr.Column(scale=1):
input_image = gr.Image(
label="Image Originale",
type="numpy"
)
k_input = gr.Number(
label="Nombre de segments (K)",
minimum=0,
step=1
)
segment_btn = gr.Button(
"🎨 Transformer",
variant="primary",
scale=0
)
with gr.Column(scale=1):
output_image = gr.Image(
label="Résultat",
type="numpy"
)
# Exemples d'utilisation
gr.Examples(
examples=[
["MEN-Denim-id_00000089-17_4_full.png", 5],
["ben.jpg", None, ""],
],
inputs=[input_image, k_input],
outputs=output_image,
fn=segment_image,
cache_examples=True
)
# Configuration des événements
segment_btn.click(
fn=segment_image,
inputs=[input_image, k_input],
outputs=output_image
)
# Message d'aide supplémentaire
gr.Markdown("""
### 💡 Conseils :
- Pour K-means : essayez des valeurs entre 3 et 10 pour des résultats intéressants
- Pour Mean Shift : idéal pour les images complexes avec beaucoup de détails
- Les images de taille moyenne fonctionnent le mieux
""")
# Lancement de l'application
app.launch(share=True)